In [None]:
!pip install gdown lightning

In [None]:
!gdown https://drive.google.com/uc?id=1rh_21CJliIkuahqqaWETH8Zf06qEA7eG
!mkdir temp
!mkdir data
!unzip -q data.zip -d temp
!mv temp/data/* data/
!rm -rf temp
!ls -l data

In [None]:
from pathlib import Path
from typing import List, Optional, Tuple

import cv2
import lightning.pytorch as pl
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from lightning.pytorch.loggers import TensorBoardLogger
from PIL import Image
from scipy import ndimage
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import models, transforms

In [None]:
# ============================================================================
# Configuration
# ============================================================================
class Config:
    """Central settings for data locations, preprocessing, model and training.

    All settings are plain class attributes, so they can be read either from
    the class (``Config.IMG_SIZE``) or from an instance (``Config().IMG_SIZE``).
    Derived values (sub-paths, ``NUM_CLASSES``) are computed from their source
    setting so the two cannot drift apart.
    """

    # Data paths (sub-paths are derived from DATA_DIR)
    DATA_DIR = "./data"
    TRAIN_DATA_DIR = f"{DATA_DIR}/train_data"
    TEST_DATA_DIR = f"{DATA_DIR}/test_data"
    TRAIN_LABELS_PATH = f"{DATA_DIR}/train_labels_cleaned.csv"
    OUTPUT_PATH = "./predictions.csv"

    # Class labels
    CLASSES = ["Luminal A", "Luminal B", "HER2(+)", "Triple negative"]
    NUM_CLASSES = len(CLASSES)  # kept in sync with CLASSES

    # Image settings
    IMG_SIZE = 512  # Larger size for histopathology
    USE_MASK = True

    # Tissue detection settings
    # NOTE(review): these three values are not read by any code visible in
    # this notebook section; presumably consumed by the dataset class.
    TISSUE_THRESHOLD = 0.8  # Threshold for tissue detection (lower = more sensitive)
    MIN_TISSUE_AREA = 0.05  # Minimum tissue area ratio
    PADDING = 50  # Padding around tissue bounding box

    # Patch-based settings (for very large images)
    USE_PATCHES = True
    PATCH_SIZE = 512
    NUM_PATCHES = 8  # Number of patches to sample per image

    # Stain normalization
    USE_STAIN_NORMALIZATION = True

    # Training settings
    BATCH_SIZE = 4
    NUM_WORKERS = 2
    MAX_EPOCHS = 100
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-4

    # Model settings
    MODEL_NAME = "efficientnet_b3"
    PRETRAINED = True

    # Validation split
    VAL_SPLIT = 0.2  # fraction of labelled samples held out for validation
    RANDOM_SEED = 42

In [None]:
# ============================================================================
# Stain Normalization (LAB mean/std matching, Reinhard-style)
# ============================================================================
class StainNormalizer:
    """Simple stain normalization for H&E images.

    Shifts and scales each channel of the image, expressed in OpenCV's 8-bit
    LAB encoding, so that its mean and standard deviation match fixed target
    statistics.
    """

    def __init__(self):
        # Target per-channel mean / std in OpenCV 8-bit LAB space.
        # NOTE(review): the second mean (41.56) equals target_stds[0] and, with
        # OpenCV's +128 offset on the a-channel, lies far towards green --
        # looks like a possible copy/paste slip; confirm against the reference
        # these numbers were taken from.
        self.target_means = np.array([148.60, 41.56, 105.97])  # LAB color space
        self.target_stds = np.array([41.56, 9.01, 6.67])

    @staticmethod
    def _to_uint8(img: np.ndarray) -> np.ndarray:
        """Return ``img`` as a uint8 array suitable for 8-bit colour conversion.

        * uint8 input is returned unchanged (same object).
        * other integer input is taken to be on the 0-255 scale and is clipped.
        * anything else (floats) is taken to be on the 0-1 scale, multiplied by
          255 and clipped, so slightly out-of-range values saturate instead of
          wrapping around on the cast.
        """
        if img.dtype == np.uint8:
            return img
        if np.issubdtype(img.dtype, np.integer):
            return np.clip(img, 0, 255).astype(np.uint8)
        return np.clip(img * 255, 0, 255).astype(np.uint8)

    def normalize(self, img: np.ndarray) -> np.ndarray:
        """Normalize stain colors using LAB color space.

        Args:
            img: RGB image array (passed to ``cv2.COLOR_RGB2LAB``); non-uint8
                input is converted with :meth:`_to_uint8` first.

        Returns:
            The normalized RGB image as a uint8 array.
        """
        img = self._to_uint8(img)

        # Convert to LAB; float32 so the statistics are not computed on uint8.
        lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB).astype(np.float32)

        # Standardize each channel, then rescale to the target distribution.
        for i in range(3):
            channel = lab[:, :, i]
            channel_mean = channel.mean()
            channel_std = channel.std() + 1e-6  # avoid division by zero on flat channels

            lab[:, :, i] = ((channel - channel_mean) / channel_std) * self.target_stds[
                i
            ] + self.target_means[i]

        # Clip to the valid 8-bit range and convert back to RGB.
        lab = np.clip(lab, 0, 255).astype(np.uint8)
        return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)


class ReinhardNormalizer:
    """Reinhard stain normalization.

    Matches the per-channel mean and standard deviation of an image, in
    OpenCV's 8-bit LAB encoding, to target statistics. The targets default to
    fixed values and can be replaced by calling :meth:`fit` on a reference
    image.
    """

    def __init__(self):
        # Target statistics (can be computed from a reference image)
        self.target_means = None
        self.target_stds = None
        self._set_default_target()

    def _set_default_target(self):
        """Set default target statistics for H&E."""
        self.target_means = np.array([180.0, 135.0, 165.0])
        self.target_stds = np.array([25.0, 15.0, 20.0])

    @staticmethod
    def _to_uint8(img: np.ndarray) -> np.ndarray:
        """Return ``img`` as a uint8 array suitable for 8-bit colour conversion.

        uint8 input is returned unchanged; other integer input is clipped to
        0-255; anything else (floats) is taken to be on the 0-1 scale, scaled
        by 255 and clipped. The LAB statistics used by this class are on
        OpenCV's 8-bit scale, so non-uint8 input must be converted first.
        """
        if img.dtype == np.uint8:
            return img
        if np.issubdtype(img.dtype, np.integer):
            return np.clip(img, 0, 255).astype(np.uint8)
        return np.clip(img * 255, 0, 255).astype(np.uint8)

    def _rgb_to_lab(self, img: np.ndarray) -> np.ndarray:
        """Convert an RGB image to a float32 array in OpenCV 8-bit LAB encoding."""
        return cv2.cvtColor(self._to_uint8(img), cv2.COLOR_RGB2LAB).astype(np.float32)

    def fit(self, target_img: np.ndarray):
        """Fit normalizer to a target image.

        Replaces the target statistics with the per-channel LAB mean and
        standard deviation of ``target_img`` (an RGB image).
        """
        lab = self._rgb_to_lab(target_img)
        self.target_means = np.array([lab[:, :, i].mean() for i in range(3)])
        self.target_stds = np.array([lab[:, :, i].std() for i in range(3)])

    def normalize(self, img: np.ndarray) -> np.ndarray:
        """Normalize image to target statistics.

        Args:
            img: RGB image array; non-uint8 input is converted with
                :meth:`_to_uint8` first.

        Returns:
            The normalized RGB image as a uint8 array.
        """
        lab = self._rgb_to_lab(img)

        for i in range(3):
            src_mean = lab[:, :, i].mean()
            src_std = lab[:, :, i].std() + 1e-6  # avoid division by zero on flat channels
            lab[:, :, i] = ((lab[:, :, i] - src_mean) / src_std) * self.target_stds[
                i
            ] + self.target_means[i]

        # Clip to the valid 8-bit range before converting back to RGB.
        lab = np.clip(lab, 0, 255).astype(np.uint8)
        return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)




In [None]:
# ============================================================================
# Data Module
# ============================================================================
class PathologyDataModule(LightningDataModule):
    """DataModule for pathology images.

    Reads the training label CSV, makes a stratified train/validation split,
    builds a ``PathologyDataset`` for each split and for the test directory,
    and serves them through ``DataLoader`` objects. The training loader draws
    samples with a ``WeightedRandomSampler`` weighted by inverse class
    frequency of the training split.

    Args:
        config: Settings object; only its attributes are read.
    """

    # Channel statistics passed to ``transforms.Normalize`` (the values
    # commonly used with ImageNet-pretrained torchvision models).
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    # Lightning calls ``setup`` with one of "fit", "validate", "test",
    # "predict" (or None for "everything").
    _LABELLED_STAGES = ("fit", "validate", None)
    _UNLABELLED_STAGES = ("test", "predict", None)

    def __init__(self, config: Config = Config()):
        super().__init__()
        self.config = config
        # Maps class names <-> integer ids. LabelEncoder orders ``classes_``
        # by sorted value, which is not necessarily the order of config.CLASSES.
        self.label_encoder = LabelEncoder()
        self.label_encoder.fit(config.CLASSES)

        target_size = (config.IMG_SIZE, config.IMG_SIZE)

        # Training: resize + geometric / colour augmentation + normalization.
        self.train_transform = transforms.Compose(
            [
                transforms.Resize(target_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomRotation(degrees=90),
                transforms.ColorJitter(
                    brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05
                ),
                transforms.RandomAffine(
                    degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)
                ),
                *self._to_normalized_tensor(),
            ]
        )

        # Validation / test: deterministic resize + normalization only.
        self.val_transform = transforms.Compose(
            [transforms.Resize(target_size), *self._to_normalized_tensor()]
        )

    @classmethod
    def _to_normalized_tensor(cls) -> list:
        """Return fresh ``ToTensor`` + ``Normalize`` transforms shared by all pipelines."""
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=cls._NORM_MEAN, std=cls._NORM_STD),
        ]

    @staticmethod
    def _clean_sample_index(idx) -> str:
        """Strip an ``img_`` prefix and a ``.png`` suffix from a sample identifier."""
        idx = str(idx)
        if idx.startswith("img_"):
            idx = idx[4:]
        if idx.endswith(".png"):
            idx = idx[:-4]
        return idx

    def _build_dataset(self, data_dir: str, transform, **extra):
        """Create a ``PathologyDataset`` with the settings common to every split.

        ``extra`` carries the split-specific keyword arguments
        (``labels_df=...`` for labelled splits, ``is_test=True`` for the test set).
        """
        cfg = self.config
        return PathologyDataset(
            data_dir=data_dir,
            transform=transform,
            img_size=cfg.IMG_SIZE,
            use_mask=cfg.USE_MASK,
            use_patches=cfg.USE_PATCHES,
            patch_size=cfg.PATCH_SIZE,
            num_patches=cfg.NUM_PATCHES,
            use_stain_norm=cfg.USE_STAIN_NORMALIZATION,
            label_encoder=self.label_encoder,
            **extra,
        )

    def setup(self, stage: Optional[str] = None):
        """Build the datasets needed for ``stage``.

        "fit" and "validate" build the train and validation datasets (and the
        sampler weights); "test" and "predict" build the test dataset; ``None``
        builds everything.
        """
        # "validate" is included so that ``trainer.validate`` finds
        # ``self.val_dataset`` without a prior ``fit``.
        if stage in self._LABELLED_STAGES:
            labels_df = pd.read_csv(self.config.TRAIN_LABELS_PATH)
            labels_df["sample_index"] = labels_df["sample_index"].apply(
                self._clean_sample_index
            )

            train_df, val_df = train_test_split(
                labels_df,
                test_size=self.config.VAL_SPLIT,
                stratify=labels_df["label"],
                random_state=self.config.RANDOM_SEED,
            )

            self.train_dataset = self._build_dataset(
                self.config.TRAIN_DATA_DIR, self.train_transform, labels_df=train_df
            )
            self.val_dataset = self._build_dataset(
                self.config.TRAIN_DATA_DIR, self.val_transform, labels_df=val_df
            )

            # Per-sample weight = 1 / (number of training samples of its class),
            # used for class-balanced sampling.
            # NOTE(review): assumes the dataset yields samples in train_df row
            # order so weight i belongs to dataset item i -- confirm in PathologyDataset.
            class_counts = train_df["label"].value_counts()
            weights = 1.0 / train_df["label"].map(class_counts).to_numpy()
            self.sample_weights = torch.as_tensor(weights, dtype=torch.double)

        if stage in self._UNLABELLED_STAGES:
            self.test_dataset = self._build_dataset(
                self.config.TEST_DATA_DIR, self.val_transform, is_test=True
            )

    def _eval_loader(self, dataset) -> DataLoader:
        """Return a sequential (unshuffled) loader over ``dataset``."""
        return DataLoader(
            dataset,
            batch_size=self.config.BATCH_SIZE,
            shuffle=False,
            num_workers=self.config.NUM_WORKERS,
            pin_memory=True,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the training loader, sampling with replacement by class-balancing weights."""
        sampler = WeightedRandomSampler(
            self.sample_weights, len(self.sample_weights), replacement=True
        )

        return DataLoader(
            self.train_dataset,
            batch_size=self.config.BATCH_SIZE,
            sampler=sampler,
            num_workers=self.config.NUM_WORKERS,
            pin_memory=True,
            drop_last=True,  # incomplete final batch is discarded
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader."""
        return self._eval_loader(self.val_dataset)

    def test_dataloader(self) -> DataLoader:
        """Return the test loader."""
        return self._eval_loader(self.test_dataset)

    def predict_dataloader(self) -> DataLoader:
        """Return the prediction loader (same data as the test loader)."""
        return self.test_dataloader()

In [None]:
# ============================================================================
# Model with Multi-Instance Learning (MIL) for patches
# ============================================================================
class AttentionMIL(nn.Module):
    """Attention pooling over a bag of instance features (Multiple Instance Learning).

    A small scoring network assigns one scalar to every instance; the scores
    are softmax-normalized across the instances of a bag and used as weights
    for a weighted sum of the instance features.
    """

    def __init__(self, feature_dim: int, hidden_dim: int = 256):
        super().__init__()
        # features -> hidden (tanh) -> one scalar score per instance
        scorer_layers = [
            nn.Linear(feature_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        ]
        self.attention = nn.Sequential(*scorer_layers)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pool the instances of each bag into a single feature vector.

        Args:
            x: [batch, num_instances, features]
        Returns:
            pooled: [batch, features]
            attention_weights: [batch, num_instances]
        """
        # [batch, num_instances, 1]; normalized over the instance axis
        weights = self.attention(x).softmax(dim=1)

        # Broadcast the weights over the feature axis and sum out the instances.
        bag_embedding = (weights * x).sum(dim=1)

        return bag_embedding, weights.squeeze(-1)


class PathologyClassifier(LightningModule):
    """Classifier optimized for histopathology with optional MIL.

    A torchvision backbone turns each image (or each patch) into a feature
    vector. In patch mode the patch features of one sample are pooled with
    :class:`AttentionMIL` before the classification head; otherwise the head is
    applied to the image features directly.

    Args:
        num_classes: Number of output classes.
        model_name: One of ``resnet50``, ``efficientnet_b3``,
            ``efficientnet_b4``, ``convnext_small``, ``densenet121``.
        pretrained: Load ImageNet weights for the backbone.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        class_names: Display names used for per-class validation metrics;
            ``class_names[i]`` must be the name of integer label ``i``.
        use_patches: Expect 5-D patch input and pool with attention MIL.
        num_patches: Stored as a hyperparameter; the actual patch count is
            taken from the input tensor in ``forward``.
    """

    def __init__(
        self,
        num_classes: int = Config.NUM_CLASSES,
        model_name: str = Config.MODEL_NAME,
        pretrained: bool = Config.PRETRAINED,
        learning_rate: float = Config.LEARNING_RATE,
        weight_decay: float = Config.WEIGHT_DECAY,
        class_names: List[str] = Config.CLASSES,
        use_patches: bool = Config.USE_PATCHES,
        num_patches: int = Config.NUM_PATCHES,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.class_names = class_names
        self.use_patches = use_patches
        self.num_patches = num_patches

        # Create backbone
        self.backbone, self.feature_dim = self._create_backbone(model_name, pretrained)

        # MIL attention pooling (for patch-based)
        if use_patches:
            self.mil_attention = AttentionMIL(self.feature_dim)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

        # Loss function
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        # Per-batch predictions/labels collected during a validation epoch;
        # consumed and cleared in ``on_validation_epoch_end``.
        self.validation_step_outputs = []

    @staticmethod
    def _pooled(features: nn.Module, *pre_pool: nn.Module) -> nn.Sequential:
        """Wrap a convolutional feature extractor so it emits flat [N, C] vectors.

        ``pre_pool`` modules (if any) are inserted between the extractor and
        the global average pool.
        """
        return nn.Sequential(features, *pre_pool, nn.AdaptiveAvgPool2d(1), nn.Flatten())

    def _create_backbone(
        self, model_name: str, pretrained: bool
    ) -> Tuple[nn.Module, int]:
        """Create backbone model.

        Returns:
            ``(backbone, feature_dim)`` where ``backbone`` maps ``[N, C, H, W]``
            images to ``[N, feature_dim]`` vectors.

        Raises:
            ValueError: If ``model_name`` is not one of the supported names.
        """
        # Weight enums are looked up only inside the selected branch so that an
        # unused architecture never has to exist in the installed torchvision.
        if model_name == "resnet50":
            weights = models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
            model = models.resnet50(weights=weights)
            feature_dim = 2048
            # Everything except the final fully connected layer.
            backbone = nn.Sequential(*list(model.children())[:-1], nn.Flatten())

        elif model_name == "efficientnet_b3":
            weights = (
                models.EfficientNet_B3_Weights.IMAGENET1K_V1 if pretrained else None
            )
            model = models.efficientnet_b3(weights=weights)
            feature_dim = 1536
            backbone = self._pooled(model.features)

        elif model_name == "efficientnet_b4":
            weights = (
                models.EfficientNet_B4_Weights.IMAGENET1K_V1 if pretrained else None
            )
            model = models.efficientnet_b4(weights=weights)
            feature_dim = 1792
            backbone = self._pooled(model.features)

        elif model_name == "convnext_small":
            weights = (
                models.ConvNeXt_Small_Weights.IMAGENET1K_V1 if pretrained else None
            )
            model = models.convnext_small(weights=weights)
            feature_dim = 768
            backbone = self._pooled(model.features)

        elif model_name == "densenet121":
            weights = models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
            model = models.densenet121(weights=weights)
            feature_dim = 1024
            backbone = self._pooled(model.features, nn.ReLU(inplace=True))

        else:
            raise ValueError(f"Unsupported model: {model_name}")

        return backbone, feature_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits.

        Args:
            x: ``[batch, num_patches, C, H, W]`` in patch mode, otherwise
                ``[batch, C, H, W]``.

        Returns:
            Logits of shape ``[batch, num_classes]``.

        Raises:
            ValueError: In patch mode, if ``x`` is not 5-dimensional.
        """
        if not self.use_patches:
            # x shape: [batch, C, H, W]
            return self.classifier(self.backbone(x))

        if x.dim() != 5:
            raise ValueError(
                "Expected patch input of shape [batch, num_patches, C, H, W], "
                f"got {tuple(x.shape)}"
            )
        batch_size, num_patches = x.shape[:2]

        # Merge batch and patch axes; flatten() copies only if the input is
        # non-contiguous (where .view() would raise).
        features = self.backbone(x.flatten(0, 1))  # [batch*num_patches, feature_dim]

        # Restore the per-sample bag structure.
        features = features.reshape(batch_size, num_patches, -1)

        # MIL attention pooling
        pooled, _ = self.mil_attention(features)  # [batch, feature_dim]

        return self.classifier(pooled)

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Compute the training loss for one batch and log loss / accuracy."""
        images, labels = batch
        logits = self(images)
        loss = self.criterion(logits, labels)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> dict:
        """Evaluate one validation batch, log epoch-level loss / accuracy, and
        stash predictions for the per-class metrics."""
        images, labels = batch
        logits = self(images)
        loss = self.criterion(logits, labels)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()

        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        self.validation_step_outputs.append({"preds": preds, "labels": labels})

        return {"val_loss": loss, "val_acc": acc}

    def on_validation_epoch_end(self):
        """Log per-class validation accuracy from the collected batch outputs."""
        # torch.cat raises on an empty list, so skip when nothing was collected.
        if not self.validation_step_outputs:
            return

        all_preds = torch.cat([x["preds"] for x in self.validation_step_outputs])
        all_labels = torch.cat([x["labels"] for x in self.validation_step_outputs])

        for i, class_name in enumerate(self.class_names):
            mask = all_labels == i
            # Classes absent from this epoch's validation data are not logged.
            if mask.sum() > 0:
                class_acc = (all_preds[mask] == all_labels[mask]).float().mean()
                self.log(f"val_acc_{class_name}", class_acc, on_epoch=True)

        self.validation_step_outputs.clear()

    def predict_step(
        self, batch: Tuple[torch.Tensor, List[str]], batch_idx: int
    ) -> Tuple[torch.Tensor, List[str]]:
        """Return ``(predicted class ids, sample identifiers)`` for one batch."""
        images, sample_indices = batch
        logits = self(images)
        preds = torch.argmax(logits, dim=1)
        return preds, sample_indices

    def configure_optimizers(self):
        """AdamW with a cosine-annealing-with-warm-restarts schedule stepped per epoch."""
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=10, T_mult=2, eta_min=1e-7
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"},
        }

In [None]:
# ============================================================================
# Training and Inference
# ============================================================================
def train_model(
    config: Config = Config(),
) -> Tuple[PathologyClassifier, PathologyDataModule, Trainer]:
    """Train the pathology classifier.

    Args:
        config: Settings object; only its attributes are read.

    Returns:
        ``(model, data_module, trainer)`` after ``trainer.fit`` has finished.
        The returned model holds the weights of the last trained epoch; the
        best checkpoints (by ``val_acc``) are written to ``./checkpoints``.
    """

    # workers=True also derives distinct, reproducible seeds for the
    # DataLoader worker processes.
    pl.seed_everything(config.RANDOM_SEED, workers=True)

    data_module = PathologyDataModule(config)

    # Integer labels follow the LabelEncoder's (sorted) class order, which
    # differs from the order of config.CLASSES. Per-class metric names must be
    # indexed the same way, so take them from the encoder.
    class_names = data_module.label_encoder.classes_.tolist()

    model = PathologyClassifier(
        num_classes=config.NUM_CLASSES,
        model_name=config.MODEL_NAME,
        pretrained=config.PRETRAINED,
        learning_rate=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
        class_names=class_names,
        use_patches=config.USE_PATCHES,
        num_patches=config.NUM_PATCHES,
    )

    # Keep the three best checkpoints by validation accuracy, plus the last one.
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",
        filename="pathology-{epoch:02d}-{val_acc:.4f}",
        monitor="val_acc",
        mode="max",
        save_top_k=3,
        save_last=True,
    )

    # Stop when validation accuracy has not improved for 15 epochs.
    early_stopping = EarlyStopping(monitor="val_acc", patience=15, mode="max")

    lr_monitor = LearningRateMonitor(logging_interval="epoch")

    logger = TensorBoardLogger("logs", name="pathology")

    trainer = Trainer(
        max_epochs=config.MAX_EPOCHS,
        accelerator="auto",
        devices="auto",
        precision="16-mixed",
        callbacks=[checkpoint_callback, early_stopping, lr_monitor],
        logger=logger,
        gradient_clip_val=1.0,
        # Effective batch size = BATCH_SIZE * 16 samples per optimizer step.
        accumulate_grad_batches=16,
        deterministic=True,
    )

    trainer.fit(model, datamodule=data_module)

    return model, data_module, trainer

In [None]:
def predict_and_save(
    model: PathologyClassifier,
    data_module: PathologyDataModule,
    trainer: Trainer,
    output_path: str = Config.OUTPUT_PATH,
) -> pd.DataFrame:
    """Predict labels for the test set and write them to a CSV file.

    Args:
        model: Classifier used for prediction.
        data_module: Provides the prediction dataloader and the label encoder.
        trainer: Trainer used to run the prediction loop.
        output_path: Destination CSV path.

    Returns:
        DataFrame with columns ``sample_index`` and ``label``, sorted by
        ``sample_index``.
    """
    data_module.setup(stage="predict")
    batch_outputs = trainer.predict(model, datamodule=data_module)

    # Each batch output is (class-id tensor, sample identifiers); flatten both.
    class_ids = [
        class_id
        for batch_preds, _ in batch_outputs
        for class_id in batch_preds.cpu().numpy()
    ]
    sample_ids = [
        sample_id for _, batch_ids in batch_outputs for sample_id in batch_ids
    ]

    # Integer class ids -> original class names.
    class_labels = data_module.label_encoder.inverse_transform(class_ids)

    results_df = (
        pd.DataFrame({"sample_index": sample_ids, "label": class_labels})
        .sort_values("sample_index")
        .reset_index(drop=True)
    )
    results_df.to_csv(output_path, index=False)

    print(f"Predictions saved to: {output_path}")

    return results_df


In [None]:
# Main training and inference pipeline: print the run settings, then train.

config = Config()

print(
    "=" * 60,
    "Pathology-Optimized Molecular Subtype Classification",
    "=" * 60,
    f"Model: {config.MODEL_NAME}",
    f"Image Size: {config.IMG_SIZE}",
    f"Use Patches: {config.USE_PATCHES}",
    sep="\n",
)
if config.USE_PATCHES:
    # Patch details are only relevant in patch mode.
    print(
        f"  - Patch Size: {config.PATCH_SIZE}",
        f"  - Num Patches: {config.NUM_PATCHES}",
        sep="\n",
    )
print(
    f"Stain Normalization: {config.USE_STAIN_NORMALIZATION}",
    f"Batch Size: {config.BATCH_SIZE}",
    "=" * 60,
    sep="\n",
)

print("\n[1/2] Training model...")
model, data_module, trainer = train_model(config)

In [None]:
# Run the `nvidia-smi` shell command and show its output in the notebook.
!nvidia-smi

In [None]:
# Run inference with the trained model and summarize the predicted labels.
print("\n[2/2] Running inference on test set...")
results_df = predict_and_save(
    model, data_module, trainer, output_path=config.OUTPUT_PATH
)

print("\n" + "=" * 60, "Prediction Summary", "=" * 60, sep="\n")
# Number of test samples assigned to each class.
print(results_df["label"].value_counts())
print("\nDone!")