In [None]:
import torch

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from transformers import TextStreamer

from PIL import Image

def load_image(image_file):
    image = Image.open(image_file).convert('RGB')
    return image


disable_torch_init()

model_path= "microsoft/llava-med-v1.5-mistral-7b"

model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device="cuda")

In [27]:

def run_llava(image_file="fracture.jpg"):
    disable_torch_init()


    conv_mode = "mistral_instruct"
    roles = ('USER', 'ASSISTANT')

    image = load_image(image_file)
    image_tensor = process_images([image], image_processor, model.config)
    image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    # User input (replace with your actual user input)
    inp = "what do you see in the image"

    conv = conv_templates[conv_mode].copy()
    if "mpt" in model_name.lower():
        roles = ('user', 'assistant')
    else:
        roles = conv.roles

    # Prepare conversation
    if image is not None:
        # first message
        if model.config.mm_use_im_start_end:
            inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
        else:
            inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
        conv.append_message(conv.roles[0], inp)
        image = None
    else:
        # later messages
        conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Generate response
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    #streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            #do_sample=True,
            temperature=0.0,
            max_new_tokens=512,
            #streamer=streamer,
            #use_cache=True,
            #stopping_criteria=[stopping_criteria]
            )

    output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    return output[0],image

output,image = run_llava("fracture.jpg")

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [28]:
output

'The image is a lateral radiograph of the left ankle. It shows a fracture of the lateral malleolus, which is the bony prominence on the outer side of the ankle. Additionally, there is a small amount of soft tissue swelling present. '

In [26]:
output 

'The image is a lateral radiograph of the left ankle. It shows a fracture of the lateral malleolus, which is the bony prominence on the outer side of the ankle. Additionally, there is a small amount of soft tissue swelling present. '

In [24]:
image