In [None]:
# Install the extra packages this notebook relies on:
#   - gdown: used in the cells below to download files from Google Drive
#   - lightning: imported later as `L`
!pip install gdown lightning

## Get data


In [None]:
# Download the dataset archive from Google Drive.
!gdown https://drive.google.com/uc?id=19o_R8S5f09XFXZIk_F1B7nui7TxNq9mH
# Unpack quietly. NOTE(review): assumes the download is saved as data.zip
# and extracts to a folder named an2dl2526c2 -- confirm.
!unzip -q data.zip
# Rename the extracted folder so it matches the ./data paths used in Config.
!mv an2dl2526c2 data
# List the contents as a quick sanity check.
!ls -l data

In [None]:
# Download trash_list.txt from Google Drive and move it into the data folder.
# NOTE(review): how this list is consumed is not shown in this part of the
# notebook -- presumably it names samples to exclude; confirm.
!gdown https://drive.google.com/uc?id=1Hp0nKCW6DNDVCvWYGLGkahQFc4xtu4eu
!mv trash_list.txt data/trash_list.txt

## Import Libraries


In [None]:
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import cv2
import lightning as L
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset
from torchmetrics import AUROC, Accuracy, ConfusionMatrix, F1Score
from torchvision import transforms

## Utilities


### Config


In [None]:
class Config:
    """Notebook-wide settings, read as plain class attributes (never instantiated here)."""

    # ---- Filesystem layout (everything except the output lives under DATA_DIR) ----
    DATA_DIR = "./data"
    TRAIN_DATA_DIR = f"{DATA_DIR}/train_data"
    TEST_DATA_DIR = f"{DATA_DIR}/test_data"
    TRAIN_LABELS_PATH = f"{DATA_DIR}/train_labels.csv"
    OUTPUT_PATH = "./predictions.csv"

    # ---- Target classes ----
    CLASSES = ["Luminal A", "Luminal B", "HER2(+)", "Triple negative"]
    NUM_CLASSES = len(CLASSES)

    # ---- Whole-image mode ----
    IMG_SIZE = 512  # side length used when images are not split into patches
    USE_MASK = True

    # ---- Tissue detection ----
    TISSUE_THRESHOLD = 0.8  # detection threshold (lower = more sensitive)
    MIN_TISSUE_AREA = 0.05  # minimum tissue area ratio
    PADDING = 50  # padding around the tissue bounding box

    # ---- Patch mode (for very large images) ----
    USE_PATCHES = True
    PATCH_SIZE = 224
    NUM_PATCHES = 16  # patches sampled per image

    # ---- Stain normalization ----
    USE_STAIN_NORMALIZATION = False

    # ---- Optimisation ----
    BATCH_SIZE = 16
    NUM_WORKERS = 2
    MAX_EPOCHS = 50
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-4

    # ---- Train/validation split ----
    VAL_SPLIT = 0.2
    RANDOM_SEED = 42

### Extract tissue


In [None]:
class TissueExtractor:
    """
    Extract square patches from an image, guided by an existing tissue mask.

    A patch is accepted only when the fraction of non-zero mask pixels inside
    it is at least ``min_tissue_ratio``. Returned patches are slices (views)
    of the input arrays, not copies.

    Args:
        patch_size: Side length, in pixels, of the square patches to extract.
        min_tissue_ratio: Minimum ratio of tissue (non-zero mask) pixels
            required in a patch.
    """

    def __init__(self, patch_size: int = 224, min_tissue_ratio: float = 0.05):
        self.patch_size = patch_size
        self.min_tissue_ratio = min_tissue_ratio

    @staticmethod
    def _to_single_channel(mask: np.ndarray) -> np.ndarray:
        """Return a 2-D mask; for a 3-D mask only the first channel is kept."""
        if len(mask.shape) == 3:
            return mask[:, :, 0]
        return mask

    def _grid_candidates(
        self,
        mask: np.ndarray,
        h: int,
        w: int,
        stride: Optional[int],
    ) -> List[Tuple[int, int, float]]:
        """
        Scan a regular grid and list the windows that contain enough tissue.

        Args:
            mask: 2-D mask; non-zero pixels count as tissue.
            h, w: Image height and width.
            stride: Step between window origins; ``None`` means ``patch_size``.

        Returns:
            ``(y_min, x_min, tissue_ratio)`` tuples in row-major scan order.

        Raises:
            ValueError: If ``stride`` is not a positive integer.
        """
        if stride is None:
            stride = self.patch_size  # No overlap by default
        if stride <= 0:
            # range() would either raise a cryptic error (0) or silently
            # yield nothing (negative), so fail loudly instead.
            raise ValueError(f"stride must be a positive integer, got {stride}.")

        size = self.patch_size
        candidates = []
        # The "+ 1" lets the last window end exactly on the image border.
        for y_min in range(0, h - size + 1, stride):
            for x_min in range(0, w - size + 1, stride):
                window = mask[y_min : y_min + size, x_min : x_min + size]
                tissue_ratio = np.count_nonzero(window) / window.size
                if tissue_ratio >= self.min_tissue_ratio:
                    candidates.append((y_min, x_min, tissue_ratio))
        return candidates

    def get_valid_patches(
        self,
        img: np.ndarray,
        mask: np.ndarray,
        num_patches: int = 8,
        strategy: str = "random",  # 'random' or 'grid'
        stride: Optional[int] = None,  # For grid strategy: step size between patches
        shuffle: bool = True,  # For grid strategy: shuffle valid patches before selecting
    ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Extract up to ``num_patches`` tissue patches from ``img``.

        Args:
            img: RGB image (H, W, 3)
            mask: Binary or Index mask (H, W). Assumes tissue > 0. A 3-D mask
                is reduced to its first channel.
            num_patches: Maximum number of patches to extract per image.
            strategy: 'random' samples patch centers from tissue pixels;
                'grid' slides a window across the image.
            stride: Step size for grid strategy. Defaults to patch_size (no overlap).
            shuffle: Grid strategy only. If True, valid windows are shuffled
                before selection; if False, windows with the most tissue come first.

        Returns:
            images: List of RGB patches
            masks: List of corresponding mask patches
            Both lists are empty when the mask contains no tissue; fewer than
            ``num_patches`` entries may be returned.

        Raises:
            ValueError: If ``strategy`` is unknown, or if the grid strategy is
                used with a non-positive ``stride``.
        """
        h, w = img.shape[:2]
        mask = self._to_single_channel(mask)

        # Coordinates of every tissue pixel (row indices, column indices)
        tissue_indices = np.where(mask > 0)

        # If no tissue found, return empty lists
        if len(tissue_indices[0]) == 0:
            print("Warning: No tissue found in mask!")
            return [], []

        if strategy == "random":
            return self._extract_random(img, mask, tissue_indices, num_patches, h, w)
        elif strategy == "grid":
            return self._extract_grid(img, mask, num_patches, h, w, stride, shuffle)
        else:
            raise ValueError(f"Unknown strategy: {strategy}. Use 'random' or 'grid'.")

    def _extract_random(
        self,
        img: np.ndarray,
        mask: np.ndarray,
        tissue_indices: Tuple[np.ndarray, np.ndarray],
        num_patches: int,
        h: int,
        w: int,
    ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Random strategy: use randomly chosen tissue pixels as patch centers.

        Candidates whose patch would cross the image border, or that contain
        too little tissue, are rejected and re-drawn. At most
        ``num_patches * 50`` candidates are drawn, so fewer than
        ``num_patches`` patches may be returned.
        """
        patches_img = []
        patches_mask = []

        # No patch can fit inside an image smaller than the patch itself;
        # return immediately instead of burning every attempt on rejections.
        if h < self.patch_size or w < self.patch_size:
            return patches_img, patches_mask

        half_size = self.patch_size // 2
        num_tissue_pixels = len(tissue_indices[0])

        # Protection mechanism: limit the number of attempts to avoid infinite loop
        attempts = 0
        max_attempts = num_patches * 50

        while len(patches_img) < num_patches and attempts < max_attempts:
            attempts += 1

            # Randomly select a tissue pixel as center
            idx = np.random.randint(num_tissue_pixels)
            cy, cx = tissue_indices[0][idx], tissue_indices[1][idx]

            y_min = int(cy - half_size)
            x_min = int(cx - half_size)
            y_max = y_min + self.patch_size
            x_max = x_min + self.patch_size

            # Boundary check: if the patch goes out of image bounds, skip and retry
            if y_min < 0 or x_min < 0 or y_max > h or x_max > w:
                continue

            # Validate on the mask first so the image is sliced only when needed
            mask_patch = mask[y_min:y_max, x_min:x_max]
            current_ratio = np.count_nonzero(mask_patch) / mask_patch.size

            if current_ratio >= self.min_tissue_ratio:
                patches_img.append(img[y_min:y_max, x_min:x_max])
                patches_mask.append(mask_patch)

        return patches_img, patches_mask

    def _extract_grid(
        self,
        img: np.ndarray,
        mask: np.ndarray,
        num_patches: int,
        h: int,
        w: int,
        stride: Optional[int] = None,
        shuffle: bool = True,
    ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Grid strategy: systematically slide across the image.

        Args:
            img: RGB image
            mask: Binary mask
            num_patches: Maximum number of patches to extract
            h, w: Image dimensions
            stride: Step size between patches. Defaults to patch_size (no overlap).
            shuffle: If True, shuffle valid patches before selecting to add
                diversity; otherwise pick the patches with the highest tissue ratio.

        Raises:
            ValueError: If ``stride`` is not a positive integer.
        """
        valid_patches = self._grid_candidates(mask, h, w, stride)

        if shuffle:
            np.random.shuffle(valid_patches)
        else:
            # Sort by tissue ratio (descending) to prioritize patches with more tissue
            valid_patches.sort(key=lambda item: item[2], reverse=True)

        size = self.patch_size
        patches_img = []
        patches_mask = []
        for y_min, x_min, _ in valid_patches[:num_patches]:
            patches_img.append(img[y_min : y_min + size, x_min : x_min + size])
            patches_mask.append(mask[y_min : y_min + size, x_min : x_min + size])

        return patches_img, patches_mask

    def get_all_valid_patches(
        self,
        img: np.ndarray,
        mask: np.ndarray,
        stride: Optional[int] = None,
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[Tuple[int, int]]]:
        """
        Extract ALL valid patches from the image using grid strategy.
        Useful for inference or when you need complete coverage.

        Args:
            img: RGB image (H, W, 3)
            mask: Binary or Index mask (H, W). A 3-D mask is reduced to its
                first channel.
            stride: Step size between patches. Defaults to patch_size.

        Returns:
            images: List of RGB patches
            masks: List of corresponding mask patches
            coordinates: List of (y_min, x_min) for each patch, in row-major order

        Raises:
            ValueError: If ``stride`` is not a positive integer.
        """
        h, w = img.shape[:2]
        mask = self._to_single_channel(mask)

        size = self.patch_size
        patches_img = []
        patches_mask = []
        coordinates = []

        for y_min, x_min, _ in self._grid_candidates(mask, h, w, stride):
            patches_img.append(img[y_min : y_min + size, x_min : x_min + size])
            patches_mask.append(mask[y_min : y_min + size, x_min : x_min + size])
            coordinates.append((y_min, x_min))

        return patches_img, patches_mask, coordinates

## Custom dataset


In [None]:
class PathologyDataset(Dataset):
    """Dataset optimized for histopathology images.

    Each item is either a single image tensor (whole-image mode, optionally
    cropped to the tissue bounding box) or a stack of patch tensors
    (patch mode), paired with an encoded label or, for test data, the sample id.

    Args:
        data_dir: Directory containing ``img_<id>.png`` images and, optionally,
            ``mask_<id>.png`` masks.
        labels_df: DataFrame with 'sample_index' and 'label' columns for training/validation.
        transform: torchvision transforms to apply to images. Falls back to
            ``ToTensor`` when omitted.
        use_mask: Whether to use existing masks for tissue extraction.
        use_patches: Whether to load images as patches.
        patch_size: Size of each patch if using patches.
        num_patches: Number of patches to extract per image.
        patch_strategy: Strategy for patch extraction ('random' or 'grid').
        min_tissue_ratio: Minimum tissue ratio for valid patches.
        use_stain_norm: Whether to apply stain normalization. No normalizer is
            wired in here, so this currently has no effect.
        is_test: Whether the dataset is for testing (no labels).
        label_encoder: Optional LabelEncoder for encoding labels. A new one
            fitted on the four known classes is created when omitted.

    Raises:
        ValueError: If ``is_test`` is False and ``labels_df`` is not provided.
    """

    def __init__(
        self,
        data_dir: str,
        labels_df: Optional[pd.DataFrame] = None,
        transform: Optional[transforms.Compose] = None,
        use_mask: bool = True,
        use_patches: bool = False,
        patch_size: int = 224,
        num_patches: int = 8,
        patch_strategy: str = "random",
        min_tissue_ratio: float = 0.05,
        use_stain_norm: bool = True,
        is_test: bool = False,
        label_encoder: Optional[LabelEncoder] = None,
    ):
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.use_mask = use_mask
        self.use_patches = use_patches
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.patch_strategy = patch_strategy
        self.use_stain_norm = use_stain_norm
        self.is_test = is_test
        self.label_encoder = label_encoder

        # Initialize helpers
        self.tissue_extractor = TissueExtractor(
            patch_size=patch_size,
            min_tissue_ratio=min_tissue_ratio,
        )
        # Placeholder: no stain normalizer implementation is attached, so
        # _apply_stain_normalization returns its input unchanged.
        self.stain_normalizer = None

        if is_test:
            self.samples = self._get_test_samples()
            self.labels = None
            self.encoded_labels = None
        else:
            if labels_df is None:
                raise ValueError("labels_df must be provided for training/validation.")

            self.samples = [
                self._clean_sample_idx(str(idx))
                for idx in labels_df["sample_index"].tolist()
            ]
            self.labels = labels_df["label"].tolist()

            if self.label_encoder is None:
                self.label_encoder = LabelEncoder()
                self.label_encoder.fit(
                    ["Luminal A", "Luminal B", "HER2(+)", "Triple negative"]
                )
            self.encoded_labels = self.label_encoder.transform(self.labels)

    def _clean_sample_idx(self, sample_idx: str) -> str:
        """Strip a leading ``img_`` and a trailing ``.png`` from a sample name."""
        sample_idx = str(sample_idx)
        if sample_idx.startswith("img_"):
            sample_idx = sample_idx[4:]
        if sample_idx.endswith(".png"):
            sample_idx = sample_idx[:-4]
        return sample_idx

    def _get_test_samples(self) -> List[str]:
        """Return the ids of all ``img_*.png`` files in ``data_dir``, sorted by path."""
        return [
            self._clean_sample_idx(path.stem)
            for path in sorted(self.data_dir.glob("img_*.png"))
        ]

    def __len__(self) -> int:
        return len(self.samples)

    def _load_image_and_mask(
        self, sample_idx: str
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Load the RGB image and, when enabled and present on disk, its grayscale mask.

        Returns:
            ``(img, mask)`` where ``mask`` is None if masks are disabled or
            the mask file does not exist.
        """
        img_path = self.data_dir / f"img_{sample_idx}.png"
        # Context managers make sure the file handles are released promptly.
        with Image.open(img_path) as img_file:
            img = np.array(img_file.convert("RGB"))

        mask = None
        if self.use_mask:
            mask_path = self.data_dir / f"mask_{sample_idx}.png"
            if mask_path.exists():
                with Image.open(mask_path) as mask_file:
                    mask = np.array(mask_file.convert("L"))

        return img, mask

    def _crop_to_tissue_bbox(
        self, img: np.ndarray, mask: np.ndarray, padding: int = 10
    ) -> np.ndarray:
        """Crop image to the bounding box of the tissue region.

        Args:
            img: Image array (H, W, C).
            mask: 2-D mask; pixels > 0 count as tissue.
            padding: Extra pixels kept on every side, clipped to the image.

        Returns:
            The cropped image, or the original image when the mask is empty.
        """
        rows = np.any(mask > 0, axis=1)
        cols = np.any(mask > 0, axis=0)

        if not rows.any() or not cols.any():
            return img  # No tissue found, return original

        # First and last row/column containing tissue (both inclusive)
        y_min, y_max = np.where(rows)[0][[0, -1]]
        x_min, x_max = np.where(cols)[0][[0, -1]]

        # Slice ends are exclusive, so add 1 to the inclusive last index
        # before padding; otherwise the bottom/right margin is one pixel short.
        y_min = max(0, y_min - padding)
        y_max = min(img.shape[0], y_max + 1 + padding)
        x_min = max(0, x_min - padding)
        x_max = min(img.shape[1], x_max + 1 + padding)

        return img[y_min:y_max, x_min:x_max]

    def _apply_stain_normalization(self, img: np.ndarray) -> np.ndarray:
        """Apply stain normalization when enabled and a normalizer is available.

        Normalization is best-effort: if the normalizer raises, a warning is
        issued and the input image is returned unchanged.
        """
        if self.use_stain_norm and self.stain_normalizer is not None:
            try:
                return self.stain_normalizer.normalize(img)
            except Exception as exc:
                import warnings

                # Keep going with the raw image, but do not hide the failure.
                warnings.warn(f"Stain normalization failed, using raw image: {exc}")
        return img

    def _load_and_preprocess(self, sample_idx: str) -> np.ndarray:
        """Load a full image, crop it to the tissue region and stain-normalize it."""
        img, mask = self._load_image_and_mask(sample_idx)

        # Crop to tissue region if mask is available
        if mask is not None:
            img = self._crop_to_tissue_bbox(img, mask)

        # No-op unless stain normalization is enabled and a normalizer exists
        return self._apply_stain_normalization(img)

    def _load_patches(self, sample_idx: str) -> List[np.ndarray]:
        """Load an image as exactly ``num_patches`` patches using TissueExtractor.

        When no valid patch is found, a resized center crop is repeated; when
        too few are found, the found patches are repeated cyclically.
        """
        img, mask = self._load_image_and_mask(sample_idx)

        if mask is None:
            # If no mask, create a simple one (all tissue)
            mask = np.ones(img.shape[:2], dtype=np.uint8) * 255

        patches, _ = self.tissue_extractor.get_valid_patches(
            img=img,
            mask=mask,
            num_patches=self.num_patches,
            strategy=self.patch_strategy,
            # Grid windows overlap by half a patch
            stride=self.patch_size // 2 if self.patch_strategy == "grid" else None,
            shuffle=False,
        )

        num_found = len(patches)
        if num_found == 0:
            # Fallback: center crop, resized in case the image is smaller than a patch
            h, w = img.shape[:2]
            cy, cx = h // 2, w // 2
            half = self.patch_size // 2
            y1 = max(0, cy - half)
            x1 = max(0, cx - half)
            y2 = min(h, y1 + self.patch_size)
            x2 = min(w, x1 + self.patch_size)
            fallback_patch = cv2.resize(
                img[y1:y2, x1:x2], (self.patch_size, self.patch_size)
            )
            patches = [fallback_patch] * self.num_patches

        elif num_found < self.num_patches:
            # Cycle through the patches that were found. The modulus must use
            # the ORIGINAL count; using the growing list's own length always
            # yields index 0 and only ever duplicates the first patch.
            patches = [patches[i % num_found] for i in range(self.num_patches)]

        # Apply stain normalization to each patch
        return [self._apply_stain_normalization(patch) for patch in patches]

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, ...]:
        """Return ``(tensor, label)`` for labelled data or ``(tensor, sample_id)`` for test data.

        In patch mode the tensor is a stack of the transformed patches;
        otherwise it is the single transformed image.
        """
        sample_idx = self.samples[idx]
        to_tensor = self.transform if self.transform else transforms.ToTensor()

        if self.use_patches:
            # One tensor per patch, stacked along a new leading dimension
            img_tensor = torch.stack(
                [
                    to_tensor(Image.fromarray(patch))
                    for patch in self._load_patches(sample_idx)
                ]
            )
        else:
            img = self._load_and_preprocess(sample_idx)
            img_tensor = to_tensor(Image.fromarray(img))

        if self.is_test:
            return img_tensor, sample_idx

        label = self.encoded_labels[idx]
        return img_tensor, torch.tensor(label, dtype=torch.long)

## Custom data module


In [None]:
class PathologyDataModule(L.LightningDataModule):
    """Lightning DataModule for histopathology image classification.

    Splits the labelled data into stratified train/validation subsets and
    builds the train, validation and test ``PathologyDataset`` objects and
    their dataloaders.

    Args:
        train_data_dir: Directory containing training images and masks.
        test_data_dir: Directory containing test images and masks.
        train_labels_path: Path to CSV file with training labels.
        batch_size: Batch size for dataloaders.
        num_workers: Number of workers for dataloaders.
        img_size: Target image size (used when not using patches).
        use_mask: Whether to use masks for tissue extraction.
        use_patches: Whether to use patch-based loading.
        patch_size: Size of patches to extract.
        num_patches: Number of patches per image.
        min_tissue_ratio: Minimum tissue ratio for valid patches.
        use_stain_norm: Whether to apply stain normalization.
        val_split: Fraction of training data to use for validation.
        random_seed: Random seed for reproducibility.
    """

    # Channel statistics used by Normalize in both transform pipelines
    # (these are the values commonly used for ImageNet-pretrained backbones).
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(
        self,
        train_data_dir: str = Config.TRAIN_DATA_DIR,
        test_data_dir: str = Config.TEST_DATA_DIR,
        train_labels_path: str = Config.TRAIN_LABELS_PATH,
        batch_size: int = Config.BATCH_SIZE,
        num_workers: int = Config.NUM_WORKERS,
        img_size: int = Config.IMG_SIZE,
        use_mask: bool = Config.USE_MASK,
        use_patches: bool = Config.USE_PATCHES,
        patch_size: int = Config.PATCH_SIZE,
        num_patches: int = Config.NUM_PATCHES,
        min_tissue_ratio: float = Config.MIN_TISSUE_AREA,
        use_stain_norm: bool = Config.USE_STAIN_NORMALIZATION,
        val_split: float = Config.VAL_SPLIT,
        random_seed: int = Config.RANDOM_SEED,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.train_data_dir = train_data_dir
        self.test_data_dir = test_data_dir
        self.train_labels_path = train_labels_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.img_size = img_size
        self.use_mask = use_mask
        self.use_patches = use_patches
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.min_tissue_ratio = min_tissue_ratio
        self.use_stain_norm = use_stain_norm
        self.val_split = val_split
        self.random_seed = random_seed

        # One encoder shared by every dataset so label ids stay consistent
        self.label_encoder = LabelEncoder()
        self.label_encoder.fit(Config.CLASSES)

        # Will be set in setup()
        self.train_df = None
        self.val_df = None
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def _target_size(self) -> int:
        """Side length images are resized to: patch size in patch mode, else image size."""
        return self.patch_size if self.use_patches else self.img_size

    def _get_train_transforms(self) -> transforms.Compose:
        """Get augmentation transforms for training."""
        target_size = self._target_size()
        return transforms.Compose(
            [
                transforms.Resize((target_size, target_size)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomRotation(degrees=90),
                transforms.ColorJitter(
                    brightness=0.2,
                    contrast=0.2,
                    saturation=0.1,
                    hue=0.05,
                ),
                transforms.ToTensor(),
                transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
            ]
        )

    def _get_val_transforms(self) -> transforms.Compose:
        """Get transforms for validation/test (no augmentation)."""
        target_size = self._target_size()
        return transforms.Compose(
            [
                transforms.Resize((target_size, target_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
            ]
        )

    def _build_dataset(
        self,
        data_dir: str,
        labels_df: Optional[pd.DataFrame],
        transform: transforms.Compose,
        patch_strategy: str,
        is_test: bool,
    ) -> PathologyDataset:
        """Create a PathologyDataset that shares this module's settings and label encoder."""
        return PathologyDataset(
            data_dir=data_dir,
            labels_df=labels_df,
            transform=transform,
            use_mask=self.use_mask,
            use_patches=self.use_patches,
            patch_size=self.patch_size,
            num_patches=self.num_patches,
            patch_strategy=patch_strategy,
            min_tissue_ratio=self.min_tissue_ratio,
            use_stain_norm=self.use_stain_norm,
            is_test=is_test,
            label_encoder=self.label_encoder,
        )

    def setup(self, stage: Optional[str] = None):
        """Setup datasets for each stage.

        Args:
            stage: Lightning stage name. 'fit' and 'validate' build the train
                and validation datasets; 'test' and 'predict' build the test
                dataset; None builds everything.
        """
        # 'validate' must be handled too: a standalone validation run calls
        # setup("validate") and still needs val_dataset.
        if stage in ("fit", "validate") or stage is None:
            full_df = pd.read_csv(self.train_labels_path)

            from sklearn.model_selection import train_test_split

            # Stratified split keeps class proportions equal in both subsets
            self.train_df, self.val_df = train_test_split(
                full_df,
                test_size=self.val_split,
                stratify=full_df["label"],
                random_state=self.random_seed,
            )

            # Training: randomly sampled patches with augmentation
            self.train_dataset = self._build_dataset(
                data_dir=self.train_data_dir,
                labels_df=self.train_df,
                transform=self._get_train_transforms(),
                patch_strategy="random",
                is_test=False,
            )

            # Validation: grid patches (PathologyDataset slides the grid with
            # a half-patch stride), no augmentation
            self.val_dataset = self._build_dataset(
                data_dir=self.train_data_dir,
                labels_df=self.val_df,
                transform=self._get_val_transforms(),
                patch_strategy="grid",
                is_test=False,
            )

        if stage == "test" or stage == "predict" or stage is None:
            # Test: same grid extraction as validation, but without labels
            self.test_dataset = self._build_dataset(
                data_dir=self.test_data_dir,
                labels_df=None,
                transform=self._get_val_transforms(),
                patch_strategy="grid",
                is_test=True,
            )

    def _make_loader(self, dataset: Dataset, shuffle: bool, drop_last: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using this module's batch size and worker count."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=drop_last,
        )

    def train_dataloader(self) -> DataLoader:
        """Shuffled training loader; the last incomplete batch is dropped."""
        return self._make_loader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self) -> DataLoader:
        """Validation loader in fixed order, keeping every sample."""
        return self._make_loader(self.val_dataset, shuffle=False, drop_last=False)

    def test_dataloader(self) -> DataLoader:
        """Test loader in fixed order, keeping every sample."""
        return self._make_loader(self.test_dataset, shuffle=False, drop_last=False)

    def predict_dataloader(self) -> DataLoader:
        """Prediction uses the same loader as testing."""
        return self.test_dataloader()

## Define neural network model


In [None]:
class PathologyModel(L.LightningModule):
    """Lightning module for histopathology image classification on a timm backbone.

    The backbone is created without its classifier head; a dropout + linear
    head is added on top. When patch-based input is enabled, every patch is
    encoded by the backbone and the per-patch features are pooled into one
    vector per sample before classification.

    Args:
        model_name: Name of the timm model to use (e.g., 'resnet50', 'resnet34').
        num_classes: Number of output classes.
        pretrained: Whether to use pretrained weights.
        learning_rate: Learning rate for the classifier (and attention) parameters;
            the backbone is trained at one tenth of this value.
        weight_decay: Weight decay for optimizer.
        use_patches: Whether input is patch-based [B, num_patches, C, H, W].
        patch_aggregation: How to aggregate patch features ('mean', 'max', 'attention').
        dropout_rate: Dropout rate before final classifier.
        label_smoothing: Label smoothing factor for cross-entropy loss.
        class_weights: Optional tensor of class weights for imbalanced data.
            NOTE(review): this argument is excluded from the saved
            hyperparameters, so ``load_from_checkpoint`` rebuilds the loss
            without weights. A checkpoint trained with weights presumably
            carries a ``criterion.weight`` entry - confirm that loading such a
            checkpoint works (``strict=False`` may be required).
    """

    def __init__(
        self,
        model_name: str = "resnet50",
        num_classes: int = Config.NUM_CLASSES,
        pretrained: bool = True,
        learning_rate: float = Config.LEARNING_RATE,
        weight_decay: float = Config.WEIGHT_DECAY,
        use_patches: bool = Config.USE_PATCHES,
        patch_aggregation: str = "mean",
        dropout_rate: float = 0.3,
        label_smoothing: float = 0.1,
        class_weights: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        # The weight tensor is kept out of the hyperparameters (see class docstring).
        self.save_hyperparameters(ignore=["class_weights"])

        self.model_name = model_name
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.use_patches = use_patches
        self.patch_aggregation = patch_aggregation
        self.label_smoothing = label_smoothing
        self.class_weights = class_weights

        # Build backbone
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,  # Remove classifier head, we'll add our own
        )

        # Get feature dimension from backbone
        self.feature_dim = self.backbone.num_features

        # Attention module for patch aggregation (optional): scores each patch
        # feature vector with a single scalar.
        if patch_aggregation == "attention":
            self.attention = nn.Sequential(
                nn.Linear(self.feature_dim, 128),
                nn.Tanh(),
                nn.Linear(128, 1),
            )
        else:
            self.attention = None

        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(self.feature_dim, num_classes),
        )

        # Loss function
        self.criterion = nn.CrossEntropyLoss(
            weight=class_weights,
            label_smoothing=label_smoothing,
        )

        # Metrics
        self._setup_metrics()

    def _setup_metrics(self):
        """Initialize metrics for training, validation, and test.

        Each stage gets its own metric objects so their accumulated state is
        never shared between stages.
        """
        # Accuracy
        self.train_acc = Accuracy(task="multiclass", num_classes=self.num_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=self.num_classes)
        self.test_acc = Accuracy(task="multiclass", num_classes=self.num_classes)

        # F1 Score (macro averaged)
        self.train_f1 = F1Score(
            task="multiclass", num_classes=self.num_classes, average="macro"
        )
        self.val_f1 = F1Score(
            task="multiclass", num_classes=self.num_classes, average="macro"
        )
        self.test_f1 = F1Score(
            task="multiclass", num_classes=self.num_classes, average="macro"
        )

        # AUROC (one-vs-rest)
        self.val_auroc = AUROC(task="multiclass", num_classes=self.num_classes)
        self.test_auroc = AUROC(task="multiclass", num_classes=self.num_classes)

        # Confusion matrix for test
        self.test_confmat = ConfusionMatrix(
            task="multiclass", num_classes=self.num_classes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor. Shape depends on use_patches:
               - If use_patches=False: [B, C, H, W]
               - If use_patches=True: [B, num_patches, C, H, W]
               A 4-D input is always treated as a plain image batch, even
               when use_patches=True.

        Returns:
            Logits of shape [B, num_classes]
        """
        if self.use_patches and x.dim() == 5:
            # Patch-based input: [B, num_patches, C, H, W]
            batch_size, num_patches, c, h, w = x.shape

            # Fold patches into the batch axis so the backbone sees them all at
            # once. reshape (rather than view) also accepts non-contiguous
            # inputs, e.g. batches produced by advanced indexing.
            x = x.reshape(batch_size * num_patches, c, h, w)

            # Extract features for all patches
            features = self.backbone(x)  # [B * num_patches, feature_dim]

            # Reshape back to [B, num_patches, feature_dim]
            features = features.reshape(batch_size, num_patches, -1)

            # Aggregate patch features
            features = self._aggregate_patches(features)  # [B, feature_dim]
        else:
            # Standard input: [B, C, H, W]
            features = self.backbone(x)  # [B, feature_dim]

        # Classification
        logits = self.classifier(features)  # [B, num_classes]
        return logits

    def _aggregate_patches(self, features: torch.Tensor) -> torch.Tensor:
        """
        Aggregate features from multiple patches.

        Args:
            features: Patch features of shape [B, num_patches, feature_dim]

        Returns:
            Aggregated features of shape [B, feature_dim]

        Raises:
            ValueError: If ``self.patch_aggregation`` is not one of
                'mean', 'max' or 'attention'.
        """
        if self.patch_aggregation == "mean":
            return features.mean(dim=1)

        elif self.patch_aggregation == "max":
            return features.max(dim=1)[0]

        elif self.patch_aggregation == "attention":
            # Attention-based aggregation (MIL-style)
            # attention_weights: [B, num_patches, 1], normalised over patches
            attention_scores = self.attention(features)
            attention_weights = F.softmax(attention_scores, dim=1)

            # Weighted sum: [B, feature_dim]
            aggregated = (attention_weights * features).sum(dim=1)
            return aggregated

        else:
            raise ValueError(f"Unknown aggregation method: {self.patch_aggregation}")

    def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
        """Compute the training loss for one ``(inputs, labels)`` batch and log metrics."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        # Get predictions
        preds = torch.argmax(logits, dim=1)

        # Update metrics
        self.train_acc(preds, y)
        self.train_f1(preds, y)

        # Log metrics
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log(
            "train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True
        )
        self.log("train/f1", self.train_f1, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
        """Compute the validation loss for one ``(inputs, labels)`` batch and log metrics."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        # Get predictions and probabilities
        preds = torch.argmax(logits, dim=1)
        probs = F.softmax(logits, dim=1)

        # Update metrics
        self.val_acc(preds, y)
        self.val_f1(preds, y)
        self.val_auroc(probs, y)

        # Log metrics. The loss is a plain tensor, so it has to be reduced
        # across devices explicitly; the metric objects synchronise themselves.
        self.log(
            "val/loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/f1", self.val_f1, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/auroc", self.val_auroc, on_step=False, on_epoch=True)

        return loss

    def test_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
        """Compute the test loss for one ``(inputs, labels)`` batch and log metrics.

        NOTE(review): ``predict_step`` unpacks the second batch element as
        sample ids. If the test dataloader yields the same kind of batch, the
        ``y`` used here would not be class labels - confirm before calling
        ``trainer.test``.
        """
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        # Get predictions and probabilities
        preds = torch.argmax(logits, dim=1)
        probs = F.softmax(logits, dim=1)

        # Update metrics
        self.test_acc(preds, y)
        self.test_f1(preds, y)
        self.test_auroc(probs, y)
        self.test_confmat(preds, y)

        # Log metrics (loss reduced across devices, see validation_step)
        self.log("test/loss", loss, on_step=False, on_epoch=True, sync_dist=True)
        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
        self.log("test/f1", self.test_f1, on_step=False, on_epoch=True)
        self.log("test/auroc", self.test_auroc, on_step=False, on_epoch=True)

        return loss

    def predict_step(self, batch: tuple, batch_idx: int) -> Dict[str, Any]:
        """Prediction step for inference.

        Args:
            batch: ``(inputs, sample_ids)`` pair.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the untouched ``sample_ids``, the arg-max class index per
            sample (``predictions``) and the softmax ``probabilities``.
        """
        x, sample_ids = batch
        logits = self(x)

        probs = F.softmax(logits, dim=1)
        preds = torch.argmax(logits, dim=1)

        return {
            "sample_ids": sample_ids,
            "predictions": preds,
            "probabilities": probs,
        }

    def configure_optimizers(self) -> Dict[str, Any]:
        """Configure optimizer and learning rate scheduler.

        Returns:
            Lightning optimizer config: AdamW with two parameter groups and a
            cosine-annealing-with-warm-restarts scheduler stepped once per epoch.
        """
        # Use different learning rates for backbone and classifier
        backbone_params = list(self.backbone.parameters())
        classifier_params = list(self.classifier.parameters())

        # The attention pooling layers are new, so they train at the head's rate.
        if self.attention is not None:
            classifier_params += list(self.attention.parameters())

        param_groups = [
            {
                "params": backbone_params,
                "lr": self.learning_rate * 0.1,
            },  # Lower LR for pretrained backbone
            {"params": classifier_params, "lr": self.learning_rate},
        ]

        optimizer = torch.optim.AdamW(
            param_groups,
            weight_decay=self.weight_decay,
        )

        # Cosine annealing scheduler with warm restarts
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=10,  # Restart every 10 epochs
            T_mult=2,  # Double the period after each restart
            eta_min=1e-7,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
            },
        }

    def on_test_epoch_end(self):
        """Print the test confusion matrix and per-class accuracy, then reset it."""
        confmat = self.test_confmat.compute()
        print("\nConfusion Matrix:")
        print(confmat.cpu().numpy())

        # Per-class accuracy: diagonal over the row sum. Rows with no samples
        # are reported as "n/a" instead of dividing by zero (which gave NaN).
        support = confmat.sum(dim=1)
        per_class_acc = confmat.diag().float() / support.clamp(min=1).float()
        for class_name, acc, n_samples in zip(
            Config.CLASSES, per_class_acc.tolist(), support.tolist()
        ):
            if n_samples == 0:
                print(f"  {class_name}: n/a (no samples)")
            else:
                print(f"  {class_name}: {acc:.4f}")

        # The confusion matrix is never passed to self.log, so nothing resets
        # it automatically; clear it so a second test run starts from zero.
        self.test_confmat.reset()


class PathologyModelWithMixup(PathologyModel):
    """PathologyModel variant that applies mixup to training batches.

    With probability ``mixup_prob`` a training batch is blended with a
    shuffled copy of itself and the loss is the matching blend of the losses
    against both label sets. Validation, test and prediction are inherited
    unchanged.

    Args:
        mixup_alpha: Alpha parameter for mixup. Set to 0 to disable.
        cutmix_alpha: Alpha parameter for cutmix. Currently only stored on the
            instance: this class does not implement CutMix, so the value has
            no effect on training. Kept so existing callers keep working.
        mixup_prob: Probability of applying mixup to a training batch.
        **kwargs: Arguments passed to PathologyModel.
    """

    def __init__(
        self,
        mixup_alpha: float = 0.4,
        cutmix_alpha: float = 1.0,
        mixup_prob: float = 0.5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha  # stored only; see class docstring
        self.mixup_prob = mixup_prob

    def _mixup_data(self, x: torch.Tensor, y: torch.Tensor) -> tuple:
        """Apply mixup augmentation.

        Args:
            x: Input batch; samples are mixed along the first (batch) axis.
            y: Labels aligned with ``x``.

        Returns:
            ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x`` is
            ``lam * x + (1 - lam) * x[perm]``, ``y_a`` are the original labels,
            ``y_b`` the labels of the permuted samples and ``lam`` a Python float.
        """
        if self.mixup_alpha > 0:
            # Convert to a plain float so no NumPy scalar leaks into tensor
            # arithmetic and the documented return type holds.
            lam = float(np.random.beta(self.mixup_alpha, self.mixup_alpha))
        else:
            lam = 1.0

        batch_size = x.size(0)
        index = torch.randperm(batch_size, device=x.device)

        mixed_x = lam * x + (1 - lam) * x[index]
        y_a, y_b = y, y[index]

        return mixed_x, y_a, y_b, lam

    def _mixup_criterion(
        self, logits: torch.Tensor, y_a: torch.Tensor, y_b: torch.Tensor, lam: float
    ) -> torch.Tensor:
        """Compute mixup loss: the ``lam``-weighted blend of the losses for both label sets."""
        return lam * self.criterion(logits, y_a) + (1 - lam) * self.criterion(
            logits, y_b
        )

    def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
        """Training step that applies mixup to a random subset of batches."""
        x, y = batch

        # Apply mixup with probability. When mixup is disabled
        # (mixup_alpha <= 0) the mixed batch would equal the original one, so
        # skip the permuted copy and the second loss term entirely.
        use_mixup = (
            self.training
            and self.mixup_alpha > 0
            and np.random.random() < self.mixup_prob
        )
        if use_mixup:
            x, y_a, y_b, lam = self._mixup_data(x, y)
            logits = self(x)
            loss = self._mixup_criterion(logits, y_a, y_b, lam)

            # For metrics, use the dominant class
            preds = torch.argmax(logits, dim=1)
            target = y_a if lam > 0.5 else y_b
        else:
            logits = self(x)
            loss = self.criterion(logits, y)
            preds = torch.argmax(logits, dim=1)
            target = y

        # Update metrics
        self.train_acc(preds, target)
        self.train_f1(preds, target)

        # Log metrics
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log(
            "train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True
        )
        self.log("train/f1", self.train_f1, on_step=False, on_epoch=True)

        return loss

## Train logic


In [None]:
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

# Initialize
# Seed Python, NumPy and PyTorch (including dataloader workers) so that a
# Restart & Run All reproduces the same training run.
L.seed_everything(Config.RANDOM_SEED, workers=True)

datamodule = PathologyDataModule()
model = PathologyModel(
    model_name="resnet50",  # or resnet34, resnet101, etc.
    use_patches=True,
    patch_aggregation="attention",
)

# Use up to two GPUs when they are present and a single device otherwise, so
# the cell also runs on single-GPU and CPU-only machines.
NUM_DEVICES = max(1, min(2, torch.cuda.device_count()))

# Train
trainer = L.Trainer(
    max_epochs=Config.MAX_EPOCHS,
    accelerator="auto",
    callbacks=[
        # Keep the checkpoint with the best validation macro-F1 ...
        ModelCheckpoint(monitor="val/f1", mode="max"),
        # ... and stop once validation loss has not improved for 10 epochs.
        EarlyStopping(monitor="val/loss", patience=10),
    ],
    accumulate_grad_batches=1,
    gradient_clip_val=0.5,
    precision="16-mixed",
    devices=NUM_DEVICES,
)
trainer.fit(model, datamodule)

In [None]:
# Show the path of the best checkpoint recorded by the trainer's checkpoint
# callback (the ModelCheckpoint monitoring "val/f1" configured above).
print(trainer.checkpoint_callback.best_model_path)

## Inference logic


In [None]:
# Inference on test set
# 1. Load the best model from the checkpoint
# Prefer the checkpoint recorded by the training run in this kernel. The
# hard-coded path only exists after specific earlier runs, so it is kept
# purely as a fallback for sessions where training was not executed.
FALLBACK_CHECKPOINT = (
    "/kaggle/working/lightning_logs/version_2/checkpoints/epoch=45-step=1610.ckpt"
)
try:
    best_checkpoint = trainer.checkpoint_callback.best_model_path
except (NameError, AttributeError):
    # No training trainer in this kernel, or it has no checkpoint callback.
    best_checkpoint = ""
best_checkpoint = best_checkpoint or FALLBACK_CHECKPOINT

print(f"Loading model from: {best_checkpoint}")
best_model = PathologyModel.load_from_checkpoint(best_checkpoint)

datamodule = PathologyDataModule()
datamodule.setup(stage="test")

# Predict on a single device: with several devices each process only sees a
# shard of the test set, so the collected predictions would be incomplete.
# A separate name keeps the training `trainer` intact, so re-running this cell
# resolves the same checkpoint again.
predict_trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    precision="16-mixed",
)

# 2. Run prediction using the Trainer
print("Generating predictions...")
predictions = predict_trainer.predict(best_model, datamodule=datamodule)

# 3. Process the results
sample_ids = []
pred_labels_encoded = []

for batch in predictions:
    batch_ids = batch["sample_ids"]
    # Numeric ids are collated into a tensor; convert them to plain Python
    # values so they format as "1004" rather than "tensor(1004)".
    if torch.is_tensor(batch_ids):
        batch_ids = batch_ids.cpu().tolist()
    sample_ids.extend(batch_ids)
    pred_labels_encoded.extend(batch["predictions"].cpu().tolist())

assert len(sample_ids) == len(pred_labels_encoded), "ids and predictions misaligned"

# 4. Decode integer labels back to string labels
decoded_labels = datamodule.label_encoder.inverse_transform(pred_labels_encoded)

# 5. Format sample_ids back to filenames (e.g., "1004" -> "img_1004.png")
# We assume sample_ids contains the raw IDs (e.g., "1004", "1005")
formatted_sample_ids = [f"img_{sid}.png" for sid in sample_ids]

# 6. Create DataFrame
submission_df = pd.DataFrame(
    {"sample_index": formatted_sample_ids, "label": decoded_labels}
)

# Optional: Sort by sample_index for a cleaner look
submission_df = submission_df.sort_values("sample_index").reset_index(drop=True)

# 7. Save to CSV
output_csv_path = "submission.csv"
submission_df.to_csv(output_csv_path, index=False)

print("\n" + "=" * 50)
print(f"Submission saved to: {output_csv_path}")
print(f"Total samples predicted: {len(submission_df)}")
print("=" * 50)

# Preview the first few rows
print(submission_df.head())


In [None]:
# Display the submission table as this cell's output.
submission_df