## Training on MD5 dataset

In [None]:
# @title Install libraries if necessary
%%capture
import os
if "COLAB_" in "".join(os.environ.keys()):
    !pip install rouge_score
    !pip install accelerate
    !pip install evaluate
    !pip install trl
    !pip install bitsandbytes
    import ntlk
    # Download ntlk tokenizers. if needed
    nltk_data_path = os.path.join(os.path.expanduser("~"), "nltk_data")
    nltk.download('punkt_tab', download_dir=nltk_data_path)

    # Check download location:
    print(nltk.data.path)  # Make sure the home directory is in here!

    # Confirm file existence:
    try:
        nltk.data.find("tokenizers/punkt_tab")
        print("punkt_tab found!")
    except LookupError:
        print("punkt_tab still not found. Please check NLTK's data path and files.")

In [None]:
# @title Imports
from huggingface_hub import login
import torch
import json
import os
from PIL import Image
from tqdm import tqdm

import random
from typing import Tuple, List, Dict, Any


# from data_loader import load_md5_dataset
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer

# Download NLTK resources
nltk.download('punkt')
nltk.download('punkt_tab')

# Authenticate with Hugging Face.
# The access token is read from the HF_TOKEN environment variable so that no
# credential is stored in the notebook. When the variable is unset, `token` is
# None and login() falls back to its interactive prompt.
# NOTE(review): a literal token used to be hard-coded here; it should be
# revoked in the Hugging Face account settings because it has been shared.
login(token=os.environ.get("HF_TOKEN"))
print("Successfully authenticated with Hugging Face")

2025-05-01 08:48:25.475296: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746107306.537921   14861 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746107306.837369   14861 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746107309.317443   14861 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746107309.317504   14861 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746107309.317561   14861 computation_placer.cc:177] computation placer alr

Successfully authenticated with Hugging Face


In [None]:
import evaluate
!pip install bert_score

[0m

In [None]:
# @title Function to load md5 dataset

def load_md5_dataset(
    train_dataset_path: str = 'vqa_rad_gemini_train.json',
    test_dataset_path: str = 'vqa_rad_gemini_test.json',
    num_train_images: int | None = None,
    num_test_images: int | None = None,
):

    # function to load dataset
    def _load_dataset(
            dataset_path: str, num_images: int | None = None
    ) -> List[Tuple[str, Dict[str, Any]]]:
        # Load the MD5 dataset and truncate it to first num_images images if specified
        print(f"Loading MD5 VQA-RAD dataset from {dataset_path}...")
        try:
            with open(dataset_path, 'r') as f:
                md5_data = json.load(f)
                md5_data = list(md5_data.items())
                if num_images is not None:
                    md5_data = md5_data[:num_images]
            print(f"Loaded dataset info: {len(md5_data)} images with " \
                  f"{sum(len(img_data['qa_pairs']) for _, img_data in md5_data)} Q&A pairs")
        except Exception as e:
            print(f"Error loading MD5 dataset: {e}")
            raise
        return md5_data

    # Load train and test datasets
    train_data = _load_dataset(train_dataset_path, num_train_images)
    test_data = _load_dataset(test_dataset_path, num_test_images)

    return train_data, test_data


In [None]:
# @title Template and function to format data

# System turn shared by every conversation (training samples and inference).
system_message = (
    "You are a medical image analysis expert. "
    "Provide detailed and accurate answers to questions about medical images."
)

# User turn template; `{question}` is filled in with str.format().
user_prompt = (
    "Analyze the following medical image and answer this question:\n"
    "\n"
    "<QUESTION>\n"
    "{question}\n"
    "</QUESTION>\n"
)

# Process the data to format for OAI messages
def format_data(image_path, qa_pair):
    """Build one chat-format sample from an image file and a Q&A pair.

    Args:
        image_path: Path of the image; it is opened with PIL and converted to RGB.
        qa_pair: Mapping providing the "question" and "descriptive_answer" keys.

    Returns:
        dict with a single "messages" key holding the system, user and
        assistant turns. The user turn lists the question text first and the
        image second.
    """
    question_text = user_prompt.format(question=qa_pair["question"])
    rgb_image = Image.open(image_path).convert("RGB")
    answer_text = qa_pair["descriptive_answer"]

    system_turn = {
        "role": "system",
        "content": [{"type": "text", "text": system_message}],
    }
    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": question_text},
            {"type": "image", "image": rgb_image},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": answer_text}],
    }
    return {"messages": [system_turn, user_turn, assistant_turn]}

def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    """Collect the image entries found in a list of chat messages.

    Args:
        messages: Chat turns; each turn's "content" is either a list of
            elements or a single element (which is treated as a one-item list).

    Returns:
        The collected images in message order. For a dict element with an
        "image" key the value of that key is collected as-is (no conversion is
        applied here). A dict element that only has ``"type" == "image"`` and
        no "image" key is collected itself.
    """
    collected = []
    for turn in messages:
        parts = turn.get("content", [])
        if not isinstance(parts, list):
            parts = [parts]

        for part in parts:
            # Non-dict content (e.g. plain strings) cannot carry an image.
            if not isinstance(part, dict):
                continue
            if "image" in part:
                collected.append(part["image"])
            elif part.get("type") == "image":
                collected.append(part)
    return collected


In [None]:
# @title Load and format md5 dataset
# Read both splits; enable the commented keyword arguments to truncate them.
train_data, test_data = load_md5_dataset(
    # num_train_images=3,
    # num_test_images=3
)

# Turn every (image, Q&A pair) combination of the training split into one
# chat-format sample. A pair that fails to format is reported and skipped.
formatted_train_data = []
for img_hash, img_data in train_data:
    for qa_pair in img_data["qa_pairs"]:
        try:
            formatted_train_data.append(format_data(img_data["image_path"], qa_pair))
        except Exception as e:
            print(f"Error processing image {img_hash}: {e}")

print(f"Formatted {len(formatted_train_data)} training samples")
if len(formatted_train_data) > 0:
    print("Sample data:", formatted_train_data[0]["messages"])


Loading MD5 VQA-RAD dataset from vqa_rad_gemini_train.json...
Loaded dataset info: 313 images with 1793 Q&A pairs
Loading MD5 VQA-RAD dataset from vqa_rad_gemini_test.json...
Loaded dataset info: 203 images with 451 Q&A pairs
Formatted 1793 training samples
Sample data: [{'role': 'system', 'content': [{'type': 'text', 'text': 'You are a medical image analysis expert. Provide detailed and accurate answers to questions about medical images.'}]}, {'role': 'user', 'content': [{'type': 'text', 'text': 'Analyze the following medical image and answer this question:\n\n<QUESTION>\nare regions of the brain infarcted?\n</QUESTION>\n'}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=566x555 at 0x15001AEEACE0>}]}, {'role': 'assistant', 'content': [{'type': 'text', 'text': 'Yes. Explanation: The image shows areas of hyperintensity in the left hemisphere, which are indicative of restricted diffusion, a hallmark of acute infarction.'}]}]


In [None]:
# @title Load Gemma3 4B Instruction tuned model and its processor

# Hugging Face model id (used for both the model and the processor below)
model_id = "google/gemma-3-4b-it"  # or `google/gemma-3-12b-it`

# Report whether the GPU has compute capability >= 8; this is only
# informational, loading continues with bfloat16 either way.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    print("GPU supports bfloat16")
else:
    print("Warning: GPU may not support bfloat16, proceeding anyway")

# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager",  # Use "flash_attention_2" when running on Ampere or newer GPU
    torch_dtype=torch.bfloat16,  # What torch dtype to use, defaults to auto
    device_map="auto",  # Automatic device placement of the model weights
)

# BitsAndBytesConfig int-4 config (NF4 with double quantization); compute and
# storage dtype follow the model dtype chosen above.
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

# Load model and processor from the same checkpoint id, so that switching
# `model_id` (e.g. to the 12B variant) cannot leave a mismatched processor.
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)

GPU supports bfloat16


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# @title Evaluate model with BERT score

def evaluate_model(model, processor, test_data, name=""):
    """
    Evaluate the model on test data using BLEU, ROUGE, BERTScore, and exact-match accuracy.

    Every Q&A pair of every test image is answered with
    generate_medical_answer() and compared with the pair's
    "descriptive_answer". A pair that raises at any point is reported and
    skipped as a whole, so that all averaged metrics cover exactly the same
    `samples` pairs.

    Args:
        model: The model to evaluate
        processor: The processor to use for tokenization
        test_data: Iterable of (image_hash, image_data) pairs; image_data must
            provide "image_path" and "qa_pairs" (each pair with "question"
            and "descriptive_answer")
        name: Name identifier for logging purposes

    Returns:
        dict: Dictionary containing all evaluation metrics ("bleu", "rouge1",
        "rouge2", "rougeL", "bertscore", "accuracy") and the number of
        evaluated pairs ("samples"). Averages are 0 when nothing was evaluated.
    """
    print(f"\n===== Model Evaluation: {name} =====")

    bleu_scores = []
    rouge_metrics = []
    bert_scores = []
    exact_match_count = 0
    total_samples = 0

    # Scorers are loop-invariant, so they are created once up front.
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    bertscore = evaluate.load("bertscore")
    smoothie = SmoothingFunction().method1

    def _short_answer(text):
        # Part of the answer before "Explanation:" (the whole text when the
        # marker is absent), stripped and lower-cased for comparison.
        return text.split("Explanation:")[0].strip().lower()

    def _mean(values):
        # Arithmetic mean, 0 for an empty list.
        return sum(values) / len(values) if values else 0

    # Process each test sample
    for img_hash, img_data in tqdm(test_data, desc="Evaluating"):
        for qa_pair in img_data["qa_pairs"]:
            try:
                # Get ground truth
                reference = qa_pair["descriptive_answer"]

                # Generate prediction
                prediction = generate_medical_answer(
                    img_data["image_path"],
                    qa_pair["question"],
                    model,
                    processor
                )

                # BLEU on lower-cased NLTK tokens
                reference_tokens = nltk.word_tokenize(reference.lower())
                prediction_tokens = nltk.word_tokenize(prediction.lower())
                bleu = sentence_bleu([reference_tokens], prediction_tokens, smoothing_function=smoothie)

                # ROUGE on the raw strings
                rouge_results = scorer.score(reference, prediction)

                # BERTScore; only the F1 value is kept
                bert_result = bertscore.compute(
                    predictions=[prediction],
                    references=[reference],
                    lang="en",
                    model_type="microsoft/deberta-xlarge-mnli"
                )
                bert_f1 = bert_result["f1"][0]

                # Case-insensitive exact match on the answer part only
                is_exact_match = _short_answer(reference) == _short_answer(prediction)
            except Exception as e:
                print(f"Error evaluating sample {img_hash}: {e}")
                continue

            # Record the sample only after every metric succeeded; appending
            # piecemeal inside the try block would let a late failure leave
            # the metric lists with different lengths and skew the averages.
            bleu_scores.append(bleu)
            rouge_metrics.append(rouge_results)
            bert_scores.append(bert_f1)
            if is_exact_match:
                exact_match_count += 1
            total_samples += 1

    # Calculate aggregated metrics
    avg_bleu = _mean(bleu_scores)
    rouge1 = _mean([r['rouge1'].fmeasure for r in rouge_metrics])
    rouge2 = _mean([r['rouge2'].fmeasure for r in rouge_metrics])
    rougeL = _mean([r['rougeL'].fmeasure for r in rouge_metrics])
    avg_bert_score = _mean(bert_scores)
    accuracy = exact_match_count / total_samples if total_samples > 0 else 0

    # Print results
    print(f"Evaluated on {total_samples} samples")
    print(f"BLEU Score: {avg_bleu:.4f}")
    print(f"ROUGE-1 F1: {rouge1:.4f}")
    print(f"ROUGE-2 F1: {rouge2:.4f}")
    print(f"ROUGE-L F1: {rougeL:.4f}")
    print(f"BERT Score F1: {avg_bert_score:.4f}")
    print(f"Exact Match Accuracy (reference answer only): {accuracy:.4f}")

    # Return all metrics
    return {
        "bleu": avg_bleu,
        "rouge1": rouge1,
        "rouge2": rouge2,
        "rougeL": rougeL,
        "bertscore": avg_bert_score,
        "accuracy": accuracy,
        "samples": total_samples
    }


def generate_medical_answer(
    image_path,
    question,
    model,
    processor,
    max_new_tokens=256,
    temperature=0.8,
    top_p=1.0,
    do_sample=True,
):
    """Generate the model's answer to one question about one image.

    Args:
        image_path: Path of the image; it is opened with PIL and converted to RGB.
        question: Question text inserted into `user_prompt`.
        model: Model used for generation; inputs are moved to `model.device`.
        processor: Processor providing the chat template, the tokenizer and
            the image preprocessing.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (only passed when `do_sample` is True).
        top_p: Nucleus sampling threshold (only passed when `do_sample` is True).
        do_sample: Whether to sample instead of decoding greedily.

    Returns:
        str: The decoded completion without the prompt and without special tokens.

    The keyword defaults reproduce the previously hard-coded generate() call.
    """
    # Load the image eagerly and close the file handle right away.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Convert sample into messages and then apply the chat template.
    # The user content uses the same order as format_data() (question text,
    # then image) so the inference prompt has the layout seen during training.
    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_message}]},
        {"role": "user", "content": [
            {"type": "text", "text": user_prompt.format(question=question)},
            {"type": "image", "image": image},
        ]},
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Process the image and text
    image_inputs = process_vision_info(messages)
    # Tokenize the text and process the images
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Move the inputs to the device
    inputs = inputs.to(model.device)

    # Stop on either the tokenizer's EOS token or the end-of-turn marker.
    stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "eos_token_id": stop_token_ids,
    }
    if do_sample:
        # Sampling-only settings are left out for greedy decoding.
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    # Generate the output. The generate() call used to be commented out, which
    # left the `with` block empty (IndentationError) and `generated_ids` undefined.
    with torch.no_grad():
        generated_ids = model.generate(**inputs, **generation_kwargs)

    # Trim the prompt tokens from each sequence and decode the remainder to text
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]




IndentationError: expected an indented block after 'with' statement on line 142 (2824962431.py, line 145)

In [None]:
# @title Create a data collator to encode text and image pairs
def collate_fn(examples):
    """Collate chat-format examples into a padded batch with loss labels.

    Args:
        examples: Sequence of samples, each with a "messages" list as produced
            by format_data().

    Returns:
        The processor's batch with an added "labels" entry: a copy of
        "input_ids" in which padding and image-related token ids are replaced
        by -100 so they are excluded from the loss.
    """
    # Label value used for positions masked out of the loss.
    ignore_index = -100
    # NOTE(review): presumably the id of Gemma 3's image placeholder token —
    # TODO confirm against the tokenizer/config of the checkpoint in use.
    image_soft_token_id = 262144

    texts = []
    images = []
    for example in examples:
        image_inputs = process_vision_info(example["messages"])
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, with padding and image tokens masked
    labels = batch["input_ids"].clone()

    # Id of the begin-of-image token. This must be a scalar: comparing the
    # label tensor with a one-element list is not an element-wise comparison,
    # so the mask would not select these tokens.
    boi_token_id = processor.tokenizer.convert_tokens_to_ids(
        processor.tokenizer.special_tokens_map["boi_token"]
    )

    # Mask tokens that must not contribute to the loss
    labels[labels == processor.tokenizer.pad_token_id] = ignore_index
    labels[labels == boi_token_id] = ignore_index
    labels[labels == image_soft_token_id] = ignore_index

    batch["labels"] = labels
    return batch


In [None]:
# @title LORA & Trainer config
# LoRA configuration
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    # These modules are trained and saved in full in addition to the adapters.
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)

# Training configuration
# `output_dir` used to be passed twice, which is a SyntaxError ("keyword
# argument repeated"); only one value is kept.
# NOTE(review): the directory name says "8r" while r=16 above — rename it if
# this run must not overwrite the checkpoints of an earlier run.
args = SFTConfig(
    output_dir="gemma-md5-10epochs_8r_run",    # directory to save and repository id
    num_train_epochs=10,                       # number of training epochs
    per_device_train_batch_size=1,             # batch size per device during training
    gradient_accumulation_steps=32,            # number of micro-batches accumulated per optimizer update
    gradient_checkpointing=True,               # use gradient checkpointing to save memory
    optim="adamw_torch_fused",                 # use fused adamw optimizer
    logging_steps=1,                           # log every N steps
    save_strategy="epoch",                     # save checkpoint every epoch
    learning_rate=1e-4,                        # learning rate
    bf16=True,                                 # use bfloat16 precision
    max_grad_norm=0.3,                         # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                         # warmup ratio based on QLoRA paper
    lr_scheduler_type="cosine",                # use cosine learning rate scheduler
    # push_to_hub=True,                          # push model to hub
    report_to="tensorboard",                   # report metrics to tensorboard
    gradient_checkpointing_kwargs={
        "use_reentrant": False
    },
    dataset_text_field="",                     # need a dummy field for collator
    dataset_kwargs={"skip_prepare_dataset": True},  # important for collator
    # resume_from_checkpoint='gemma',
)
args.remove_unused_columns = False  # important for collator

# Disable wandb integration
os.environ["WANDB_DISABLED"] = "true"

# Initialize the trainer; batching and label masking are done by collate_fn.
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=formatted_train_data,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
# @title Printing number of trainable parameters

from peft import PeftModel

# Count the elements of every parameter that still requires gradients.
trainable_params = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)
print(f"Trainable parameters: {trainable_params:,}")


Trainable parameters: 1361753856
Trainable parameters: 1,361,753,856


In [None]:
# @title Evaluate the model before training
# Baseline metrics, kept in `before_metrics` for a later comparison.
print("\nEvaluating the base model before fine-tuning...")
before_metrics = evaluate_model(model, processor, test_data, "Before Fine-tuning")  # NOTE: the whole test_data is passed, no sample limit is applied; slice test_data for a faster run
print(before_metrics)



Evaluating the base model before fine-tuning...



KeyboardInterrupt



In [None]:
# @title Start training
# Run the fine-tuning loop of the SFTTrainer configured above.
print("Starting training...")
trainer.train()

# Save the final model
# save_model() is called without a path; presumably it writes to
# args.output_dir, which is where the next cell loads from — TODO confirm.
trainer.save_model()
print("Model saved!")


Starting training...


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
1,210.3217
2,204.3471
3,207.1181
4,191.3896
5,201.0484
6,195.8064
7,189.1103
8,165.4267
9,153.3032
10,136.5456


Model saved!


In [None]:
# @title Load the saved model
# Rebinds `model` to the checkpoint stored in args.output_dir, so every later
# cell that uses `model` evaluates this reloaded checkpoint. The existing
# `processor` is reused unchanged.
# NOTE(review): no quantization_config is passed here, unlike the initial load.
model = AutoModelForImageTextToText.from_pretrained(
    args.output_dir,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
print("\nEvaluating the fine-tuned model...")

after_metrics = evaluate_model(model, processor, test_data, "After Fine-tuning")
print(after_metrics)

# NOTE: `before_metrics` is intentionally not recomputed here. At this point
# `model` is the reloaded fine-tuned checkpoint, so evaluating it again under
# the name "before Fine-tuning" would overwrite the baseline with fine-tuned
# results. The baseline must come from the evaluation run before training.


In [None]:
# Repeated evaluation of the current `model`; the result replaces any earlier
# `after_metrics`.
after_metrics = evaluate_model(model, processor, test_data, "After Fine-tuning")
print(after_metrics)

# Disabled before/after comparison. It needs `before_metrics` from a completed
# baseline run and does not list the "bertscore" metric.
# # Print comparison
# print("\nPerformance Comparison:")
# print(f"BLEU Score: {before_metrics['bleu']:.4f} → {after_metrics['bleu']:.4f}")
# print(f"ROUGE-1 F1: {before_metrics['rouge1']:.4f} → {after_metrics['rouge1']:.4f}")
# print(f"ROUGE-2 F1: {before_metrics['rouge2']:.4f} → {after_metrics['rouge2']:.4f}")
# print(f"ROUGE-L F1: {before_metrics['rougeL']:.4f} → {after_metrics['rougeL']:.4f}")
# print(f"Accuracy: {before_metrics['accuracy']:.4f} → {after_metrics['accuracy']:.4f}")


===== Model Evaluation: After Fine-tuning =====


Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████| 203/203 [31:51<00:00,  9.42s/it]

Evaluated on 451 samples
BLEU Score: 0.2312
ROUGE-1 F1: 0.4859
ROUGE-2 F1: 0.2832
ROUGE-L F1: 0.4040
BERT Score F1: 0.7342
Exact Match Accuracy (reference answer only): 0.4878
{'bleu': 0.2312323163893028, 'rouge1': 0.48592959167588895, 'rouge2': 0.28320957421439014, 'rougeL': 0.4039511628305343, 'bertscore': 0.7341564389528034, 'accuracy': 0.4878048780487805, 'samples': 451}





In [None]:
# Run label only: evaluate_model() is called without any temperature or
# LoRA-rank argument, so the settings named in the label are not applied by
# this cell.
print('Lora=8, Temperature = 0.5')
after_metrics = evaluate_model(model, processor, test_data, "After Fine-tuning")
print(after_metrics)

Temperature = 0.5

===== Model Evaluation: After Fine-tuning =====


Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████| 203/203 [32:21<00:00,  9.56s/it]

Evaluated on 451 samples
BLEU Score: 0.2224
ROUGE-1 F1: 0.4800
ROUGE-2 F1: 0.2710
ROUGE-L F1: 0.3945
BERT Score F1: 0.7314
Exact Match Accuracy (reference answer only): 0.4479
{'bleu': 0.22241845962093695, 'rouge1': 0.48002207779727624, 'rouge2': 0.2710426961689062, 'rougeL': 0.39451398105657126, 'bertscore': 0.7313649768707757, 'accuracy': 0.44789356984478934, 'samples': 451}





In [None]:
# Run label only: evaluate_model() is called without any temperature or
# LoRA-rank argument, so the settings named in the label are not applied by
# this cell.
print('LORA = 8, Temperature = 0.2')
after_metrics = evaluate_model(model, processor, test_data, "After Fine-tuning")
print(after_metrics)

Temperature = 0.2

===== Model Evaluation: After Fine-tuning =====


Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████| 203/203 [30:41<00:00,  9.07s/it]

Evaluated on 451 samples
BLEU Score: 0.2240
ROUGE-1 F1: 0.4813
ROUGE-2 F1: 0.2751
ROUGE-L F1: 0.4015
BERT Score F1: 0.7337
Exact Match Accuracy (reference answer only): 0.4612
{'bleu': 0.22398105843146573, 'rouge1': 0.4813235691791108, 'rouge2': 0.27508217253007033, 'rougeL': 0.4014893559524255, 'bertscore': 0.7336562686246673, 'accuracy': 0.4611973392461197, 'samples': 451}





In [None]:
# Run label only: evaluate_model() is called without any temperature or
# LoRA-rank argument, so the settings named in the label are not applied by
# this cell.
print('Temperature = 0.2, LORA16')
after_metrics = evaluate_model(model, processor, test_data, "After Fine-tuning")
print(after_metrics)

Temperature = 0.2, LORA16

===== Model Evaluation: After Fine-tuning =====


Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████| 203/203 [29:19<00:00,  8.67s/it]

Evaluated on 451 samples
BLEU Score: 0.2407
ROUGE-1 F1: 0.4965
ROUGE-2 F1: 0.2944
ROUGE-L F1: 0.4179
BERT Score F1: 0.7464
Exact Match Accuracy (reference answer only): 0.4745
{'bleu': 0.24072255017835026, 'rouge1': 0.4965194948005037, 'rouge2': 0.29437741118218586, 'rougeL': 0.4179027531567964, 'bertscore': 0.7463569337142809, 'accuracy': 0.4745011086474501, 'samples': 451}





In [None]:
# Run label only: evaluate_model() is called without any temperature or
# LoRA-rank argument, so the settings named in the label are not applied by
# this cell.
print('Temperature = 0.5, LORA16')
after_metrics = evaluate_model(model, processor, test_data, "After Fine-tuning")
print(after_metrics)

Temperature = 0.5, LORA16

===== Model Evaluation: After Fine-tuning =====


Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████| 203/203 [29:43<00:00,  8.78s/it]

Evaluated on 451 samples
BLEU Score: 0.2389
ROUGE-1 F1: 0.4864
ROUGE-2 F1: 0.2859
ROUGE-L F1: 0.4110
BERT Score F1: 0.7411
Exact Match Accuracy (reference answer only): 0.4789
{'bleu': 0.23887398671055451, 'rouge1': 0.48641951989442517, 'rouge2': 0.28586536271926455, 'rougeL': 0.41100367274475347, 'bertscore': 0.741133305027321, 'accuracy': 0.4789356984478936, 'samples': 451}





In [None]:
# Run label only: evaluate_model() is called without any temperature or
# LoRA-rank argument, so the settings named in the label are not applied by
# this cell.
print('Temperature = 0.8, LORA16')
after_metrics = evaluate_model(model, processor, test_data, "After Fine-tuning")
print(after_metrics)

Temperature = 0.8, LORA16

===== Model Evaluation: After Fine-tuning =====


Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████| 203/203 [29:14<00:00,  8.64s/it]

Evaluated on 451 samples
BLEU Score: 0.2275
ROUGE-1 F1: 0.4880
ROUGE-2 F1: 0.2776
ROUGE-L F1: 0.4022
BERT Score F1: 0.7402
Exact Match Accuracy (reference answer only): 0.4656
{'bleu': 0.22746627384181467, 'rouge1': 0.4880306311725099, 'rouge2': 0.27761990035463485, 'rougeL': 0.40222257171815573, 'bertscore': 0.7402149718378706, 'accuracy': 0.4656319290465632, 'samples': 451}





In [None]:
# @title Printing a few predictions from test dataset

# Show question, reference answer and model prediction for every Q&A pair of
# the first five test images.
for img_hash, img_data in tqdm(test_data[:5], desc="Sample Printing"):
    image_path = img_data["image_path"] if img_data["qa_pairs"] else None
    for qa_pair in img_data["qa_pairs"]:
        # Get ground truth
        reference = qa_pair["descriptive_answer"]

        # Generate prediction
        prediction = generate_medical_answer(
            image_path, qa_pair["question"], model, processor
        )

        print(f'Question: {qa_pair["question"]}')
        print(f'Reference Answer: {reference} ')
        print(f'Predicted Answer: {prediction}')
        print('\n')

Sample Printing:   0%|          | 0/5 [00:00<?, ?it/s]

Question: is there evidence of an aortic aneurysm?
Reference Answer: No. Explanation: The image quality is poor, and there is no obvious aortic aneurysm. 
Predicted Answer: no. Explanation: The aortic knob and the ascending aorta appear within normal limits in size and contour. There is no obvious widening or outpouching to suggest an aneurysm.




Sample Printing:  20%|██        | 1/5 [00:08<00:34,  8.67s/it]

Question: is there cardiomyopathy?
Reference Answer: yes. Explanation: The presence of a pacemaker/ICD device suggests underlying cardiac disease, which can often be cardiomyopathy. Additionally, the cardiac silhouette appears enlarged, which is consistent with cardiomegaly, a common finding in cardiomyopathy. 
Predicted Answer: no. Explanation: The heart size appears normal, and the heart muscle's structure and function do not show signs of damage or enlargement, which are typical features of cardiomyopathy. The heart borders are within normal limits, and there is no evidence of wall thickening or unusual contour.


Question: is there airspace consolidation on the left side?
Reference Answer: yes. Explanation: There is increased opacity in the left lung, suggesting airspace consolidation. The lung markings are less distinct compared to the right lung, which is consistent with fluid or inflammatory material filling the airspaces. 
Predicted Answer: Yes. Explanation: There is a large ar

Sample Printing:  40%|████      | 2/5 [00:19<00:29,  9.93s/it]

Question: how is the patient oriented?
Reference Answer: posterior-anterior. Explanation: The image is a PA (posterior-anterior) chest radiograph. This is determined by the position of the heart and the scapulae. In a PA view, the heart appears less magnified, and the scapulae are rotated out of the lung fields. The "LEFT" marker also indicates the left side of the patient, which is standard for PA views. 
Predicted Answer: PA. Explanation: The image is a posteroanterior (PA) chest X-ray. This is determined by the positioning of the scapulae and the clarity of the heart shadow. In a PA view, the scapulae are typically seen superior to the heart, and the heart shadow is more apex-focused. The clavicles are also more horizontal in a PA view.


Question: is there any intraparenchymal abnormalities in the lung fields?
Reference Answer: no. Explanation: The lung fields appear clear and without any visible intraparenchymal abnormalities. The vasculature is normal in appearance. 
Predicted An

Sample Printing:  60%|██████    | 3/5 [00:38<00:28, 14.01s/it]

Question: is there evidence of any fractures of the ribs?
Reference Answer: no. Explanation: The ribs appear intact without any obvious discontinuities or breaks. 
Predicted Answer: no. Explanation: The ribs appear intact without any obvious discontinuities or breaks indicative of fractures. The cortical margins of the ribs are smooth and continuous throughout their visible course.


Question: which side of the heart border is obscured?
Reference Answer: right. Explanation: The right heart border is obscured by an infiltrate in the right middle lobe. 
Predicted Answer: right. Explanation: The right heart border is less clearly defined than the left heart border in this image, suggesting it is obscured. This could be due to a number of factors, including overlying structures or pathology.




Sample Printing:  80%|████████  | 4/5 [00:46<00:11, 11.76s/it]

Question: where is the lesion located?
Reference Answer: Anterior mediastinum. Explanation: The opacity seen in the right upper lung field, particularly its location relative to the heart and great vessels, suggests a mass or lesion in the anterior mediastinum. The anterior mediastinum is a common location for thymomas, teratomas, lymphomas, and thyroid masses, which are all possible differential diagnoses. 
Predicted Answer: Retrocardiac. Explanation: The lesion appears to be located behind the heart, in the retrocardiac space. This is suggested by its location outside the normal lung fields and its appearance on both sides of the chest.


Question: where are the kidney?
Reference Answer: not seen here. Explanation: The image is a transverse section of the upper abdomen. The liver, spleen, stomach, and bowel are visible. The kidneys are located more inferiorly in the abdomen and are not included in this slice. 
Predicted Answer: The kidneys are not visible in this single axial CT imag

Sample Printing: 100%|██████████| 5/5 [01:03<00:00, 12.65s/it]

Question: what is the dense mass visualized in the liver?
Reference Answer: Blood vessel. Explanation: The dense mass visualized in the liver is likely a blood vessel due to its location and appearance within the liver parenchyma. The contrast enhancement suggests vascularity. 
Predicted Answer: Hepatocellular carcinoma. Explanation: The CT scan shows a large, dense mass in the liver, which is consistent with hepatocellular carcinoma. The mass appears to be growing exophytically and distending the abdomen.







In [None]:
# Show the checkpoint directory configured on the training arguments.
print(args.output_dir)

gemma-md5-10epochs_8r_run


In [None]:
# Drop the notebook-level reference to the model and force a garbage
# collection pass; the number displayed below is gc.collect()'s return value.
# NOTE(review): other objects (e.g. `trainer`) may still hold a model
# reference, in which case its memory is not released — verify.
del model

import gc
gc.collect()

24434