In [None]:
# ===== 0. Bibliotheken =====
from pathlib import Path
import csv
from typing import Dict

import numpy as np
import torch
from datasets import load_from_disk
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
import albumentations as A

In [None]:

# ===== 1. Model & processor =====
MODEL_ID = "../models/Gemma_3_4B/merged_model_batchsize_2/"

# Resolve device and dtype *before* loading the model so that the weights and
# the inputs that are later cast with DTYPE always share one dtype.
# bfloat16 is used only on a CUDA device that supports it, otherwise float32.
HAS_CUDA = torch.cuda.is_available()
DTYPE = (
    torch.bfloat16
    if HAS_CUDA and torch.cuda.is_bf16_supported()
    else torch.float32
)

# Put the whole model on GPU 0 when CUDA is available, otherwise on the CPU
# (the previous hard-coded device 0 / bfloat16 contradicted the CPU fallback).
model = Gemma3ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    device_map={"": 0} if HAS_CUDA else {"": "cpu"},
    torch_dtype=DTYPE,
).eval()

processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=True)

device_name = torch.cuda.get_device_name(0) if HAS_CUDA else "CPU"
print(f"🖥️ Torch device: {device_name} | Dtype: {DTYPE}")


In [None]:
# ===== 2. Load data =====
# The project root is taken to be the parent of the current working directory;
# the validation split is read from <root>/data/validation.
PROJECT_ROOT = Path.cwd().parent
DATA_PATH = PROJECT_ROOT.joinpath("data", "validation")

dataset = load_from_disk(str(DATA_PATH))

print(f"Dataset geladen: {len(dataset)} Samples")

In [None]:
# ===== 3. Image augmentation =====
# Transforms are handed to A.Compose in the order listed below; each one
# carries its own application probability `p`.
# NOTE(review): no random seed is set in the visible code, so the augmented
# images presumably differ between runs - confirm whether that is intended.
augmentation_steps = [
    A.Rotate(limit=15, p=0.5),
    A.RandomBrightnessContrast(0.1, 0.1, p=0.5),
    A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),
    A.MotionBlur(blur_limit=3, p=0.3),
    A.CoarseDropout(max_holes=5, max_h_size=32, max_w_size=32, p=0.3),
    A.ImageCompression(70, 100, p=0.3),
]

augment = A.Compose(augmentation_steps)

In [None]:
# ===== 4. Prepare CSV output =====
# Column order of the result file.
FIELDNAMES = ["ID", "question", "correct_answer", "model_output"]

# Results go to <root>/data/llm_answers/; the directory is created on demand.
OUTPUT_CSV = PROJECT_ROOT.joinpath(
    "data", "llm_answers", "batchsize_2_imageaug_results.csv"
)
OUTPUT_CSV.parent.mkdir(parents=True, exist_ok=True)

In [None]:
# ===== 5. Generation =====
# The system prompt is the same for every sample, so it is defined once.
SYSTEM_PROMPT = (
    "You are a medical pathology expert. "
    "Answer strictly based on the visual information in the image. "
    "Use short precise terms without explanations."
)
# Upper bound on the number of tokens generated per answer.
MAX_NEW_TOKENS = 100


def _build_messages(image, question):
    """Return a fresh Gemma chat conversation for one image/question pair.

    Args:
        image: The (augmented) PIL image shown to the model.
        question: The question text for the user turn.

    Returns:
        A list with a system message (SYSTEM_PROMPT) followed by a user
        message containing the image and the question. New dicts are built on
        every call so nothing is shared between samples.
    """
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": SYSTEM_PROMPT}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": question},
            ],
        },
    ]


# IDs of samples that raised an error and therefore have no row in the CSV.
failed_ids = []

with OUTPUT_CSV.open("w", newline="", encoding="utf-8") as f_out:
    writer = csv.DictWriter(f_out, fieldnames=FIELDNAMES)
    writer.writeheader()

    for idx in tqdm(range(len(dataset)), desc="Samples"):
        # Fallback ID for the error message in case loading the sample fails.
        qid = idx + 1

        try:
            # Loading and augmenting happen inside the try block so that a
            # single broken sample is logged and skipped instead of aborting
            # the whole run.
            sample = dataset[idx]
            qid = sample.get("id", idx + 1)
            question = sample["question"]
            ground_truth = sample["answer"]

            # Load the image and augment it
            image_rgb = sample["image"].convert("RGB")
            aug_np = augment(image=np.array(image_rgb))["image"]
            aug_image = Image.fromarray(aug_np)

            # Tokenisation. Floating-point inputs are cast to the dtype the
            # model actually holds, so inputs and weights cannot diverge.
            inputs = processor.apply_chat_template(
                _build_messages(aug_image, question),
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(model.device, dtype=model.dtype)

            # Number of prompt tokens; used to strip the prompt from the output
            prompt_len = inputs["input_ids"].shape[-1]

            # Inference (greedy decoding)
            with torch.inference_mode():
                gen = model.generate(
                    **inputs,
                    max_new_tokens=MAX_NEW_TOKENS,
                    do_sample=False,
                )

            answer = processor.decode(gen[0][prompt_len:], skip_special_tokens=True).strip()

            # CSV row
            writer.writerow(
                {
                    "ID": qid,
                    "question": question,
                    "correct_answer": ground_truth,
                    "model_output": answer,
                }
            )
            # Flush per row so finished results survive a crash mid-run.
            f_out.flush()

        except Exception as err:
            failed_ids.append(qid)
            # tqdm.write keeps the progress bar intact, unlike a bare print.
            tqdm.write(f"Fehler bei Sample {qid}: {err}")
            continue

print(f"\nAlle Ergebnisse gesichert: {OUTPUT_CSV.relative_to(PROJECT_ROOT)}")
if failed_ids:
    # Make skipped samples explicit: they are missing from the CSV.
    print(
        f"Warnung: {len(failed_ids)} von {len(dataset)} Samples fehlgeschlagen "
        f"und nicht in der CSV enthalten: {failed_ids}"
    )