# Training with ResNet18

## Install Dependencies

In [None]:
# ! is used to run console commands in jupyter notebooks
!pip install -q nbstripout
!pip install torch-summary

## Import Dependencies

In [None]:
import pandas as pd

import numpy as np
import ast
import os
from sklearn.model_selection import train_test_split
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms
from PIL import Image

import torch.optim as optim
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

from torchvision import models

## Set Variables

In [None]:
# Dataset root; later cells read images from its `preprocessed_images` subfolder.
path = './KaggleCache/datasets/andrewmvd/ocular-disease-recognition-odir5k/versions/2'
# Table whose `target` and `filename` columns are used by the cells below.
_full_df_csv = os.path.join(path, 'full_df.csv')
df = pd.read_csv(_full_df_csv)

## Train Test Split

In [None]:
# --- 1. CONFIGURATION AND DATA ASSUMPTION ---

# Fractions of the full table assigned to the train / validation / test splits.
TRAIN_RATIO, VAL_RATIO, TEST_RATIO = 0.70, 0.15, 0.15

# Number of distinct class indices (0 .. NUM_CLASSES - 1) the notebook works with.
NUM_CLASSES = 8

# --- 2. DATA PREPROCESSING: CONVERT TARGET TO CLASS INDEX ---

def target_string_to_index(target_str: str) -> int:
    """
    Return the class index encoded by a stringified one-hot list.

    The string (e.g. '[0, 1, 0, ...]') is parsed with ``ast.literal_eval``,
    which only accepts Python literals, and the position of the first
    element equal to 1 is returned (1 for the example above).

    Raises:
        ValueError: if the parsed list contains no element equal to 1.
    """
    one_hot = ast.literal_eval(target_str)
    hot_position = one_hot.index(1)
    return hot_position

# Add the integer class column; it is the stratification key for both splits below.
df['class_index'] = df['target'].apply(target_string_to_index)

# Print initial class distribution
print("--- Initial Class Distribution ---")
print(df['class_index'].value_counts().sort_index())
print("-" * 34)

# --- 3. STRATIFIED TRAIN / VAL / TEST SPLIT ---

# Step 1: hold out the combined validation + test share; the rest is training data.
holdout_ratio = VAL_RATIO + TEST_RATIO
train_df, temp_df = train_test_split(
    df,
    test_size=holdout_ratio,
    stratify=df['class_index'],
    random_state=42,
)

# Step 2: divide the hold-out set into validation and test.
# The test share of the hold-out is derived from the configured ratios
# (0.15 / 0.30 = 0.5 with the current values) instead of being hard-coded,
# so changing VAL_RATIO / TEST_RATIO keeps the split consistent.
val_df, test_df = train_test_split(
    temp_df,
    test_size=TEST_RATIO / holdout_ratio,
    stratify=temp_df['class_index'],
    random_state=42,
)

# Verify the final split sizes (labels follow the configured ratios).
print("\n--- Final Dataset Sizes ---")
print(f"Total Samples: {len(df)}")
print(f"Training Samples ({TRAIN_RATIO:.0%}): {len(train_df)}")
print(f"Validation Samples ({VAL_RATIO:.0%}): {len(val_df)}")
print(f"Test Samples ({TEST_RATIO:.0%}): {len(test_df)}")

# --- 4. CALCULATE INVERSE CLASS FREQUENCY FOR WEIGHTED SAMPLER ---

# Count occurrences of each class in the training set
class_counts = Counter(train_df['class_index'])
# Get total number of samples in the training set
total_samples = len(train_df)
# Relative frequency of every class index 0..NUM_CLASSES-1 (0.0 for a class with no training samples)
class_frequencies = {i: class_counts.get(i, 0) / total_samples for i in range(NUM_CLASSES)}

# Inverse-frequency weight per class: w_i = 1 / f_i.
# Classes with zero frequency are left out here to avoid dividing by zero.
# These weights are later mapped onto individual samples for the WeightedRandomSampler.
class_weights = {
    i: 1.0 / class_frequencies[i]
    for i in range(NUM_CLASSES) if class_frequencies[i] > 0
}

# Arrange the weights as a plain list ordered by class index [w0, w1, w2, ...];
# a class that was left out above gets weight 0.0.
inverse_weights = [class_weights.get(i, 0.0) for i in range(NUM_CLASSES)]
# Normalize by the LARGEST weight: the rarest class present ends up at 1.0
# and every other class at a value <= 1.0.
max_weight = max(inverse_weights)
normalized_weights = [w / max_weight for w in inverse_weights]

# Print the final class weights for review
# (the line labelled "Inverse Weights" shows the normalized values, not the raw 1 / f_i)
print("\n--- Training Set Class Weights (Normalized) ---")
print(f"Class Frequencies: {class_frequencies}")
print(f"Inverse Weights: {normalized_weights}")
# Store the weights as a float32 numpy array, indexable by class index
class_weights_np = np.array(normalized_weights, dtype=np.float32)

# Print the per-class share (in percent) of each split.
for _split_label, _split_frame in (
    ("Training", train_df),
    ("Validation", val_df),
    ("Test", test_df),
):
    print(f"\n--- Stratification Check ({_split_label} Set) ---")
    print(_split_frame['class_index'].value_counts(normalize=True).sort_index() * 100)

## Focal Loss

In [None]:
# --- CUSTOM LOSS FUNCTION: FOCAL LOSS ---
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """
    Focal loss for imbalanced multi-class classification.

    Reduces the loss of well-classified examples and thereby focuses training
    on hard, misclassified ones.

    Formula: FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

    Args:
        alpha: optional per-class weights indexed by class id (a tensor, or
            anything ``torch.as_tensor`` accepts). ``None`` disables class
            weighting.
        gamma: focusing parameter; 0 yields plain (optionally alpha-weighted)
            cross entropy.
        reduction: 'mean' or 'sum'; any other value returns the per-sample
            losses unreduced.
    """
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha  # optional per-class weights
        self.gamma = gamma  # focusing parameter
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of raw logits ``inputs`` against class-index ``targets``."""
        # Unweighted cross entropy, i.e. -log(p_t) per sample. It must stay
        # unweighted: with weight=alpha the value would be -alpha_t * log(p_t)
        # and exp(-ce) would yield p_t ** alpha_t instead of p_t.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')

        # p_t: predicted probability of the true class
        pt = torch.exp(-ce_loss)

        # Focal term: (1 - p_t)^gamma
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.alpha is not None:
            # Apply the class weight after the focal term; follow the inputs'
            # device/dtype so a CPU-created alpha also works with GPU logits.
            alpha = torch.as_tensor(self.alpha, dtype=inputs.dtype, device=inputs.device)
            focal_loss = alpha[targets] * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

## ResNet 18

In [None]:
## 1. Model Architecture: ResNet18 with Feature Fusion

from torchvision import models

class ResNet18WithSideInfo(nn.Module):
    """
    ResNet18 feature extractor (ImageNet weights) fused with a 2-element side
    vector (left/right eye encoding) in front of a single linear classifier.
    """
    def __init__(self, num_classes=8):
        super().__init__()

        # Backbone with the default pre-trained torchvision weights.
        print("Loading ResNet18 backbone (ImageNet weights)...")
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        # Remember the width of the features feeding the stock classifier,
        # then turn that classifier into a pass-through so the backbone
        # returns the features themselves.
        self.num_ftrs = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.resnet = backbone

        # Fusion head: image features + 2 side-info entries -> class logits.
        self.final_fc = nn.Linear(self.num_ftrs + 2, num_classes)

    def forward(self, image, side_vector):
        """Return class logits for a batch of images and their side vectors."""
        image_features = self.resnet(image)
        # Append the side encoding along the feature dimension.
        fused = torch.cat([image_features, side_vector], dim=1)
        return self.final_fc(fused)

# Build the model on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = ResNet18WithSideInfo(num_classes=8).to(device)
print(f"ResNet18 Side-Aware Model initialized on {device}")

In [None]:
## 2. Data Pipeline: Custom Dataset & Mirroring Strategy

# --- SAMPLER SETUP ---
# Every training sample is drawn with the (normalized) weight of its class.
sample_weights = train_df['class_index'].map(lambda cls: class_weights_np[cls]).to_numpy()

# One epoch draws as many samples as the training set holds, with replacement.
train_sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

# --- TRANSFORMS ---
# Standard ImageNet normalization stats
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Training pipeline (512x512) with light augmentation.
# NOTE: deliberately no RandomHorizontalFlip - mirroring is decided per eye
# side inside the Dataset class.
_train_steps = [
    transforms.Resize((512, 512)),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
]
train_transforms = transforms.Compose(_train_steps)

# Validation/test pipeline: same size and normalization, no randomness.
_eval_steps = [
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
]
val_test_transforms = transforms.Compose(_eval_steps)

# --- CUSTOM DATASET CLASS ---
class OcularDatasetSideAware(Dataset):
    """
    Dataset yielding (image, side_vector, label) triples.

    Images are read from ``<root_dir>/preprocessed_images``. Files whose name
    contains 'right' are mirrored horizontally before the transform runs
    (the 'Mirroring Trick', meant to make right eyes resemble left eyes).
    """
    def __init__(self, df, root_dir, transform=None):
        self.df = df
        self.root_dir = root_dir
        self.transform = transform
        self.image_dir = os.path.join(root_dir, "preprocessed_images")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        filename = record['filename']
        class_id = record['class_index']

        # The eye side is derived from the file name.
        from_right_eye = 'right' in filename
        # One-hot side encoding: [1, 0] for left, [0, 1] for right.
        side_vector = torch.tensor([float(not from_right_eye), float(from_right_eye)])

        image = Image.open(os.path.join(self.image_dir, filename)).convert('RGB')

        # Mirror right eyes so both sides share one orientation.
        if from_right_eye:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        if self.transform:
            image = self.transform(image)

        target = torch.tensor(class_id, dtype=torch.long)
        return image, side_vector, target

# --- DATALOADERS ---
# Only the Dataset objects are created here; the DataLoaders (and BATCH_SIZE)
# follow in the training configuration below.

train_dataset, val_dataset, test_dataset = (
    OcularDatasetSideAware(split_frame, path, transform=split_transform)
    for split_frame, split_transform in (
        (train_df, train_transforms),
        (val_df, val_test_transforms),
        (test_df, val_test_transforms),
    )
)

In [None]:
## 3. Training Configuration & Execution

# --- HYPERPARAMETERS ---
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
NUM_EPOCHS = 30
BATCH_SIZE = 16

# --- LOADER INITIALIZATION ---
# Training batches come from the weighted sampler; validation/test keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

# --- MODEL SETUP ---
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)


# MANUAL WEIGHTS
weights_modified = class_weights_np.copy()

# Index 5 = Hypertension.
# Intent: tell the model that an error on this class is 5x as bad as elsewhere,
# to force it to reproduce the 2 hits seen earlier, but in a stable way.
weights_modified[5] = weights_modified[5] * 5.0
print("Manuelle Gewichte aktiv:", weights_modified)

# NOTE(review): despite the message printed above, the manual weights are not
# in effect. `class_weights_tensor` is built from the unmodified
# `class_weights_np`, and the criterion below is created with alpha=None, so
# neither weight vector reaches the loss. Confirm whether this is intended
# (class balance is already handled by the weighted sampler).
class_weights_tensor = torch.tensor(class_weights_np, dtype=torch.float32).to(device)

criterion = FocalLoss(alpha=None, gamma=2.0).to(device)
# mode='max' because the scheduler is stepped with the validation F1 score.
# The deprecated `verbose` flag is no longer passed; the training loop prints
# the current learning rate every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3)

# --- TRAINING HELPER FUNCTIONS ---
def train_one_epoch_side(model, dataloader, criterion, optimizer, device):
    """
    Run one training pass over ``dataloader`` and return the mean loss per sample.

    Each batch yields (images, sides, labels); the model is called as
    ``model(images, sides)`` and updated once per batch.

    The running loss is averaged over the number of samples actually
    processed rather than ``len(dataloader.dataset)``, so the result stays
    correct when a sampler draws a different number of samples or the last
    batch is dropped. An empty loader returns 0.0.

    Assumes ``criterion`` returns the batch-mean loss (it is scaled by the
    batch size before being accumulated).
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0
    for images, sides, labels in tqdm(dataloader, desc="Training"):
        images, sides, labels = images.to(device), sides.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images, sides)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size
    return running_loss / max(samples_seen, 1)

def validate_epoch_side(model, dataloader, criterion, device):
    """
    Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns:
        (mean loss per sample, accuracy, weighted F1 score).

    The loss is averaged over the number of samples actually processed
    rather than ``len(dataloader.dataset)``, so it stays correct when the
    loader does not visit every dataset item exactly once.

    Assumes ``criterion`` returns the batch-mean loss (it is scaled by the
    batch size before being accumulated).
    """
    model.eval()
    running_loss = 0.0
    samples_seen = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, sides, labels in tqdm(dataloader, desc="Validation"):
            images, sides, labels = images.to(device), sides.to(device), labels.to(device)

            outputs = model(images, sides)
            loss = criterion(outputs, labels)

            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

            # Predicted class = index of the largest logit
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    mean_loss = running_loss / max(samples_seen, 1)
    accuracy = accuracy_score(all_labels, all_preds)
    weighted_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    return mean_loss, accuracy, weighted_f1

# --- MAIN TRAINING LOOP ---
# Per-epoch histories, consumed by the plotting cell further down.
history_train_loss, history_val_loss = [], []
history_val_acc, history_val_f1 = [], []
best_val_f1 = 0.0

print("\n--- Starting Optimized Training (ResNet18 + Side Aware + Focal Loss) ---")

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss = train_one_epoch_side(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, val_f1 = validate_epoch_side(model, val_loader, criterion, device)

    # The plateau scheduler watches the validation F1 score.
    scheduler.step(val_f1)

    for _series, _value in (
        (history_train_loss, train_loss),
        (history_val_loss, val_loss),
        (history_val_acc, val_acc),
        (history_val_f1, val_f1),
    ):
        _series.append(_value)

    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch}: Train Loss {train_loss:.4f} | Val F1 {val_f1:.4f} | LR: {current_lr:.2e}")

    # Checkpoint only when the validation F1 strictly improves.
    is_new_best = val_f1 > best_val_f1
    if is_new_best:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), 'best_resnet18_focal.pth')
        print("  --> Best model saved as 'best_resnet18_focal.pth'!")

print(f"Training Complete. Best Validation F1: {best_val_f1:.4f}")

In [None]:
## 4. Evaluation & Visualization

# --- CONFIGURATION ---
# Display names for the confusion-matrix axes; position i labels class index i.
CLASS_NAMES = [
    "Normal",
    "Diabetic",
    "Glaucoma",
    "Cataract",
    "Macular Deg.",
    "Hypertension",
    "Myopia",
    "Other",
]

# --- PLOTTING HELPERS ---
def plot_training_metrics(train_loss, val_loss, val_f1, val_acc):
    """
    Plot loss curves and validation metrics over epochs, side by side.

    Args:
        train_loss: per-epoch training loss values.
        val_loss: per-epoch validation loss values.
        val_f1: per-epoch weighted validation F1 scores.
        val_acc: per-epoch validation accuracies.

    All sequences are plotted against epochs 1..len(train_loss).
    """
    epochs = range(1, len(train_loss) + 1)
    fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Loss curve
    loss_ax.plot(epochs, train_loss, 'bo-', label='Training Loss')
    loss_ax.plot(epochs, val_loss, 'ro-', label='Validation Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.set_xlabel('Epochs')
    # The notebook trains and validates with FocalLoss, not plain cross entropy.
    loss_ax.set_ylabel('Loss (Focal)')
    loss_ax.legend()
    loss_ax.grid(True)

    # Metrics curve
    metric_ax.plot(epochs, val_f1, 'go-', label='Validation F1-Score (Weighted)')
    metric_ax.plot(epochs, val_acc, 'yo--', label='Validation Accuracy')
    metric_ax.set_title('Validation Performance')
    metric_ax.set_xlabel('Epochs')
    metric_ax.set_ylabel('Score')
    metric_ax.legend()
    metric_ax.grid(True)

    fig.tight_layout()
    plt.show()

def plot_confusion_matrix(true_labels, predictions, class_names):
    """
    Draw the confusion matrix as an annotated seaborn heatmap.

    Rows are true classes and columns predicted classes; class indices
    0..len(class_names)-1 are always included and labelled with ``class_names``.
    """
    class_ids = np.arange(len(class_names))
    matrix = confusion_matrix(true_labels, predictions, labels=class_ids)

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.title('Confusion Matrix (Test Set)')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.show()

def get_test_predictions_side_aware(model, dataloader, device):
    """
    Run inference over ``dataloader`` and collect predictions and labels.

    Each batch provides (images, sides, labels); both inputs are moved to
    ``device`` and passed to the model together.

    Returns:
        (predicted class indices, true labels) as two numpy arrays.
    """
    model.eval()
    predicted_ids, true_ids = [], []
    with torch.no_grad():
        for images, sides, labels in tqdm(dataloader, desc="Testing"):
            logits = model(images.to(device), sides.to(device))
            # Index of the largest logit per sample
            batch_pred = torch.max(logits, 1).indices
            predicted_ids.extend(batch_pred.cpu().numpy())
            true_ids.extend(labels.cpu().numpy())
    return np.array(predicted_ids), np.array(true_ids)

# --- EXECUTION ---

# 1. Training history (only available if the training cell ran in this session)
print("\n--- Plotting Training Metrics ---")
if 'history_train_loss' not in locals() or len(history_train_loss) == 0:
    print("No training history found in memory.")
else:
    plot_training_metrics(history_train_loss, history_val_loss, history_val_f1, history_val_acc)

# 2. Test-set evaluation
print("\n--- Evaluating Best Model on Test Set ---")

# Restore the checkpoint written by the training loop.
best_state = torch.load('best_resnet18_focal.pth', weights_only=True)
model.load_state_dict(best_state)
model.to(device)

# Predictions and ground truth for the whole test loader
test_preds, test_labels = get_test_predictions_side_aware(model, test_loader, device)

# Final metrics (same definitions as during validation)
test_acc = accuracy_score(test_labels, test_preds)
test_f1 = f1_score(test_labels, test_preds, average='weighted', zero_division=0)

print("\n=== FINAL TEST RESULTS (ResNet18) ===")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Test F1-Score (Weighted): {test_f1:.4f}")

# Confusion matrix over all configured classes
plot_confusion_matrix(test_labels, test_preds, CLASS_NAMES)