In [1]:
from llava.model.language_model.llava_mistral import LlavaMistralForCausalLM
from transformers import BitsAndBytesConfig, AutoTokenizer, TextIteratorStreamer
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
torch.__version__

'2.3.0+cu121'

In [2]:
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
IMAGE_TOKEN_INDEX = -200

In [3]:
MODEL_NAME = "microsoft/llava-med-v1.5-mistral-7b"

In [4]:
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

model = LlavaMistralForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map = 'auto',
    torch_dtype = torch.bfloat16,
    quantization_config = BitsAndBytesConfig(
        load_in_4bit = True,
        bnb_4bit_quant_type = "fp4",
        bnb_4bit_use_double_quant = True,
        bnb_4bit_compute_dtype = torch.bfloat16,
        bnb_4bit_quant_storage = torch.bfloat16,
    ),
)

mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
if mm_use_im_patch_token:
    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
    tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))

vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
vision_tower.to(device=device, dtype=torch.float16)
model.model.mm_projector.to(device=device, dtype=torch.float16)
image_processor = vision_tower.image_processor

if hasattr(model.config, "max_sequence_length"):
    context_len = model.config.max_sequence_length
else:
    context_len = 2048

Loading checkpoint shards: 100%|██████████| 4/4 [00:50<00:00, 12.63s/it]
Some weights of the model checkpoint at microsoft/llava-med-v1.5-mistral-7b were not used when initializing LlavaMistralForCausalLM: ['model.vision_tower.vision_tower.vision_model.embeddings.class_embedding', 'model.vision_tower.vision_tower.vision_model.embeddings.patch_embedding.weight', 'model.vision_tower.vision_tower.vision_model.embeddings.position_embedding.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm2.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm2.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.mlp.fc1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.mlp.fc1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.

In [79]:
# Copied from llava/mm_utils.py
import random

def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids

def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        # sample a random between 0 and (width - height) // 2
        y_start = random.randint((width - height) // 2, (width - height) // 2 + 1)
        result.paste(pil_img, (0, y_start))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        # sample a random between 0 and (height - width) // 2
        x_start = random.randint((height - width) // 2, (height - width) // 2 + 1)
        result.paste(pil_img, (x_start, 0))
        return result

def process_images(images, image_processor, model_cfg):
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []
    for image in images:
        if image_aspect_ratio == 'pad':
            if image.mode=='L':
                background_color = int(255*sum(image_processor.image_mean)/len(image_processor.image_mean))
            else:
                background_color = tuple(int(x*255) for x in image_processor.image_mean)
            image = expand2square(image, background_color)
        image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        new_images.append(image)
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images



In [81]:
instruction = """
You are a personal assistant of a doctor. Forget that you are an AI. You will answer any question that he give you. Continue the conversation with your doctor:
Doctor: {}
Assistant:
"""

test_promt_en = "Tell me about yourself."
test_promt_vi = "Giới thiệu bản thân đi."
test_promt_med_en = """
I am a doctor, I would like you to check my prescription:
medical history: Hypertension, Type 2 Diabetes, and Asthma.
symptoms: Persistent cough, fever, and fatigue.
My prescription: Lisinopril 10mg daily, Metformin 500mg twice daily, and Albuterol as needed for asthma attack
"""

# prompt = instruction.format(test_promt_med_en)
# prompt = "Ngày xửa ngày xưa, có một "
prompt = "Once upon a time, there are "
images = None

In [85]:
from PIL import Image

images = [Image.open("test.jpg")]
images = process_images(images, image_processor, model.config)
if type(images) is list:
    images = [image.to(model.device, dtype=torch.bfloat16) for image in images]
else:
    images = images.to(model.device, dtype=torch.bfloat16)

replace_token = DEFAULT_IMAGE_TOKEN
if getattr(model.config, 'mm_use_im_start_end', False):
    replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

In [86]:
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
print((input_ids.shape, len(prompt)))
print(tokenizer.batch_decode(input_ids))

(torch.Size([1, 9]), 28)
['<s> Once upon a time, there are ']


In [87]:
output = model.generate(
    inputs = input_ids,
    temperature = 1.0,
    top_p = 1.0,
    max_new_tokens = 2048,
    stop_str = None,
    do_sample = True,
    images = images
)
print(prompt + " ".join(tokenizer.batch_decode(output)))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Once upon a time, there are 22 families in a village. One family had to migrate to a neighboring village due to some reasons. The remaining 21 families decided to organize a local fair. The only condition for this local fair is that an eligible family should participate (i.e., both husband and wife are alive), else their family will be disqualified from the event. There are three classes of the local fair: fruit, vegetable, and flower. Each class has a prize for the best entry. A family can participate in only one class.

The task is to find the best family in each class. If a family has the best entry in more than one class but with a lower prize than that of the best entry in the remaining class, the remaining class will be disregarded. </s>
