In [1]:
import os, torch, pickle, gc
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForImageTextToText
from evaluate import load as load_metric

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open("test_reasoning_data_cache.pkl", "rb") as f:
    test_data = pickle.load(f)

count = 200
rouge = load_metric("rouge")

def evaluate_model(model, processor):
    model.eval()
    predictions, references = [], []
    for sample in tqdm(test_data[:count]):
        image = sample["image"]
        gt = sample["messages"][1]["content"][0]["text"]
        messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Analyze this medical image and provide step-by-step findings."}]}]
        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=prompt, images=[image], return_tensors="pt", padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=300)
        decoded = processor.tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
        predictions.append(decoded.strip())
        references.append(gt.strip())
    return rouge.compute(predictions=predictions, references=references)

print("Evaluating base model")
processor = AutoProcessor.from_pretrained("google/medgemma-4b-it")
model = AutoModelForImageTextToText.from_pretrained("google/medgemma-4b-it", torch_dtype=torch.bfloat16, device_map="auto")
base_scores = evaluate_model(model, processor)

del model, processor
torch.cuda.empty_cache()
gc.collect()


Evaluating base model


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 200/200 [31:32<00:00,  9.46s/it]


15163

In [2]:
print("🔍 Evaluating fine-tuned reasoning model")

processor = AutoProcessor.from_pretrained("Manusinhh/medgemma-finetuned-cxr-reasoning")
model = AutoModelForImageTextToText.from_pretrained("Manusinhh/medgemma-finetuned-cxr-reasoning", torch_dtype=torch.bfloat16, device_map="auto")
ft_scores = evaluate_model(model, processor)

del model, processor
torch.cuda.empty_cache()
gc.collect()


🔍 Evaluating fine-tuned reasoning model


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 200/200 [17:05<00:00,  5.13s/it]


16085

In [3]:
print("\n📊 ROUGE Score Comparison (on 2 samples):")
print(f"{'Metric':<12} | {'Base Model':<10} | {'Fine-tuned':<10}")
print("-" * 38)
for metric in ["rouge1", "rouge2", "rougeL", "rougeLsum"]:
    b = base_scores[metric]
    f = ft_scores[metric]
    print(f"{metric:<12} | {b:<10.4f} | {f:<10.4f}")



📊 ROUGE Score Comparison (on 2 samples):
Metric       | Base Model | Fine-tuned
--------------------------------------
rouge1       | 0.2716     | 0.4299    
rouge2       | 0.0521     | 0.1862    
rougeL       | 0.1472     | 0.2806    
rougeLsum    | 0.2588     | 0.4147    


In [4]:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

metrics = ["rouge1", "rouge2", "rougeL", "rougeLsum"]

base_vals = [base_scores[m] for m in metrics]
ft_vals = [ft_scores[m] for m in metrics]

# Create bar chart
x = range(len(metrics))
plt.figure(figsize=(8, 5))
plt.bar(x, base_vals, width=0.4, label="Base Model", align='center')
plt.bar([i + 0.4 for i in x], ft_vals, width=0.4, label="Fine-tuned", align='center')
plt.xticks([i + 0.2 for i in x], metrics)
plt.ylabel("ROUGE F1 Score")
plt.title("ROUGE Score Comparison (Base vs Fine-tuned)")
plt.legend()
plt.tight_layout()
plt.savefig("rouge_comparison.png")
plt.show()