# Biomedical Image-Text-to-Text with microsoft/llava-med-v1.5-mistral-7b

[microsoft/llava-med-v1.5-mistral-7b](https://huggingface.co/microsoft/llava-med-v1.5-mistral-7b)

## Approach 1: Hugging Face `transformers`

In [1]:
# Hugging Face Hub id of the LLaVA-Med v1.5 checkpoint (the id names a Mistral-7B backbone).
MODEL_NAME: str = "microsoft/llava-med-v1.5-mistral-7b"

In [None]:
# from transformers import AutoModelForCausalLM, AutoTokenizer
# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True, device_map="auto")
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

In [2]:
from transformers import AutoConfig, LlamaConfig 
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM

[2025-01-18 23:49:49,623] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [3]:
class LlavaConfig(LlamaConfig):
    """Llama-derived config whose ``model_type`` is ``"llava"``.

    It is registered with ``AutoConfig`` under that key in the following cell.

    NOTE(review): the checkpoint loaded with it reports model type
    ``llava_mistral`` (see the warning captured after ``from_pretrained``), so
    this Llama-based config does not describe the checkpoint's architecture.
    """

    model_type = "llava"

In [4]:
# Map model_type "llava" to the notebook's LlavaConfig in AutoConfig's registry;
# this must run before from_pretrained below.
# NOTE(review): this loads the Mistral-based checkpoint through
# LlavaLlamaForCausalLM, and transformers warns that a "llava_mistral" model is
# being instantiated as "llava" (captured output below). Approach 2 loads the
# same checkpoint via llava's load_pretrained_model instead.
AutoConfig.register("llava", LlavaConfig)
model = LlavaLlamaForCausalLM.from_pretrained(MODEL_NAME)

You are using a model of type llava_mistral to instantiate a model of type llava. This is not supported for all configurations of models and can yield errors.


: 

## Approach 2: the `llava` package (llava-torch)

In [None]:
import torch
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    process_images,
    tokenizer_image_token,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from PIL import Image

In [None]:
# Checkpoint id for the llava-package path (same id as MODEL_NAME above).
MODEL = "microsoft/llava-med-v1.5-mistral-7b"

# llava helper; per its name it turns off torch's default layer initialisation
# before the model is built (the weights come from the checkpoint anyway).
disable_torch_init()

# Short model name derived from the path; passed to load_pretrained_model below.
model_name = get_model_name_from_path(MODEL)
model_name

In [None]:
# Load tokenizer, multimodal model, image processor and context length together.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=MODEL,
    model_base=None,  # no separate base model to merge with
    model_name=model_name,
    load_4bit=False,  # keep full (non-4-bit) weights
)

In [None]:
image_file = "../data/master_data/brain_xray.jpeg"
# Open inside a context manager so the file handle Image.open keeps is closed
# deterministically; convert() returns a new, independent RGB image that stays
# valid after the `with` block exits.
with Image.open(image_file) as source_image:
    image = source_image.convert("RGB")
image

In [5]:
def process_image(image):
    """Preprocess one PIL image into a float16 tensor on the model's device.

    The image is wrapped in a batch of one and passed to llava's
    ``process_images`` with ``image_aspect_ratio="pad"``.

    Args:
        image: PIL image to preprocess (e.g. the RGB image loaded above).

    Returns:
        The tensor returned by ``process_images``, moved to ``model.device``
        with dtype ``torch.float16``.
    """
    from types import SimpleNamespace

    # llava.mm_utils.process_images reads this option with
    # getattr(model_cfg, "image_aspect_ratio", None). A plain dict has no such
    # attribute, so {"image_aspect_ratio": "pad"} was silently ignored and the
    # padding branch never ran. An attribute-style object is required.
    model_cfg = SimpleNamespace(image_aspect_ratio="pad")
    image_tensor = process_images([image], image_processor, model_cfg)
    return image_tensor.to(model.device, dtype=torch.float16)

In [None]:
# Sanity check: preprocess the image once and inspect the resulting tensor.
processed_image = process_image(image=image)
(type(processed_image), processed_image.shape)

In [None]:
# Key into llava's conv_templates used to format every prompt.
# NOTE(review): I believe LLaVA-Med's own evaluation examples use
# "mistral_instruct" for this checkpoint — verify that "llava_v0" is intended.
# Switching templates also changes the separator that ask_image uses as its
# stop string.
CONV_MODE = "llava_v0"


def create_prompt(prompt: str):
    """Render a single-turn image+text prompt from the ``CONV_MODE`` template.

    The user turn is the image placeholder token, a newline, then ``prompt``.
    The assistant role is then appended with a ``None`` message, leaving the
    assistant turn open for generation.

    Returns:
        ``(rendered_prompt, conversation)``: the string produced by
        ``get_prompt()`` and the conversation object it came from.
    """
    conversation = conv_templates[CONV_MODE].copy()
    user_role = conversation.roles[0]
    assistant_role = conversation.roles[1]

    user_message = DEFAULT_IMAGE_TOKEN + "\n" + prompt
    conversation.append_message(user_role, user_message)
    conversation.append_message(assistant_role, None)

    rendered = conversation.get_prompt()
    return rendered, conversation


prompt, conv = create_prompt("What type of imaging does this represent?")
print("prompt:", prompt, "\nconv:", conv)

In [12]:
def ask_image(
    image: "Image.Image",
    prompt: str,
    *,
    temperature: float = 0.01,
    max_new_tokens: int = 512,
):
    """Ask the loaded LLaVA-Med model a question about ``image``.

    Args:
        image: PIL image to condition on; preprocessed with ``process_image``.
        prompt: The user's question or instruction; ``create_prompt`` prepends
            the image token and applies the conversation template.
        temperature: Sampling temperature for ``model.generate``. Sampling is
            enabled only when it is > 0; the default 0.01 reproduces the
            original near-greedy settings.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The decoded answer with special tokens, surrounding whitespace and a
        trailing stop string removed.
    """
    print("I got this prompt:", prompt)

    image_tensor = process_image(image)
    full_prompt, conv = create_prompt(prompt)
    input_ids = (
        tokenizer_image_token(full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    # Templates with SeparatorStyle.TWO close the assistant turn with sep2;
    # all other styles use sep.
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria(
        keywords=[stop_str], tokenizer=tokenizer, input_ids=input_ids
    )

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=temperature > 0,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )

    # Depending on the llava version, generate() either echoes the prompt
    # tokens before the answer or (when it generates from inputs_embeds)
    # returns only the new tokens. Cutting off len(prompt) tokens
    # unconditionally truncates the answer in the second case. So strip the
    # prompt only when the output really starts with it. The prompt contains
    # IMAGE_TOKEN_INDEX, a placeholder id outside the vocabulary, so a
    # new-tokens-only output cannot match it by accident.
    prompt_len = input_ids.shape[1]
    sequence = output_ids[0]
    if sequence.shape[0] >= prompt_len and torch.equal(sequence[:prompt_len], input_ids[0]):
        sequence = sequence[prompt_len:]

    print("decoded output_ids", tokenizer.decode(sequence))

    answer = tokenizer.decode(sequence, skip_special_tokens=True).strip()
    # Generation halts only after the stop string has been produced, so it can
    # remain at the end of the text. An empty stop string must be skipped:
    # endswith("") is always True and answer[:-0] would erase the answer.
    if stop_str and answer.endswith(stop_str):
        answer = answer[: -len(stop_str)].strip()
    return answer

In [13]:
## 1
# query = "a photography of"

## 2
# Built from explicit lines so the prompt sent to the model does not carry the
# source-code indentation that an indented triple-quoted string embeds in
# every line after the first.
query = (
    "Analyze the provided photograph and perform the following tasks:\n"
    "1. Identification: Describe the key features or abnormalities visible in the image.\n"
    "2. Prognosis: Suggest potential implications or diagnoses based on the identified features (if applicable).\n"
    "3. Description: Provide a detailed summary of the observed structures, focusing on medical relevance."
)

## 3
# query = "do the prognosis"

## 4
# query = "What type of imaging does this represent?"

In [None]:
# Run the selected query against the loaded image and display the answer.
result = ask_image(image=image, prompt=query)
result