# Part 1: Environment, Config, and Data Pipeline

## 1. Environment Setup

In [None]:
# Mount Drive (Colab Only)
# The google.colab import only succeeds inside a Colab runtime; IS_COLAB records
# the outcome so that later cells can choose between Colab and local paths.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    IS_COLAB = True
    print("Colab detected. Drive mounted.")
except ImportError:
    # google.colab is not importable here: treat this as a local run and skip the mount.
    IS_COLAB = False
    print("Local environment detected.")

# Install dependencies
# NOTE: `%pip` is an IPython magic, so this cell only runs inside a notebook kernel.
%pip install transformers torchview torchsummary


## 2. Imports & Configuration

In [None]:
import os
import random
import numpy as np
import torch
import pandas as pd
import cv2
import shutil
import math
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, f1_score

# --- Torch & Vision ---
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms.functional as TF
from torchvision import transforms

# --- Foundation Model Support ---
from transformers import AutoImageProcessor, AutoModel

# Set seeds for reproducibility
# The same SEED is applied to Python's `random`, NumPy and PyTorch (CPU + all CUDA devices).
SEED = 42
# NOTE(review): PYTHONHASHSEED is presumably only read at interpreter start-up, so setting it
# here likely affects sub-processes only - confirm.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# NOTE(review): cuDNN benchmark mode auto-tunes kernels for speed; this may select different
# algorithms between runs and work against the reproducibility goal above - confirm the intent.
torch.backends.cudnn.benchmark = True

# Device Configuration
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Plotting Config
sns.set(font_scale=1.4)
sns.set_style('white')
plt.rc('font', size=14)

# Suppress warnings
# The second filter targets the base `Warning` category, so every warning category
# (including FutureWarning from the line before) is silenced from here on.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)

In [None]:
# --- CONFIGURATION ---

# Model input geometry and loader batch size
TARGET_SIZE = (224, 224)
BATCH_SIZE = 32 # Can increase to 64/128 with AMP on A100

# ImageNet Normalization (Used by Phikon)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Dataset root. Adjust these paths to match your specific folder structure.
# IS_COLAB is set by the environment-setup cell; when it is absent we assume a local run.
if globals().get('IS_COLAB'):
    # Assuming standard Drive structure. ADJUST THIS PATH IF NEEDED.
    datasets_path = "/content/drive/MyDrive/AN2DL_CH_2/an2dl2526c2"
    print(f"Using Colab Path: {datasets_path}")
else:
    datasets_path = os.path.join(os.path.pardir, "an2dl2526c2")
    print(f"Using Local Path: {datasets_path}")

# Raw training inputs
train_data_path = os.path.join(datasets_path, "train_data")
train_labels_path = os.path.join(datasets_path, "train_labels.csv")
CSV_PATH = train_labels_path

# Output Directories
# Patches and their masks are read from the same preprocessing folder.
PATCHES_OUT = os.path.join(datasets_path, "preprocessing_results", "train_patches")
MASKS_DIR = PATCHES_OUT

print(f"Patches Directory: {PATCHES_OUT}")

## 3. Metadata Generation & Splitting

In [None]:
def create_metadata_dataframe(patches_dir, labels_csv_path):
    """
    Scan the patch directory and link every patch to the label of its source image.

    The first CSV column is used as the sample ID (a file extension, if present, is
    stripped) and the second column as the label. Every ``.png`` file in
    ``patches_dir`` whose name does not contain "mask" counts as a patch. Its sample
    ID is the filename up to the last ``_p`` (naming convention ``img_XXXX_pY.png``),
    or the filename without extension when no ``_p`` is present.

    Args:
        patches_dir: Directory containing the patch PNG files.
        labels_csv_path: CSV with the sample ID in column 0 and the label in column 1.

    Returns:
        DataFrame with the columns ``filename``, ``label``, ``sample_id`` and ``path``,
        one row per patch that has a matching label. The frame is empty (but keeps
        these columns) when no patch is found or none matches a label.
    """
    df_labels = pd.read_csv(labels_csv_path)
    id_col = df_labels.columns[0]
    label_col = df_labels.columns[1]

    # Standardize IDs: compare as strings and without file extension
    df_labels[id_col] = df_labels[id_col].astype(str)
    df_labels[id_col] = df_labels[id_col].apply(lambda x: os.path.splitext(x)[0])

    # Scan patches (mask files live in the same folder and are skipped)
    patch_files = [f for f in os.listdir(patches_dir) if f.endswith('.png') and "mask" not in f]
    print(f"Parsing {len(patch_files)} patches...")

    records = [
        {
            'filename': filename,
            # Naming convention: img_XXXX_pY.png -> bag id img_XXXX
            'sample_id': filename.rsplit('_p', 1)[0] if '_p' in filename else os.path.splitext(filename)[0],
            'path': os.path.join(patches_dir, filename),
        }
        for filename in patch_files
    ]

    # Explicit columns keep the merge key available even when no patch was found
    df_patches = pd.DataFrame(records, columns=['filename', 'sample_id', 'path'])
    df_patches['sample_id'] = df_patches['sample_id'].astype(str)

    # Inner merge: patches without a label row are dropped, so report how many were lost
    df = pd.merge(df_patches, df_labels, left_on='sample_id', right_on=id_col)
    dropped = len(df_patches) - df['filename'].nunique()
    if dropped > 0:
        print(f"Warning: {dropped} patches had no matching label and were dropped.")

    df = df[['filename', label_col, 'sample_id', 'path']]
    df = df.rename(columns={label_col: 'label'})

    return df

# Create Metadata
if not os.path.exists(PATCHES_OUT):
    print("WARNING: Patches directory not found. Please run preprocessing first.")
else:
    patches_metadata_df = create_metadata_dataframe(PATCHES_OUT, CSV_PATH)

    # Label Encoding: class names -> integer ids
    label_encoder = LabelEncoder()
    patches_metadata_df['label_encoded'] = label_encoder.fit_transform(patches_metadata_df['label'])
    print(f"Classes: {label_encoder.classes_}")

    # Stratified Split (Image Level, not Patch Level):
    # whole samples are assigned to train or val so patches of one image never straddle the split.
    unique_samples = patches_metadata_df['sample_id'].unique()
    # One label per sample, looked up in the same order as unique_samples
    labels_by_sample = (
        patches_metadata_df
        .drop_duplicates('sample_id')
        .set_index('sample_id')['label']
    )
    train_ids, val_ids = train_test_split(
        unique_samples,
        test_size=0.2,
        random_state=SEED,
        stratify=labels_by_sample.loc[unique_samples],
    )

    in_train = patches_metadata_df['sample_id'].isin(train_ids)
    in_val = patches_metadata_df['sample_id'].isin(val_ids)
    df_train = patches_metadata_df[in_train].reset_index(drop=True)
    df_val = patches_metadata_df[in_val].reset_index(drop=True)

    print(f"Train Patches: {len(df_train)} | Val Patches: {len(df_val)}")

## 4. Custom Dataset Class (The "God Mode" Pipeline)
This dataset handles the **Two-Stream** architecture requirements:
1.  **Synchronized Augmentations:** Ensures geometric transforms (flips, rotations) happen identically on both the RGB Image and the Mask.
2.  **Dual Output:** Returns `(image, label, mask)` tuples.

In [None]:
class TissueDataset(Dataset):
    """Patch dataset that yields an RGB image tensor, its label and a one-channel mask tensor."""

    def __init__(self, df, img_dir=None, masks_dir=None, augmentation=None, normalize_imagenet=True, target_size=(224, 224), label_col='label_encoded'):
        """
        Args:
            df: DataFrame containing metadata; each row needs a 'path' entry and `label_col`.
            img_dir: Root directory for images (accepted but not used; images are read from row['path']).
            masks_dir: Directory where masks are stored. When falsy, an all-black mask is used.
            augmentation: Truthy value enables the synchronized image/mask augmentation.
            normalize_imagenet: Apply ImageNet mean/std (Required for Phikon).
            target_size: Size both image and mask are resized to.
            label_col: Name of the column holding the integer label.
        """
        self.df = df
        self.label_col = label_col
        self.target_size = target_size
        self.masks_location = masks_dir
        self.do_augmentation = augmentation
        self.normalize_imagenet = normalize_imagenet

        # Color Jitter (Applied ONLY to Image, NOT Mask)
        self.color_jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)

        # Normalization for the RGB Stream (Phikon)
        self.normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

    def __len__(self):
        return len(self.df)

    def _load_mask(self, img_path, size):
        """Return the grayscale mask that belongs to `img_path`, or a black mask of `size` if none exists."""
        if self.masks_location:
            # Logic: img_123 -> mask_123
            mask_name = os.path.basename(img_path).replace('img_', 'mask_')
            candidate = os.path.join(self.masks_location, mask_name)
            if os.path.exists(candidate):
                return Image.open(candidate).convert("L")
        # Fallback: Black mask
        return Image.new('L', size, 0)

    def _augment(self, image, mask):
        """Apply identical geometric transforms to image and mask, then jitter the image colors."""
        # 1. Random Horizontal Flip
        if random.random() > 0.5:
            image, mask = TF.hflip(image), TF.hflip(mask)

        # 2. Random Vertical Flip
        if random.random() > 0.5:
            image, mask = TF.vflip(image), TF.vflip(mask)

        # 3. Random Rotation (Assumes rotation invariant tissue); one angle shared by both
        angle = transforms.RandomRotation.get_params(degrees=[-15, 15])
        image = TF.rotate(image, angle)
        mask = TF.rotate(mask, angle)

        # 4. Color Jitter (Image Only)
        image = self.color_jitter(image)
        return image, mask

    def __getitem__(self, idx):
        """Return `(img_tensor, label, mask_tensor)` for row `idx`."""
        record = self.df.iloc[idx]

        # --- A/B. Load image and its mask ---
        source_path = record['path']
        image = Image.open(source_path).convert("RGB")
        mask = self._load_mask(source_path, image.size)

        # --- C. Resize (nearest neighbour keeps the mask values unblended) ---
        image = TF.resize(image, self.target_size)
        mask = TF.resize(mask, self.target_size, interpolation=transforms.InterpolationMode.NEAREST)

        # --- D. SYNCHRONIZED AUGMENTATION ---
        # The mask must stay aligned with the image, so both receive the same geometry.
        if self.do_augmentation:
            image, mask = self._augment(image, mask)

        # --- E. Convert to Tensor ---
        img_tensor = TF.to_tensor(image)
        # Mask stream stays in to_tensor's float range; ImageNet stats are NOT applied to it.
        mask_tensor = TF.to_tensor(mask)

        # --- F. Stream 1 (Phikon): ImageNet normalization ---
        if self.normalize_imagenet:
            img_tensor = self.normalize(img_tensor)

        # --- G. Label ---
        label = torch.tensor(record[self.label_col], dtype=torch.long)

        return img_tensor, label, mask_tensor

## 5. Data Loaders & Visualization

In [None]:
# Setup Datasets
# Only the training set is augmented; both streams use ImageNet normalization for the image.
train_dataset = TissueDataset(df_train, masks_dir=MASKS_DIR, augmentation=True, normalize_imagenet=True)
val_dataset = TissueDataset(df_val, masks_dir=MASKS_DIR, augmentation=False, normalize_imagenet=True)

# Setup Loaders
# os.cpu_count() returns None when the CPU count cannot be determined; fall back to
# loading in the main process in that case.
num_workers = os.cpu_count() or 0
# persistent_workers is only valid when worker processes exist.
use_persistent_workers = num_workers > 0

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=num_workers, pin_memory=True, persistent_workers=use_persistent_workers
)

val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=num_workers, pin_memory=True, persistent_workers=use_persistent_workers
)

print(f"Train Batches: {len(train_loader)}")
print(f"Val Batches: {len(val_loader)}")

def show_batch(loader, count=4):
    """
    Plot the first `count` images of one batch (top row) with their masks (bottom row).

    Args:
        loader: Iterable yielding `(images, labels, masks)` batches; images are expected
            to be ImageNet-normalized, since they are de-normalized before plotting.
        count: Number of samples to show. Clamped to the batch size so a short batch
            does not raise an IndexError.
    """
    batch = next(iter(loader))
    images, labels, masks = batch[0], batch[1], batch[2]

    # Never index past the end of the batch
    count = min(count, images.size(0))

    # De-normalize for visualization
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)

    plt.figure(figsize=(15, 8))

    for i in range(count):
        # Plot Image
        plt.subplot(2, count, i + 1)
        img = images[i] * std + mean
        img = torch.clamp(img, 0, 1)
        plt.imshow(img.permute(1, 2, 0))  # CHW -> HWC for matplotlib
        plt.title(f"Label: {labels[i].item()}")
        plt.axis("off")

        # Plot Mask (directly below its image)
        plt.subplot(2, count, i + 1 + count)
        plt.imshow(masks[i].squeeze(), cmap="gray")
        plt.title("Mask")
        plt.axis("off")

    plt.tight_layout()
    plt.show()

show_batch(train_loader)

# Part 2: The "God Mode" Hybrid Architecture

We now construct the Dual-Stream model.
*   **Stream 1:** `owkin/phikon` (ViT) for RGB images.
*   **Stream 2:** Custom CNN for Binary Masks.
*   **Fusion:** Concatenation + MLP.

In [None]:
class MaskEncoder(nn.Module):
    """
    Pillar 2: The Guide.
    Small three-block CNN that turns a single-channel mask into a 128-dim feature vector.
    Input: (B, 1, H, W), e.g. (B, 1, 224, 224)
    Output: (B, 128)
    """

    def __init__(self):
        super().__init__()
        # Layer order is kept fixed so the state_dict keys (features.0 / .3 / .6) stay stable.
        layers = [
            # Block 1: 1 -> 32 channels, halves the resolution (224 -> 112)
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Block 2: 32 -> 64 channels, halves the resolution again (112 -> 56)
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Block 3: 64 -> 128 channels, then global average pooling down to 1x1
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        pooled = self.features(x)
        batch_size = pooled.size(0)
        # Flatten (B, 128, 1, 1) -> (B, 128)
        return pooled.view(batch_size, -1)


class GodModeClassifier(nn.Module):
    """
    The Summit: Phikon + Geometry Fusion.

    Two streams are fused by concatenation and classified by an MLP head:
    the CLS embedding of the pretrained `owkin/phikon` backbone (RGB image) and
    the 128-dim output of `MaskEncoder` (mask).

    Args:
        num_classes: Number of output logits.
        dropout_rate: Dropout probability inside the fusion head.
        freeze_backbone: When True, all Phikon parameters start frozen.
    """
    def __init__(self, num_classes=4, dropout_rate=0.3, freeze_backbone=True):
        super().__init__()

        # --- Pillar 1: Phikon (Foundation Model) ---
        print("Loading Phikon (this may take a moment).... ")
        # We use owkin/phikon. It maps to a ViT architecture.
        self.phikon = AutoModel.from_pretrained("owkin/phikon")

        # Read the embedding width from the loaded config instead of hard-coding it
        self.phikon_dim = self.phikon.config.hidden_size
        print(f"Phikon Loaded. Embedding Dimension: {self.phikon_dim}")

        # --- Pillar 2: Mask Encoder ---
        self.mask_encoder = MaskEncoder()
        self.mask_dim = 128

        # --- The Summit: Fusion Head ---
        fusion_dim = self.phikon_dim + self.mask_dim

        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes)
        )

        # Initialize weights for custom heads (the pretrained backbone is left untouched)
        self._init_weights(self.mask_encoder)
        self._init_weights(self.classifier)

        # Freeze Logic
        if freeze_backbone:
            self.freeze_phikon()

    def _init_weights(self, module):
        """Kaiming-initialize every Linear/Conv2d weight inside `module` and zero its bias."""
        for m in module.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def freeze_phikon(self):
        """Disable gradients for every parameter of the Phikon backbone."""
        print("Freezing Phikon Backbone...")
        for param in self.phikon.parameters():
            param.requires_grad = False

    def unfreeze_phikon_layers(self, n_layers=2):
        """
        Unfreezes the last n_layers of the ViT encoder for fine-tuning.

        The backbone's `pooler` and final `layernorm` sub-modules are unfrozen as well
        when they exist. Note that `forward` reads the CLS token from
        `last_hidden_state`, so the pooler output itself is not used by this model.

        Args:
            n_layers: Number of trailing encoder blocks to unfreeze. 0 leaves every
                encoder block frozen.

        Raises:
            ValueError: If `n_layers` is negative.
        """
        if n_layers < 0:
            raise ValueError(f"n_layers must be non-negative, got {n_layers}")

        print(f"Unfreezing last {n_layers} layers of Phikon...")
        # Phikon (ViT) structure: phikon.encoder.layer is a ModuleList

        # 1. Unfreeze the pooler and the final LayerNorm, whichever of them is present
        for name in ('pooler', 'layernorm'):
            submodule = getattr(self.phikon, name, None)
            if submodule is not None:
                for param in submodule.parameters():
                    param.requires_grad = True

        # 2. Unfreeze last N Transformer Blocks.
        # Guard n_layers == 0 explicitly: the slice [-0:] would select ALL blocks.
        if n_layers == 0:
            return
        encoder_layers = self.phikon.encoder.layer
        for layer in encoder_layers[-n_layers:]:
            for param in layer.parameters():
                param.requires_grad = True

    def forward(self, img, mask):
        """
        Args:
            img: Image batch passed to Phikon as `pixel_values` (expected (B, 3, 224, 224)).
            mask: Mask batch passed to the mask encoder (expected (B, 1, H, W)).

        Returns:
            Logits of shape (B, num_classes).
        """
        # 1. Pillar 1: Phikon - take the CLS token (index 0) of the last hidden state
        phikon_out = self.phikon(pixel_values=img)
        img_features = phikon_out.last_hidden_state[:, 0, :]

        # 2. Pillar 2: Mask
        mask_features = self.mask_encoder(mask)

        # 3. Fusion: concatenate along the feature dimension
        combined = torch.cat((img_features, mask_features), dim=1)

        # 4. Classify
        logits = self.classifier(combined)
        return logits

In [None]:
# --- Instantiate Model ---
# Assuming label_encoder exists from Part 1
# One output logit per class known to the label encoder.
NUM_CLASSES = len(label_encoder.classes_)

# Start with the Phikon backbone frozen (warmup phase), then move the model to the chosen device.
model = GodModeClassifier(num_classes=NUM_CLASSES, freeze_backbone=True)
model = model.to(device)

# --- Stats Helper (from ResNet notebook) ---
def print_model_stats(model):
    """Print the total number of parameters of `model` and how many of them require gradients."""
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size
    print(f"Total params: {total_params:,}")
    print(f"Trainable params: {trainable_params:,}")

# Report total vs. trainable parameter counts of the freshly built model.
print_model_stats(model)

## 6. Architecture Visualization

In [None]:
from torchview import draw_graph

# Create dummy inputs on CPU for visualization to save GPU memory
# NOTE: this builds a second GodModeClassifier (and therefore loads Phikon again) that
# lives on the CPU and is used only for drawing the graph.
model_cpu = GodModeClassifier(num_classes=NUM_CLASSES, freeze_backbone=True).to("cpu")
# Random tensors shaped like one image (1, 3, 224, 224) and one mask (1, 1, 224, 224)
dummy_img = torch.randn(1, 3, 224, 224)
dummy_mask = torch.randn(1, 1, 224, 224)

model_graph = draw_graph(
    model_cpu,
    input_data=(dummy_img, dummy_mask),
    device='cpu',
    expand_nested=True,
    depth=2 # Limit depth to avoid massive ViT graph explosion
)

# Last expression of the cell: the notebook renders the graph object as the cell output.
model_graph.visual_graph

# Part 3: Training Phase 1 (Warmup)

## 7. Loss & Optimizer Configuration

We use **Focal Loss** to focus on hard-to-classify examples, which is crucial for Histopathology where "Luminal A" might dominate the dataset.

**Spec:** Gamma=2.0, RAdam Optimizer (LR=1e-4).

In [None]:
# 1. Calculate Class Weights (Inverse Frequency)
# Patch counts per encoded class, ordered by class id
class_counts = df_train['label_encoded'].value_counts().sort_index().values
n_classes = len(class_counts)
total_samples = sum(class_counts)

# weight_c = total / (n_classes * count_c): a perfectly balanced dataset gives every class weight 1
balanced_weights = total_samples / (n_classes * class_counts)
weight_tensor = torch.tensor(balanced_weights, dtype=torch.float32).to(device)

print(f"Class Weights: {weight_tensor.cpu().numpy()}")

# 2. Define Focal Loss
class FocalLoss(nn.Module):
    """
    Multi-class focal loss: ``alpha_t * (1 - p_t) ** gamma * (-log p_t)``.

    ``p_t`` is the softmax probability of the true class, taken from the *unweighted*
    cross-entropy. The class weight ``alpha_t`` is applied afterwards, so it scales the
    loss without distorting the focusing term.

    Args:
        alpha: Optional 1-D tensor with one weight per class, or None for no weighting.
        gamma: Focusing exponent; 0 reduces the loss to (weighted) cross-entropy.
        reduction: 'mean' or 'sum'; any other value returns the per-sample losses.
    """
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs: Logits of shape (B, C).
            targets: Integer class indices of shape (B,).

        Returns:
            Scalar tensor for 'mean'/'sum', otherwise a tensor of shape (B,).
        """
        # Unweighted CE equals -log(p_t), so exp(-CE) recovers the true-class probability.
        # nn.functional is used because the file has no `F` alias in scope.
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.alpha is not None:
            # Look up each sample's class weight on the same device/dtype as the loss
            alpha_t = self.alpha.to(device=focal_loss.device, dtype=focal_loss.dtype)[targets]
            focal_loss = alpha_t * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Loss used by both training phases: focal loss with the inverse-frequency class weights as alpha.
criterion = FocalLoss(alpha=weight_tensor, gamma=2.0)
print("Criterion: Focal Loss (Gamma=2.0) with Class Weights enabled.")

## 8. Training & Validation Loops

In [None]:
from torch.cuda.amp import autocast, GradScaler

# Initialize Scaler for AMP
# Module-level scaler: train_one_epoch below reads this name as a global rather than a parameter.
scaler = GradScaler()

def train_one_epoch(model, loader, criterion, optimizer, device):
    """
    Run one training epoch with mixed precision.

    Uses the module-level `scaler` for loss scaling. Batches are expected as
    `(images, labels, masks)`.

    Returns:
        Tuple `(mean batch loss, macro F1)` computed over the whole epoch.
    """
    model.train()
    loss_sum = 0.0
    epoch_preds, epoch_targets = [], []

    progress = tqdm(loader, desc="Training", leave=False)
    for images, labels, masks in progress:
        images, masks, labels = images.to(device), masks.to(device), labels.to(device)

        optimizer.zero_grad()

        # Forward pass under autocast (AMP)
        with autocast():
            logits = model(images, masks)
            loss = criterion(logits, labels)

        # Backward pass and optimizer step through the gradient scaler
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        loss_sum += batch_loss

        # Metrics: index of the largest logit is the predicted class
        predicted = torch.max(logits, 1)[1]
        epoch_preds.extend(predicted.cpu().numpy())
        epoch_targets.extend(labels.cpu().numpy())

        progress.set_postfix(loss=batch_loss)

    mean_loss = loss_sum / len(loader)
    macro_f1 = f1_score(epoch_targets, epoch_preds, average='macro')
    return mean_loss, macro_f1

def validate(model, loader, criterion, device):
    """
    Evaluate `model` on `loader` without gradient tracking.

    Batches are expected as `(images, labels, masks)`.

    Returns:
        Tuple `(mean batch loss, macro F1)` over the whole loader.
    """
    model.eval()
    loss_sum = 0.0
    epoch_preds, epoch_targets = [], []

    with torch.no_grad():
        for images, labels, masks in loader:
            images, masks, labels = images.to(device), masks.to(device), labels.to(device)

            # Validate with AMP (optional but good practice for speed)
            with autocast():
                logits = model(images, masks)
                loss = criterion(logits, labels)

            loss_sum += loss.item()
            predicted = torch.max(logits, 1)[1]
            epoch_preds.extend(predicted.cpu().numpy())
            epoch_targets.extend(labels.cpu().numpy())

    mean_loss = loss_sum / len(loader)
    macro_f1 = f1_score(epoch_targets, epoch_preds, average='macro')
    return mean_loss, macro_f1

## 9. Phase 1 Execution: Warmup
**Objective:** Train the Fusion Head and the Mask Encoder. The Phikon Backbone is **FROZEN**.
Keeping the backbone frozen prevents the large gradients produced by the randomly initialized head from disrupting the pre-trained Phikon weights.

In [None]:
# Configuration for Phase 1
NUM_EPOCHS_WARMUP = 10
LR_WARMUP = 1e-4

# Optimizer: RAdam as requested
# Only parameters with requires_grad=True are handed over, so the frozen Phikon weights are left out.
optimizer = torch.optim.RAdam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LR_WARMUP,
    weight_decay=1e-4
)

history = {key: [] for key in ('train_loss', 'train_f1', 'val_loss', 'val_f1')}
best_val_f1 = 0.0

# Ensure persistence on Drive: checkpoints go next to the dataset on Colab, into ./models otherwise
if globals().get('IS_COLAB'):
    MODELS_DIR = os.path.join(datasets_path, "models")
else:
    MODELS_DIR = "models"
if not os.path.exists(MODELS_DIR):
    os.makedirs(MODELS_DIR)

model_path_warmup = os.path.join(MODELS_DIR, "best_godmode_warmup.pt")


print("--- Starting Phase 1: Warmup (Phikon Frozen) ---")

for epoch in range(NUM_EPOCHS_WARMUP):
    # One pass over the training data, then one over the validation data
    train_loss, train_f1 = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_f1 = validate(model, val_loader, criterion, device)

    # History
    history['train_loss'].append(train_loss)
    history['train_f1'].append(train_f1)
    history['val_loss'].append(val_loss)
    history['val_f1'].append(val_f1)

    # Save Best: a checkpoint is written only when validation macro F1 strictly improves
    save_msg = ""
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), model_path_warmup)
        save_msg = "[Saved]"

    print(f"Epoch {epoch+1}/{NUM_EPOCHS_WARMUP} | Train Loss: {train_loss:.4f} F1: {train_f1:.4f} | Val Loss: {val_loss:.4f} F1: {val_f1:.4f} {save_msg}")

print(f"Warmup Complete. Best Val F1: {best_val_f1:.4f}")

In [None]:
# Plot Warmup History
# Left panel: loss curves. Right panel: macro F1 curves.
plt.figure(figsize=(20, 5))

warmup_panels = [
    ('train_loss', 'Train Loss', 'val_loss', 'Val Loss', 'Warmup Loss'),
    ('train_f1', 'Train F1', 'val_f1', 'Val F1', 'Warmup F1 Score (Macro)'),
]
for position, (train_key, train_label, val_key, val_label, panel_title) in enumerate(warmup_panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(history[train_key], label=train_label)
    plt.plot(history[val_key], label=val_label)
    plt.title(panel_title)
    plt.legend()

plt.show()

# Part 4: Training Phase 2 (Fine-Tuning)

## 10. Unfreezing & Config

Now that the Fusion Head is stable, we unlock the **Foundation Model**. 
We unfreeze the last **2 Transformer Blocks** of Phikon to adapt the high-level texture features to our specific histology dataset.

*   **Mixed Precision (AMP):** Enabled (Critical for A100/V100 memory management).
*   **LR:** Reduced to `1e-5` to preserve pre-trained knowledge.

In [None]:
# 1. Load Best Warmup Weights
print("Loading best model from Warmup Phase...")
# map_location keeps the load working when the checkpoint was written on a different device
model.load_state_dict(torch.load(model_path_warmup, map_location=device), strict=True)

# 2. Unfreeze Phikon Layers
model.unfreeze_phikon_layers(n_layers=2)
model.mask_encoder.requires_grad_(True) # Ensure mask encoder is also training

# 3. Verify Trainable Parameters
print_model_stats(model)

# 4. Configuration for Fine-Tuning
NUM_EPOCHS_FT = 30
LR_FT = 1e-5
PATIENCE = 8

# Optimizer: Re-initialize so the newly unfrozen parameters are included
optimizer_ft = torch.optim.RAdam(
    [p for p in model.parameters() if p.requires_grad],
    lr=LR_FT,
    weight_decay=1e-4
)

# Scheduler: mode='max' because it is stepped with the validation F1 score
# NOTE(review): `verbose` is reported as deprecated in recent PyTorch releases - confirm
# against the installed version before relying on it.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_ft, mode='max', factor=0.1, patience=3, verbose=True)

# Mixed Precision Scaler (fresh instance for the fine-tuning phase)
scaler = torch.cuda.amp.GradScaler()
# Store the fine-tuned checkpoint in MODELS_DIR, next to the warmup checkpoint; a bare
# relative "models/" folder does not exist when MODELS_DIR points to Drive.
model_path_ft = os.path.join(MODELS_DIR, "best_godmode_finetuned.pt")

In [None]:
def train_one_epoch_amp(model, loader, criterion, optimizer, device, scaler):
    """
    Run one fine-tuning epoch with mixed precision, using the `scaler` passed in.

    Batches are expected as `(images, labels, masks)`.

    Returns:
        Tuple `(mean batch loss, macro F1)` computed over the whole epoch.
    """
    model.train()
    loss_sum = 0.0
    epoch_preds, epoch_targets = [], []

    progress = tqdm(loader, desc="Fine-Tuning", leave=False)
    for images, labels, masks in progress:
        images, masks, labels = images.to(device), masks.to(device), labels.to(device)

        optimizer.zero_grad()

        # Mixed precision forward
        with torch.cuda.amp.autocast():
            logits = model(images, masks)
            loss = criterion(logits, labels)

        # Mixed precision backward: scale the loss, step, then update the scale factor
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        loss_sum += batch_loss

        # Metrics are computed outside the autocast block
        predicted = torch.max(logits, 1)[1]
        epoch_preds.extend(predicted.cpu().numpy())
        epoch_targets.extend(labels.cpu().numpy())

        progress.set_postfix(loss=batch_loss)

    mean_loss = loss_sum / len(loader)
    macro_f1 = f1_score(epoch_targets, epoch_preds, average='macro')
    return mean_loss, macro_f1

In [None]:
print("--- Starting Phase 2: Fine-Tuning (Phikon Partial Unfreeze) ---")

# Reset history for FT phase visualization
ft_history = {'train_loss': [], 'train_f1': [], 'val_loss': [], 'val_f1': []}
best_val_f1_ft = best_val_f1 # Carry over best from warmup
patience_counter = 0

# Make sure the checkpoint directory exists before anything is saved into it
ft_checkpoint_dir = os.path.dirname(model_path_ft)
if ft_checkpoint_dir:
    os.makedirs(ft_checkpoint_dir, exist_ok=True)

# The bar to beat is the warmup best, so fine-tuning may never improve on it. Seed the
# checkpoint with the current weights (the warmup-best weights are loaded into `model`
# before this phase) so that model_path_ft always exists and always holds the best model.
torch.save(model.state_dict(), model_path_ft)

for epoch in range(NUM_EPOCHS_FT):
    # Train with AMP
    train_loss, train_f1 = train_one_epoch_amp(model, train_loader, criterion, optimizer_ft, device, scaler)

    # Validate (no scaler needed for inference)
    val_loss, val_f1 = validate(model, val_loader, criterion, device)

    # Scheduler Step: the LR is reduced when validation F1 stops improving
    scheduler.step(val_f1)

    # Update History
    ft_history['train_loss'].append(train_loss)
    ft_history['train_f1'].append(train_f1)
    ft_history['val_loss'].append(val_loss)
    ft_history['val_f1'].append(val_f1)

    # Checkpointing & Early Stopping
    if val_f1 > best_val_f1_ft:
        best_val_f1_ft = val_f1
        torch.save(model.state_dict(), model_path_ft)
        patience_counter = 0
        save_msg = "[Saved Best FT]"
    else:
        patience_counter += 1
        save_msg = ""

    print(f"FT Epoch {epoch+1}/{NUM_EPOCHS_FT} | Train Loss: {train_loss:.4f} F1: {train_f1:.4f} | Val Loss: {val_loss:.4f} F1: {val_f1:.4f} {save_msg}")

    # Stop after PATIENCE consecutive epochs without a new best
    if patience_counter >= PATIENCE:
        print("Early Stopping Triggered.")
        break

print(f"Fine-Tuning Complete. Best Val F1: {best_val_f1_ft:.4f}")

## 11. Training History Visualization

In [None]:
plt.figure(figsize=(20, 5))

# Vertical line indicating phase switch: x-position of the last warmup epoch
phase_switch = len(history['train_loss']) - 1

# Combine histories for visualization (warmup epochs first, fine-tuning epochs after)
full_train_loss = history['train_loss'] + ft_history['train_loss']
full_val_loss = history['val_loss'] + ft_history['val_loss']
full_train_f1 = history['train_f1'] + ft_history['train_f1']
full_val_f1 = history['val_f1'] + ft_history['val_f1']

history_panels = [
    (full_train_loss, 'Train Loss', full_val_loss, 'Val Loss', 'Total Loss History'),
    (full_train_f1, 'Train F1', full_val_f1, 'Val F1', 'Total F1 Score History'),
]
for position, (train_curve, train_label, val_curve, val_label, panel_title) in enumerate(history_panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(train_curve, label=train_label)
    plt.plot(val_curve, label=val_label)
    plt.axvline(x=phase_switch, color='r', linestyle='--', label='Start Fine-Tuning')
    plt.title(panel_title)
    plt.legend()

plt.show()

# Part 5: Evaluation and Submission

## 12. Evaluation (Confusion Matrix)
We evaluate the model on the Validation set to see per-class performance.

In [None]:
# 1. Load the Best Model (Fine-Tuned)
final_model = GodModeClassifier(num_classes=NUM_CLASSES, freeze_backbone=False).to(device)
# map_location lets a checkpoint written on one device (e.g. GPU) load on another (e.g. CPU)
final_model.load_state_dict(torch.load(model_path_ft, map_location=device), strict=True)
final_model.eval()
print("Final Model Loaded.")

# 2. Get Predictions
y_true = []
y_pred = []

print("Generating predictions for Confusion Matrix...")
with torch.no_grad():
    for images, labels, masks in tqdm(val_loader):
        images = images.to(device)
        masks = masks.to(device)

        logits = final_model(images, masks)
        preds = torch.argmax(logits, dim=1).cpu().numpy()

        # labels are not moved to the device, so they convert to NumPy directly
        y_true.extend(labels.numpy())
        y_pred.extend(preds)

# 3. Plot: rows are true classes, columns are predicted classes (patch level)
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=label_encoder.classes_,
            yticklabels=label_encoder.classes_)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix (Patch Level)')
plt.show()

## 13. Submission Generation

**Aggregation Strategy:** `Max Mean Confidence`.
1.  We collect all patch probabilities for a Bag (WSI).
2.  We calculate the **Mean** probability for each class across all patches.
3.  The class with the highest Mean probability is the WSI label.

In [None]:
from datetime import datetime

# Output configs
# Test patches and their masks are read from the same preprocessing folder.
TEST_PATCHES_DIR = os.path.join(datasets_path, "preprocessing_results", "test_patches")
TEST_MASKS_DIR = os.path.join(datasets_path, "preprocessing_results", "test_patches") # Assuming same struct
# Submission CSVs are written one level above the working directory.
SUBMISSION_DIR = os.path.join(os.path.pardir, "submission_csvs")
os.makedirs(SUBMISSION_DIR, exist_ok=True)

def generate_submission(model, submission_folder, masks_folder=None):
    """
    Predict one label per sample from its test patches and write a submission CSV.

    Each ``.png`` in `submission_folder` whose name does not contain "mask" is run
    through `model` together with its mask (or an all-zero mask when none is found).
    Patch probabilities are averaged per sample and the class with the highest mean
    probability becomes the sample label.

    Args:
        model: Two-input classifier called as `model(image_tensor, mask_tensor)`.
        submission_folder: Directory with the test patches.
        masks_folder: Optional directory with the matching mask files.

    Returns:
        The submission DataFrame (columns `sample_index`, `label`), or None when
        `submission_folder` does not exist.
    """
    model.eval()

    # 1. Scan Test Patches
    if not os.path.exists(submission_folder):
        print(f"Test folder not found: {submission_folder}")
        return

    patch_files = sorted(
        name for name in os.listdir(submission_folder)
        if name.lower().endswith('.png') and "mask" not in name
    )
    print(f"Found {len(patch_files)} test patches.")

    # 2. Setup Transforms (Standard Phikon normalization for images, plain tensor for masks)
    val_transform = transforms.Compose([
        transforms.Resize(TARGET_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])
    mask_transform = transforms.Compose([
        transforms.Resize(TARGET_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
        transforms.ToTensor()
    ])

    # 3. Inference Loop: collect per-patch class probabilities, grouped by sample
    image_predictions = {}

    with torch.no_grad():
        for filename in tqdm(patch_files):
            filepath = os.path.join(submission_folder, filename)

            # Extract Sample ID (everything before the last '_p', else the bare file stem)
            if '_p' in filename:
                sample_id = filename.rsplit('_p', 1)[0]
            else:
                sample_id = os.path.splitext(filename)[0]

            # Register the sample up front so it still gets a row if all its patches fail
            bag = image_predictions.setdefault(sample_id, [])

            try:
                # A. Load Image
                image = Image.open(filepath).convert('RGB')
                img_tensor = val_transform(image).unsqueeze(0).to(device)

                # B. Load Mask (Robust Fallback)
                mask_tensor = None
                if masks_folder:
                    mask_path = os.path.join(masks_folder, filename.replace('img_', 'mask_'))
                    if os.path.exists(mask_path):
                        mask = Image.open(mask_path).convert('L')
                        mask_tensor = mask_transform(mask).unsqueeze(0).to(device)

                if mask_tensor is None:
                    # Black mask fallback
                    mask_tensor = torch.zeros((1, 1, TARGET_SIZE[0], TARGET_SIZE[1])).to(device)

                # C. Predict
                logits = model(img_tensor, mask_tensor)
                bag.append(torch.softmax(logits, dim=1).cpu().numpy()[0])

            except Exception as e:
                # Best effort: report the failing patch and carry on with the rest
                print(f"Error processing {filename}: {e}")
                continue

    # 4. Aggregation (Max Mean Confidence)
    final_results = []
    print("Aggregating results...")

    for sample_id, probs_list in image_predictions.items():
        if probs_list:
            # Mean probability across all patches for each class, then the top class
            avg_probs = np.mean(np.array(probs_list), axis=0)
            final_class_idx = np.argmax(avg_probs)
            pred_label = label_encoder.inverse_transform([final_class_idx])[0]
        else:
            # Fallback for empty bags (rare): first class of the encoder
            pred_label = label_encoder.classes_[0]

        final_results.append({
            'sample_index': f"{sample_id}.png", # Format requirement from sample_submission
            'label': pred_label
        })

    # 5. Save
    df_submission = pd.DataFrame(final_results).sort_values('sample_index')

    timestamp = datetime.now().strftime("%d_%b-%H_%M")
    out_path = os.path.join(SUBMISSION_DIR, f"submission_godmode_{timestamp}.csv")

    df_submission.to_csv(out_path, index=False)
    print(f"Submission saved to: {out_path}")
    return df_submission

# Run Generation (Ensure TEST_PATCHES_DIR is correct)
# The fine-tuned model is used; masks are looked up in TEST_MASKS_DIR.
if os.path.exists(TEST_PATCHES_DIR):
    generate_submission(final_model, TEST_PATCHES_DIR, TEST_MASKS_DIR)
else:
    print("Test patches directory not found. Skipping submission generation.")

# Part 6: Explainability (God Mode CAM)

## 14. ViT-Compatible Class Activation Maps

We need a custom CAM implementation because our model has two distinct pillars with different architectures:
1.  **Phikon (ViT):** Requires reshaping sequence tokens `(B, 197, 768)` -> `(B, 14, 14, 768)` to get a spatial map.
2.  **Mask Encoder (CNN):** Uses standard spatial feature maps.

In [None]:
import matplotlib.gridspec as gridspec

class GodModeCAM:
    """Grad-CAM style explainer for a two-branch (token + convolutional) model.

    Forward and full-backward hooks are attached to one layer per branch:

    * ``target_layer_phikon`` is expected to emit a token sequence shaped
      ``(B, 1 + N, C)`` where the first token is a CLS token and ``N`` is a
      perfect square (e.g. 196 -> 14x14).
    * ``target_layer_mask`` is expected to emit a spatial feature map shaped
      ``(B, C, H, W)``.

    Call :meth:`remove_hooks` when the engine is no longer needed; otherwise
    every new ``GodModeCAM`` built on the same model (e.g. by re-running a
    notebook cell) stacks additional hooks on the target layers.
    """

    def __init__(self, model, target_layer_phikon, target_layer_mask):
        self.model = model
        self.gradients = {"phikon": None, "mask": None}
        self.activations = {"phikon": None, "mask": None}

        # Keep the handles so the hooks can be detached later.
        self._handles = [
            target_layer_phikon.register_forward_hook(self.save_activation_phikon),
            target_layer_phikon.register_full_backward_hook(self.save_gradient_phikon),
            target_layer_mask.register_forward_hook(self.save_activation_mask),
            target_layer_mask.register_full_backward_hook(self.save_gradient_mask),
        ]

    def remove_hooks(self):
        """Detach every hook this engine registered (safe to call twice)."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def save_activation_phikon(self, module, input, output):
        """Forward hook: store the token-branch activations."""
        # Some layers return a tuple; index 0 is then taken as the hidden state.
        if isinstance(output, tuple):
            self.activations["phikon"] = output[0]
        else:
            self.activations["phikon"] = output

    def save_gradient_phikon(self, module, grad_input, grad_output):
        """Backward hook: store the gradient w.r.t. the token-branch output."""
        self.gradients["phikon"] = grad_output[0]

    def save_activation_mask(self, module, input, output):
        """Forward hook: store the mask-branch feature map."""
        self.activations["mask"] = output

    def save_gradient_mask(self, module, grad_input, grad_output):
        """Backward hook: store the gradient w.r.t. the mask-branch output."""
        self.gradients["mask"] = grad_output[0]

    def generate_cam(self, img_tensor, mask_tensor, class_idx=None):
        """Compute one CAM per branch for the first sample of the batch.

        Args:
            img_tensor: Image batch passed as the model's first argument.
            mask_tensor: Mask batch passed as the model's second argument.
            class_idx: Class whose score is back-propagated. When ``None``
                the arg-max class of the first sample is used.

        Returns:
            ``(cam_phi, cam_mask, logits, class_idx)`` where both CAMs are
            2-D NumPy arrays (ReLU-ed, not normalised) upsampled to the
            spatial size of ``img_tensor`` / ``mask_tensor`` respectively,
            ``logits`` is the raw model output and ``class_idx`` is the
            class that was explained (a 0-dim tensor when auto-selected).

        Raises:
            RuntimeError: If a hook did not fire during the forward/backward
                pass (e.g. the hooked layer's output carries no gradient).
            ValueError: If the token count (without CLS) is not a square.
        """
        # Drop anything captured by a previous call so a hook that fails to
        # fire cannot silently reuse stale tensors.
        self.gradients = {"phikon": None, "mask": None}
        self.activations = {"phikon": None, "mask": None}

        # 1. Forward + 2. Backward. Gradients are needed even if the caller
        # is inside a ``torch.no_grad()`` context.
        with torch.enable_grad():
            self.model.zero_grad()
            logits = self.model(img_tensor, mask_tensor)

            if class_idx is None:
                # Arg-max over the classes of sample 0 only (not the
                # flattened batch).
                class_idx = torch.argmax(logits[0])

            score = logits[0, class_idx]
            score.backward()

        for branch in ("phikon", "mask"):
            if self.activations[branch] is None or self.gradients[branch] is None:
                raise RuntimeError(
                    f"No activations/gradients captured for the '{branch}' branch; "
                    "check that the hooked layer is used in the forward pass and "
                    "that its output requires grad."
                )

        # --- Token branch: drop CLS, weight channels, fold tokens to a grid ---
        grads_phi = self.gradients["phikon"][:, 1:, :]
        acts_phi = self.activations["phikon"][:, 1:, :]

        # Channel weights = gradient averaged over the token axis.
        weights_phi = grads_phi.mean(dim=1, keepdim=True)
        cam_phi = (weights_phi * acts_phi).sum(dim=2)

        n_tokens = cam_phi.shape[1]
        side = math.isqrt(n_tokens)
        if side * side != n_tokens:
            raise ValueError(
                f"Cannot fold {n_tokens} spatial tokens into a square grid."
            )
        cam_phi = cam_phi.reshape(-1, 1, side, side)

        # Upsample to the input resolution instead of a hard-coded 224x224.
        cam_phi = nn.functional.interpolate(
            cam_phi, size=tuple(img_tensor.shape[-2:]),
            mode='bilinear', align_corners=False,
        )
        cam_phi = torch.relu(cam_phi)

        # --- Mask branch: standard Grad-CAM on a conv feature map ---
        grads_mask = self.gradients["mask"]
        acts_mask = self.activations["mask"]

        weights_mask = grads_mask.mean(dim=(2, 3), keepdim=True)
        cam_mask = (weights_mask * acts_mask).sum(dim=1, keepdim=True)
        cam_mask = nn.functional.interpolate(
            cam_mask, size=tuple(mask_tensor.shape[-2:]),
            mode='bilinear', align_corners=False,
        )
        cam_mask = torch.relu(cam_mask)

        return (
            cam_phi.detach().cpu().numpy()[0, 0],
            cam_mask.detach().cpu().numpy()[0, 0],
            logits,
            class_idx,
        )

# --- CAM hook wiring ---
# Phikon branch: hook the `output` sub-module of the final encoder block,
# i.e. phikon.encoder.layer[-1].output.
target_phikon = final_model.phikon.encoder.layer[-1].output

# Mask branch: hook mask_encoder.features[6].
# NOTE(review): presumably the last conv layer, Conv2d(64, 128) -- confirm
# against the mask encoder definition.
target_mask = final_model.mask_encoder.features[6]

god_cam = GodModeCAM(
    model=final_model,
    target_layer_phikon=target_phikon,
    target_layer_mask=target_mask,
)

In [None]:
def _normalize_cam(cam, eps=1e-7):
    """Min-max scale ``cam`` into [0, 1]; a constant map becomes all zeros."""
    shifted = cam - cam.min()
    # Divide by the range (max - min), not by the raw maximum.
    return shifted / (shifted.max() + eps)


def visualize_god_mode_analysis(dataset, model, cam_engine, count=3):
    """Plot image, mask, both CAMs and class probabilities for random samples.

    Args:
        dataset: Indexable dataset yielding ``(img_tensor, label, mask_tensor)``.
        model: Model to explain; it is switched to eval mode.
        cam_engine: Object exposing ``generate_cam(img, mask)`` that returns
            ``(cam_phi, cam_mask, logits, pred_idx)``.
        count: Number of samples to draw without replacement. Values larger
            than ``len(dataset)`` are clamped to the dataset size.
    """
    model.eval()

    # Sampling without replacement cannot draw more items than exist.
    n_show = min(count, len(dataset))
    if n_show <= 0:
        return
    indices = np.random.choice(len(dataset), n_show, replace=False)

    mean = np.array(IMAGENET_MEAN).reshape(1, 1, 3)
    std = np.array(IMAGENET_STD).reshape(1, 1, 3)

    for idx in indices:
        # Load Data
        img_tensor, label, mask_tensor = dataset[idx]
        img_tensor = img_tensor.unsqueeze(0).to(device)
        mask_tensor = mask_tensor.unsqueeze(0).to(device)
        # Plain ints avoid tensor/int and cross-device comparisons below.
        true_idx = int(label)
        true_label = label_encoder.inverse_transform([true_idx])[0]

        # Generate CAMs
        cam_phi, cam_mask, logits, pred_idx = cam_engine.generate_cam(img_tensor, mask_tensor)
        pred_idx = int(pred_idx)
        pred_label = label_encoder.inverse_transform([pred_idx])[0]

        # Denormalize the image back to [0, 1] in HWC layout.
        img_np = img_tensor[0].detach().cpu().numpy().transpose(1, 2, 0)
        img_np = np.clip(img_np * std + mean, 0, 1)

        # Normalize CAMs (0-1)
        cam_phi = _normalize_cam(cam_phi)
        cam_mask = _normalize_cam(cam_mask)

        # Create the heatmap overlay for the Phikon branch.
        heatmap_phi = cv2.applyColorMap(np.uint8(255 * cam_phi), cv2.COLORMAP_JET)
        heatmap_phi = cv2.cvtColor(heatmap_phi, cv2.COLOR_BGR2RGB)
        if heatmap_phi.shape[:2] != img_np.shape[:2]:
            # cv2.resize takes (width, height); keeps the blend shape-safe.
            heatmap_phi = cv2.resize(heatmap_phi, (img_np.shape[1], img_np.shape[0]))
        overlay_phi = (np.uint8(255 * img_np) * 0.5 + heatmap_phi * 0.5) / 255.0

        # Plot
        fig = plt.figure(figsize=(20, 6))
        gs = gridspec.GridSpec(1, 5, width_ratios=[1, 1, 1, 1, 0.2])

        # 1. Original Image
        ax0 = plt.subplot(gs[0])
        ax0.imshow(img_np)
        ax0.set_title(f"Original\nTrue: {true_label}")
        ax0.axis('off')

        # 2. Mask Input
        ax1 = plt.subplot(gs[1])
        ax1.imshow(mask_tensor.cpu().squeeze(), cmap='gray')
        ax1.set_title("Input Mask")
        ax1.axis('off')

        # 3. Phikon Focus (Where the Foundation Model looked)
        ax2 = plt.subplot(gs[2])
        ax2.imshow(overlay_phi)
        ax2.set_title("Phikon Attention\n(Texture Focus)")
        ax2.axis('off')

        # 4. Mask Focus (Where the CNN looked)
        ax3 = plt.subplot(gs[3])
        ax3.imshow(cam_mask, cmap='jet')
        ax3.set_title("Geometry Attention\n(Shape Focus)")
        ax3.axis('off')

        # 5. Prediction Bar: predicted class in green when correct, red otherwise.
        ax4 = plt.subplot(gs[4])
        probs = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]
        colors = ['gray'] * len(probs)
        colors[pred_idx] = 'green' if pred_idx == true_idx else 'red'
        ax4.barh(label_encoder.classes_, probs, color=colors)
        ax4.set_xlim(0, 1)
        ax4.set_title(f"Pred: {pred_label}")
        ax4.invert_yaxis()

        plt.tight_layout()
        plt.show()

# Render CAM panels for five random validation samples.
print("Visualizing God Mode Inference Analysis on Validation Set...")
visualize_god_mode_analysis(
    dataset=val_dataset,
    model=final_model,
    cam_engine=god_cam,
    count=5,
)

# End of Notebook
**Summary:**
1.  **Foundation:** Uses `owkin/phikon` (ViT-Base, 768-dim tokens) for SOTA texture extraction.
2.  **Guidance:** Uses a parallel CNN to process binary masks explicitly.
3.  **Training:** Two-stage process (Warmup -> Fine-Tuning with AMP).
4.  **Explainability:** Custom ViT-Reshape CAM to visualize Foundation Model attention.

# Part 7 (Optional): Test Time Augmentation (TTA)

To squeeze the final drops of performance out of the model, we use TTA.
Instead of predicting on just the image, we predict on:
1.  Original
2.  Horizontal Flip
3.  Vertical Flip

We average the probabilities before making the final decision.

In [None]:
def predict_with_tta(model, img_tensor, mask_tensor):
    """
    Average class probabilities over three views of the same input.

    The views are the untouched input, a horizontal flip and a vertical
    flip; the mask always receives the same flip as the image. The model is
    put in eval mode and the averaged probabilities of the first batch
    element are returned as a NumPy array.
    """
    model.eval()

    # (image, mask) pairs, in the order: original, H-flip, V-flip.
    views = (
        (img_tensor, mask_tensor),
        (TF.hflip(img_tensor), TF.hflip(mask_tensor)),
        (TF.vflip(img_tensor), TF.vflip(mask_tensor)),
    )

    running_total = None
    for view_img, view_mask in views:
        with torch.no_grad():
            view_probs = torch.softmax(model(view_img, view_mask), dim=1)

        running_total = view_probs if running_total is None else running_total + view_probs

    mean_probs = running_total / len(views)
    return mean_probs.cpu().numpy()[0]

def generate_submission_tta(model, submission_folder, masks_folder=None):
    """Predict every test patch with TTA and write a per-sample submission CSV.

    Patches are the ``.png`` files in ``submission_folder`` whose name does
    not contain ``"mask"``. Patches named ``<sample>_p<k>.png`` are grouped
    under ``<sample>``; their TTA probabilities are averaged and the arg-max
    class becomes the sample's label.

    Args:
        model: Model called through ``predict_with_tta``.
        submission_folder: Directory containing the test patches.
        masks_folder: Optional directory with masks; the mask file name is
            the patch name with ``img_`` replaced by ``mask_``. A missing
            mask is replaced by an all-zero mask.

    Returns:
        The submission ``DataFrame`` (columns ``sample_index`` and
        ``label``), which is also written to ``SUBMISSION_DIR``.
    """
    print("--- Generating Submission with TTA (3x Slower, Better Accuracy) ---")
    model.eval()

    patch_files = sorted(
        f for f in os.listdir(submission_folder)
        if f.lower().endswith('.png') and "mask" not in f
    )

    # Transforms
    val_transform = transforms.Compose([
        transforms.Resize(TARGET_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])

    mask_transform = transforms.Compose([
        # Nearest-neighbour resize so mask values are not interpolated.
        transforms.Resize(TARGET_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
        transforms.ToTensor()
    ])

    image_predictions = {}

    for filename in tqdm(patch_files):
        filepath = os.path.join(submission_folder, filename)

        if '_p' in filename:
            sample_id = filename.rsplit('_p', 1)[0]
        else:
            sample_id = os.path.splitext(filename)[0]

        # Register the sample even if all of its patches fail, so it still
        # gets a (fallback) row in the submission.
        sample_probs = image_predictions.setdefault(sample_id, [])

        try:
            # Load & Prep
            with Image.open(filepath) as image:
                img_tensor = val_transform(image.convert('RGB')).unsqueeze(0).to(device)

            mask_tensor = None
            if masks_folder:
                mask_name = filename.replace('img_', 'mask_')
                mask_path = os.path.join(masks_folder, mask_name)
                if os.path.exists(mask_path):
                    with Image.open(mask_path) as mask:
                        mask_tensor = mask_transform(mask.convert('L')).unsqueeze(0).to(device)

            if mask_tensor is None:
                mask_tensor = torch.zeros(
                    (1, 1, TARGET_SIZE[0], TARGET_SIZE[1]), device=device
                )

            # Predict with TTA
            sample_probs.append(predict_with_tta(model, img_tensor, mask_tensor))

        except Exception as e:
            # Best-effort: skip the patch, but report it instead of hiding it.
            print(f"Error processing {filename}: {e}")
            continue

    # Aggregate
    final_results = []
    for sample_id, probs_list in image_predictions.items():
        if not probs_list:
            # Every patch of this sample failed: fall back to the first class.
            pred_label = label_encoder.classes_[0]
        else:
            avg_probs = np.mean(np.array(probs_list), axis=0)
            final_class_idx = np.argmax(avg_probs)
            pred_label = label_encoder.inverse_transform([final_class_idx])[0]

        final_results.append({'sample_index': f"{sample_id}.png", 'label': pred_label})

    # Explicit columns keep sort_values working when no patch was found.
    df_submission = pd.DataFrame(final_results, columns=['sample_index', 'label'])
    df_submission = df_submission.sort_values('sample_index')

    now = datetime.now()
    timestamp = now.strftime("%d_%b-%H_%M")
    os.makedirs(SUBMISSION_DIR, exist_ok=True)
    out_path = os.path.join(SUBMISSION_DIR, f"submission_godmode_TTA_{timestamp}.csv")

    df_submission.to_csv(out_path, index=False)
    print(f"TTA Submission saved to: {out_path}")
    return df_submission

# Uncomment to run TTA Submission
# if os.path.exists(TEST_PATCHES_DIR):
#     generate_submission_tta(final_model, TEST_PATCHES_DIR, TEST_MASKS_DIR)