In [1]:
import numpy as np
import pandas as pd
from PIL import Image
from datasets import load_dataset

In [2]:
test_dataset = load_dataset("adishourya/MEDPIX-ShortQA",split="Test")
test_dataset = test_dataset.to_pandas()

In [3]:
import torch
import pandas as pd
from datasets import load_dataset
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
from peft import PeftModel, PeftConfig
from accelerate import infer_auto_device_map

# ┌──────────────────┐
# │ Load Quantized Model │
# └──────────────────┘

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)

# adapter_model_id = "adishourya/results__fullrun__0310-134147"
adapter_model_id = "adishourya/results__fullrun__0710-111627"
peft_config = PeftConfig.from_pretrained(adapter_model_id)

model_id = peft_config.base_model_name_or_path
base_model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
)

# device_map = infer_auto_device_map(
#     base_model,
#     max_memory={0: "2GB", "cpu": "10GB"},
# )

model = PeftModel.from_pretrained(
    base_model,
    adapter_model_id,
)

processor = AutoProcessor.from_pretrained(model_id)
model.eval()

# ┌───────┐
# │ Dataset│
# └───────┘

# Load the full test dataset
test_dataset = load_dataset("adishourya/MEDPIX-ShortQA", split="Test")

# ┌─────────────────────┐
# │ Inference Function  │
# └─────────────────────┘

def generate_answer(batch):
    if isinstance(batch, dict):
        batch = [{key: batch[key][i] for key in batch} for i in range(len(batch["image_id"]))]
    
    images = [item["image_id"].convert("RGB") for item in batch]
    questions = ["answer " + item["question"] for item in batch]
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    inputs = processor(
        text=questions,
        images=images,
        return_tensors="pt",
        padding="longest"
    ).to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            inputs["input_ids"], 
            max_new_tokens=100,
        )
    
    generated_answers = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return generated_answers

# ┌────────────────────────┐
# │ Perform Inference      │
# └────────────────────────┘

batch_size = 1
output_data = []

for i in range(0, len(test_dataset), batch_size):
    batch = test_dataset[i:i + batch_size]
    
    if isinstance(batch, dict):
        batch = [{key: batch[key][i] for key in batch} for i in range(len(batch["image_id"]))]
    
    generated_answers = generate_answer(batch)
    
    for image_id, question, generated_answer, label in zip(
        [item["image_id"] for item in batch],
        [item["question"] for item in batch], 
        generated_answers, 
        [item["answer"] for item in batch]
    ):
        output_data.append({
            "Image ID": image_id,  # You can modify this if you want to save image paths or metadata
            "Question": question,
            "Generated Answer": generated_answer,
            "Label": label
        })
        print(generated_answer)
        print("-"*25)
        print(label)
        print("="*50)

# ┌─────────────────────┐
# │ Save to Excel       │
# └─────────────────────┘

# Create a DataFrame from the output data
df_results = pd.DataFrame(output_data)



`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

answer What appears to be the issue?
The issue is that the image is blurry and not clear enough to make out what is happening in the image.
-------------------------
The findings suggest Plain radiographs demonstrate a subtle lucency in the acromion.  A CT ordered to rule out fracture clearly demonstrates a congentially unfused acromion.. CT demonstrates unfused meso-acromion and meta-acromion.. The patient might have a history of 26 year old male with left shoulder pain..
answer What could the diagnosis suggest?
The diagnosis suggests that the patient might have a history of significant motor vehicle accident and multiple concussions, which could lead to conus hemorrhage and/or other brain damage.
-------------------------
The possible diagnosis includes Acromion fracture
Normal unfused ossification center in patients under 25 years
Os Acromiale.
answer What was observed in the imaging?
The imaging shows that the image is not clear and there are some color differences.
---------------

In [6]:
test_dataset = test_dataset.to_pandas()

In [7]:
test_dataset.drop(columns=["question","answer","ans_len", "mode" , "split"],inplace=True)

In [8]:
test_dataset["pg_ans"] = df_results["Generated Answer"]
test_dataset["label"] = df_results["Label"]

In [9]:
test_dataset

Unnamed: 0,image_id,case_id,pg_ans,label
0,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,MPX1389,answer What appears to be the issue?\nThe issu...,The findings suggest Plain radiographs demonst...
1,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,MPX1389,answer What could the diagnosis suggest?\nThe ...,The possible diagnosis includes Acromion fract...
2,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,MPX1389,answer What was observed in the imaging?\nThe ...,The imaging shows Plain radiographs demonstrat...
3,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,MPX2257,"answer What might the diagnosis be?\nSorry, as...",The possible diagnosis includes meningioma\nne...
4,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,MPX1251,answer What is the potential diagnosis?\nThe p...,The possible diagnosis includes Chondrosarcoma...
...,...,...,...,...
1555,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,MPX1879,answer What is the location of the lesion in t...,The lesion is located in the dura of the falx.
1556,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,MPX1879,answer What type of symptoms did the 21-year-o...,He presented with a new onset seizure and seve...
1557,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,MPX1879,answer What is the most common form of meningi...,"The most common forms are ""globose"" (spherical..."
1558,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,MPX1879,answer What is the appearance of meningiomas o...,They are typically well-circumscribed and iso-...


In [10]:
from datasets import Dataset, Features, Value, Image, DatasetDict
import pandas as pd


# Convert the DataFrame to a Hugging Face Dataset with image features
# image_id,case_id,question,answer,ans_len,mode,split
features = Features(
    {
        "image_id": Image(),
        "case_id": Value("string"),
        "pg_ans": Value("string"),
        "label": Value("string"),

    }
)
inference_dataset = Dataset.from_pandas(test_dataset,features=features)

inference_dict = DatasetDict(
    {"Infer": inference_dataset}
)

inference_dict.push_to_hub(adapter_model_id + "__infer__")



Map:   0%|          | 0/1560 [00:00<?, ? examples/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/16 [00:00<?, ?ba/s]