In [None]:
import os
import csv
import re
import glob
from PIL import Image
import torch
from tqdm import tqdm
from transformers import pipeline

# CONFIGURATION
# Glob pattern selecting the input files that will be evaluated.
IMAGE_PATH = "./full-fundus/*.*"
# Destination CSV that receives one row per processed file.
OUTPUT_FILE = "./results_marisse.csv"

# Model identifier handed to the transformers pipeline.
MODEL_ID = "google/medgemma-4b-it"

# INITIALIZE PIPELINE
# The model is pinned to the second GPU (index 1) when it exists. On hosts
# that expose fewer than two CUDA devices a hard-coded index of 1 fails at
# start-up, so fall back to the first GPU and finally to the CPU.
if torch.cuda.device_count() > 1:
    pipeline_device = "cuda:1"
elif torch.cuda.is_available():
    pipeline_device = "cuda:0"
else:
    pipeline_device = "cpu"
print(f"Running {MODEL_ID} on {pipeline_device}.")

pipe = pipeline(
    "image-text-to-text",
    model=MODEL_ID,
    torch_dtype=torch.bfloat16,
    # device_map="auto",
    device=pipeline_device,
)

# PROMPT
# Instruction text sent with every image, assembled sentence by sentence and
# joined with single spaces.
base_prompt_text = " ".join((
    "You are an expert ophthalmologist.",
    "Evaluate this fundus image for signs of glaucoma"
    " (optic disc cupping, RNFL loss, peripapillary atrophy).",
    "Write your Key Findings.",
    "Then provide your Conclusion.",
    "Do not include a Disclaimer.",
))

# LOAD IMAGES
# Sorted so the processing order (and therefore the CSV row order) is
# reproducible between runs.
image_files = sorted(glob.glob(IMAGE_PATH))
print(f"Found {len(image_files)} images.")

# Ensure the output directory exists. When OUTPUT_FILE is a bare filename its
# dirname is the empty string, and os.makedirs("") raises FileNotFoundError,
# so only create the directory when there is one to create.
output_dir = os.path.dirname(OUTPUT_FILE)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# PROCESSING LOOP
# Runs the model on every discovered file and writes one CSV row per file.
# Every row has the same three columns so the CSV stays rectangular:
#   - success: [file name, model response, ""]
#   - failure: [file name, "ERROR", exception message]
with open(OUTPUT_FILE, mode='w', newline='', encoding='utf-8') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["Image File", "Full Reasoning", "Error"])

    for image_file in tqdm(image_files):
        try:
            # Open inside a context manager so the underlying file is closed
            # as soon as the RGB copy exists; convert() returns a new image
            # and also ensures consistent color channels.
            with Image.open(image_file) as source_image:
                image = source_image.convert("RGB")

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": base_prompt_text},
                        {"type": "image", "image": image}
                    ]
                }
            ]

            output = pipe(
                messages,
                max_new_tokens=2048,
                do_sample=False
            )

            # The pipeline may return either the whole chat (a list of
            # messages, the reply being the last one) or a plain string.
            generated_text = output[0]["generated_text"]
            if isinstance(generated_text, list):
                raw_response = generated_text[-1]["content"].strip()
            else:
                raw_response = generated_text.strip()

            writer.writerow([os.path.basename(image_file), raw_response, ""])

        except Exception as e:
            # Deliberate best-effort boundary: one unreadable file or failed
            # generation must not abort the whole batch. Record it and go on.
            print(f"Error processing {image_file}: {e}")
            writer.writerow([os.path.basename(image_file), "ERROR", str(e)])

        finally:
            # Persist each row immediately so results already produced
            # survive a crash or interruption later in the run.
            csvfile.flush()

print("Processing complete.")