# Training with resnet18

## Install Dependencies

In [None]:
# ! is used to run console commands in jupyter notebooks
!pip install -q nbstripout
!pip install torch-summary

## Import Dependencies

In [None]:
import pandas as pd

import numpy as np
import ast
import os
from sklearn.model_selection import train_test_split
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms
from PIL import Image

import torch.optim as optim
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

from torchvision import models

## Set Variables

In [None]:
# Root folder of the locally cached Kaggle ODIR-5K download (version 2)
path = './KaggleCache/datasets/andrewmvd/ocular-disease-recognition-odir5k/versions/2'
# Metadata table; later cells read its 'target' and 'filename' columns
df = pd.read_csv(os.path.join(path, 'full_df.csv'))

## Train Test Split

In [None]:
# --- 1. CONFIGURATION AND DATA ASSUMPTION ---

# Split ratios. The split code below carves VAL_RATIO + TEST_RATIO off the
# full table first; the training share is whatever remains.
TRAIN_RATIO = 0.70
VAL_RATIO = 0.15
TEST_RATIO = 0.15
# Number of distinct class indices (0 .. NUM_CLASSES - 1) used for weighting
# and for the size of the model's output layer
NUM_CLASSES = 8

# --- 2. DATA PREPROCESSING: CONVERT TARGET TO CLASS INDEX ---

def target_string_to_index(target_str: str) -> int:
    """
    Map a stringified one-hot vector to an integer class index.

    Example: '[0, 0, 1, 0]' -> 2. If the vector contains several 1s, the
    position of the first one is returned.

    Raises:
        ValueError: if the parsed list contains no 1.
    """
    # literal_eval only accepts Python literals, so nothing in the CSV is executed
    one_hot = ast.literal_eval(target_str)
    class_idx = one_hot.index(1)
    return class_idx

# Derive the integer class index used for stratification and as training label
df['class_index'] = df['target'].apply(target_string_to_index)

# Print initial class distribution
print("--- Initial Class Distribution ---")
print(df['class_index'].value_counts().sort_index())
print("-" * 34)

# --- 3. STRATIFIED TRAIN / VAL / TEST SPLIT ---

# Fail fast when the three ratios are edited inconsistently; otherwise the
# printed percentages would silently disagree with the actual split.
if abs(TRAIN_RATIO + VAL_RATIO + TEST_RATIO - 1.0) > 1e-9:
    raise ValueError(
        "TRAIN_RATIO + VAL_RATIO + TEST_RATIO must sum to 1.0, got "
        f"{TRAIN_RATIO + VAL_RATIO + TEST_RATIO}"
    )

# Step 1: split off the combined validation+test share; the rest is training data
holdout_ratio = VAL_RATIO + TEST_RATIO
train_df, temp_df = train_test_split(
    df,
    test_size=holdout_ratio,
    stratify=df['class_index'],
    random_state=42
)

# Step 2: divide the hold-out part into validation and test.
# The test share is expressed relative to the hold-out part (0.15 / 0.30 = 0.5
# with the default ratios), so changing the constants keeps the split consistent.
val_df, test_df = train_test_split(
    temp_df,
    test_size=TEST_RATIO / holdout_ratio,
    stratify=temp_df['class_index'],
    random_state=42
)

# Verify the final split sizes (percentages are taken from the constants)
print("\n--- Final Dataset Sizes ---")
print(f"Total Samples: {len(df)}")
print(f"Training Samples ({TRAIN_RATIO:.0%}): {len(train_df)}")
print(f"Validation Samples ({VAL_RATIO:.0%}): {len(val_df)}")
print(f"Test Samples ({TEST_RATIO:.0%}): {len(test_df)}")

# --- 4. CALCULATE INVERSE CLASS FREQUENCY FOR WEIGHTED SAMPLER ---

# Count occurrences of each class in the training set
class_counts = Counter(train_df['class_index'])
# Get total number of samples in the training set
total_samples = len(train_df)
# Relative frequency per class; a class absent from the training split gets 0.0
class_frequencies = {i: class_counts.get(i, 0) / total_samples for i in range(NUM_CLASSES)}

# Inverse frequency: w_i = 1 / f_i, so rare classes receive large weights.
# Classes with zero frequency are left out here to avoid a division by zero.
# These values are the basis for the per-sample weights of the WeightedRandomSampler.
class_weights = {
    i: 1.0 / class_frequencies[i]
    for i in range(NUM_CLASSES) if class_frequencies[i] > 0
}

# Arrange the weights as a list ordered by class index [w0, w1, w2, ...];
# classes missing from class_weights fall back to 0.0 (never sampled).
inverse_weights = [class_weights.get(i, 0.0) for i in range(NUM_CLASSES)]
# Scale by the LARGEST weight: the rarest class ends up at exactly 1.0 and every
# other class below 1.0. Only the ratios between the weights matter for sampling.
max_weight = max(inverse_weights)
normalized_weights = [w / max_weight for w in inverse_weights]

# Print the final class weights for review
print("\n--- Training Set Class Weights (Normalized) ---")
print(f"Class Frequencies: {class_frequencies}")
# Note: the values printed under "Inverse Weights" are the normalized ones
print(f"Inverse Weights: {normalized_weights}")
# Keep the weights as a float32 numpy array so they can be indexed by class index later
class_weights_np = np.array(normalized_weights, dtype=np.float32)

# Per-split class shares in percent, to confirm that stratification kept the
# class proportions comparable across train / validation / test
print("\n--- Stratification Check (Training Set) ---")
print(train_df['class_index'].value_counts(normalize=True).sort_index() * 100)
print("\n--- Stratification Check (Validation Set) ---")
print(val_df['class_index'].value_counts(normalize=True).sort_index() * 100)
print("\n--- Stratification Check (Test Set) ---")
print(test_df['class_index'].value_counts(normalize=True).sort_index() * 100)

## Train with resnet18

In [None]:
def build_ocular_resnet18(num_classes=8, feature_extract=True):
    """
    Build a ResNet18 initialised with torchvision's default pre-trained weights
    and swap its classification head for one sized to the ocular disease task.

    Args:
        num_classes (int): Width of the new output layer (one logit per class).
        feature_extract (bool): If True, every pre-trained parameter is frozen,
            so only the freshly created head remains trainable.

    Returns:
        The adapted model (not yet moved to any device).
    """
    print("Initializing Pre-trained ResNet18...")

    backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

    # Freeze BEFORE replacing the head: the new layer is created afterwards
    # and therefore keeps the default requires_grad=True.
    if feature_extract:
        backbone.requires_grad_(False)

    # New head: same input width as the original fc layer, num_classes outputs
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)

    return backbone

# --- MODEL INSTANTIATION ---

# Pick the first CUDA device when one is present, otherwise stay on the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the model with a frozen backbone and place it on the chosen device
ocular_model = build_ocular_resnet18(num_classes=NUM_CLASSES, feature_extract=True)
ocular_model.to(device)

# Sanity check: with a frozen backbone only the new head should be trainable
print("\n--- Parameter Check ---")
param_sizes = [(p.numel(), p.requires_grad) for p in ocular_model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)
print(f"Total Parameters: {total_params:,}")
print(f"Trainable Parameters (Head only): {trainable_params:,}")

## Data Pipeline

In [None]:
# --- 1. CONFIGURATION AND ASSUMPTIONS ---

# Batch size for the loaders built in this cell. The transforms defined below
# resize every image to 224x224, so 32 refers to that resolution.
BATCH_SIZE = 32
# Assumed to be available from previous cells:
# train_df, val_df, test_df (DataFrames for splits)
# path (Root path string)
# class_weights_np (Numpy array of class weights for WeightedRandomSampler)

# --- 2. CUSTOM DATASET CLASS (The PyTorch Way) ---

class OcularDataset(Dataset):
    """
    Map-style PyTorch dataset yielding (image, label) pairs for fundus images.

    Each row of ``df`` supplies a 'filename' (looked up inside
    ``<root_dir>/preprocessed_images``) and an integer 'class_index'.
    The label is read as-is: converting the one-hot 'target' strings into
    'class_index' has to happen before the DataFrame is handed to this class.
    """
    def __init__(self, df: pd.DataFrame, root_dir: str, transform=None):
        """
        Store the sample table and derive the image folder.

        Args:
            df: DataFrame containing 'filename' and 'class_index'.
            root_dir: Base directory containing the 'preprocessed_images' folder.
            transform: Optional callable applied to the loaded PIL image
                (e.g. composed torchvision transforms).
        """
        self.df = df
        self.root_dir = root_dir
        self.transform = transform
        self.image_dir = os.path.join(self.root_dir, "preprocessed_images")

    def __len__(self):
        """Return the number of rows (samples) in the DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """
        Load the sample at positional index ``idx``.

        Returns:
            (image, label): the transformed image (a PIL image when no
            transform is set) and the class index as a 0-dim long tensor,
            the target format CrossEntropyLoss works with.
        """
        # iloc is positional, so the (shuffled) index left by train_test_split is irrelevant
        row = self.df.iloc[idx]
        img_path = os.path.join(self.image_dir, row['filename'])

        # The context manager closes the file handle even if decoding fails;
        # convert() returns an independent in-memory image, so it stays valid afterwards.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        # Explicit None check: a transform object must not be skipped just
        # because it happens to evaluate as falsy (e.g. an empty container).
        if self.transform is not None:
            image = self.transform(image)

        return image, torch.tensor(int(row['class_index']), dtype=torch.long)

# --- 3. TRANSFORMS DEFINITION ---

# Per-channel normalization statistics matching the ImageNet-pretrained backbone
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def _to_normalized_tensor():
    """Return the closing steps of every pipeline: PIL image -> tensor -> normalization."""
    return [transforms.ToTensor(), transforms.Normalize(mean=MEAN, std=STD)]


# Training pipeline: resize to 224x224, apply light random augmentation
# (flip, small rotation, mild brightness/contrast jitter), then tensor + normalize
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    *_to_normalized_tensor(),
])

# Validation/test pipeline: same resize and normalization, but no randomness,
# so repeated evaluations see identical inputs
val_test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    *_to_normalized_tensor(),
])

# --- 4. DATASET INSTANTIATION ---

# One dataset per split; only the training split uses the augmenting transforms
train_dataset = OcularDataset(df=train_df, root_dir=path, transform=train_transforms)
val_dataset = OcularDataset(df=val_df, root_dir=path, transform=val_test_transforms)
test_dataset = OcularDataset(df=test_df, root_dir=path, transform=val_test_transforms)

# --- 5. DATA LOADER CONFIGURATION (Addressing Imbalance) ---

# Per-sample sampling weight = weight of that sample's class.
# Indexing the weight array with the whole label column does the lookup in one
# vectorized step instead of one Python call per row.
sample_weights = class_weights_np[train_df['class_index'].to_numpy()]
# Weighted sampling takes the place of plain shuffling for the training loader
train_sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),  # one epoch draws as many samples as the training set holds
    replacement=True  # lets rare-class samples be drawn several times per epoch
)

# Pinned host memory only speeds up host-to-GPU copies. On the CPU fallback it
# brings no benefit and makes the DataLoader emit a warning, so request it
# only when CUDA is actually available.
PIN_MEMORY = torch.cuda.is_available()

# Instantiate the DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=train_sampler,  # a custom sampler is used instead of shuffle=True
    num_workers=4,  # worker subprocesses that load and transform images in parallel
    pin_memory=PIN_MEMORY
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # No shuffling needed for validation
    num_workers=4,
    pin_memory=PIN_MEMORY
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # No shuffling needed for testing
    num_workers=4,
    pin_memory=PIN_MEMORY
)

# --- 6. SANITY CHECK ---

print("\n--- DataLoader Sanity Check ---")
print(f"Train batches per epoch: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

# Pull exactly one training batch to confirm tensor shapes and the label dtype;
# an empty loader simply prints nothing.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, labels = first_batch
    print("\nSample Batch Test (Training Loader):")
    print(f"Image Tensor Shape (B x C x H x W): {images.shape}")
    print(f"Label Tensor Shape (B): {labels.shape}")
    print(f"Label Data Type (Should be torch.long): {labels.dtype}")

## Training (resnet18)

In [None]:
# --- 1. CONFIGURATION AND ASSUMPTIONS ---

# Hyperparameters for the head-only training phase (initial estimates)
LEARNING_RATE = 1e-3
NUM_EPOCHS = 20
# Weight decay passed to AdamW as regularization against overfitting.
# There is no early stopping: all NUM_EPOCHS epochs are always run.
WEIGHT_DECAY = 1e-5

# Assumed to be available from previous cells:
# ocular_model (ResNet18 built by build_ocular_resnet18, already on 'device')
# train_loader, val_loader (DataLoaders)
# device (torch.device('cuda:0' or 'cpu'))
# NUM_CLASSES (int, 8)

# --- 2. INITIALIZATION ---

# Hand the optimizer only the parameters that still require gradients (the new
# head); the frozen backbone weights are left out entirely.
head_params = [p for p in ocular_model.parameters() if p.requires_grad]
optimizer = optim.AdamW(head_params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Unweighted cross-entropy: class imbalance is handled by the weighted sampler
# of the training loader rather than by loss weights
criterion = nn.CrossEntropyLoss().to(device)

# --- 3. TRAINING FUNCTION DEFINITION ---

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """
    Run one optimization pass over ``dataloader`` and return the mean loss per sample.

    The mean is taken over the samples that were actually processed in this
    epoch. That equals ``len(dataloader.dataset)`` for a plain loader, but stays
    correct when a sampler draws a different number of samples or when
    ``drop_last`` discards a partial batch.

    Args:
        model: Network to train; switched to training mode here.
        dataloader: Iterable yielding (images, labels) batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer holding the parameters to update.
        device: Device the batches are moved to before the forward pass.

    Returns:
        float: Mean per-sample loss, or NaN if the loader yielded no batches.
    """
    # Training mode: dropout is active and batch-norm statistics are updated.
    # NOTE(review): batch-norm layers of a frozen backbone still update their
    # running statistics in this mode — confirm that this is intended.
    model.train()
    running_loss = 0.0
    samples_seen = 0

    for images, labels in tqdm(dataloader, desc="Training"):
        images = images.to(device)
        labels = labels.to(device)

        # Gradients accumulate across backward() calls, so clear them per batch
        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Undo the batch averaging so batches of different size are weighted
        # correctly (assumes the criterion uses mean reduction — TODO confirm)
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    if samples_seen == 0:
        return float("nan")
    return running_loss / samples_seen

# --- 4. VALIDATION FUNCTION DEFINITION ---

def validate_epoch(model, dataloader, criterion, device):
    """
    Evaluate ``model`` on ``dataloader`` without updating any weights.

    The loss is averaged over the samples actually processed, so the value
    stays correct even when the loader does not yield every dataset element
    (e.g. ``drop_last=True``).

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        dataloader: Iterable yielding (images, labels) batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        tuple: (mean loss, accuracy, weighted F1, predictions array, labels array).
        The two arrays are returned so callers can build confusion matrices.
        The mean loss is NaN if the loader yielded no batches.
    """
    # Evaluation mode: dropout is disabled and batch-norm uses its running statistics
    model.eval()
    running_loss = 0.0
    samples_seen = 0
    all_preds = []
    all_labels = []

    # No gradients are needed for evaluation
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validation"):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

            # Predicted class = index of the largest logit in each row
            _, predicted = torch.max(outputs, 1)

            # Metrics are computed on the CPU after the loop
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    epoch_loss = running_loss / samples_seen if samples_seen else float("nan")

    accuracy = accuracy_score(all_labels, all_preds)
    # 'weighted' averages the per-class F1 scores by class support;
    # zero_division=0 scores classes that were never predicted as 0 without warning
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    return epoch_loss, accuracy, f1, np.array(all_preds), np.array(all_labels)

# --- 5. MAIN TRAINING LOOP EXECUTION ---

print(f"\n--- Starting Training on {device} for {NUM_EPOCHS} epochs ---")

# Make sure the model sits on the target device even if this cell is re-run
# after the model was rebuilt or moved elsewhere.
ocular_model.to(device)

# Start below every reachable F1 value so the first epoch always writes a
# checkpoint. With 0.0 and a strict '>' comparison, a run whose F1 never rose
# above 0 would leave no 'best_ocular_cnn.pth' (or a stale one from an earlier
# run) for the fine-tuning cell to load.
best_val_f1 = float("-inf")

# Metric history, extended by the fine-tuning cell and plotted afterwards
history_train_loss = []
history_val_loss = []
history_val_acc = []
history_val_f1 = []

for epoch in range(1, NUM_EPOCHS + 1):
    # Train
    train_loss = train_one_epoch(ocular_model, train_loader, criterion, optimizer, device)

    # Validate (the raw predictions/labels are not needed during training)
    val_loss, val_acc, val_f1, _, _ = validate_epoch(ocular_model, val_loader, criterion, device)

    # Store history
    history_train_loss.append(train_loss)
    history_val_loss.append(val_loss)
    history_val_acc.append(val_acc)
    history_val_f1.append(val_f1)

    print(f"\nEpoch {epoch}/{NUM_EPOCHS}:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Validation Loss: {val_loss:.4f} | Accuracy: {val_acc:.4f} | F1-Score: {val_f1:.4f}")

    # Checkpoint only on improvement of the validation F1
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        # Save only the model's learned parameters (state_dict)
        torch.save(ocular_model.state_dict(), 'best_ocular_cnn.pth')
        print("  --> Model saved! New best F1-Score achieved.")

print("\n--- Training Complete ---")
print(f"Best Validation F1-Score: {best_val_f1:.4f}")

## Fine-Tuning

In [None]:
# --- STEP 2: FINE-TUNING ---

print("\n--- Start Fine-Tuning ---")

# 1. Restart from the best head-only checkpoint rather than from the weights
# of the last epoch, which may already be past the validation peak.
ocular_model.load_state_dict(torch.load('best_ocular_cnn.pth', weights_only=True))

# 2. Measure the baseline to beat on the current validation set instead of
# copying a number from an old log, and store the loaded weights as the initial
# fine-tuned checkpoint. This guarantees that 'best_ocular_cnn_finetuned.pth'
# exists and belongs to this run even if fine-tuning never improves the score.
_, _, best_val_f1, _, _ = validate_epoch(ocular_model, val_loader, criterion, device)
torch.save(ocular_model.state_dict(), 'best_ocular_cnn_finetuned.pth')
print(f"Baseline Validation F1 (loaded checkpoint): {best_val_f1:.4f}")

# 3. Unfreeze ALL layers so updates propagate back into the ResNet backbone
for param in ocular_model.parameters():
    param.requires_grad = True

# 4. Fresh optimizer over all parameters with a much lower learning rate
# (1e-5, i.e. 100x smaller than the head-only phase) to avoid wiping out the
# pre-trained features.
FINE_TUNE_LR = 1e-5
optimizer = optim.AdamW(ocular_model.parameters(), lr=FINE_TUNE_LR, weight_decay=WEIGHT_DECAY)

# 5. Train for a few more epochs
NUM_FINETUNE_EPOCHS = 15
print(f"Fine-tuning for {NUM_FINETUNE_EPOCHS} epochs with LR={FINE_TUNE_LR}...")

for epoch in range(1, NUM_FINETUNE_EPOCHS + 1):
    current_epoch = NUM_EPOCHS + epoch  # display only: continues the epoch count of phase 1

    # Train
    train_loss = train_one_epoch(ocular_model, train_loader, criterion, optimizer, device)

    # Validate
    val_loss, val_acc, val_f1, _, _ = validate_epoch(ocular_model, val_loader, criterion, device)

    # Append to the shared history so both phases can be plotted together
    history_train_loss.append(train_loss)
    history_val_loss.append(val_loss)
    history_val_acc.append(val_acc)
    history_val_f1.append(val_f1)

    print(f"\nEpoch {current_epoch}/{NUM_EPOCHS + NUM_FINETUNE_EPOCHS}:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Validation Loss: {val_loss:.4f} | Accuracy: {val_acc:.4f} | F1-Score: {val_f1:.4f}")

    # Save if improved
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(ocular_model.state_dict(), 'best_ocular_cnn_finetuned.pth')
        print("  --> Model saved! New best F1-Score achieved (Fine-Tuned).")

print(f"\n--- Fine-Tuning Complete. Best F1: {best_val_f1:.4f} ---")

## Plotting (after Fine-Tuning)

In [None]:
# --- 1. CONFIGURATION AND ASSUMPTIONS ---

# Assumed to be available from previous cells:
# ocular_model (ResNet18 built by build_ocular_resnet18)
# test_loader (DataLoader for the final, unseen data)
# device (torch.device('cuda:0' or 'cpu'))
# history_train_loss, history_val_loss, history_val_f1, history_val_acc (Metric lists from training)
# NUM_CLASSES (int, 8)

# Display names for the confusion matrix; position i must describe class index i.
# ODIR-5K orders its one-hot label columns as N, D, G, C, A, H, M, O, i.e.
# Normal, Diabetes, Glaucoma, Cataract, AMD, Hypertension, Myopia, Other.
# Index 5 is therefore hypertension and index 6 pathological myopia; the
# dataset has no retinal detachment class.
# NOTE(review): verify the column order against the 'target' column of full_df.csv.
CLASS_NAMES = [
    "Normal", "Diabetic", "Glaucoma", "Cataract",
    "Macular Deg.", "Hypertensive", "Myopia", "Other"
]

# --- 2. PLOTTING FUNCTIONS ---

def plot_training_metrics(train_loss, val_loss, val_f1, val_acc):
    """
    Draw two side-by-side panels over the epoch axis: training vs. validation
    loss on the left, validation F1 and accuracy on the right.

    All four arguments are per-epoch sequences of equal length; the epoch
    axis starts at 1 and its length is taken from ``train_loss``.
    """
    epochs = range(1, len(train_loss) + 1)

    # Each panel: (title, y-axis label, [(values, line format, legend label), ...])
    panels = [
        (
            'Training and Validation Loss',
            'Loss (CrossEntropy)',
            [
                (train_loss, 'bo-', 'Training Loss'),
                (val_loss, 'ro-', 'Validation Loss'),
            ],
        ),
        (
            'Validation Performance',
            'Score',
            [
                (val_f1, 'go-', 'Validation F1-Score (Weighted)'),
                (val_acc, 'yo--', 'Validation Accuracy'),
            ],
        ),
    ]

    plt.figure(figsize=(15, 5))

    for position, (title, y_label, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for values, line_format, legend_label in curves:
            plt.plot(epochs, values, line_format, label=legend_label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.show()

def get_test_predictions(model, dataloader, device):
    """
    Run inference over ``dataloader`` and gather predictions next to the true labels.

    Returns:
        (predictions, labels): two numpy arrays in loader order, where each
        prediction is the index of the largest logit for that sample.
    """
    model.eval()
    collected_preds, collected_labels = [], []

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Testing"):
            logits = model(images.to(device))
            # Only the images go to the device; labels are just recorded
            collected_preds.extend(logits.argmax(dim=1).cpu().numpy())
            collected_labels.extend(labels.cpu().numpy())

    return np.array(collected_preds), np.array(collected_labels)

def plot_confusion_matrix(true_labels, predictions, class_names):
    """
    Compute the confusion matrix of ``predictions`` against ``true_labels``
    and show it as an annotated heatmap (rows: true class, columns: predicted class).

    ``class_names`` provides the tick labels; its length also fixes the set of
    class indices (0 .. len - 1) included in the matrix, so classes that never
    occur still get a row and a column.
    """
    label_ids = np.arange(len(class_names))
    cm = confusion_matrix(true_labels, predictions, labels=label_ids)

    heatmap_options = dict(
        annot=True,
        fmt='d',  # cell counts are rendered as integers
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
    )

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, **heatmap_options)
    plt.title('Confusion Matrix (Test Set)')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()
    
# --- 3. EXECUTION ---

# 1. Plot the metric history collected during head-only training and fine-tuning
print("\n--- Plotting Training Metrics ---")
plot_training_metrics(
    history_train_loss, 
    history_val_loss, 
    history_val_f1, 
    history_val_acc
)

# 2. Evaluate on Test Set and Plot Confusion Matrix

# Load the best fine-tuned checkpoint (not the head-only 'best_ocular_cnn.pth').
# weights_only=True restricts unpickling to tensor data and avoids the
# FutureWarning that torch.load emits otherwise.
ocular_model.load_state_dict(torch.load('best_ocular_cnn_finetuned.pth', weights_only=True))
ocular_model.to(device) # Ensure model is on device before running inference

# Get predictions and true labels from the unseen test set
test_preds, test_labels = get_test_predictions(ocular_model, test_loader, device)

# Final test metrics, computed the same way as during validation
test_acc = accuracy_score(test_labels, test_preds)
test_f1 = f1_score(test_labels, test_preds, average='weighted', zero_division=0) 

print("\n--- Final Test Set Results ---")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Test F1-Score (Weighted): {test_f1:.4f}")

# Plot the confusion matrix
plot_confusion_matrix(test_labels, test_preds, CLASS_NAMES)

## Training with 512x512 resolution

In [None]:
# --- STEP 3: HIGH-RESOLUTION TRAINING (512x512) ---

print("\n--- Setting up 512x512 Data Pipeline ---")

# 1. Config
# Smaller batches than in the 224px phase because 512px images need far more
# GPU memory. On a "CUDA Out of Memory" error reduce this to 8 or 4.
HIGH_RES_BATCH_SIZE = 16

# 2. High-resolution transforms
# Same augmentation and normalization as the 224px pipeline; only the target
# size changes. The explicit resize keeps every input at exactly 512x512.
train_transforms_512 = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD)
])

val_test_transforms_512 = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD)
])

# 3. Datasets over the same splits, now with the 512px transforms
train_dataset_512 = OcularDataset(train_df, path, transform=train_transforms_512)
val_dataset_512 = OcularDataset(val_df, path, transform=val_test_transforms_512)
test_dataset_512 = OcularDataset(test_df, path, transform=val_test_transforms_512)

# 4. Loaders
# Settings shared by all three loaders. Pinned host memory only helps
# host-to-GPU copies; on the CPU fallback it brings no benefit and triggers a
# DataLoader warning, so it is requested only when CUDA is available.
loader_kwargs_512 = dict(
    batch_size=HIGH_RES_BATCH_SIZE,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

# The weighted sampler is reused unchanged: it only produces indices into the
# training split, which is the same DataFrame as before.
train_loader_512 = DataLoader(train_dataset_512, sampler=train_sampler, **loader_kwargs_512)
val_loader_512 = DataLoader(val_dataset_512, shuffle=False, **loader_kwargs_512)
test_loader_512 = DataLoader(test_dataset_512, shuffle=False, **loader_kwargs_512)

print(f"High-Res Loaders ready. Batch Size: {HIGH_RES_BATCH_SIZE}")

In [None]:
print("\n--- Preparing Model for High-Res Fine-Tuning ---")

# Progressive resizing: continue from the best weights of the 224px phase
checkpoint_224 = torch.load('best_ocular_cnn_finetuned.pth', weights_only=True)
ocular_model.load_state_dict(checkpoint_224)
del checkpoint_224  # the weights now live in the model; drop the extra copy
ocular_model.to(device)

# Keep the learning rate low so the model adapts to the new resolution
# without overwriting what it learned at 224px
LR_HIGH_RES = 1e-5
optimizer_512 = optim.AdamW(
    ocular_model.parameters(),
    lr=LR_HIGH_RES,
    weight_decay=WEIGHT_DECAY,
)

print(f"Model loaded. Starting High-Res training with LR={LR_HIGH_RES}...")

In [None]:
# --- EXECUTE HIGH-RES TRAINING ---

NUM_EPOCHS_512 = 20  # epochs at 512px; each one takes noticeably longer than at 224px
# Start below every reachable F1 value so the first epoch always writes
# 'best_ocular_cnn_512.pth'. With 0.0 and a strict '>' comparison the file
# could be missing (or stale from an earlier run) when the evaluation cell loads it.
best_val_f1_512 = float("-inf")

# Separate history lists for this phase
history_train_loss_512 = []
history_val_loss_512 = []
history_val_f1_512 = []
history_val_acc_512 = []

for epoch in range(1, NUM_EPOCHS_512 + 1):
    print(f"\nHigh-Res Epoch {epoch}/{NUM_EPOCHS_512}:")

    # Train
    train_loss = train_one_epoch(ocular_model, train_loader_512, criterion, optimizer_512, device)

    # Validate
    val_loss, val_acc, val_f1, _, _ = validate_epoch(ocular_model, val_loader_512, criterion, device)

    # Store history
    history_train_loss_512.append(train_loss)
    history_val_loss_512.append(val_loss)
    history_val_acc_512.append(val_acc)
    history_val_f1_512.append(val_f1)

    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Validation Loss: {val_loss:.4f} | Accuracy: {val_acc:.4f} | F1-Score: {val_f1:.4f}")

    # Keep the weights of the best high-resolution epoch
    if val_f1 > best_val_f1_512:
        best_val_f1_512 = val_f1
        torch.save(ocular_model.state_dict(), 'best_ocular_cnn_512.pth')
        print("  --> High-Res Model saved!")

print(f"\n--- High-Res Training Complete. Best F1: {best_val_f1_512:.4f} ---")

## Plotting after resizing to 512x512

In [None]:
# --- STEP 4: FINAL EVALUATION (HIGH-RES) ---

print("\n--- Running FINAL Evaluation on Test Set (512x512) ---")

# 1. Load the best HIGH-RES model
# Important: this is the checkpoint written by the high-resolution training loop
ocular_model.load_state_dict(torch.load('best_ocular_cnn_512.pth', weights_only=True))
ocular_model.to(device)
ocular_model.eval()

# 2. Get predictions using the HIGH-RES Test Loader
# 'test_loader_512' is used so the model is evaluated at the resolution it was
# last trained on
test_preds, test_labels = get_test_predictions(ocular_model, test_loader_512, device)

# 3. Calculate Metrics
test_acc = accuracy_score(test_labels, test_preds)
test_f1 = f1_score(test_labels, test_preds, average='weighted', zero_division=0)

print(f"\n=== ULTIMATE TEST RESULTS ===")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Test F1-Score (Weighted): {test_f1:.4f}")

# 4. Plot Confusion Matrix
# The key figure for the report
plot_confusion_matrix(test_labels, test_preds, CLASS_NAMES)

# 5. Optional: Plot High-Res History
# Shows how loss and validation metrics developed over the high-resolution epochs.
# The x-axis is the 0-based list position, since no explicit epoch values are passed.
plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.plot(history_train_loss_512, 'bo-', label='Train Loss')
plt.plot(history_val_loss_512, 'ro-', label='Val Loss')
plt.title('High-Res Training Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(history_val_f1_512, 'go-', label='Val F1')
plt.plot(history_val_acc_512, 'yo--', label='Val Accuracy')
plt.title('High-Res Validation Metrics')
plt.legend()
plt.grid(True)
plt.show()