In [None]:
import csv
import os
import torch
import pandas as pd
from transformers import  AutoProcessor
from datasets import load_dataset
from pathlib import Path
from huggingface_hub import login
from datasets import load_dataset
from PIL import Image
from pathlib import Path
from datasets import load_from_disk
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig, TrainerCallback, TrainerControl, TrainerState
from peft import LoraConfig
from trl import SFTConfig
from trl import SFTTrainer
from peft import get_peft_model
from transformers import AutoConfig


In [None]:
# Authenticate against the Hugging Face Hub.
# The access token is read from the HF_TOKEN environment variable instead of
# being hard-coded in the notebook (notebooks and their outputs get shared and
# committed, which would leak the secret). If the variable is unset, `login`
# receives None and falls back to its interactive prompt.
login(token=os.environ.get("HF_TOKEN"))

# Load the dataset and pick the training and validation splits.
dataset = load_dataset("flaviagiammarino/path-vqa")
train_dataset = dataset["train"]
val_dataset = dataset["validation"]

print("Datasets wurden geladen.")
print("Trainingsgröße:", len(train_dataset))
print("Validierungsgröße:", len(val_dataset))

In [None]:
# Cell 2: prepare the few-shot examples.
# Fixed positions in the training split that serve as in-context demonstrations.
few_shot_indices = [10658, 18497, 8273, 16324, 10392, 9073, 4623, 10336]

# Fetch each training row once and keep only the three fields the prompt needs.
few_shot_examples = [
    {field: row[field] for field in ("question", "answer", "image")}
    for row in (train_dataset[position] for position in few_shot_indices)
]

print("Few-Shot-Beispiele vorbereitet. Anzahl:", len(few_shot_examples))

In [None]:
# Cell 3: build the message list (prompt) for a single validation sample.

# Pick one example from the validation split (here: the first sample).
sample = val_dataset[0]
val_image = sample["image"]
val_question = sample["question"]

# System message that fixes the answering style.
system_message = (
    "You are a medical pathology expert. Your task is to answer medical questions "
    "based solely on the visual information in the provided pathology image. "
    "Focus only on what is visible in the image — do not rely on prior medical knowledge, "
    "assumptions, or external information. Your responses should be short, factual, "
    "and medically precise, using appropriate terminology. "
    "Do not include any explanations, reasoning, or additional text. "
    "Use a consistent format, without punctuation, and avoid capitalisation unless medically required. "
    "Only return the exact answer."
)

# Start the conversation with the system message.
messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": system_message}]
    }
]

# Add every few-shot example as a user turn (image + question) followed by an
# assistant turn (answer). Putting the answer into the user turn would produce
# consecutive "user" messages, which does not match the user/assistant layout
# used by the generation cell and is not a valid alternating conversation.
for ex in few_shot_examples:
    messages.append({
        "role": "user",
        "content": [
            {"type": "image", "image": ex["image"]},
            {"type": "text", "text": "question: " + ex["question"]}
        ]
    })
    messages.append({
        "role": "assistant",
        "content": [
            {"type": "text", "text": "answer: " + ex["answer"]}
        ]
    })

# Add the validation sample (question only, the model is supposed to answer it).
messages.append({
        "role": "user",
        "content": [
            {"type": "image", "image": val_image},
            {"type": "text", "text": "question: " + val_question},
            {"type": "text", "text": "Answer: "}
        ]
    })

# Show the resulting message list.
print("Nachrichtenliste für den Prompt:")
for i, msg in enumerate(messages):
    print(f"Nachricht {i+1}: {msg}")


In [None]:
def process_vision_info(messages: list[dict]) -> tuple[list[Image.Image], list]:
    """Collect all images referenced in a chat message list, converted to RGB.

    Args:
        messages: Chat messages. Each message's "content" may be a list of
            parts or a single value; only dict parts are inspected. A part
            counts as an image part when it has an "image" key or its "type"
            is "image".

    Returns:
        A tuple ``(images, videos)``. ``images`` holds the RGB-converted
        images in the order they appear in ``messages``; ``videos`` is always
        an empty list.

    Raises:
        ValueError: If a part is marked as an image but carries no image
            object under the "image" key. Skipping it silently would leave
            the prompt with more image slots than images.
    """
    image_inputs = []
    for msg in messages:
        # Normalise the content to a list so plain (non-list) content is handled too.
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        for element in content:
            if not isinstance(element, dict):
                continue
            if "image" not in element and element.get("type") != "image":
                continue

            image = element.get("image")
            if image is None:
                # Previously the part dict itself was used as the image here,
                # which failed later with an opaque AttributeError on `.convert`.
                raise ValueError(
                    f"image content element has no 'image' payload: {element!r}"
                )
            image_inputs.append(image.convert("RGB"))
    return image_inputs, []  # no videos are extracted

In [None]:
# Build the final prompt and model inputs, then generate answers for 10 validation examples.
model_id = "google/gemma-3-4b-it"
config = AutoConfig.from_pretrained(model_id)

# Switch off the cache flag on Gemma's text config.
# NOTE(review): disabling the cache is usually a training-time setting and may
# slow down generation; confirm this is intended for inference.
config.text_config.use_cache = False

# Model init arguments.
model_kwargs = dict(
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    text_config=config.text_config
)

# 4-bit NF4 quantisation with double quantisation; compute and storage dtype
# follow the model dtype chosen above.
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

# Load model and processor from the same checkpoint id so they cannot drift apart.
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer.padding_side = 'right'

# System message; identical for every example, so it is defined once outside the loop.
system_message = (
    "You are a medical pathology expert. Your task is to answer medical questions "
    "based solely on the visual information in the provided pathology image. "
    "Focus only on what is visible in the image — do not rely on prior medical knowledge, "
    "assumptions, or external information. Your responses should be short, factual, "
    "and medically precise, using appropriate terminology. "
    "Do not include any explanations, reasoning, or additional text. "
    "Use a consistent format, without punctuation, and avoid capitalisation unless medically required. "
    "Only return the exact answer."
)

# The system message and the few-shot turns do not depend on the validation
# sample, so this shared conversation prefix is built a single time.
few_shot_messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": system_message}]
    }
]
for ex in few_shot_examples:
    few_shot_messages.append({
        "role": "user",
        "content": [
            {"type": "image", "image": ex["image"]},
            {"type": "text", "text": "question: " + ex["question"]}
        ]
    })
    few_shot_messages.append({
        "role": "assistant",
        "content": [
            {"type": "text", "text": "answer: " + ex["answer"]}
        ]
    })

# Iterate over the first 10 examples of the validation split.
for idx in range(10):
    sample = val_dataset[idx]
    val_image = sample["image"]
    val_question = sample["question"]

    # Shared prefix plus the validation example as the final user turn
    # (list concatenation creates a new list, the prefix is left untouched).
    messages = few_shot_messages + [{
        "role": "user",
        "content": [
            {"type": "image", "image": val_image},
            {"type": "text", "text": "question: " + val_question},
            {"type": "text", "text": "Answer: "}
        ]
    }]

    # Show the message list for the current example.
    print("\n=== Nachrichtenliste für Beispiel", idx + 1, "===")
    for i, msg in enumerate(messages):
        print(f"Nachricht {i+1}: {msg}")

    # Render the whole conversation into a text prompt.
    text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print("\nErzeugter Text-Prompt für Beispiel", idx + 1, ":\n", text_prompt)

    # Collect the images (and the empty video list) referenced by the messages.
    image_inputs, video_inputs = process_vision_info(messages)

    # Build the final model inputs.
    inputs = processor(
        text=[text_prompt],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own placement instead of a hard-coded "cuda":
    # the model is loaded with device_map="auto".
    inputs = inputs.to(model.device)

    # Generate the answer without tracking gradients.
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=128)
        # Drop the prompt tokens so only the newly generated answer remains.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    generated_answer = output_text[0]
    print("Generierte Antwort für Beispiel", idx + 1, ":", generated_answer)
