In [None]:
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from pathlib import Path
from datasets import load_from_disk
from tqdm import tqdm
import torch
import csv
import albumentations as A
from PIL import Image
import numpy as np
import openai

  check_for_updates()


In [None]:
# ---------------------------------------------------------------------------
# 1. Load model and processor
# ---------------------------------------------------------------------------
# Local checkpoint directory (relative to the notebook's working directory).
model_id = "merged_model_batchsize_2/"

# All weights are placed on GPU 0 in bfloat16; eval() switches the module to
# inference mode (returns the same object, so no reassignment is needed).
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
model.eval()

# Processor matching the checkpoint; the fast implementation is requested.
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)

Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]

In [None]:
# ---------------------------------------------------------------------------
# 2. Inference dtype and device report
# ---------------------------------------------------------------------------
# Section 1 loads the weights with a fixed torch_dtype (bfloat16), so the
# floating-point inputs are cast to the model's own dtype instead of an
# independently detected one; otherwise a GPU without BF16 support would feed
# float32 inputs into bf16 weights.
dtype = model.dtype
if torch.cuda.is_available() and not torch.cuda.is_bf16_supported() and dtype == torch.bfloat16:
    # NOTE(review): bf16 on such GPUs may be emulated (slow) or unsupported by some kernels.
    print("⚠️ GPU reports no native BF16 support – bfloat16 inference may be slow or fail.")
print("🖥️ Torch device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU only")

🖥️ Torch device: NVIDIA L4


In [None]:
# ---------------------------------------------------------------------------
# 3. Load dataset
# ---------------------------------------------------------------------------
# The notebook is expected to live one directory below the project root.
project_root = Path.cwd().parent
data_path = project_root / "data" / "validation"
dataset = load_from_disk(str(data_path))

# Fail fast, before the multi-hour evaluation loop, if the columns that loop
# reads are missing (e.g. a DatasetDict was saved instead of a single split,
# in which case column_names maps split names to column lists).
REQUIRED_COLUMNS = {"image", "question", "answer"}
missing_columns = REQUIRED_COLUMNS - set(dataset.column_names)
if missing_columns:
    raise ValueError(
        f"Dataset at {data_path} lacks required columns: {sorted(missing_columns)}"
    )

In [None]:
# ---------------------------------------------------------------------------
# 4. OpenAI client (used for question rephrasing)
# ---------------------------------------------------------------------------
# The client reads the key from the OPENAI_API_KEY environment variable and
# raises if it is unset. Never hard-code credentials in a notebook that may be
# shared or committed (saved cells and outputs leak them).
client = openai.OpenAI()
MODEL_NAME = "gpt-4o-mini"

In [None]:
# ---------------------------------------------------------------------------
# 5. Helper: rephrase a question (linguistic diversity)
# ---------------------------------------------------------------------------
def rephrase_question(
    question: str,
    n_variants: int = 1,
    api_client=None,
    model_name: "str | None" = None,
) -> list[str]:
    """Ask a chat model for meaning-preserving paraphrases of a pathology question.

    Parameters
    ----------
    question:
        The original question text.
    n_variants:
        Number of completions requested (passed as the API's ``n``); must be >= 1.
    api_client:
        OpenAI-compatible client exposing ``chat.completions.create``.
        Defaults to the notebook-level ``client``.
    model_name:
        Chat model to use. Defaults to the notebook-level ``MODEL_NAME``.

    Returns
    -------
    list[str]
        One whitespace-stripped paraphrase per returned choice. A choice with no
        usable text (``None`` content, e.g. a refusal, or only whitespace) falls
        back to the original question, which is also what the system prompt
        prescribes whenever no valid paraphrase exists.

    Raises
    ------
    ValueError
        If ``n_variants`` is smaller than 1.
    Exceptions raised by the client call propagate to the caller.
    """
    if n_variants < 1:
        raise ValueError(f"n_variants must be >= 1, got {n_variants}")
    # Resolve notebook globals lazily so callers (and tests) can inject their own.
    if api_client is None:
        api_client = client
    if model_name is None:
        model_name = MODEL_NAME

    system_prompt = (
        # ---- ROLE ----
        "You are a senior medical language specialist.\n"
        "\n"
        "TASK\n"
        "Rephrase the following pathology-related question into ONE linguistically diverse variant.\n"
        "\n"
        "HARD CONSTRAINTS (must stay identical)\n"
        "• Medical meaning and implied answer\n"
        "• Scope and subject of the question\n"
        "• Answer type (yes/no ↔ yes/no, entity ↔ entity, etc.)\n"
        "• Facts, numbers, entities – none may be added or removed\n"
        "• Do not turn a request for a characteristic into a definition (or vice versa).\n"
        "• Expected answer length and granularity must match the original.\n"
        "• Preserve key morphological descriptors exactly (e.g. “swirl”, “wavy”, “keratin pearl”).\n"
        "\n"
        "VARIATION GUIDELINES (apply at least one)\n"
        "1. Switch voice or clause order (active ↔ passive, fronting, etc.)\n"
        "2. Use precise synonyms or hyper-/hyponyms that do **not** affect meaning\n"
        "3. Embed the core question in a different grammatical structure (e.g. statement → question)\n"
        "\n"
        "QUALITY FILTER\n"
        "• Paraphrase must change ≥ 20 % of tokens compared with the original.\n"
        "• Purely cosmetic edits (punctuation, capitalisation) are insufficient.\n"
        "• If **any** hard constraint would be violated **or** the 20 % threshold is not met, "
        "output exactly the ORIGINAL question without modification.\n"
        "\n"
        "OUTPUT\n"
        "Return exactly ONE line with the rephrased question – no quotation marks, numbering or extra text."
    )

    prompt = f"Original question:\n\"{question}\"\n\nRephrased question:"

    response = api_client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=0.6,  # slightly creative, still precise
        top_p=0.95,
        n=n_variants,
    )

    variants = []
    for choice in response.choices:
        # message.content can be None (e.g. refusals); never call .strip() on None.
        text = (choice.message.content or "").strip()
        variants.append(text or question)
    return variants

In [None]:
# ---------------------------------------------------------------------------
# 6. Prepare CSV output – including augmentation info
# ---------------------------------------------------------------------------
# Anchored at project_root (section 3), like the input dataset, instead of the
# notebook's working directory; this is the same location as the former
# "../data/llm_answers/..." path. The folder is created because open(..., "w")
# fails when it does not exist.
# NOTE(review): file name mentions "Gpt_4o" although MODEL_NAME is gpt-4o-mini;
# kept unchanged so downstream consumers still find the file.
output_dir = project_root / "data" / "llm_answers"
output_dir.mkdir(parents=True, exist_ok=True)
output_file = output_dir / "batchsize_2_satzumstellung_mit_Gpt_4o.csv"

# Column order of the result CSV; keys must match the dicts written per sample.
fieldnames = [
    "ID",
    "augmentation",
    "question_orig",
    "question_augmented",
    "correct_answer",
    "model_output",
]

In [None]:
# ---------------------------------------------------------------------------
# 7. Main loop: rephrase each question, query the vision LLM, log to CSV
# ---------------------------------------------------------------------------
# Loop-invariant system prompt for the vision model, built once instead of per sample.
VISION_SYSTEM_PROMPT = (
    "You are a medical pathology expert. Your task is to answer "
    "medical questions based solely on the visual information in the "
    "provided pathology image. Focus only on what is visible in the image — "
    "do not rely on prior medical knowledge, assumptions, or external information. "
    "Your responses should be short, factual, and medically precise, using "
    "appropriate terminology. Do not include explanations. "
    "Use a consistent format, no punctuation, avoid capitalisation unless needed. "
    "Only return the exact answer."
)

with open(output_file, "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()

    for idx, sample in enumerate(tqdm(dataset, total=len(dataset))):
        question_orig = sample["question"]
        ground_truth = sample["answer"]
        qid = sample.get("id", idx + 1)  # 1-based fallback when no "id" column exists

        # ---------- Step 1: rephrasing ----------
        try:
            question_aug = rephrase_question(question_orig)[0]
        except Exception as e:  # best effort: any API failure keeps the original question
            tqdm.write(f"Rephrase-Fehler bei ID {qid}: {e}")  # does not break the progress bar
            question_aug = question_orig
        # The rephrasing prompt returns the original verbatim when no valid
        # paraphrase exists, so only rows whose question actually changed are
        # labelled as augmented.
        aug_name = "gpt4o_rephrase" if question_aug != question_orig else "none"

        # ---------- Step 2: vision LLM ----------
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": VISION_SYSTEM_PROMPT}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": sample["image"].convert("RGB")},
                    {"type": "text", "text": question_aug},
                ],
            },
        ]

        # BatchFeature.to() casts only floating-point tensors (e.g. pixel_values)
        # to the given dtype; use the model's own dtype so inputs match the weights.
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device, dtype=model.dtype)

        input_len = inputs["input_ids"].shape[-1]

        # Greedy decoding (deterministic); keep only the newly generated tokens.
        with torch.inference_mode():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=False,
            )[0][input_len:]

        llm_answer = processor.decode(output_ids, skip_special_tokens=True).strip()

        # ---------- Step 3: CSV log ----------
        writer.writerow({
            "ID": qid,
            "augmentation": aug_name,
            "question_orig": question_orig,
            "question_augmented": question_aug,
            "correct_answer": ground_truth,
            "model_output": llm_answer,
        })
        # Flush per row so finished results survive a kernel crash during the long run.
        f.flush()

print("✅ Alle Ergebnisse gesichert:", output_file)

100%|██████████| 6259/6259 [1:54:17<00:00,  1.10s/it]  

✅ Alle Ergebnisse gesichert: ../data/llm_answers/batchsize_2_satzumstellung_mit_Gpt_4o.csv



