# Install Dependencies

In [None]:
!pip install -q transformers accelerate bitsandbytes pillow pandas

# Mount Google Drive

In [None]:
# Colab-only helpers: Drive mounting and the browser file-upload widget.
from google.colab import drive, files

# Make the user's Google Drive available under /content/drive.
drive.mount('/content/drive')

# NOTE(review): the CONFIG section further down assigns IMAGE_FOLDER again
# ("/content/images"), so on a top-to-bottom run this Drive path is replaced
# before any image is listed. Confirm which location is intended.
IMAGE_FOLDER = "/content/drive/MyDrive/images"

# Prompt for local files; `uploaded` keeps whatever files.upload() returns.
uploaded = files.upload()

# Imports + Config

In [None]:
import os
import gc
import pandas as pd
from PIL import Image
import torch

from transformers import (
    AutoProcessor,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)

# =============================
# CONFIG
# =============================

# Hugging Face identifier of the checkpoint to load.
MODEL_ID = "google/medgemma-1.5-4b-it"

# Limits: number of images taken from IMAGE_FOLDER and generation length cap.
MAX_IMAGES = 10
MAX_NEW_TOKENS = 80

# Input / output locations.
# NOTE(review): IMAGE_FOLDER is also assigned in the Drive-mount cell; this
# assignment runs later and therefore wins. Confirm which folder is intended.
IMAGE_FOLDER = "/content/images"        # Change if needed
REPORT_FOLDER = "/content/reports"
CSV_OUTPUT = "/content/vlm_results.csv"

# Ensure the per-image report directory exists before anything is written.
os.makedirs(REPORT_FOLDER, exist_ok=True)

# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)


# 4-bit Quantization Setup

In [None]:
# Quantization settings handed to from_pretrained() when the model is loaded:
# 4-bit weights, "nf4" quantization type, double quantization enabled.
# NOTE(review): compute dtype is float16; confirm this is numerically adequate
# for this checkpoint (bfloat16 may be preferable where the GPU supports it).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)


# Load MedGemma

In [None]:
print("Loading MedGemma (4bit)...")

# Processor for MODEL_ID; the non-fast implementation is requested explicitly.
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False)

# Load the quantized weights, let device_map="auto" decide placement, and
# switch straight to evaluation mode.
# NOTE(review): an image+text checkpoint is loaded through
# AutoModelForCausalLM here; confirm the installed transformers version
# resolves this to a class that accepts image inputs.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=bnb_config,
).eval()

print("Model loaded successfully.")


# Load Images

In [None]:
# Collect up to MAX_IMAGES image files (.png / .jpg / .jpeg, case-insensitive)
# from IMAGE_FOLDER.
# os.listdir() returns entries in arbitrary order, so the listing is sorted
# first: the same MAX_IMAGES files are then selected, in the same order, on
# every run.
image_paths = [
    os.path.join(IMAGE_FOLDER, f)
    for f in sorted(os.listdir(IMAGE_FOLDER))
    if f.lower().endswith((".png", ".jpg", ".jpeg"))
][:MAX_IMAGES]

print(f"Found {len(image_paths)} images.")


# Prompt Builder

In [None]:
def build_messages():
    """Build the single-turn chat payload used for every image.

    Returns:
        A new list holding one ``user`` message whose content is an image
        placeholder followed by the text instructions (a fixed report
        template with Findings / Abnormalities / Impression sections).
    """
    # Blank entries become the empty lines that separate template sections.
    instruction_lines = [
        "You are an expert radiologist.",
        "Analyze this chest X-ray.",
        "",
        "Return:",
        "",
        "Findings:",
        "- ...",
        "",
        "Abnormalities:",
        "- ...",
        "",
        "Impression:",
        "- Pneumonia likely or unlikely",
    ]
    image_part = {"type": "image"}
    text_part = {"type": "text", "text": "\n".join(instruction_lines)}

    user_turn = {"role": "user", "content": [image_part, text_part]}
    return [user_turn]


# Generate Reports

In [None]:
# Generate one text report per image.
# Each successful image yields a .txt file in REPORT_FOLDER and a row in
# `results`; a failure is printed and the loop moves on (best-effort batch).
results = []

for i, img_path in enumerate(image_paths):

    print(f"\n[{i+1}/{len(image_paths)}] Processing {os.path.basename(img_path)}")

    # Pre-bind so the cleanup in `finally` is valid even if an early step fails.
    inputs = outputs = None

    try:
        # Open via a context manager so the file handle is closed promptly;
        # convert() returns a new in-memory image that outlives the handle.
        with Image.open(img_path) as source_image:
            image = source_image.convert("RGB").resize((512, 512))

        messages = build_messages()

        prompt = processor.apply_chat_template(
            messages,
            add_generation_prompt=True
        )

        inputs = processor(
            text=prompt,
            images=[image],
            return_tensors="pt"
        )

        # Move inputs to same device as model
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Greedy decoding (do_sample=False). No temperature is passed: it is a
        # sampling-only setting and would merely trigger a config warning here.
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False
            )

        # Drop the prompt tokens so only the newly generated part is decoded.
        generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]

        report = processor.decode(
            generated_tokens,
            skip_special_tokens=True
        ).strip()

        # Very short outputs are replaced with a fixed placeholder.
        if len(report) < 10:
            report = "No clear findings generated."

        report_file = os.path.join(
            REPORT_FOLDER,
            os.path.basename(img_path) + ".txt"
        )

        # Explicit UTF-8 so non-ASCII model output cannot fail on a
        # locale-dependent default encoding.
        with open(report_file, "w", encoding="utf-8") as f:
            f.write(report)

        results.append({
            "image": img_path,
            "report": report
        })

        print("✔ Done")

    except Exception as e:
        print("❌ Error:", e)

    finally:
        # Release per-image tensors on success AND failure, so one failed
        # image does not keep its memory alive into the next iteration.
        inputs = outputs = None
        gc.collect()
        torch.cuda.empty_cache()


# Save CSV

In [None]:
# Write the collected reports to CSV_OUTPUT (one row per processed image).
# The summary line reports what actually happened instead of always claiming
# that all reports were generated, including when `results` is empty.
if results:
    df = pd.DataFrame(results)
    df.to_csv(CSV_OUTPUT, index=False)
    print("CSV saved at:", CSV_OUTPUT)
    print(f"\n{len(results)} report(s) generated.")
else:
    print("\nNo reports were generated; CSV not written.")
