In [1]:
from llava.model.language_model.llava_mistral import LlavaMistralForCausalLM
from transformers import BitsAndBytesConfig, AutoTokenizer, TextIteratorStreamer
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
torch.__version__

'2.3.0+cu121'

In [2]:
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
IMAGE_TOKEN_INDEX = -200

In [3]:
MODEL_NAME = "microsoft/llava-med-v1.5-mistral-7b"

In [4]:
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

model = LlavaMistralForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map = 'auto',
    torch_dtype = torch.bfloat16,
    quantization_config = BitsAndBytesConfig(
        load_in_4bit = True,
        bnb_4bit_quant_type = "fp4",
        bnb_4bit_use_double_quant = True,
        bnb_4bit_compute_dtype = torch.bfloat16,
        bnb_4bit_quant_storage = torch.bfloat16,
    ),
)

mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
if mm_use_im_patch_token:
    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
    tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))

vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
vision_tower.to(device=device, dtype=torch.float16)
model.model.mm_projector.to(device=device, dtype=torch.float16)
image_processor = vision_tower.image_processor

if hasattr(model.config, "max_sequence_length"):
    context_len = model.config.max_sequence_length
else:
    context_len = 2048

Loading checkpoint shards: 100%|██████████| 4/4 [00:50<00:00, 12.63s/it]
Some weights of the model checkpoint at microsoft/llava-med-v1.5-mistral-7b were not used when initializing LlavaMistralForCausalLM: ['model.vision_tower.vision_tower.vision_model.embeddings.class_embedding', 'model.vision_tower.vision_tower.vision_model.embeddings.patch_embedding.weight', 'model.vision_tower.vision_tower.vision_model.embeddings.position_embedding.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm2.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm2.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.mlp.fc1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.mlp.fc1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.

In [5]:
# Copied from llava/mm_utils.py

def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids

In [29]:
test_promt_en = "Tell me about yourself"
test_promt_vi = "Giới thiệu bản thân đi"
"""
I am a doctor, I would like you to check my prescription:
medical history: Hypertension, Type 2 Diabetes, and Asthma.
symptoms: Persistent cough, fever, and fatigue.
My prescription: Lisinopril 10mg daily, Metformin 500mg twice daily, and Albuterol as needed for asthma attacks
"""

input_ids = tokenizer_image_token(test_promt_en, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
print((input_ids.shape, len(test_promt_en)))
print(tokenizer.batch_decode(input_ids))

(torch.Size([1, 5]), 22)
['<s> Tell me about yourself']


In [32]:
output = model.generate(
    inputs = input_ids,
    temperature = 1.0,
    top_p = 1.0,
    max_new_tokens = 2048,
    stop_str = None,
    do_sample = True,
)
print(tokenizer.batch_decode(output))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[". Who are you, and what do you do for a living?\n\n I'm a software engineer by trade. I've been working in the tech industry for about 8 years now, and I currently work as a team lead at a company called Mux. My expertise lies in full-stack JavaScript development, and I am active in the open-source community, particularly in the areas of client-side rendering and web performance. </s>"]
