In [None]:
import torch

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from transformers import TextStreamer

from PIL import Image, ImageFile
import PIL
def load_image(image_file) -> Image.Image:
    """Open an image and return it as an in-memory RGB ``PIL.Image.Image``.

    Args:
        image_file: Anything ``PIL.Image.open`` accepts (a filesystem path or a
            binary file-like object).

    Returns:
        A new RGB image that no longer depends on the underlying file.
    """
    # Image.open is lazy and keeps the file handle open; convert() forces the
    # pixel data to be read and returns an independent copy, so the source can
    # be closed deterministically by the context manager.
    with Image.open(image_file) as source:
        return source.convert('RGB')


# Called once before the checkpoint is loaded.
# NOTE(review): presumably skips torch's default weight initialisation to speed
# up model construction -- confirm against llava.utils.
disable_torch_init()

# Checkpoint identifier handed to the LLaVA loader.
model_path = "microsoft/llava-med-v1.5-mistral-7b"
model_name = get_model_name_from_path(model_path)

# Full-precision load (no 8-bit / 4-bit quantisation) onto the CUDA device.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    None,
    model_name,
    load_8bit=False,
    load_4bit=False,
    device="cuda",
)

In [92]:

def run_llava(inp:str, image_file="fracture.jpg"):
    """Ask the loaded LLaVA model a single question about a single image.

    Relies on the notebook-level globals ``model``, ``tokenizer`` and
    ``image_processor`` created when the checkpoint was loaded.

    Args:
        inp: The question to ask about the image.
        image_file: Either an already opened ``PIL.Image.Image`` (any format
            or mode) or anything ``load_image`` accepts, e.g. a file path.

    Returns:
        A ``(answer, image)`` tuple: the decoded text of the first generated
        sequence and the RGB image that was fed to the model.
    """
    conv_mode = "mistral_instruct"

    # Accept every PIL image, not only JPEG-backed ones, and normalise it to
    # RGB so pre-opened images go through the same preprocessing as paths.
    if isinstance(image_file, Image.Image):
        image = image_file.convert('RGB')
    else:
        image = load_image(image_file)

    image_tensor = process_images([image], image_processor, model.config)
    image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    # Every call starts a fresh single-turn conversation, so the user message
    # always carries the image placeholder token(s) in front of the question.
    conv = conv_templates[conv_mode].copy()
    if model.config.mm_use_im_start_end:
        inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
    else:
        inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
    conv.append_message(conv.roles[0], inp)
    # An empty assistant turn leaves the prompt open for the model to complete.
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

    # Greedy decoding (do_sample=False); no custom stopping criteria are used.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=False,
            temperature=0.0,
            max_new_tokens=512,
            top_p=None,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    return output[0], image

# Single-image smoke test of run_llava on a local file.
output, image = run_llava(
    inp="what is the condition of the patient",
    image_file="image (1).jpg",
)



In [76]:
import pandas as pd

# Parquet shards of the VQA-RAD dataset on the Hugging Face Hub, keyed by split.
splits = {'train': 'data/train-00000-of-00001-eb8844602202be60.parquet', 'test': 'data/test-00000-of-00001-e5bc3d208bb4deeb.parquet'}
df = pd.read_parquet("hf://datasets/flaviagiammarino/vqa-rad/" + splits["test"])

# Later cells read an 'index' column from every record. Create it here, next to
# the load, so the notebook works from a fresh kernel instead of depending on a
# cell that is no longer part of it. The guard keeps the cell safe to re-run.
if 'index' not in df.columns:
    df['index'] = df.index

In [96]:
# Preview the first two rows of the loaded split.
df.head(2)

Unnamed: 0,image,question,answer,index
0,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,is there evidence of an aortic aneurysm?,yes,0
1,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,is there airspace consolidation on the left side?,yes,1


In [104]:
import io

# One record per dataset row: the row's 'index' value plus the model's answer.
generative_responses = []
progress = 0
for progress, data in enumerate(df.to_dict(orient='records'), start=1):
    # The dataset stores each image as raw encoded bytes; decode it in memory.
    image = Image.open(io.BytesIO(data['image']['bytes']))
    output, _ = run_llava(data['question'], image)

    generative_responses.append({"index": data['index'], 'generative_answer': output})

    # Report every 50th processed row.
    if progress % 50 == 0:
        print(f'Progress {progress} of {len(df)}')



Progress 50 of 451
Progress 100 of 451
Progress 150 of 451
Progress 200 of 451
Progress 250 of 451
Progress 300 of 451
Progress 350 of 451
Progress 400 of 451
Progress 450 of 451


In [105]:
# Collect the per-row answer records into a frame (one row per question).
df_llava_med = pd.DataFrame(data=generative_responses)

In [106]:
# Display the generated answers (pandas abbreviates long frames when rendering).
df_llava_med

Unnamed: 0,index,generative_answer
0,0,"According to the chest X-ray, there is no evid..."
1,1,"Yes, the chest X-ray shows airspace consolidat..."
2,2,"According to the chest X-ray, there are no int..."
3,3,The right side of the heart border is obscured...
4,4,The kidney is located in the right upper quadr...
...,...,...
446,446,"Yes, the chest X-ray shows subcutaneous air in..."
447,447,"Yes, the image is a computed tomography (CT) s..."
448,448,"In the left apex, there is a nodule visible on..."
449,449,"Yes, the chest X-ray shows a pneumothorax in t..."


In [112]:
# Attach the ground-truth rows to the generated answers. `on='index'` matches
# the 'index' column of df_llava_med against the row labels of df. When df also
# has an 'index' column (as in the output shown below), lsuffix='r' renames the
# left-hand one to 'indexr' to avoid the name clash.
df_final = df_llava_med.join(df, on='index', how='inner',lsuffix='r')

In [116]:
# Display the joined frame of generated answers and reference answers.
df_final

Unnamed: 0,indexr,generative_answer,image,question,answer,index
0,0,"According to the chest X-ray, there is no evid...",{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,is there evidence of an aortic aneurysm?,yes,0
1,1,"Yes, the chest X-ray shows airspace consolidat...",{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,is there airspace consolidation on the left side?,yes,1
2,2,"According to the chest X-ray, there are no int...",{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,is there any intraparenchymal abnormalities in...,no,2
3,3,The right side of the heart border is obscured...,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,which side of the heart border is obscured?,right,3
4,4,The kidney is located in the right upper quadr...,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,where are the kidney?,not seen here,4
...,...,...,...,...,...,...
446,446,"Yes, the chest X-ray shows subcutaneous air in...",{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,is there subcutaneous air present in the right...,yes,446
447,447,"Yes, the image is a computed tomography (CT) s...",{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,is this image taken above the diaphragm?,no,447
448,448,"In the left apex, there is a nodule visible on...",{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,what is in the left apex?,a bullous lesion,448
449,449,"Yes, the chest X-ray shows a pneumothorax in t...",{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,is a pneumothorax present in the left apex?,no,449


In [115]:
# The 'image' column holds dicts of raw encoded image bytes. Written to CSV they
# become huge, unreadable Python reprs, so leave them out of the export; the
# images can be looked up again in the dataset through the 'index' column.
# errors='ignore' keeps the cell working if the column is already absent.
df_final.drop(columns=['image'], errors='ignore').to_csv('LLaVa_med.csv',index=False)