In [None]:
# Install the third-party packages used by this notebook into the kernel's
# environment. No versions are pinned, so the installed releases depend on
# when the cell is run.
%pip install gdown lightning lion-pytorch wandb torchstain[torch]
# Uninstall tensorflow, keras and tensorboard to avoid conflicts.
# NOTE(review): the exact conflict is not shown in this notebook — confirm it
# still applies before keeping this step.
%pip uninstall -y tensorflow keras tensorboard

## Get data


In [None]:
# Download the dataset archive from Google Drive, unpack it quietly, rename
# the extracted `an2dl2526c2v2` folder to `data`, and list the result.
# NOTE(review): the following lines assume the download is saved as
# `data.zip`. Re-running this cell once `data/` already exists makes `mv`
# place the folder *inside* `data/` instead of renaming it.
!gdown https://drive.google.com/uc?id=1fZG_-FwADGMI_TqbPaA8z10ZMGPvb_vG
!unzip -q data.zip
!mv an2dl2526c2v2 data
!ls -l data

In [None]:
# Download the trash list (sample files to exclude from training) and move it
# to data/trash_list.txt, which is the default `trash_list_path` read by
# PathologyDataModule.setup().
!gdown https://drive.google.com/uc?id=1D8tGrxR4oHZmOxKvINNQR5x-zI6MavVP
!mv trash_list.txt data/trash_list.txt

In [None]:
# List the working directory (sanity check after the downloads above).
!ls -l

## Import Libraries


In [None]:
import random
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import cv2
import lightning as L
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from lion_pytorch import Lion
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset
from torchmetrics import AUROC, Accuracy, ConfusionMatrix, F1Score
from torchvision import transforms

# Seed the random number generators through Lightning with a fixed value (42)
# so that runs are repeatable.
L.seed_everything(42)

## Utilities


### Config


In [None]:
class Config:
    """Static, notebook-wide settings: dataset locations and class labels."""

    # --- Where the data lives (relative to the working directory) ---
    DATA_DIR = "./data"
    TRAIN_DATA_DIR = "./data/train_data"
    TEST_DATA_DIR = "./data/test_data"
    TRAIN_LABELS_PATH = "./data/train_labels.csv"

    # --- Classification targets ---
    CLASSES = ["Luminal A", "Luminal B", "HER2(+)", "Triple negative"]
    # Derived from CLASSES so the two can never disagree.
    NUM_CLASSES = len(CLASSES)

### Extract tissue


In [None]:
import warnings


class TissueExtractor:
    """
    Extract square patches from an image, guided by a sparse annotation mask.

    Non-zero mask pixels mark annotated (cancer) locations. Patches are either
    centred on randomly sampled annotation pixels ("random" strategy) or taken
    from a regular grid laid over the annotated region ("grid" strategy). A
    patch is only returned when it contains at least ``min_annotation_pixels``
    non-zero mask pixels, so every patch shows tissue around an annotation.
    """

    def __init__(self, patch_size: int = 224, min_annotation_pixels: int = 1):
        """
        Args:
            patch_size: Size of square patches to extract.
            min_annotation_pixels: Minimum number of annotation pixels required in patch.
        """
        self.patch_size = patch_size
        self.min_annotation_pixels = min_annotation_pixels

    def _validate_inputs(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """
        Check that image and mask are usable and return a 2-D mask.

        Args:
            img: Image array whose first two axes are (H, W).
            mask: Mask array whose first two axes are (H, W).

        Returns:
            The mask reduced to its first channel if it had three axes, and
            binarised to uint8 (threshold 0.5) if it was float32/float64.

        Raises:
            ValueError: If image and mask differ in (H, W), or if either image
                dimension is smaller than ``patch_size``.
        """
        if img.shape[:2] != mask.shape[:2]:
            raise ValueError(
                f"Image shape {img.shape[:2]} doesn't match mask shape {mask.shape[:2]}"
            )

        h, w = img.shape[:2]
        if h < self.patch_size or w < self.patch_size:
            raise ValueError(
                f"Image dimensions ({h}, {w}) smaller than patch_size ({self.patch_size})"
            )

        # Multi-channel masks: only the first channel is consulted.
        if len(mask.shape) == 3:
            mask = mask[:, :, 0]

        if mask.dtype in [np.float32, np.float64]:
            warnings.warn("Float mask detected, thresholding at 0.5")
            mask = (mask > 0.5).astype(np.uint8)

        return mask

    def get_valid_patches(
        self,
        img: np.ndarray,
        mask: np.ndarray,
        num_patches: int = 8,
        strategy: str = "random",
        stride: Optional[int] = None,
        shuffle: bool = True,
        min_distance: Optional[int] = None,
    ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Extract tissue patches centered around cancer annotations in mask.

        Args:
            img: RGB image (H, W, 3) - the full tissue image
            mask: Annotation mask (H, W) - cancer point annotations (sparse)
            num_patches: Number of patches to extract per image.
            strategy: 'random' samples from annotation points; 'grid' finds patches containing annotations.
            stride: Step size for grid strategy. Defaults to patch_size (no overlap).
            shuffle: Whether to shuffle grid patches before selecting.
            min_distance: Minimum pixel distance between patch centers (random strategy).
                Defaults to half the patch size.

        Returns:
            images: List of RGB patches (surrounding tissue context)
            masks: List of corresponding annotation patches (sparse cancer markers)
            Both lists may hold fewer than ``num_patches`` entries (a warning
            is issued), and are empty when the mask has no annotations.

        Raises:
            ValueError: On invalid inputs (see ``_validate_inputs``), an
                unknown strategy, or a non-positive grid stride.
        """
        mask = self._validate_inputs(img, mask)
        h, w = img.shape[:2]

        # Find all annotation pixel indices
        annotation_indices = np.where(mask > 0)

        if len(annotation_indices[0]) == 0:
            warnings.warn("No annotations found in mask!")
            return [], []

        if strategy == "random":
            patches_img, patches_mask = self._extract_random(
                img, mask, annotation_indices, num_patches, h, w, min_distance
            )
        elif strategy == "grid":
            patches_img, patches_mask = self._extract_grid(
                img, mask, num_patches, h, w, stride, shuffle
            )
        else:
            raise ValueError(f"Unknown strategy: {strategy}. Use 'random' or 'grid'.")

        if len(patches_img) < num_patches:
            warnings.warn(
                f"Only {len(patches_img)} patches extracted (requested {num_patches})"
            )

        return patches_img, patches_mask

    def _extract_random(
        self,
        img: np.ndarray,
        mask: np.ndarray,
        annotation_indices: Tuple[np.ndarray, np.ndarray],
        num_patches: int,
        h: int,
        w: int,
        min_distance: Optional[int] = None,
    ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Random sampling: center patches on annotation points to capture surrounding tissue.

        Uses rejection sampling with NumPy's global RNG: a candidate centre is
        discarded when it is closer than ``min_distance`` to an accepted
        centre, when its patch would cross the image border, or when the patch
        holds too few annotation pixels. Sampling stops after
        ``num_patches * 100`` attempts, so fewer patches may be returned.

        Args:
            img: RGB image (H, W, 3)
            mask: Annotation mask (H, W)
            annotation_indices: Tuple of arrays with y and x indices of annotation pixels.
            num_patches: Number of patches to extract.
            h: Height of the image.
            w: Width of the image.
            min_distance: Minimum distance between patch centers.
                Defaults to half the patch size.

        Returns:
            (image patches, mask patches), index-aligned.
        """
        patches_img = []
        patches_mask = []
        selected_centers = []

        if min_distance is None:
            min_distance = self.patch_size // 2

        # Loop-invariant: offset from a centre to the patch's top-left corner.
        half_size = self.patch_size // 2

        attempts = 0
        max_attempts = num_patches * 100

        while len(patches_img) < num_patches and attempts < max_attempts:
            attempts += 1

            # Sample a random annotation point
            idx = np.random.randint(len(annotation_indices[0]))
            cy, cx = annotation_indices[0][idx], annotation_indices[1][idx]

            # Check minimum distance from existing patches
            if min_distance > 0 and selected_centers:
                too_close = False
                for prev_cy, prev_cx in selected_centers:
                    dist = np.sqrt((cy - prev_cy) ** 2 + (cx - prev_cx) ** 2)
                    if dist < min_distance:
                        too_close = True
                        break
                if too_close:
                    continue

            # Calculate patch bounds centered on annotation point
            y_min = cy - half_size
            x_min = cx - half_size
            y_max = y_min + self.patch_size
            x_max = x_min + self.patch_size

            # Boundary check: patches are never padded, only skipped
            if y_min < 0 or x_min < 0 or y_max > h or x_max > w:
                continue

            # Extract patches - image contains surrounding tissue, mask contains annotation
            img_patch = img[y_min:y_max, x_min:x_max]
            mask_patch = mask[y_min:y_max, x_min:x_max]

            # Check minimum annotation pixels (at least some annotation in patch)
            if np.count_nonzero(mask_patch) >= self.min_annotation_pixels:
                patches_img.append(img_patch)
                patches_mask.append(mask_patch)
                selected_centers.append((cy, cx))

        return patches_img, patches_mask

    def _grid_candidates(
        self,
        mask: np.ndarray,
        h: int,
        w: int,
        stride: Optional[int] = None,
    ) -> List[Tuple[int, int, int]]:
        """
        Enumerate grid windows that contain enough annotation pixels.

        The grid covers the bounding box of all annotations, expanded by one
        patch size on every side and clipped to the image. Windows are visited
        row by row, left to right.

        Args:
            mask: Annotation mask (H, W), already validated.
            h: Height of the image.
            w: Width of the image.
            stride: Step between window origins. Defaults to patch_size (no overlap).

        Returns:
            List of ``(y_min, x_min, annotation_count)`` for every window with
            at least ``min_annotation_pixels`` non-zero mask pixels; empty if
            the mask has no annotations.

        Raises:
            ValueError: If ``stride`` is not positive.
        """
        if stride is None:
            stride = self.patch_size

        if stride <= 0:
            raise ValueError(f"Stride must be positive, got {stride}")

        # Find annotation bounding box to focus search
        annotation_rows = np.any(mask > 0, axis=1)
        annotation_cols = np.any(mask > 0, axis=0)

        if not annotation_rows.any() or not annotation_cols.any():
            return []

        y_min_ann, y_max_ann = np.where(annotation_rows)[0][[0, -1]]
        x_min_ann, x_max_ann = np.where(annotation_cols)[0][[0, -1]]

        # Expand search region to capture surrounding tissue
        padding = self.patch_size
        y_start = max(0, y_min_ann - padding)
        y_end = min(h, y_max_ann + padding)
        x_start = max(0, x_min_ann - padding)
        x_end = min(w, x_max_ann + padding)

        # Only origins whose full window fits inside the search region are
        # generated; a trailing strip narrower than one stride is not revisited.
        y_positions = range(y_start, y_end - self.patch_size + 1, stride)
        x_positions = range(x_start, x_end - self.patch_size + 1, stride)

        candidates = []
        for y_min in y_positions:
            for x_min in x_positions:
                window = mask[
                    y_min : y_min + self.patch_size, x_min : x_min + self.patch_size
                ]
                annotation_count = np.count_nonzero(window)

                # Only include patches that contain annotations
                if annotation_count >= self.min_annotation_pixels:
                    candidates.append((y_min, x_min, annotation_count))

        return candidates

    def _extract_grid(
        self,
        img: np.ndarray,
        mask: np.ndarray,
        num_patches: int,
        h: int,
        w: int,
        stride: Optional[int] = None,
        shuffle: bool = True,
    ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Grid strategy: find patches that contain annotation points.

        With ``shuffle=True`` the valid windows are shuffled (NumPy global
        RNG) before the first ``num_patches`` are taken. With
        ``shuffle=False`` windows with more annotation pixels are taken first.

        Args:
            img: RGB image (H, W, 3)
            mask: Annotation mask (H, W)
            num_patches: Number of patches to extract.
            h: Height of the image.
            w: Width of the image.
            stride: Step size for grid sampling. Defaults to patch_size (no overlap).
            shuffle: Whether to shuffle valid patches before selection.

        Returns:
            (image patches, mask patches), index-aligned, at most
            ``num_patches`` long.

        Raises:
            ValueError: If ``stride`` is not positive.
        """
        valid_patches = self._grid_candidates(mask, h, w, stride)

        if shuffle:
            np.random.shuffle(valid_patches)
        else:
            # Sort by annotation count (highest first) for deterministic selection
            valid_patches.sort(key=lambda x: x[2], reverse=True)

        patches_img = []
        patches_mask = []

        for y_min, x_min, _ in valid_patches[:num_patches]:
            y_max = y_min + self.patch_size
            x_max = x_min + self.patch_size

            patches_img.append(img[y_min:y_max, x_min:x_max])
            patches_mask.append(mask[y_min:y_max, x_min:x_max])

        return patches_img, patches_mask

    def get_all_valid_patches(
        self,
        img: np.ndarray,
        mask: np.ndarray,
        stride: Optional[int] = None,
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[Tuple[int, int]]]:
        """
        Extract ALL patches containing annotations from the image.

        Args:
            img: RGB image (H, W, 3).
            mask: Annotation mask (H, W).
            stride: Step size of the grid. Defaults to patch_size (no overlap).

        Returns:
            images: Every grid patch with enough annotation pixels.
            masks: The matching mask patches.
            coordinates: ``(y_min, x_min)`` top-left corner of each patch.
            All three lists are empty when the mask has no annotations.

        Raises:
            ValueError: On invalid inputs (see ``_validate_inputs``) or a
                non-positive ``stride``.
        """
        mask = self._validate_inputs(img, mask)
        h, w = img.shape[:2]

        patches_img = []
        patches_mask = []
        coordinates = []

        for y_min, x_min, _ in self._grid_candidates(mask, h, w, stride):
            y_max = y_min + self.patch_size
            x_max = x_min + self.patch_size

            patches_img.append(img[y_min:y_max, x_min:x_max])
            patches_mask.append(mask[y_min:y_max, x_min:x_max])
            coordinates.append((y_min, x_min))

        return patches_img, patches_mask, coordinates

## Custom dataset


In [None]:
class PathologyDataset(Dataset):
    """
    Dataset of histopathology images stored as ``img_<id>.png`` with optional
    annotation masks stored as ``mask_<id>.png`` in the same directory.

    Each item is either a stack of patches extracted around mask annotations
    (``use_patches=True``) or the whole image (``use_patches=False``). With
    ``use_dual_stream=True`` the matching mask tensor is returned as well.
    """

    def __init__(
        self,
        data_dir: str,
        labels_df: Optional[pd.DataFrame] = None,
        transform: Optional[transforms.Compose] = None,
        use_mask: bool = True,
        use_patches: bool = False,
        patch_size: int = 224,
        num_patches: int = 8,
        patch_strategy: str = "random",
        min_annotation_pixels: int = 100,
        is_test: bool = False,
        label_encoder: Optional[LabelEncoder] = None,
        use_dual_stream: bool = False,
    ):
        """
        Args:
            data_dir: Directory with images and masks.
            labels_df: DataFrame with 'sample_index' and 'label' columns (None for test).
            transform: torchvision transforms to apply to images.
            use_mask: Whether to load and use masks.
            use_patches: Whether to extract patches or use full images.
            patch_size: Size of square patches to extract.
            num_patches: Number of patches to extract per image.
            patch_strategy: 'random' or 'grid' strategy for patch extraction.
            min_annotation_pixels: Minimum annotation pixels required in patch.
            is_test: Whether the dataset is for testing (no labels).
            label_encoder: Pre-fitted LabelEncoder (if None, one is fitted on Config.CLASSES).
            use_dual_stream: Whether to return masks alongside images.

        Raises:
            ValueError: If ``is_test`` is False and ``labels_df`` is None.
        """
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.use_mask = use_mask
        self.use_patches = use_patches
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.patch_strategy = patch_strategy
        self.min_annotation_pixels = min_annotation_pixels
        self.is_test = is_test
        self.label_encoder = label_encoder
        self.use_dual_stream = use_dual_stream

        # Initialize helpers
        self.tissue_extractor = TissueExtractor(
            patch_size=patch_size,
            min_annotation_pixels=min_annotation_pixels,
        )

        if is_test:
            # Test samples are discovered from the files on disk.
            self.samples = self._get_test_samples()
            self.labels = None
            self.encoded_labels = None
        else:
            if labels_df is None:
                raise ValueError("labels_df must be provided for training/validation.")

            self.samples = [
                self._clean_sample_idx(str(idx))
                for idx in labels_df["sample_index"].tolist()
            ]
            self.labels = labels_df["label"].tolist()

            if self.label_encoder is None:
                # Fit on the shared class list so the label -> index mapping is
                # the same one PathologyDataModule builds from Config.CLASSES.
                self.label_encoder = LabelEncoder()
                self.label_encoder.fit(Config.CLASSES)
            self.encoded_labels = self.label_encoder.transform(self.labels)

    def _clean_sample_idx(self, sample_idx: str) -> str:
        """Strip a leading ``img_`` and a trailing ``.png`` to get the bare sample id."""
        sample_idx = str(sample_idx)
        if sample_idx.startswith("img_"):
            sample_idx = sample_idx[4:]
        if sample_idx.endswith(".png"):
            sample_idx = sample_idx[:-4]
        return sample_idx

    def _get_test_samples(self) -> List[str]:
        """Return the ids of all ``img_*.png`` files in ``data_dir``, sorted by path."""
        return [
            self._clean_sample_idx(f.stem)
            for f in sorted(self.data_dir.glob("img_*.png"))
        ]

    def __len__(self) -> int:
        return len(self.samples)

    def _load_image_and_mask(
        self, sample_idx: str
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Load the RGB image and, if enabled and present, its grayscale mask.

        Returns:
            (image array, mask array or None). The mask is None when
            ``use_mask`` is False or no ``mask_<id>.png`` file exists.
        """
        img_path = self.data_dir / f"img_{sample_idx}.png"
        img = np.array(Image.open(img_path).convert("RGB"))

        mask = None
        if self.use_mask:
            mask_path = self.data_dir / f"mask_{sample_idx}.png"
            if mask_path.exists():
                mask = np.array(Image.open(mask_path).convert("L"))

        return img, mask

    def _crop_to_tissue_bbox(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """
        Crop the image to the bounding box of non-zero mask pixels plus a
        10-pixel margin (clipped to the image). Returns the image unchanged
        when the mask is empty.
        """
        rows = np.any(mask > 0, axis=1)
        cols = np.any(mask > 0, axis=0)
        if not rows.any() or not cols.any():
            return img
        y_min, y_max = np.where(rows)[0][[0, -1]]
        x_min, x_max = np.where(cols)[0][[0, -1]]
        padding = 10
        y_min = max(0, y_min - padding)
        y_max = min(img.shape[0], y_max + padding)
        x_min = max(0, x_min - padding)
        x_max = min(img.shape[1], x_max + padding)
        return img[y_min:y_max, x_min:x_max]

    def _load_and_preprocess(self, sample_idx: str) -> np.ndarray:
        """
        Load an image and crop it to its mask's bounding box when a mask exists.

        Note: this helper is not called by ``__getitem__`` in this class.
        """
        img, mask = self._load_image_and_mask(sample_idx)
        if mask is not None:
            img = self._crop_to_tissue_bbox(img, mask)
        return img

    def _load_patches(
        self, sample_idx: str
    ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Load image patches AND mask patches.
        Ensures both lists stay synchronized during padding/augmentation.

        Always returns exactly ``num_patches`` patches:
        - no patch found (or image smaller than the patch size): a centre crop
          resized to the patch size is repeated ``num_patches`` times;
        - too few patches found: randomly chosen existing patches are
          duplicated (each flipped horizontally with probability 0.5).

        Raises:
            ValueError: If image and mask shapes differ, or the patch strategy
                is unknown (raised by the tissue extractor).
        """
        img, mask = self._load_image_and_mask(sample_idx)

        if mask is None:
            # No mask on disk: treat every pixel as annotated.
            mask = np.ones(img.shape[:2], dtype=np.uint8) * 255

        h, w = img.shape[:2]
        too_small = h < self.patch_size or w < self.patch_size

        if too_small and mask.shape[:2] == img.shape[:2]:
            # The extractor raises ValueError for images smaller than the patch
            # size, which would make the resize fallback below unreachable for
            # exactly the images it exists for. Skip extraction instead.
            patches, mask_patches = [], []
        else:
            # Extract patches (TissueExtractor returns both)
            patches, mask_patches = self.tissue_extractor.get_valid_patches(
                img=img,
                mask=mask,
                num_patches=self.num_patches,
                strategy=self.patch_strategy,
                stride=self.patch_size // 2,
                shuffle=False,
                min_distance=32,
            )

        if len(patches) == 0:
            # Fallback: centre crop, clipped to the image, resized to patch size.
            cy, cx = h // 2, w // 2
            half = self.patch_size // 2
            y1 = max(0, cy - half)
            x1 = max(0, cx - half)
            y2 = min(h, y1 + self.patch_size)
            x2 = min(w, x1 + self.patch_size)

            fallback_patch = img[y1:y2, x1:x2]
            fallback_patch = cv2.resize(
                fallback_patch, (self.patch_size, self.patch_size)
            )

            fallback_mask = mask[y1:y2, x1:x2]
            # Nearest-neighbour keeps the mask values unblended.
            fallback_mask = cv2.resize(
                fallback_mask,
                (self.patch_size, self.patch_size),
                interpolation=cv2.INTER_NEAREST,
            )

            patches = [fallback_patch.copy() for _ in range(self.num_patches)]
            mask_patches = [fallback_mask.copy() for _ in range(self.num_patches)]

        elif len(patches) < self.num_patches:
            num_missing = self.num_patches - len(patches)
            indices = [random.randint(0, len(patches) - 1) for _ in range(num_missing)]

            for idx in indices:
                patch = patches[idx].copy()
                mask_patch = mask_patches[idx].copy()

                # Flip image and mask together so they stay aligned.
                if random.random() > 0.5:
                    patch = cv2.flip(patch, 1)
                    mask_patch = cv2.flip(mask_patch, 1)

                patches.append(patch)
                mask_patches.append(mask_patch)

        return patches, mask_patches

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, ...]:
        """
        Return one sample.

        Returns:
            ``(image, target)`` or, with dual stream, ``(image, mask, target)``.
            ``target`` is the sample id string for test datasets and a long
            tensor holding the encoded label otherwise. In patch mode the
            image (and mask) tensors are stacks with ``num_patches`` entries.
        """
        sample_idx = self.samples[idx]

        if self.use_patches:
            # Load both images and masks
            patches, mask_patches = self._load_patches(sample_idx)

            # Transform Image Patches
            transformed_patches = []
            for patch in patches:
                patch_pil = Image.fromarray(patch)
                if self.transform:
                    patch_tensor = self.transform(patch_pil)
                else:
                    patch_tensor = transforms.ToTensor()(patch_pil)
                transformed_patches.append(patch_tensor)

            # Stack Images: [num_patches, 3, H, W]
            img_tensor = torch.stack(transformed_patches)

            # Handle Masks for Dual Stream
            if self.use_dual_stream:
                # Masks only go through ToTensor; the image transform is not
                # applied to them.
                transformed_masks = []
                for mask_p in mask_patches:
                    mask_pil = Image.fromarray(mask_p).convert("L")
                    mask_tensor = transforms.ToTensor()(mask_pil)
                    transformed_masks.append(mask_tensor)
                mask_tensor = torch.stack(transformed_masks)
            else:
                mask_tensor = None

        else:
            img, mask = self._load_image_and_mask(sample_idx)
            img_pil = Image.fromarray(img)
            if self.transform:
                img_tensor = self.transform(img_pil)
            else:
                img_tensor = transforms.ToTensor()(img_pil)

            mask_tensor = None
            if self.use_dual_stream:
                if mask is None:
                    mask = np.zeros(img.shape[:2], dtype=np.uint8)

                mask_pil = Image.fromarray(mask).convert("L")

                # Apply only the first Resize of the image pipeline to the
                # mask; the other image transforms are not applied to it.
                if self.transform:
                    for t in self.transform.transforms:
                        if isinstance(t, transforms.Resize):
                            mask_pil = t(mask_pil)
                            break

                mask_tensor = transforms.ToTensor()(mask_pil)

        if self.is_test:
            if self.use_dual_stream and mask_tensor is not None:
                return img_tensor, mask_tensor, sample_idx
            return img_tensor, sample_idx
        else:
            label = self.encoded_labels[idx]
            label_t = torch.tensor(label, dtype=torch.long)

            if self.use_dual_stream and mask_tensor is not None:
                return img_tensor, mask_tensor, label_t
            return img_tensor, label_t

## Custom data module


In [None]:
class PathologyDataModule(L.LightningDataModule):
    """Lightning DataModule for histopathology image classification.

    Args:
        train_data_dir: Directory containing training images and masks.
        test_data_dir: Directory containing test images and masks.
        train_labels_path: Path to CSV file with training labels.
        trash_list_path: Path to a text file listing, one per line, samples to
            drop from the training CSV. Filtering is skipped if it is missing.
        batch_size: Batch size for dataloaders.
        num_workers: Number of workers for dataloaders.
        img_size: Side length every image/patch is resized to by the transforms.
        use_mask: Whether the datasets load and use masks.
        use_patches: Whether to use patch-based loading.
        patch_size: Size of patches to extract.
        num_patches: Number of patches per image.
        min_annotation_pixels: Minimum annotation pixels required in a patch.
        val_split: Fraction of training data to use for validation.
        random_seed: Random seed for the train/validation split.
    """

    def __init__(
        self,
        train_data_dir: str = Config.TRAIN_DATA_DIR,
        test_data_dir: str = Config.TEST_DATA_DIR,
        train_labels_path: str = Config.TRAIN_LABELS_PATH,
        trash_list_path: str = "data/trash_list.txt",
        batch_size: int = 16,
        num_workers: int = 2,
        img_size: int = 224,
        use_mask: bool = True,
        use_patches: bool = True,
        patch_size: int = 64,
        num_patches: int = 10,
        min_annotation_pixels: int = 50,
        val_split: float = 0.2,
        random_seed: int = 42,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.train_data_dir = train_data_dir
        self.test_data_dir = test_data_dir
        self.train_labels_path = train_labels_path
        self.trash_list_path = trash_list_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.img_size = img_size
        self.use_mask = use_mask
        self.use_patches = use_patches
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.min_annotation_pixels = min_annotation_pixels
        self.val_split = val_split
        self.random_seed = random_seed

        # Initialize label encoder
        self.label_encoder = LabelEncoder()
        self.label_encoder.fit(Config.CLASSES)

        # Will be set in setup()
        self.train_df = None
        self.val_df = None
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.class_weights = None
        self.sample_weights = None

    def _get_train_transforms(self) -> transforms.Compose:
        """Get transforms for training with augmentation."""
        target_size = self.img_size

        return transforms.Compose(
            [
                transforms.Resize((target_size, target_size)),
                # Standard augmentations
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomRotation(degrees=90),
                transforms.ColorJitter(
                    brightness=0.1,
                    contrast=0.1,
                    saturation=0.1,
                    hue=0.05,
                ),
                transforms.RandomAffine(
                    degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def _get_val_transforms(self) -> transforms.Compose:
        """Get transforms for validation/test (no augmentation)."""
        target_size = self.img_size

        return transforms.Compose(
            [
                transforms.Resize((target_size, target_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )

    @staticmethod
    def _to_sample_id(value) -> str:
        """Reduce a file name or CSV entry to its bare id (drops ``img_`` and ``.png``)."""
        return str(value).replace("img_", "").replace(".png", "")

    def _filter_trash_samples(self, full_df: pd.DataFrame) -> pd.DataFrame:
        """
        Drop the rows of ``full_df`` whose sample is listed in the trash list.

        Prints a short report of what was matched and removed. Returns
        ``full_df`` unchanged when the trash list file does not exist.
        """
        trash_path = Path(self.trash_list_path)
        if not trash_path.exists():
            print("No trash_list.txt found, skipping filtering.")
            return full_df

        print(f"Loading trash list from {trash_path}...")
        with open(trash_path, "r") as f:
            trash_files = [line.strip() for line in f if line.strip()]

        print(f"Total lines in trash_list.txt: {len(trash_files)}")

        # Normalize trash filenames to IDs
        trash_ids = {self._to_sample_id(t_file) for t_file in trash_files}

        print(f"Unique IDs in trash list (after deduplication): {len(trash_ids)}")

        # Normalize the CSV entries the same way so both sides are comparable
        cleaned_ids = full_df["sample_index"].apply(self._to_sample_id)
        csv_ids = set(cleaned_ids)

        # Calculate intersection and difference
        ids_to_remove = trash_ids.intersection(csv_ids)
        ids_not_found = trash_ids - csv_ids

        print(f"IDs from trash list FOUND in CSV: {len(ids_to_remove)}")
        print(f"IDs from trash list NOT FOUND in CSV: {len(ids_not_found)}")

        if len(ids_not_found) > 0:
            print(f"Example missing IDs: {list(ids_not_found)[:5]}")

        # Apply the filter
        initial_count = len(full_df)
        is_trash = cleaned_ids.isin(trash_ids)
        full_df = full_df[~is_trash].reset_index(drop=True)

        dropped_count = initial_count - len(full_df)
        print(f"Final check: Removed {dropped_count} rows from dataframe.")
        print(f"Remaining samples: {len(full_df)}")

        return full_df

    def setup(self, stage: Optional[str] = None):
        """
        Setup datasets for each stage.

        ``"fit"`` (or None) builds the train and validation datasets and the
        sampling weights; ``"test"``/``"predict"`` (or None) builds the test
        dataset.
        """
        if stage == "fit" or stage is None:
            # Load the labels and drop the samples listed in the trash list
            full_df = pd.read_csv(self.train_labels_path)
            full_df = self._filter_trash_samples(full_df)

            # NOTE(review): the split is not stratified by label, so class
            # proportions may differ between the train and validation parts.
            self.train_df, self.val_df = train_test_split(
                full_df,
                test_size=self.val_split,
                random_state=self.random_seed,
            )

            # Calculate class weights for balanced sampling
            self.class_weights, self.sample_weights = self._compute_sample_weights(
                self.train_df
            )

            # Training dataset: patches centred on randomly sampled annotation points
            self.train_dataset = PathologyDataset(
                data_dir=self.train_data_dir,
                labels_df=self.train_df,
                transform=self._get_train_transforms(),
                use_mask=self.use_mask,
                use_patches=self.use_patches,
                patch_size=self.patch_size,
                num_patches=self.num_patches,
                patch_strategy="random",
                min_annotation_pixels=self.min_annotation_pixels,
                is_test=False,
                label_encoder=self.label_encoder,
            )

            # Validation dataset: deterministic grid patches (PathologyDataset
            # requests a half-patch stride with shuffling disabled)
            self.val_dataset = PathologyDataset(
                data_dir=self.train_data_dir,
                labels_df=self.val_df,
                transform=self._get_val_transforms(),
                use_mask=self.use_mask,
                use_patches=self.use_patches,
                patch_size=self.patch_size,
                num_patches=self.num_patches,
                patch_strategy="grid",  # Grid for validation
                min_annotation_pixels=self.min_annotation_pixels,
                is_test=False,
                label_encoder=self.label_encoder,
            )

        if stage == "test" or stage == "predict" or stage is None:
            # Test dataset: same deterministic grid patches as validation
            self.test_dataset = PathologyDataset(
                data_dir=self.test_data_dir,
                labels_df=None,
                transform=self._get_val_transforms(),
                use_mask=self.use_mask,
                use_patches=self.use_patches,
                patch_size=self.patch_size,
                num_patches=self.num_patches,
                patch_strategy="grid",
                min_annotation_pixels=self.min_annotation_pixels,
                is_test=True,
                label_encoder=self.label_encoder,
            )

    def _compute_sample_weights(self, df: pd.DataFrame):
        """
        Compute sample weights for balanced sampling.

        Args:
            df: DataFrame with a 'label' column.

        Returns:
            class_weights: Inverse-frequency weight per class, scaled so the
                weights sum to the number of classes. A class with no samples
                in ``df`` gets weight 0.
            sample_weights: The class weight of each row's label, in row order.
        """
        # Encode labels
        labels = self.label_encoder.transform(df["label"].values)

        # Count samples per class
        class_counts = np.bincount(labels, minlength=len(Config.CLASSES))

        # Compute class weights (inverse frequency)
        class_weights = 1.0 / (class_counts + 1e-6)  # avoid division by zero
        # An absent class would otherwise get a weight of ~1e6 and, after
        # normalization, shrink every real class's weight to almost zero.
        class_weights[class_counts == 0] = 0.0

        total = class_weights.sum()
        if total > 0:
            class_weights = class_weights / total * len(Config.CLASSES)  # normalize

        # Assign weight to each sample based on its class
        sample_weights = class_weights[labels]

        return class_weights, sample_weights

    def train_dataloader(self) -> DataLoader:
        """
        Training loader that draws samples with replacement, weighted by
        inverse class frequency (so ``shuffle`` stays False).

        Raises:
            RuntimeError: If ``setup("fit")`` has not been run yet.
        """
        if self.sample_weights is None:
            raise RuntimeError("setup('fit') must be called before train_dataloader().")

        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
            sampler=torch.utils.data.WeightedRandomSampler(
                weights=self.sample_weights,
                num_samples=len(self.sample_weights),
                replacement=True,
            ),
        )

    def val_dataloader(self) -> DataLoader:
        """Validation loader: fixed order, keeps the last partial batch."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=False,
        )

    def test_dataloader(self) -> DataLoader:
        """Test loader: fixed order, keeps the last partial batch."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=False,
        )

    def predict_dataloader(self) -> DataLoader:
        """Prediction uses the same loader as testing."""
        return self.test_dataloader()

## Define neural network model


### Attention


In [None]:
from typing import Optional, Tuple

import lightning as L
import torch


# =============================================================================
# Attention Modules
# =============================================================================
class SimpleAttention(nn.Module):
    """Attention pooling: score each patch, softmax over patches, weighted sum.

    Args:
        feature_dim: Size of each patch feature vector.
        hidden_dim: Width of the hidden layer in the scoring network.
        dropout: Dropout probability applied inside the scoring network.
    """

    def __init__(self, feature_dim: int, hidden_dim: int = 256, dropout: float = 0.2):
        super().__init__()
        scoring_layers = [
            nn.Linear(feature_dim, hidden_dim),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        ]
        self.attention = nn.Sequential(*scoring_layers)

    def forward(self, features):
        """Pool ``[B, num_patches, feature_dim]`` into ``[B, feature_dim]``."""
        # One scalar score per patch: [B, num_patches, 1]
        scores = self.attention(features)
        # Normalize across the patch axis so the weights of each bag sum to 1
        weights = torch.softmax(scores, dim=1)
        # Weighted sum over patches: [B, feature_dim]
        return (weights * features).sum(dim=1)


class GatedAttention(nn.Module):
    """
    Gated attention pooling (Ilse et al. 2018).

    A tanh branch and a sigmoid branch are multiplied element-wise, projected
    to one score per patch, and the softmax of those scores weights the sum of
    the patch features.

    Paper: https://arxiv.org/abs/1802.04712

    Args:
        feature_dim: Size of each patch feature vector.
        hidden_dim: Width of both attention branches.
        dropout: Dropout probability applied at the end of each branch.
    """

    def __init__(self, feature_dim: int, hidden_dim: int = 256, dropout: float = 0.25):
        super().__init__()
        # Branches are created in V, U, w order so parameter names and the
        # order of random initialization stay the same.
        value_branch = [nn.Linear(feature_dim, hidden_dim), nn.Tanh(), nn.Dropout(dropout)]
        self.attention_V = nn.Sequential(*value_branch)
        gate_branch = [nn.Linear(feature_dim, hidden_dim), nn.Sigmoid(), nn.Dropout(dropout)]
        self.attention_U = nn.Sequential(*gate_branch)
        self.attention_w = nn.Linear(hidden_dim, 1)

    def forward(self, features):
        """Pool ``[B, num_patches, feature_dim]`` into ``[B, feature_dim]``."""
        content = self.attention_V(features)  # [B, num_patches, hidden_dim]
        gate = self.attention_U(features)  # [B, num_patches, hidden_dim]
        scores = self.attention_w(content * gate)  # [B, num_patches, 1]
        weights = torch.softmax(scores, dim=1)  # normalized over patches
        return (weights * features).sum(dim=1)  # [B, feature_dim]


class CLAMAttention(nn.Module):
    """
    CLAM-style gated attention pooling (simplified, single attention branch).

    Args:
        feature_dim: Size of each patch feature vector.
        hidden_dim: Width of both gated-attention branches.
        dropout: Dropout probability applied at the end of each branch.
        num_classes: Accepted for API compatibility only; it is not used.
    """

    def __init__(
        self,
        feature_dim: int,
        hidden_dim: int = 256,
        dropout: float = 0.25,
        num_classes: int = 4,  # kept for API compatibility but not used
    ):
        super().__init__()

        # Gated attention network: tanh branch (a), sigmoid branch (b), scorer (c)
        tanh_branch = [nn.Linear(feature_dim, hidden_dim), nn.Tanh(), nn.Dropout(dropout)]
        self.attention_a = nn.Sequential(*tanh_branch)
        sigmoid_branch = [
            nn.Linear(feature_dim, hidden_dim),
            nn.Sigmoid(),
            nn.Dropout(dropout),
        ]
        self.attention_b = nn.Sequential(*sigmoid_branch)
        self.attention_c = nn.Linear(hidden_dim, 1)

    def forward(self, features, return_attention=False):
        """Pool ``[B, num_patches, feature_dim]`` into ``[B, feature_dim]``.

        Args:
            features: Patch features of shape ``[B, num_patches, feature_dim]``.
            return_attention: When true, also return the attention weights.

        Returns:
            The pooled features, or ``(pooled, weights)`` with ``weights`` of
            shape ``[B, num_patches]`` when ``return_attention`` is true.
        """
        gated = self.attention_a(features) * self.attention_b(features)
        scores = self.attention_c(gated)  # [B, num_patches, 1]
        weights = torch.softmax(scores, dim=1)  # normalized over patches

        pooled = (weights * features).sum(dim=1)  # [B, feature_dim]

        if not return_attention:
            return pooled
        return pooled, weights.squeeze(-1)


class TransMIL(nn.Module):
    """
    Transformer-based Multiple Instance Learning (Shao et al. 2021)
    Paper: https://arxiv.org/abs/2106.00908

    A learnable class token is prepended to the projected patch features, the
    sequence is passed through a pre-norm transformer encoder, and the class
    token's output is returned as the bag representation.

    Args:
        feature_dim: Size of each patch feature vector (also the model width).
        num_heads: Number of attention heads per encoder layer.
        num_layers: Number of encoder layers.
        dropout: Dropout probability inside the encoder layers.
        max_patches: Number of patch positions covered by the positional
            embedding (one extra slot is reserved for the class token).
    """

    def __init__(
        self,
        feature_dim: int,
        num_heads: int = 8,
        num_layers: int = 2,
        dropout: float = 0.1,
        max_patches: int = 512,
    ):
        super().__init__()
        self.feature_dim = feature_dim

        # Learned linear projection of the patch features; the width stays
        # feature_dim, so it does not change divisibility by num_heads.
        self.input_proj = nn.Linear(feature_dim, feature_dim)

        # Learnable positional embedding: slot 0 is the class token, then patches
        position_init = torch.randn(1, max_patches + 1, feature_dim) * 0.02
        self.pos_embedding = nn.Parameter(position_init)

        # Learnable class token
        cls_init = torch.randn(1, 1, feature_dim) * 0.02
        self.cls_token = nn.Parameter(cls_init)

        # Transformer encoder (pre-norm layers, batch-first tensors)
        layer_options = dict(
            d_model=feature_dim,
            nhead=num_heads,
            dim_feedforward=feature_dim * 4,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,  # Pre-norm for better training stability
        )
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_options), num_layers=num_layers
        )

        # Final layer norm applied to the encoder output
        self.norm = nn.LayerNorm(feature_dim)

    def forward(self, features):
        """Encode a bag ``[B, num_patches, feature_dim]`` into ``[B, feature_dim]``."""
        batch, bag_size, _ = features.shape

        # Project patches and prepend one class token per bag: [B, 1+N, D]
        class_tokens = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([class_tokens, self.input_proj(features)], dim=1)

        # Add the positional embedding for the positions actually in use
        tokens = tokens + self.pos_embedding[:, : bag_size + 1, :]

        # Encode, normalize, and keep the class-token output
        encoded = self.norm(self.transformer(tokens))
        return encoded[:, 0]  # [B, feature_dim]


class MultiHeadAttentionMIL(nn.Module):
    """
    Single pre-norm self-attention block for MIL with a learnable class token.

    The class token is prepended to the patch features, one multi-head
    self-attention layer and one feed-forward layer (both residual) are
    applied, and the class token's output is returned.

    Args:
        feature_dim: Size of each patch feature vector.
        num_heads: Number of attention heads; ``feature_dim // num_heads`` is
            used as the per-head width.
        dropout: Dropout probability for attention weights, the output
            projection, and the feed-forward layers.
    """

    def __init__(self, feature_dim: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = feature_dim // num_heads
        self.scale = self.head_dim**-0.5

        # Fused query/key/value projection and output projection
        self.qkv = nn.Linear(feature_dim, feature_dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(feature_dim, feature_dim)
        self.proj_drop = nn.Dropout(dropout)

        # Learnable class token
        cls_init = torch.randn(1, 1, feature_dim) * 0.02
        self.cls_token = nn.Parameter(cls_init)

        # Pre-norms for the attention and feed-forward sub-layers
        self.norm1 = nn.LayerNorm(feature_dim)
        self.norm2 = nn.LayerNorm(feature_dim)

        # Feed-forward network with a 4x hidden expansion
        hidden_dim = feature_dim * 4
        ffn_layers = [
            nn.Linear(feature_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, feature_dim),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*ffn_layers)

    def forward(self, features):
        """Encode a bag ``[B, num_patches, feature_dim]`` into ``[B, feature_dim]``."""
        batch, bag_size, dim = features.shape
        seq_len = bag_size + 1  # patches plus the class token

        # Prepend the class token: [B, 1+N, D]
        class_tokens = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([class_tokens, features], dim=1)

        # Split the fused projection into per-head q, k, v: each [B, heads, 1+N, head_dim]
        fused = self.qkv(self.norm1(tokens))
        fused = fused.reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(dim=0)

        # Scaled dot-product attention
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        probs = self.attn_drop(torch.softmax(scores, dim=-1))

        # Merge heads back to [B, 1+N, D] and add the attention residual
        context = torch.matmul(probs, value).transpose(1, 2).reshape(batch, seq_len, dim)
        tokens = tokens + self.proj_drop(self.proj(context))

        # Feed-forward residual
        tokens = tokens + self.ffn(self.norm2(tokens))

        # The class token summarizes the bag
        return tokens[:, 0]  # [B, feature_dim]

### Custom Optimizer


#### Ranger


In [None]:
class Lookahead(torch.optim.Optimizer):
    """
    Lookahead optimizer wrapper (Zhang et al. 2019).

    Wraps any optimizer and maintains slow weights that are updated
    by interpolating towards fast weights every k steps.

    The wrapper deliberately does not call ``Optimizer.__init__``: it shares
    ``param_groups`` with the wrapped optimizer (so learning-rate schedulers
    attached to the wrapper drive the wrapped optimizer) and keeps only the
    slow weights in its own ``state``.

    Reference: https://arxiv.org/abs/1907.08610

    Args:
        base_optimizer: The optimizer performing the fast-weight updates.
        k: Number of fast steps between two slow-weight synchronizations.
        alpha: Interpolation factor used when moving slow weights towards
            fast weights.

    Raises:
        ValueError: If ``k`` is smaller than 1 or ``alpha`` is outside [0, 1].
    """

    def __init__(self, base_optimizer, k: int = 6, alpha: float = 0.5):
        if k < 1:
            raise ValueError(f"Invalid lookahead step interval k: {k}")
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Invalid lookahead alpha: {alpha}")

        self.base_optimizer = base_optimizer
        self.k = k
        self.alpha = alpha
        # Same list object as the wrapped optimizer, not a copy
        self.param_groups = base_optimizer.param_groups
        self.state = {}
        self._step_count = 0

        # Initialize slow weights for parameters that are trainable right now;
        # parameters unfrozen later are picked up lazily in ``step``.
        for p in self._iter_params():
            if p.requires_grad:
                self.state[p] = {"slow_weights": p.data.clone()}

    def _iter_params(self):
        """Yield every parameter in param-group order (the serialization order)."""
        for group in self.param_groups:
            yield from group["params"]

    @torch.no_grad()
    def step(self, closure=None):
        """Run one fast step and, every ``k`` steps, synchronize slow weights.

        Args:
            closure: Optional closure forwarded to the wrapped optimizer.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns.
        """
        loss = self.base_optimizer.step(closure)
        self._step_count += 1

        if self._step_count % self.k == 0:
            for p in self._iter_params():
                if not p.requires_grad:
                    continue
                entry = self.state.get(p)
                if entry is None:
                    # Parameter became trainable after construction: start
                    # tracking it from its current value.
                    entry = {"slow_weights": p.data.clone()}
                    self.state[p] = entry
                slow = entry["slow_weights"]
                # Interpolate: slow = slow + alpha * (fast - slow)
                slow.add_(p.data - slow, alpha=self.alpha)
                # Copy slow weights to fast weights
                p.data.copy_(slow)

        return loss

    def zero_grad(self, set_to_none: bool = False):
        """Clear gradients through the wrapped optimizer."""
        self.base_optimizer.zero_grad(set_to_none=set_to_none)

    @property
    def defaults(self):
        """Default hyperparameters of the wrapped optimizer."""
        return self.base_optimizer.defaults

    def state_dict(self):
        """Return a checkpointable state.

        Slow weights are stored as a list in param-group order (``None`` for
        parameters without slow weights), so they can be matched to the
        parameters of a freshly constructed optimizer in another process.
        """
        slow_weights = []
        for p in self._iter_params():
            entry = self.state.get(p)
            slow_weights.append(None if entry is None else entry["slow_weights"])
        return {
            "base_optimizer": self.base_optimizer.state_dict(),
            "slow_weights": slow_weights,
            "step_count": self._step_count,
        }

    def load_state_dict(self, state_dict):
        """Restore the wrapped optimizer, the step counter and the slow weights.

        State written by the earlier format (slow weights keyed by ``id(p)``)
        cannot be mapped back to parameters; in that case the slow weights are
        reset to the current parameter values.
        """
        self.base_optimizer.load_state_dict(state_dict["base_optimizer"])
        # ``Optimizer.load_state_dict`` installs a new ``param_groups`` list on
        # the wrapped optimizer; re-alias it so schedulers attached to this
        # wrapper keep updating the groups the wrapped optimizer reads.
        self.param_groups = self.base_optimizer.param_groups
        self._step_count = state_dict["step_count"]

        params = list(self._iter_params())
        saved = state_dict.get("slow_weights")
        if not isinstance(saved, (list, tuple)) or len(saved) != len(params):
            saved = [None] * len(params)

        self.state = {}
        for p, slow in zip(params, saved):
            if slow is not None:
                restored = slow.detach().clone().to(device=p.device, dtype=p.dtype)
                self.state[p] = {"slow_weights": restored}
            elif p.requires_grad:
                self.state[p] = {"slow_weights": p.data.clone()}


class RAdam(torch.optim.Optimizer):
    """
    RAdam optimizer (Liu et al. 2019).

    Rectified Adam - automatically adjusts adaptive learning rate
    based on variance of second moment estimate. Weight decay is applied in
    decoupled form (the parameter is shrunk directly, AdamW-style) rather than
    being added to the gradient.

    Reference: https://arxiv.org/abs/1908.03265

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate.
        betas: Coefficients for the running averages of the gradient and its
            square.
        eps: Term added to the denominator for numerical stability.
        weight_decay: Decoupled weight-decay coefficient.

    Raises:
        ValueError: If a hyperparameter is outside its valid range.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The closure's loss, or ``None`` when no closure is given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]
            beta1, beta2 = group["betas"]

            # Maximum length of the approximated SMA (depends only on beta2)
            rho_inf = 2 / (1 - beta2) - 1

            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("RAdam does not support sparse gradients")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                state["step"] += 1
                step = state["step"]

                # Decoupled weight decay
                if weight_decay != 0:
                    p.mul_(1 - lr * weight_decay)

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias correction
                bias_correction1 = 1 - beta1**step
                bias_correction2 = 1 - beta2**step

                # Length of the approximated SMA at this step
                rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2

                # Variance rectification
                if rho_t > 5:
                    # Compute variance rectification term
                    rect = (
                        (rho_t - 4)
                        * (rho_t - 2)
                        * rho_inf
                        / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
                    ) ** 0.5

                    # Adaptive step. The second moment must be bias-corrected
                    # as well: dividing by sqrt(v_t / (1 - beta2^t)) equals
                    # multiplying the step by sqrt(1 - beta2^t). Without this
                    # factor the first rectified steps are far too large,
                    # because v_t is still strongly biased towards zero.
                    step_size = lr * rect * (bias_correction2**0.5) / bias_correction1

                    denom = exp_avg_sq.sqrt().add_(eps)
                    p.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    # Variance is not tractable yet: plain momentum step
                    step_size = lr / bias_correction1
                    p.add_(exp_avg, alpha=-step_size)

        return loss


def Ranger(
    params,
    lr: float = 1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
    k=6,
    alpha=0.5,
):
    """
    Build a Ranger optimizer: an RAdam instance wrapped in Lookahead.

    Combines the variance rectification of RAdam with the
    stabilizing effect of Lookahead for robust training.

    Args:
        params: Model parameters (or parameter groups), passed to RAdam
        lr: Learning rate, passed to RAdam
        betas: Coefficients for computing running averages, passed to RAdam
        eps: Term added to denominator for numerical stability, passed to RAdam
        weight_decay: Weight decay coefficient, passed to RAdam
        k: Lookahead step interval
        alpha: Lookahead interpolation factor

    Returns:
        The Lookahead wrapper around the RAdam optimizer.

    Reference: https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
    """
    radam_options = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
    fast_optimizer = RAdam(params, **radam_options)
    return Lookahead(fast_optimizer, k=k, alpha=alpha)

### Main model


In [None]:
class PathologyModel(L.LightningModule):
    """
    Lightning Module for pathology image classification.

    Supports both single-image and multi-instance learning (patch-based) approaches
    with flexible backbone architectures and aggregation strategies.
    """

    # Length of the LR schedule used when the module is not attached to a
    # Trainer with a finite ``max_epochs`` (the value the schedule used to
    # hard-code).
    DEFAULT_SCHEDULE_EPOCHS = 50

    def __init__(
        self,
        model_name: str = "resnet50",
        num_classes: int = 4,
        pretrained: bool = True,
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-4,
        use_patches: bool = True,
        patch_aggregation: str = "clam",
        optimizer_name: str = "adamw",
        dropout_rate: float = 0.3,
        label_smoothing: float = 0.1,
        class_weights: Optional[torch.Tensor] = None,
        warmup_epochs: int = 5,
        freeze_backbone_epochs: int = 0,
        mixup_alpha: float = 0.0,
    ):
        """
        Args:
            model_name: Name of the timm model ['resnet', 'convnext_tiny', 'efficientnet_b0', etc.].
            num_classes: Number of output classes.
            pretrained: Whether to use ImageNet pretrained weights.
            learning_rate: Base learning rate for optimizer.
            weight_decay: Weight decay passed to the optimizer.
            use_patches: Whether input is patch-based [B, num_patches, C, H, W].
            patch_aggregation: Aggregation method:
                - 'mean': Simple mean pooling
                - 'max': Max pooling
                - 'attention': Simple attention with softmax
                - 'gated_attention': Gated attention
                - 'clam': CLAM attention
                - 'transmil': Transformer MIL
                - 'multihead': Multi-head self-attention
            optimizer_name: Optimizer to use:
                - 'adamw': AdamW
                - 'lion': Lion
                - 'ranger': RAdam + Lookahead
            dropout_rate: Dropout probability before classifier.
            label_smoothing: Label smoothing factor for cross-entropy.
            class_weights: Optional class weights for imbalanced datasets.
                Tensors, NumPy arrays and sequences are accepted; they are
                converted to a float32 tensor.
            warmup_epochs: Number of warmup epochs for learning rate.
            freeze_backbone_epochs: Number of epochs to freeze backbone, default 0 (no freezing).
            mixup_alpha: Alpha parameter for Mixup augmentation (0 = no Mixup).

        Raises:
            ValueError: If ``optimizer_name`` or ``patch_aggregation`` is unknown.
            ImportError: If 'lion' is requested but ``Lion`` is unavailable.
        """
        super().__init__()
        self.save_hyperparameters(ignore=["class_weights"])

        # The loss needs float32 weights; arrays such as float64 NumPy output
        # would otherwise fail inside CrossEntropyLoss.
        if class_weights is not None:
            class_weights = torch.as_tensor(class_weights, dtype=torch.float32)

        # Store hyperparameters
        self.model_name = model_name
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.use_patches = use_patches
        self.patch_aggregation = patch_aggregation
        self.optimizer_name = optimizer_name.lower()
        self.label_smoothing = label_smoothing
        self.class_weights = class_weights
        self.warmup_epochs = warmup_epochs
        self.freeze_backbone_epochs = freeze_backbone_epochs
        self.mixup_alpha = mixup_alpha

        # Validate optimizer choice
        valid_optimizers = ["adamw", "lion", "ranger"]
        if self.optimizer_name not in valid_optimizers:
            raise ValueError(
                f"Unknown optimizer: {optimizer_name}. Choose from: {valid_optimizers}"
            )

        # Check if Lion is available when requested
        if self.optimizer_name == "lion" and Lion is None:
            raise ImportError(
                "Lion optimizer requires lion-pytorch package. "
                "Install with: pip install lion-pytorch"
            )

        # Build model architecture
        self._build_model(model_name, pretrained, dropout_rate)

        # Initialize loss and metrics
        self._setup_loss()
        self._setup_metrics()

        # Freeze backbone if requested
        if freeze_backbone_epochs > 0:
            self._freeze_backbone()

    def _build_model(self, model_name: str, pretrained: bool, dropout_rate: float):
        """Build the backbone, the optional aggregation module and the classifier."""
        # Create backbone using timm; num_classes=0 makes it return features
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            drop_rate=0.0,
        )

        # Get feature dimension
        self.feature_dim = self.backbone.num_features

        # Build aggregation module for patches
        if self.use_patches:
            self.aggregation = self._build_aggregation_module()
        else:
            self.aggregation = None

        # Build classifier head
        hidden_dim = self.feature_dim // 2
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),  # LayerNorm often better than BatchNorm for MIL
            nn.Dropout(p=dropout_rate),
            nn.Linear(self.feature_dim, hidden_dim),
            nn.GELU(),  # GELU often better than ReLU
            nn.LayerNorm(hidden_dim),
            nn.Dropout(p=dropout_rate / 2),
            nn.Linear(hidden_dim, self.num_classes),
        )

    def _build_aggregation_module(self) -> Optional[nn.Module]:
        """Build patch aggregation module based on strategy.

        Returns:
            The aggregation module, or ``None`` for the parameter-free 'mean'
            and 'max' strategies.

        Raises:
            ValueError: If ``self.patch_aggregation`` is not a known strategy.
        """

        if self.patch_aggregation in ["mean", "max"]:
            return None

        elif self.patch_aggregation == "attention":
            return SimpleAttention(self.feature_dim)

        elif self.patch_aggregation == "gated_attention":
            return GatedAttention(self.feature_dim)

        elif self.patch_aggregation == "clam":
            return CLAMAttention(
                self.feature_dim,
                num_classes=self.num_classes,
            )

        elif self.patch_aggregation == "transmil":
            return TransMIL(self.feature_dim)

        elif self.patch_aggregation == "multihead":
            return MultiHeadAttentionMIL(self.feature_dim)

        else:
            raise ValueError(
                f"Unknown aggregation: {self.patch_aggregation}. "
                f"Choose from: mean, max, attention, gated_attention, clam, transmil, multihead"
            )

    def _setup_loss(self):
        """Initialize the cross-entropy loss with optional class weights."""
        self.criterion = nn.CrossEntropyLoss(
            weight=self.class_weights,
            label_smoothing=self.label_smoothing,
        )

    def _setup_metrics(self):
        """Initialize metrics for each stage."""
        metric_kwargs = {"task": "multiclass", "num_classes": self.num_classes}

        # Training metrics
        self.train_acc = Accuracy(**metric_kwargs)
        self.train_f1 = F1Score(**metric_kwargs, average="macro")

        # Validation metrics
        self.val_acc = Accuracy(**metric_kwargs)
        self.val_f1 = F1Score(**metric_kwargs, average="macro")
        self.val_auroc = AUROC(**metric_kwargs)

        # Test metrics
        self.test_acc = Accuracy(**metric_kwargs)
        self.test_f1 = F1Score(**metric_kwargs, average="macro")
        self.test_auroc = AUROC(**metric_kwargs)
        self.test_confmat = ConfusionMatrix(**metric_kwargs)

    def _freeze_backbone(self):
        """Freeze backbone parameters for transfer learning."""
        for param in self.backbone.parameters():
            param.requires_grad = False
        print(f"Backbone frozen for {self.freeze_backbone_epochs} epochs")

    def _unfreeze_backbone(self):
        """Unfreeze backbone parameters."""
        for param in self.backbone.parameters():
            param.requires_grad = True
        print("Backbone unfrozen")

    def _apply_mixup(self, x: torch.Tensor, y: torch.Tensor):
        """
        Applies Mixup augmentation to the batch.

        Args:
            x: Input batch, [B, C, H, W] or [B, num_patches, C, H, W].
            y: Labels of shape [B].

        Returns:
            mixed_x: The mixed input tensor
            target_a: Original labels
            target_b: Shuffled labels
            lam: The mixing coefficient (lambda)
        """
        # 1. Sample lambda from Beta distribution (1.0 disables mixing)
        if self.mixup_alpha > 0:
            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
        else:
            lam = 1.0

        # 2. Generate permutation indices
        index = torch.randperm(x.size(0), device=x.device)

        # 3. Create mixed inputs; indexing the batch axis works for both
        #    [B, C, H, W] and [B, Num_Patches, C, H, W]
        mixed_x = lam * x + (1 - lam) * x[index]

        # 4. Get pair of targets
        target_a, target_b = y, y[index]

        return mixed_x, target_a, target_b, lam

    def _forward_single(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for single images."""
        return self.backbone(x)

    def _forward_patches(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for patch-based images.

        Args:
            x: Tensor of shape [B, num_patches, C, H, W].

        Returns:
            Aggregated bag features of shape [B, feature_dim].
        """
        batch_size, num_patches, c, h, w = x.shape

        # Fold patches into the batch axis. ``reshape`` (unlike ``view``) also
        # accepts non-contiguous inputs such as rotated or permuted tensors.
        x = x.reshape(batch_size * num_patches, c, h, w)

        # Extract features
        features = self.backbone(x)  # [B * num_patches, feature_dim]

        # Restore the bag axis
        features = features.reshape(batch_size, num_patches, -1)

        # Aggregate patches
        return self._aggregate_patches(features)

    def _aggregate_patches(self, features: torch.Tensor) -> torch.Tensor:
        """
        Aggregate patch features.

        Args:
            features: [B, num_patches, feature_dim]

        Returns:
            Aggregated features [B, feature_dim]
        """
        if self.patch_aggregation == "mean":
            return features.mean(dim=1)

        elif self.patch_aggregation == "max":
            return features.max(dim=1)[0]

        else:
            return self.aggregation(features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor
               - If use_patches=False: [B, C, H, W]
               - If use_patches=True: [B, num_patches, C, H, W]

        Returns:
            Logits of shape [B, num_classes]
        """
        if self.use_patches and x.dim() == 5:
            features = self._forward_patches(x)
        else:
            features = self._forward_single(x)

        logits = self.classifier(features)
        return logits

    def training_step(self, batch: Tuple, batch_idx: int) -> torch.Tensor:
        """Training step with Mixup support."""
        x, y = batch

        if self.mixup_alpha > 0:
            mixed_x, target_a, target_b, lam = self._apply_mixup(x, y)

            # Forward pass with mixed input
            logits = self(mixed_x)

            # Mixup Loss: weighted sum of loss against both targets
            loss_a = self.criterion(logits, target_a)
            loss_b = self.criterion(logits, target_b)
            loss = lam * loss_a + (1 - lam) * loss_b

            # Metrics are noisy under Mixup; score against the target that
            # carries the larger mixing weight.
            metric_target = target_a if lam >= 0.5 else target_b
        else:
            # Standard training (No Mixup)
            logits = self(x)
            loss = self.criterion(logits, y)
            metric_target = y

        preds = torch.argmax(logits, dim=1)
        self.train_acc(preds, metric_target)
        self.train_f1(preds, metric_target)

        # Log metrics
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log(
            "train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True
        )
        self.log("train/f1", self.train_f1, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch: Tuple, batch_idx: int):
        """Validation step."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        preds = torch.argmax(logits, dim=1)
        probs = F.softmax(logits, dim=1)

        self.val_acc(preds, y)
        self.val_f1(preds, y)
        self.val_auroc(probs, y)

        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/f1", self.val_f1, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/auroc", self.val_auroc, on_step=False, on_epoch=True)

    def test_step(self, batch: Tuple, batch_idx: int):
        """Test step. Expects labelled batches ``(x, y)``."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        preds = torch.argmax(logits, dim=1)
        probs = F.softmax(logits, dim=1)

        self.test_acc(preds, y)
        self.test_f1(preds, y)
        self.test_auroc(probs, y)
        self.test_confmat(preds, y)

        self.log("test/loss", loss)
        self.log("test/acc", self.test_acc)
        self.log("test/f1", self.test_f1)
        self.log("test/auroc", self.test_auroc)

    def predict_step(self, batch: Tuple, batch_idx: int) -> Dict[str, Any]:
        """
        Prediction step with Test Time Augmentation (TTA).
        Averages logits across: original, horizontal flip, vertical flip and a
        90 degree rotation.

        Returns:
            Dict with ``sample_ids``, ``predictions``, ``probabilities`` and
            the averaged ``logits``.
        """
        x, sample_ids = batch
        # x shape is either [B, C, H, W] or [B, N, C, H, W]
        # augment the spatial dims: H and W (the last two dimensions)
        spatial_dims = [-2, -1]

        # Define the list of augmentations to apply
        augmentations = [
            lambda t: t,
            lambda t: torch.flip(t, dims=[-1]),  # Horizontal Flip
            lambda t: torch.flip(t, dims=[-2]),  # Vertical Flip
            lambda t: torch.rot90(t, k=1, dims=spatial_dims),  # 90 degree rotation
        ]

        logits_sum = 0

        # Accumulate logits over all augmented views
        for aug_func in augmentations:
            logits_sum += self(aug_func(x))

        # Average the logits
        avg_logits = logits_sum / len(augmentations)

        # Compute final probabilities and predictions
        probs = F.softmax(avg_logits, dim=1)
        preds = torch.argmax(avg_logits, dim=1)

        return {
            "sample_ids": sample_ids,
            "predictions": preds,
            "probabilities": probs,
            "logits": avg_logits,
        }

    def _schedule_epochs(self) -> int:
        """Return the number of epochs the LR schedule should span.

        Uses the attached Trainer's ``max_epochs`` when it is a positive
        number, otherwise falls back to ``DEFAULT_SCHEDULE_EPOCHS``.
        """
        try:
            max_epochs = self.trainer.max_epochs
        except (RuntimeError, AttributeError):
            # Not attached to a Trainer
            max_epochs = None

        if max_epochs is None or max_epochs <= 0:
            return self.DEFAULT_SCHEDULE_EPOCHS
        return int(max_epochs)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Configure optimizer and scheduler.

        The backbone and the head (classifier plus aggregation module) get
        separate learning rates. The schedule is a linear warmup followed by
        cosine annealing over the remaining training epochs.
        """
        # Separate parameters for differential learning rates
        backbone_params = list(self.backbone.parameters())
        classifier_params = list(self.classifier.parameters())
        if self.aggregation is not None:
            classifier_params += list(self.aggregation.parameters())

        # Adjust hyperparameters based on optimizer
        if self.optimizer_name == "lion":
            backbone_lr = self.learning_rate * 0.03  # Even lower for backbone
            classifier_lr = self.learning_rate * 0.3
            wd = self.weight_decay * 10
        else:
            backbone_lr = self.learning_rate * 0.1
            classifier_lr = self.learning_rate
            wd = self.weight_decay

        # Differential learning rates
        param_groups = [
            {
                "params": backbone_params,
                "lr": backbone_lr,
                "name": "backbone",
            },
            {
                "params": classifier_params,
                "lr": classifier_lr,
                "name": "classifier",
            },
        ]

        # Create optimizer based on selection
        if self.optimizer_name == "adamw":
            optimizer = torch.optim.AdamW(
                param_groups,
                weight_decay=wd,
            )
        elif self.optimizer_name == "lion":
            optimizer = Lion(
                param_groups,
                weight_decay=wd,
                betas=(0.9, 0.99),
            )
        elif self.optimizer_name == "ranger":
            optimizer = Ranger(
                param_groups,
                weight_decay=wd,
                k=6,  # Lookahead step
                alpha=0.5,  # Lookahead alpha
            )
        else:
            raise ValueError(f"Unknown optimizer: {self.optimizer_name}")

        # Cosine annealing with warmup. The horizon follows the Trainer instead
        # of a fixed 50 epochs; the decay span is at least one epoch (no
        # division by zero) and progress is capped at 1 so the learning rate
        # never climbs back up after the horizon.
        warmup_epochs = self.warmup_epochs
        decay_epochs = max(1, self._schedule_epochs() - warmup_epochs)

        def lr_lambda(epoch):
            if epoch < warmup_epochs:
                return (epoch + 1) / warmup_epochs
            progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
            return 0.5 * (1.0 + float(np.cos(np.pi * progress)))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
            },
        }

    def on_train_epoch_start(self):
        """Unfreeze the backbone once the configured number of epochs has passed."""
        if (
            self.freeze_backbone_epochs > 0
            and self.current_epoch == self.freeze_backbone_epochs
        ):
            self._unfreeze_backbone()
            # Rebuild optimizer and scheduler through the strategy.
            # NOTE(review): this presumably discards the optimizer state and
            # restarts the LR schedule (including warmup) from epoch 0 -
            # confirm that this is intended.
            self.trainer.strategy.setup_optimizers(self.trainer)

## Train logic


In [None]:
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)

# Build the data module and the model for this run.
# The patch flag is passed to both objects, so it is defined once here.
USE_PATCHES = True

# Data-side settings: masking, patch extraction geometry and batching.
datamodule_kwargs = dict(
    use_mask=True,
    use_patches=USE_PATCHES,
    patch_size=112,
    num_patches=8,
    img_size=224,
    batch_size=16,
    min_annotation_pixels=1,
)

# Model-side settings: backbone, patch aggregation, regularisation, optimizer.
model_kwargs = dict(
    model_name="convnext_tiny",
    use_patches=USE_PATCHES,
    patch_aggregation="clam",
    learning_rate=2e-4,
    weight_decay=1e-2,
    dropout_rate=0.3,
    label_smoothing=0.1,
    # freeze_backbone_epochs=5,
    optimizer_name="adamw",
    mixup_alpha=1,
)

datamodule = PathologyDataModule(**datamodule_kwargs)
model = PathologyModel(**model_kwargs)

In [None]:
# Train
# max_epochs is kept at 50 to match the 50-epoch horizon hard-coded in the
# model's `lr_lambda` cosine schedule; change both together.
trainer = L.Trainer(
    max_epochs=50,
    accelerator="auto",
    callbacks=[
        # Keep the checkpoint with the best validation F1. The monitored key
        # contains "/", so metric-name auto-insertion is disabled and the names
        # are written into the template by hand; otherwise the "/" ends up in
        # the file name and creates an extra directory level.
        ModelCheckpoint(
            monitor="val/f1",
            mode="max",
            filename="epoch={epoch:02d}-val_f1={val/f1:.4f}",
            auto_insert_metric_name=False,
        ),
        # Stop when the validation loss has not improved for 10 epochs.
        EarlyStopping(monitor="val/loss", patience=10, mode="min"),
        LearningRateMonitor(logging_interval="epoch"),
    ],
    accumulate_grad_batches=1,
    # gradient_clip_val=0.5,
    precision="16-mixed",
    log_every_n_steps=5,
    devices="auto",
    logger=None,
)
trainer.fit(model, datamodule)

In [None]:
# Show the path of the best checkpoint recorded by the trainer's checkpoint callback.
print(trainer.checkpoint_callback.best_model_path)

## Inference logic


In [None]:
# Inference on test set
# 1. Load the best model from the checkpoint
best_checkpoint = trainer.checkpoint_callback.best_model_path
if not best_checkpoint:
    # Fail with a clear message instead of an obscure loading error.
    raise RuntimeError(
        "No best checkpoint recorded on `trainer`; run the training cell first."
    )

print(f"Loading model from: {best_checkpoint}")
best_model = PathologyModel.load_from_checkpoint(best_checkpoint)

# Dedicated names are used for the prediction objects so that `trainer` and
# `datamodule` keep pointing at the training run. Rebinding them would make this
# cell fail on a second run, because the prediction trainer has no best checkpoint.
# NOTE(review): this data module is built with default arguments, whereas the
# training data module was given explicit settings (use_mask, use_patches,
# patch_size, num_patches, ...) - confirm the defaults match the training setup.
test_datamodule = PathologyDataModule()
test_datamodule.setup(stage="test")

predict_trainer = L.Trainer(
    accelerator="auto",
    precision="16-mixed",
)

# 2. Run prediction using the Trainer
print("Generating predictions...")
predictions = predict_trainer.predict(best_model, datamodule=test_datamodule)

# 3. Process the results: flatten the per-batch outputs into two parallel lists
sample_ids = []
pred_labels_encoded = []

for batch in predictions:
    sample_ids.extend(batch["sample_ids"])
    pred_labels_encoded.extend(batch["predictions"].tolist())

# 4. Decode integer labels back to string labels
decoded_labels = test_datamodule.label_encoder.inverse_transform(pred_labels_encoded)

# 5. Format sample_ids back to filenames (e.g., "1004" -> "img_1004.png")
# We assume sample_ids contains the raw ID strings (e.g., "1004", "1005")
formatted_sample_ids = [f"img_{sid}.png" for sid in sample_ids]

# 6. Create DataFrame
submission_df = pd.DataFrame(
    {"sample_index": formatted_sample_ids, "label": decoded_labels}
)

# Optional: Sort by sample_index for a cleaner look
submission_df = submission_df.sort_values("sample_index").reset_index(drop=True)

# 7. Save to CSV
output_csv_path = "submission.csv"
submission_df.to_csv(output_csv_path, index=False)

print("\n" + "=" * 50)
print(f"Submission saved to: {output_csv_path}")
print(f"Total samples predicted: {len(submission_df)}")
print("=" * 50)

# Preview the first few rows (last expression, so the notebook renders it)
submission_df.head()

In [None]:
# Display the submission table as this cell's output.
submission_df