# Hierarchisches Labeling

## Install Dependencies

In [None]:
# ! is used to run console commands in jupyter notebooks
!pip install -q nbstripout
!pip install torch-summary
!pip install grad-cam

## Import Dependencies

In [None]:
import pandas as pd

import numpy as np
import ast
import os
from sklearn.model_selection import train_test_split
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms
from PIL import Image

import torch.optim as optim
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

from torchvision import models

import torchvision.transforms.functional as TF

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import cv2

## Set Variables

In [None]:
# Root folder of the locally cached Kaggle ODIR-5K dataset (version 2)
path = './KaggleCache/datasets/andrewmvd/ocular-disease-recognition-odir5k/versions/2'
# Metadata table; later cells read its 'target' and 'filename' columns
df = pd.read_csv(os.path.join(path, 'full_df.csv'))

## Train Test Split

In [None]:
# --- 1. CONFIGURATION AND PREPROCESSING ---
import ast
from collections import Counter
from sklearn.model_selection import train_test_split

# Define the split ratios (fractions of the full table).
# Only the validation and test shares are handed to train_test_split below;
# the training share is whatever remains after they are taken out.
TRAIN_RATIO = 0.70
VAL_RATIO = 0.15
TEST_RATIO = 0.15

# --- HELPER: CONVERT TARGET TO CLASS INDEX ---
def target_string_to_index(target_str: str) -> int:
    """Return the position of the first 1 in a stringified one-hot list.

    Args:
        target_str: Text form of a Python list such as ``'[0, 1, 0]'``.

    Returns:
        Zero-based index of the first element equal to 1.

    Raises:
        ValueError: If the parsed list contains no 1 (or the text is not a
            valid Python literal).
    """
    one_hot = ast.literal_eval(target_str)
    hot_position = one_hot.index(1)
    return hot_position

# 1. Apply conversion to original DF: adds an integer 'class_index' column
#    derived from the stringified one-hot 'target' column
df['class_index'] = df['target'].apply(target_string_to_index)

# Show how many rows fall into each class, ordered by class index
print("--- Initial Class Distribution (Original) ---")
print(df['class_index'].value_counts().sort_index())
print("-" * 34)


# --- 2. SHARED TRAIN / VAL / TEST SPLIT (used by BOTH stages) ---
# The final evaluation runs the binary model and the specialist together on
# the Stage 1 test rows. If each stage gets its own independent split, rows of
# that test set can end up in the specialist's training set (train/test
# leakage). We therefore split ONCE, stratified on the original 8-class label,
# and derive the binary and the disease-only views from the same partitions.
# Stratifying on all 8 classes also keeps the healthy/sick ratio per partition.
# NOTE(review): if the CSV holds several rows per patient (e.g. one per eye),
# rows of one patient can still land in different partitions - consider a
# grouped split; TODO confirm against the dataset description.
holdout_ratio = VAL_RATIO + TEST_RATIO
train_df_full, temp_df_full = train_test_split(
    df, test_size=holdout_ratio,
    stratify=df['class_index'], random_state=42
)
val_df_full, test_df_full = train_test_split(
    temp_df_full, test_size=TEST_RATIO / holdout_ratio,
    stratify=temp_df_full['class_index'], random_state=42
)


def _binary_view(frame):
    """Return a copy of *frame* whose class_index is 0 (normal) or 1 (any disease)."""
    view = frame.copy()
    view['class_index'] = (view['class_index'] != 0).astype('int64')
    return view


def _disease_view(frame):
    """Return only the diseased rows of *frame* with labels shifted from 1..7 to 0..6."""
    view = frame[frame['class_index'] != 0].copy()
    view['class_index'] = view['class_index'] - 1
    return view


# --- 3. PREPARE STAGE 1: BINARY (Healthy vs. Sick) ---
print("\n=== STAGE 1 PREPARATION: BINARY ===")
# Map: 0 remains 0, 1-7 become 1 (original row index is preserved)
df_binary = _binary_view(df)
train_df_binary = _binary_view(train_df_full)
temp_df_binary = _binary_view(temp_df_full)
val_df_binary = _binary_view(val_df_full)
test_df_binary = _binary_view(test_df_full)

# Calculate Weights for Stage 1 (Binary)
# Simple inverse ratio is enough here
count_0 = int((train_df_binary['class_index'] == 0).sum())
count_1 = int((train_df_binary['class_index'] == 1).sum())
# Weight for class 0 and class 1
weights_binary_np = np.array([1.0, count_0 / count_1], dtype=np.float32)
print(f"Binary Training Sizes: 0 (Normal)={count_0}, 1 (Sick)={count_1}")
print(f"Binary Weights: {weights_binary_np}")


# --- 4. PREPARE STAGE 2: SPECIALIST (Disease Only) ---
print("\n=== STAGE 2 PREPARATION: SPECIALIST ===")
# Filter: keep only sick rows and shift labels 1->0, 2->1 ... 7->6.
# Every specialist partition is a subset of the matching shared partition.
df_disease = _disease_view(df)
train_df_dis = _disease_view(train_df_full)
temp_df_dis = _disease_view(temp_df_full)
val_df_dis = _disease_view(val_df_full)
test_df_dis = _disease_view(test_df_full)

# Calculate Weights for Stage 2 (inverse class frequency on the training part)
class_counts = Counter(train_df_dis['class_index'])
total_samples = len(train_df_dis)
NUM_CLASSES_DIS = 7

class_frequencies = {i: class_counts.get(i, 0) / total_samples for i in range(NUM_CLASSES_DIS)}
# Classes absent from the training part get no entry here and weight 0.0 below
class_weights = {i: 1.0 / class_frequencies[i] for i in range(NUM_CLASSES_DIS) if class_frequencies[i] > 0}
inverse_weights = [class_weights.get(i, 0.0) for i in range(NUM_CLASSES_DIS)]

# Normalize so that the rarest class has weight 1.0
max_weight = max(inverse_weights)
weights_specialist_np = np.array([w / max_weight for w in inverse_weights], dtype=np.float32)

print("Specialist Class Weights (Normalized):")
print(weights_specialist_np)
# Index 4 is now Hypertension, index 5 is Myopia (every label moved down by one)

## ResNet18 and ResNet50

In [None]:
# --- MODEL DEFINITIONS ---

class ResNet18WithSideInfo(nn.Module):
    """ResNet18 backbone whose pooled features are joined with a side vector.

    The backbone's own classification layer is replaced by an identity, and a
    new linear head maps ``backbone features + 2 side values`` to class logits.
    """

    def __init__(self, num_classes=8):
        super().__init__()
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Remember the feature width before the original head is removed
        self.num_ftrs = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.resnet = backbone
        # "+ 2" leaves room for the two-value side encoding
        self.final_fc = nn.Linear(self.num_ftrs + 2, num_classes)

    def forward(self, image, side_vector):
        """Return logits for a batch of images and their matching side vectors."""
        pooled = self.resnet(image)
        joined = torch.cat((pooled, side_vector), dim=1)
        return self.final_fc(joined)

class ResNet50WithSideInfo(nn.Module):
    """
    The bigger sibling: a deeper network with more parameters (author's note:
    roughly 23M vs 11M). Intended for Stage 1 (binary), where most data is
    available.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

        # Feature width is read from the stock head before it is removed
        # (author's note: 2048 for ResNet50, 512 for ResNet18)
        self.num_ftrs = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.resnet = backbone

        # "+ 2" leaves room for the two-value side encoding
        self.final_fc = nn.Linear(self.num_ftrs + 2, num_classes)

    def forward(self, image, side_vector):
        """Return logits for a batch of images and their matching side vectors."""
        pooled = self.resnet(image)
        joined = torch.cat((pooled, side_vector), dim=1)
        return self.final_fc(joined)

## Data Pipeline

In [None]:
## 2. Data Pipeline: Custom Dataset & Mirroring Strategy

# --- TRANSFORMS ---
# Standard ImageNet normalization stats (per RGB channel)
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Training Transforms (High-Res 512x512)
# Augmentation (rotation, colour jitter) is applied to the PIL image first;
# conversion to a tensor and normalization come last.
train_transforms = transforms.Compose([
    transforms.Resize((512, 512)), 
    # NOTE: No RandomHorizontalFlip here! 
    # Mirroring is handled logically in the Dataset class to align optic disc position.
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(), 
    transforms.Normalize(mean=MEAN, std=STD)
])

# Validation/Test Transforms (Deterministic): same resize and normalization,
# but without any random augmentation
val_test_transforms = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD)
])

# --- CUSTOM DATASET CLASS ---
class OcularDatasetSideAware(Dataset):
    """
    Dataset that loads fundus images and encodes which eye they show.

    Applies the 'mirroring trick': images whose file name contains 'right' are
    flipped horizontally so they structurally resemble left-eye images.
    Images are read from ``<root_dir>/preprocessed_images``.
    """

    def __init__(self, df, root_dir, transform=None):
        # df must provide the 'filename' and 'class_index' columns
        self.df, self.root_dir, self.transform = df, root_dir, transform
        self.image_dir = os.path.join(root_dir, "preprocessed_images")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, side_vector, label)`` for the row at position *idx*."""
        record = self.df.iloc[idx]
        file_name = record['filename']
        target = record['class_index']

        # The eye side is inferred from the file name only
        right_eye = 'right' in file_name
        # One-hot side encoding: [1, 0] for left, [0, 1] for right
        if right_eye:
            side_vector = torch.tensor([0.0, 1.0])
        else:
            side_vector = torch.tensor([1.0, 0.0])

        picture = Image.open(os.path.join(self.image_dir, file_name)).convert('RGB')

        # Mirror right eyes so the optic disc always sits on the same side
        if right_eye:
            picture = picture.transpose(Image.FLIP_LEFT_RIGHT)

        if self.transform:
            picture = self.transform(picture)

        return picture, side_vector, torch.tensor(target, dtype=torch.long)

## Hierarchical Training

In [None]:
# --- 3. TRAINING CONFIGURATION & EXECUTION (HIERARCHICAL) ---

# --- GLOBAL SETTINGS ---
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Mini-batch size shared by the DataLoaders that reference it
BATCH_SIZE = 16 

# --- HELPER FUNCTIONS ---

def train_one_epoch_side(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over *dataloader* and return the mean loss.

    Args:
        model: Network called as ``model(images, sides)``.
        dataloader: Yields ``(images, sides, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        Sum of the per-batch losses weighted by batch size, divided by
        ``len(dataloader.dataset)``.
    """
    model.train()
    loss_sum = 0.0
    progress = tqdm(dataloader, desc="Training", leave=False)
    for batch_images, batch_sides, batch_labels in progress:
        batch_images = batch_images.to(device)
        batch_sides = batch_sides.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images, batch_sides), batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller last batch does not skew the mean
        loss_sum += batch_loss.item() * batch_images.size(0)

    return loss_sum / len(dataloader.dataset)

def validate_epoch_side(model, dataloader, criterion, device):
    """Evaluate *model* on *dataloader* without gradient tracking.

    Args:
        model: Network called as ``model(images, sides)``.
        dataloader: Yields ``(images, sides, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(mean_loss, accuracy, weighted_f1)``; the loss is averaged over
        ``len(dataloader.dataset)`` and the F1 uses ``zero_division=0``.
    """
    model.eval()
    loss_sum = 0.0
    predicted_classes, true_classes = [], []

    with torch.no_grad():
        progress = tqdm(dataloader, desc="Validation", leave=False)
        for batch_images, batch_sides, batch_labels in progress:
            batch_images = batch_images.to(device)
            batch_sides = batch_sides.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images, batch_sides)
            loss_sum += criterion(logits, batch_labels).item() * batch_images.size(0)

            # Index of the highest logit per sample is the predicted class
            predicted_classes.extend(logits.argmax(dim=1).cpu().numpy())
            true_classes.extend(batch_labels.cpu().numpy())

    mean_loss = loss_sum / len(dataloader.dataset)
    accuracy = accuracy_score(true_classes, predicted_classes)
    weighted_f1 = f1_score(true_classes, predicted_classes, average='weighted', zero_division=0)
    return mean_loss, accuracy, weighted_f1


# ==========================================
# STAGE 1: BINARY TRAINING (Healthy vs. Sick)
# ==========================================
print("\n" + "="*40 + "\n>>> STARTING STAGE 1: BINARY MODEL <<<\n" + "="*40)

# 1. Config Stage 1 (these values now actually drive the training below)
NUM_EPOCHS_BINARY = 15
LR_BINARY = 1e-4

# Class balancing happens through sampling: every sample of class 1 is drawn
# with weight count_0 / count_1, so both classes are drawn about equally often.
count_0 = len(train_df_binary[train_df_binary['class_index'] == 0])
count_1 = len(train_df_binary[train_df_binary['class_index'] == 1])
weights_binary = [1.0, count_0 / count_1] 
sample_weights_binary = train_df_binary['class_index'].apply(lambda x: weights_binary[x]).values
sampler_binary = WeightedRandomSampler(weights=sample_weights_binary, num_samples=len(sample_weights_binary), replacement=True)

# 2. Loaders
train_ds_binary = OcularDatasetSideAware(train_df_binary, path, transform=train_transforms)
val_ds_binary = OcularDatasetSideAware(val_df_binary, path, transform=val_test_transforms)
train_loader_binary = DataLoader(train_ds_binary, batch_size=BATCH_SIZE, sampler=sampler_binary, num_workers=4)
val_loader_binary = DataLoader(val_ds_binary, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# 3. Model Setup
model_binary = ResNet50WithSideInfo(num_classes=2).to(DEVICE)

optimizer = optim.AdamW(model_binary.parameters(), lr=LR_BINARY, weight_decay=1e-5)
# Unweighted loss: the sampler above already balances the classes
criterion = nn.CrossEntropyLoss().to(DEVICE)
# mode='max' because the monitored quantity is the validation F1
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3)

# Training Loop (same helpers as Stage 2 instead of a hand-rolled copy)
best_f1 = 0.0
for epoch in range(1, NUM_EPOCHS_BINARY + 1):
    train_one_epoch_side(model_binary, train_loader_binary, criterion, optimizer, DEVICE)
    val_loss_binary, val_acc_binary, val_f1 = validate_epoch_side(model_binary, val_loader_binary, criterion, DEVICE)

    print(f"Binary (R50) Epoch {epoch}: Val F1 {val_f1:.4f}")
    scheduler.step(val_f1)

    # Keep only the checkpoint with the best validation F1 seen so far
    if val_f1 > best_f1:
        best_f1 = val_f1
        torch.save(model_binary.state_dict(), 'best_binary_model_r50.pth')
        print("  --> Best Binary Model saved!")

print("Stage 1 Done. Saved 'best_binary_model_r50.pth'")


# ==========================================
# STAGE 2: SPECIALIST TRAINING (7 Diseases)
# ==========================================
print("\n" + "="*40 + "\n>>> STARTING STAGE 2: SPECIALIST MODEL <<<\n" + "="*40)

# 1. Config Stage 2
NUM_EPOCHS_DIS = 30
LR_DIS = 1e-4

# 2. Weights & Sampler Stage 2
# Base weights: inverse class count, scaled so the rarest class gets 1.0
class_counts = Counter(train_df_dis['class_index'])
weights_dis_raw = dict((k, 1.0 / class_counts[k]) for k in range(7))
max_w = max(weights_dis_raw.values())
weights_dis_norm = [weights_dis_raw[k] / max_w for k in range(7)]

# Sampler: every training row is drawn with the weight of its class
sample_weights_dis = train_df_dis['class_index'].map(lambda label: weights_dis_norm[label]).to_numpy()
sampler_dis = WeightedRandomSampler(
    weights=sample_weights_dis,
    num_samples=len(sample_weights_dis),
    replacement=True,
)

# Loaders
train_ds_dis = OcularDatasetSideAware(train_df_dis, path, transform=train_transforms)
val_ds_dis = OcularDatasetSideAware(val_df_dis, path, transform=val_test_transforms)
train_loader_dis = DataLoader(train_ds_dis, batch_size=BATCH_SIZE, sampler=sampler_dis, num_workers=4)
val_loader_dis = DataLoader(val_ds_dis, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# 3. Model Setup Stage 2 & loss weights
model_dis = ResNet18WithSideInfo(num_classes=7).to(DEVICE)
optimizer = optim.AdamW(model_dis.parameters(), lr=LR_DIS, weight_decay=1e-5)

# Sledgehammer approach: the loss weights are boosted by hand on top of the
# normalized inverse-frequency weights
loss_weights = np.array(weights_dis_norm, dtype=np.float32)
# Index 4 = Hypertension (shifted down from 5) -> x5 boost
loss_weights[4] *= 5.0
# Index 5 = Myopia (shifted down from 6) -> x3 boost
loss_weights[5] *= 3.0
print("Specialist Loss Weights (inkl. Boost):", loss_weights)

criterion = nn.CrossEntropyLoss(weight=torch.tensor(loss_weights).to(DEVICE)).to(DEVICE)
# mode='max' because the monitored quantity is the validation F1
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5)

# 4. Main Loop Stage 2
history_dis = {'train_loss': [], 'val_loss': [], 'val_f1': []}
best_val_f1 = 0.0

for epoch in range(1, NUM_EPOCHS_DIS + 1):
    train_loss = train_one_epoch_side(model_dis, train_loader_dis, criterion, optimizer, DEVICE)
    val_loss, val_acc, val_f1 = validate_epoch_side(model_dis, val_loader_dis, criterion, DEVICE)

    history_dis['train_loss'].append(train_loss)
    history_dis['val_loss'].append(val_loss)
    history_dis['val_f1'].append(val_f1)

    # Step first so the printed learning rate is the one used for the next epoch
    scheduler.step(val_f1)
    current_lr = optimizer.param_groups[0]['lr']

    print(f"Specialist Epoch {epoch}: Train Loss {train_loss:.4f} | Val F1 {val_f1:.4f} | LR: {current_lr:.2e}")

    # Keep only the checkpoint with the best validation F1 seen so far
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model_dis.state_dict(), 'best_specialist_model_r18.pth')
        print("  --> Best Specialist Model saved!")

print(f"Stage 2 Done. Best F1: {best_val_f1:.4f}")

## Evaluation

In [None]:
# --- 4. EVALUATION & VISUALIZATION (HIERARCHICAL) ---

# --- 1. CONFIGURATION ---
# IMPORTANT: the order follows the ODIR convention (author's note); the list
# position of a name is the global class index it stands for.
CLASS_NAMES = [
    "Normal",           # 0
    "Diabetes",         # 1
    "Glaucoma",         # 2
    "Cataract",         # 3
    "Macular Deg.",     # 4
    "Hypertension",     # 5
    "Myopia",           # 6
    "Other"             # 7
]

# Re-declared here so this cell does not depend on the training cell
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- 2. PLOTTING HELPERS ---
def plot_confusion_matrix(true_labels, predictions, class_names):
    """Draw the confusion matrix as an annotated seaborn heatmap.

    Args:
        true_labels: Ground-truth class indices.
        predictions: Predicted class indices.
        class_names: Tick labels; their count fixes the matrix size, so classes
            that never occur still get a row and a column.
    """
    label_ids = np.arange(len(class_names))
    matrix = confusion_matrix(true_labels, predictions, labels=label_ids)

    heatmap_options = dict(
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
    )

    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, **heatmap_options)
    plt.title('Hierarchical Model Confusion Matrix (Test Set)')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.show()

# --- 3. HIERARCHICAL PREDICTION LOGIC ---
def predict_hierarchical(binary_model, specialist_model, dataloader, device):
    """
    Combine Stage 1 (binary gate) and Stage 2 (disease specialist).

    Args:
        binary_model: Model whose argmax is 0 (healthy) or 1 (sick).
        specialist_model: Model whose argmax is a disease index 0-6.
        dataloader: Yields ``(images, sides, labels)`` batches.
        device: Device the images and side vectors are moved to.

    Returns:
        Tuple ``(labels, predictions)`` of numpy arrays. A prediction is 0
        where the gate says healthy, otherwise the specialist index plus 1.
    """
    binary_model.eval()
    specialist_model.eval()

    collected_preds, collected_labels = [], []

    with torch.no_grad():
        for images, sides, labels in tqdm(dataloader, desc="Hierarchical Inference"):
            images = images.to(device)
            sides = sides.to(device)

            # A: Stage 1 - ask the gatekeeper (0 = healthy, 1 = sick)
            gate = torch.argmax(binary_model(images, sides), dim=1).cpu().numpy()

            # B: Stage 2 - the specialist scores EVERY image; simpler than
            # picking out only the ones flagged as sick
            disease = torch.argmax(specialist_model(images, sides), dim=1).cpu().numpy()

            # C: Merge - healthy stays 0, otherwise shift the specialist's
            # 0-6 back to the global 1-7 (e.g. 4 -> 5 = Hypertension)
            merged = np.where(gate == 0, 0, disease + 1)

            collected_preds.extend(merged)
            collected_labels.extend(labels.numpy())

    return np.array(collected_labels), np.array(collected_preds)

# --- 4. EXECUTION ---

print("\n" + "="*40 + "\nSTARTING FINAL EVALUATION\n" + "="*40)

# A. Load Test Data (original labels 0-7)
# The rows are selected via the index of test_df_binary, but the labels are
# taken from df, which still carries the original 8-class class_index.
print("Loading Original Test Data...")
test_indices = test_df_binary.index
df_test_final = df.loc[test_indices].copy() 

# Build dataset & loader
test_ds_final = OcularDatasetSideAware(df_test_final, path, transform=val_test_transforms)
test_loader_final = DataLoader(test_ds_final, batch_size=16, shuffle=False, num_workers=4)

# 1. Load Models
# map_location=DEVICE lets checkpoints that were saved on a GPU load on a
# CPU-only machine (and the other way round).
# Stage 1: ResNet50 (healthy vs. sick)
model_binary = ResNet50WithSideInfo(num_classes=2).to(DEVICE)
model_binary.load_state_dict(
    torch.load('best_binary_model_r50.pth', map_location=DEVICE, weights_only=True)
)

# Stage 2: ResNet18 (disease classes)
model_dis = ResNet18WithSideInfo(num_classes=7).to(DEVICE)
model_dis.load_state_dict(
    torch.load('best_specialist_model_r18.pth', map_location=DEVICE, weights_only=True)
)

# 2. TTA Prediction Function
def predict_hierarchical_tta(loader):
    """Hierarchical prediction with test-time augmentation (TTA).

    Uses the module-level ``model_binary``, ``model_dis`` and ``DEVICE``. Each
    batch is scored three times (as is, rotated +5 degrees, rotated -5
    degrees); the softmax outputs are averaged before the argmax is taken.

    Args:
        loader: Yields ``(images, sides, labels)`` batches.

    Returns:
        Tuple ``(labels, predictions)`` of numpy arrays. A prediction is 0
        where the averaged binary output picks class 0, otherwise the
        specialist's class index plus 1.
    """
    model_binary.eval()
    model_dis.eval()
    collected_preds, collected_labels = [], []

    with torch.no_grad():
        for img, side, lbl in tqdm(loader, desc="Inference with TTA"):
            img = img.to(DEVICE)
            side = side.to(DEVICE)

            # --- TTA: original view plus two slightly rotated views ---
            # NOTE(review): the rotation runs on the already transformed
            # tensor, so uncovered corners get the default fill value -
            # confirm this matches what the models saw during training.
            views = (img, TF.rotate(img, 5), TF.rotate(img, -5))
            binary_probs = [F.softmax(model_binary(view, side), dim=1) for view in views]
            spec_probs = [F.softmax(model_dis(view, side), dim=1) for view in views]

            # Average the three probability estimates
            avg_binary = (binary_probs[0] + binary_probs[1] + binary_probs[2]) / 3.0
            avg_spec = (spec_probs[0] + spec_probs[1] + spec_probs[2]) / 3.0

            # Decide on the averaged probabilities
            is_sick = torch.argmax(avg_binary, 1).cpu().numpy()  # 0/1
            spec_out = torch.argmax(avg_spec, 1).cpu().numpy()   # 0-6

            # Merge: healthy stays 0 (Normal), otherwise shift by +1
            collected_preds.extend(np.where(is_sick == 0, 0, spec_out + 1))
            collected_labels.extend(lbl.numpy())

    return np.array(collected_labels), np.array(collected_preds)

# Run the TTA inference over the final test loader
y_true, y_pred = predict_hierarchical_tta(test_loader_final)

# Overall accuracy and support-weighted F1 over the predicted global labels
acc = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average='weighted')

print(f"Final TTA Accuracy: {acc:.4f}")
print(f"Final TTA F1: {f1:.4f}")

# Same names, in the same order, as the CLASS_NAMES list at the top of this cell
CLASS_NAMES = ["Normal", "Diabetes", "Glaucoma", "Cataract", "Macular Deg.", "Hypertension", "Myopia", "Other"]
plot_confusion_matrix(y_true, y_pred, CLASS_NAMES)

## GradCAM

In [None]:
# 1. Wrapper so that GradCAM can drive a model that needs 2 inputs (image + side)
class ModelWrapper(torch.nn.Module):
    """Expose a two-input ``model(image, side)`` as a one-input ``wrapper(image)``.

    The side vector is fixed when the wrapper is built. On every call it is
    moved to the device of the incoming images and repeated along the batch
    dimension, so the wrapper works for any batch size and does not depend on
    a module-level device constant.

    Args:
        model: Callable taking ``(images, side_vectors)``.
        side_vector: 1-D tensor holding the side encoding of one sample.
    """

    def __init__(self, model, side_vector):
        super().__init__()
        self.model = model
        # Add the batch dimension: shape (1, len(side_vector))
        self.side = side_vector.unsqueeze(0)

    def forward(self, x):
        # One identical side row per image in the batch, on the batch's device
        side = self.side.to(x.device).expand(x.size(0), -1)
        return self.model(x, side)

def visualize_gradcam(model, dataset, num_images):
    """Plot randomly chosen diseased samples next to their Grad-CAM overlays.

    The heatmap explains output index 1 of *model* ("sick" for the binary
    Stage 1 model), computed on the last block of ``model.resnet.layer4``.

    Args:
        model: Network exposing ``model.resnet.layer4`` and called as
            ``model(image, side_vector)``.
        dataset: Dataset with a ``df`` frame containing ``class_index``;
            ``dataset[i]`` must return ``(image_tensor, side_vector, label)``.
        num_images: Number of samples to show; capped at the number of
            diseased samples in *dataset*.
    """
    model.eval()

    # Last convolutional block
    target_layers = [model.resnet.layer4[-1]]

    # Take the candidates from the dataset that was actually passed in, so the
    # positions always line up with dataset[idx]. Only diseased rows (label > 0).
    sick_positions = np.flatnonzero(dataset.df['class_index'].to_numpy() > 0)
    num_images = min(num_images, len(sick_positions))
    if num_images <= 0:
        print("No diseased samples available for GradCAM visualization.")
        return
    selected_indices = np.random.choice(sick_positions, num_images, replace=False)

    # Inverse of the ImageNet normalization, used only for display
    inv_normalize = transforms.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
        std=[1/0.229, 1/0.224, 1/0.225]
    )

    # We ask: "Why do you think this is class 1 (sick)?"
    targets = [ClassifierOutputTarget(1)]

    plt.figure(figsize=(15, 5 * num_images))

    for i, idx in enumerate(selected_indices):
        img_tensor, side, label_idx = dataset[int(idx)]

        # Original image for display (de-normalized, clipped to the 0-1 range)
        img_display = inv_normalize(img_tensor).permute(1, 2, 0).numpy()
        img_display = np.clip(img_display, 0, 1)

        # The wrapper pins the side vector of this one image
        wrapped_model = ModelWrapper(model, side)
        input_tensor = img_tensor.unsqueeze(0).to(DEVICE)

        # The context manager removes GradCAM's hooks again; without it every
        # loop iteration would leave another set of hooks on the model.
        with GradCAM(model=wrapped_model, target_layers=target_layers) as cam:
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]

        # Overlay of heatmap and image
        visualization = show_cam_on_image(img_display, grayscale_cam, use_rgb=True)

        true_label = CLASS_NAMES[int(label_idx)]

        # Left column: original
        plt.subplot(num_images, 2, 2*i + 1)
        plt.imshow(img_display)
        plt.title(f"Original: {true_label}")
        plt.axis('off')

        # Right column: GradCAM
        plt.subplot(num_images, 2, 2*i + 2)
        plt.imshow(visualization)
        plt.title("GradCAM (Focus for 'Sick')")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize the Stage 1 model (ResNet50), since it decides healthy vs. sick.
# The model variable is named model_binary; the previous name model_bin was
# never defined and raised a NameError.
print("Visualizing Stage 1 (ResNet50) Attention...")
visualize_gradcam(model_binary, test_ds_final, 10)