In [None]:
# ===== 0. Bibliotheken =====
import csv
import random
from pathlib import Path

import numpy as np
import torch
from datasets import load_from_disk
from huggingface_hub import login            # Token sollte über ENV gesetzt sein
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

# NLP-Augmentation
import nlpaug.augmenter.char as nac
import nlpaug.augmenter.word as naw
import nlpaug.model.word_dict.wordnet as nmw_wordnet

import nltk
from nltk.corpus import wordnet
# Make sure the NLTK data used by the augmenters is installed.
# Each resource lives under its own category directory inside the NLTK data
# folder. Probing everything under "corpora/" never finds the tagger or the
# tokenizer, which would trigger a download attempt on every single run.
nltk_resources = {
    "wordnet": "corpora/wordnet",
    "omw-1.4": "corpora/omw-1.4",
    "averaged_perceptron_tagger_eng": "taggers/averaged_perceptron_tagger_eng",
    "punkt": "tokenizers/punkt",
}
for pkg, resource_path in nltk_resources.items():
    try:
        nltk.data.find(resource_path)
    except LookupError:
        # Only hit the network when the resource is really missing locally.
        nltk.download(pkg, quiet=True)

# nlpaug needs a module-level WordNet object
nmw_wordnet.wordnet = wordnet

In [None]:
# ===== 2. Model & processor =====
model_id = "../models/Gemma_3_4B/merged_model_batchsize_2"

# Resolve device and dtype BEFORE loading the model so that the weights and the
# inputs cast with DTYPE later on always agree. bfloat16 is only used when a
# CUDA device that supports it is present; otherwise fall back to float32.
use_cuda = torch.cuda.is_available()
DTYPE = (
    torch.bfloat16
    if use_cuda and torch.cuda.is_bf16_supported()
    else torch.float32
)
device_name = torch.cuda.get_device_name(0) if use_cuda else "CPU"

model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    # Place the whole model on GPU 0 when available, otherwise on the CPU
    # (a hard-coded GPU index fails on machines without CUDA).
    device_map={"": 0 if use_cuda else "cpu"},
    torch_dtype=DTYPE,
).eval()

processor = AutoProcessor.from_pretrained(model_id, use_fast=True)

print(f"🖥️ Torch device: {device_name} | Dtype: {DTYPE}")

In [None]:
# ===== 3. Load data =====
# The parent of the current working directory is treated as the project root;
# the validation split is read from <project_root>/data/validation.
project_root = Path.cwd().parent
data_path = project_root.joinpath("data", "validation")
dataset = load_from_disk(f"{data_path}")
print(f"Datensatz geladen: {len(dataset)} Samples")

In [None]:
# ===== 4. Define text augmenters =====
# Maps the augmentation label (stored in the "augmentation" CSV column) to the
# nlpaug augmenter that perturbs the question text.
# Note: "sentence_shuffle" is a word-level swap, despite what its label suggests.
text_augmenters = dict(
    char_swap=nac.RandomCharAug(
        action="swap",
        aug_char_p=0.015,
        aug_char_min=1,
        aug_char_max=1,
    ),
    typo_insert=nac.RandomCharAug(
        action="insert",
        aug_char_p=0.007,
        aug_char_min=1,
        aug_char_max=1,
    ),
    sentence_shuffle=naw.RandomWordAug(
        action="swap",
        aug_p=0.15,
        aug_min=1,
        aug_max=1,
    ),
)

In [None]:
# ===== 5. Prepare CSV output =====
# Results go to <project_root>/data/llm_answers/; the directory is created on
# demand so that opening the file for writing later cannot fail on a missing folder.
output_csv = project_root.joinpath(
    "data", "llm_answers", "batchsize_2_textaugmented_results.csv"
)
output_csv.parent.mkdir(exist_ok=True, parents=True)

# Column order of the result file (one row per sample and augmentation).
fieldnames = [
    "ID", "augmentation",
    "question_orig", "question_augmented",
    "correct_answer", "model_output",
]

In [None]:

# ===== 6. Generation =====
# Runs every text augmenter over the whole dataset, asks the model the
# augmented question about the sample image and writes one CSV row per
# (sample, augmentation) pair.

# Fixed seed so the random text perturbations are identical on every run.
AUGMENTATION_SEED = 42

# The system prompt is the same for every sample, so define it once.
SYSTEM_PROMPT = (
    "You are a medical pathology expert. Answer strictly "
    "based on the visual information in the image. Use short "
    "precise terms without explanations."
)

n_failed = 0  # (sample, augmentation) pairs that raised and have no CSV row

with output_csv.open("w", newline="", encoding="utf-8") as f_out:
    writer = csv.DictWriter(f_out, fieldnames=fieldnames)
    writer.writeheader()

    for aug_name, text_aug in text_augmenters.items():
        print(f"\n==> Running text augmentation: {aug_name}")

        # Re-seed per augmenter so each augmentation's output does not depend
        # on which augmenters ran before it. Both Python's `random` and NumPy's
        # global RNG are seeded to cover whichever one the augmenter draws from.
        random.seed(AUGMENTATION_SEED)
        np.random.seed(AUGMENTATION_SEED)

        for idx in tqdm(range(len(dataset)), desc=aug_name):
            sample = dataset[idx]
            # Fall back to the 1-based position when the sample has no "id".
            qid = sample.get("id", idx + 1)

            try:
                # Image as PIL RGB
                image = sample["image"].convert("RGB")

                question_orig = sample["question"]
                ground_truth = sample["answer"]

                # Augment the question. Newer nlpaug releases return a list of
                # strings instead of a single string; unwrap it so the prompt
                # and the CSV always receive plain text.
                question_aug = text_aug.augment(question_orig)
                if isinstance(question_aug, (list, tuple)):
                    question_aug = question_aug[0] if question_aug else question_orig

                # Gemma chat prompt: system instruction + user turn (image, question)
                messages = [
                    {
                        "role": "system",
                        "content": [{"type": "text", "text": SYSTEM_PROMPT}],
                    },
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": image},
                            {"type": "text", "text": question_aug},
                        ],
                    },
                ]

                # Tokenize; cast to the dtype the model was actually loaded
                # with so the inputs can never disagree with the weights.
                inputs = processor.apply_chat_template(
                    messages,
                    add_generation_prompt=True,
                    tokenize=True,
                    return_dict=True,
                    return_tensors="pt",
                ).to(model.device, dtype=model.dtype)

                prompt_len = inputs["input_ids"].shape[-1]

                # Inference (greedy decoding)
                with torch.inference_mode():
                    gen_output = model.generate(
                        **inputs,
                        max_new_tokens=100,
                        do_sample=False,
                    )

                # Decode only the newly generated tokens, not the echoed prompt.
                answer = processor.decode(
                    gen_output[0][prompt_len:], skip_special_tokens=True
                ).strip()

                writer.writerow(
                    {
                        "ID": qid,
                        "augmentation": aug_name,
                        "question_orig": question_orig,
                        "question_augmented": question_aug,
                        "correct_answer": ground_truth,
                        "model_output": answer,
                    }
                )
                # Flush per row so an interrupted long run keeps finished results.
                f_out.flush()

            except Exception as err:
                # Best effort: report the failure and keep going with the next sample.
                n_failed += 1
                print(f"Fehler bei Sample {qid}, Augmentation {aug_name}: {err}")
                continue

# Report only after the file has been closed, i.e. everything is on disk.
print(f"\n Alle Ergebnisse gesichert: {output_csv.relative_to(project_root)}")
if n_failed:
    print(f"{n_failed} sample/augmentation pairs failed and are missing from the CSV.")