In [None]:
import warnings
import weakref
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import cv2
import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

In [None]:
class TissueExtractor:
    """
    Extract square patches from an image, guided by an existing tissue mask.

    A patch is "valid" when the fraction of non-zero mask pixels inside it is
    at least ``min_tissue_ratio``. Patches are always returned at full
    ``patch_size`` x ``patch_size``; windows that would cross the image border
    are never produced.

    Args:
        patch_size: Side length (in pixels) of the square patch to extract.
        min_tissue_ratio: Minimum ratio of tissue pixels required in a patch.
    """

    def __init__(self, patch_size: int = 224, min_tissue_ratio: float = 0.05):
        self.patch_size = patch_size
        self.min_tissue_ratio = min_tissue_ratio

    @staticmethod
    def _as_single_channel(mask: np.ndarray) -> np.ndarray:
        """Return a 2-D mask; for a 3-D mask only the first channel is kept."""
        return mask[:, :, 0] if mask.ndim == 3 else mask

    def _scan_grid(
        self,
        mask: np.ndarray,
        h: int,
        w: int,
        stride: Optional[int],
    ) -> List[Tuple[int, int, float]]:
        """
        Slide a window over the image and collect every valid position.

        Returns:
            List of ``(y_min, x_min, tissue_ratio)`` in row-major scan order.

        Raises:
            ValueError: If ``stride`` is zero or negative.
        """
        if stride is None:
            stride = self.patch_size  # No overlap by default
        if stride <= 0:
            raise ValueError(f"stride must be a positive integer, got {stride}.")

        size = self.patch_size
        hits = []
        # range() is empty when the image is smaller than one patch.
        for y_min in range(0, h - size + 1, stride):
            for x_min in range(0, w - size + 1, stride):
                window = mask[y_min : y_min + size, x_min : x_min + size]
                tissue_ratio = np.count_nonzero(window) / window.size
                if tissue_ratio >= self.min_tissue_ratio:
                    hits.append((y_min, x_min, tissue_ratio))
        return hits

    def get_valid_patches(
        self,
        img: np.ndarray,
        mask: np.ndarray,
        num_patches: int = 8,
        strategy: str = "random",  # 'random' or 'grid'
        stride: Optional[int] = None,  # For grid strategy: step size between patches
        shuffle: bool = True,  # For grid strategy: shuffle valid patches before selecting
    ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Extract up to ``num_patches`` valid patches from one image.

        Args:
            img: RGB image (H, W, 3)
            mask: Binary or Index mask (H, W). Assumes tissue > 0. A 3-channel
                mask is reduced to its first channel.
            num_patches: Maximum number of patches to extract per image.
            strategy: 'random' samples tissue pixels as patch centers; 'grid'
                slides a window across the image.
            stride: Step size for grid strategy. Defaults to patch_size (no overlap).
            shuffle: Grid strategy only. If True the valid positions are
                shuffled; if False they are ordered by descending tissue ratio.

        Returns:
            images: List of RGB patches
            masks: List of corresponding Mask patches
            Both lists may be shorter than ``num_patches`` (or empty).

        Raises:
            ValueError: If ``strategy`` is unknown, or the grid stride is not positive.
        """
        h, w = img.shape[:2]
        mask = self._as_single_channel(mask)

        # Find all tissue pixel indices
        tissue_indices = np.where(mask > 0)

        # If no tissue found, return empty lists
        if len(tissue_indices[0]) == 0:
            print("Warning: No tissue found in mask!")
            return [], []

        if strategy == "random":
            return self._extract_random(img, mask, tissue_indices, num_patches, h, w)
        if strategy == "grid":
            return self._extract_grid(img, mask, num_patches, h, w, stride, shuffle)
        raise ValueError(f"Unknown strategy: {strategy}. Use 'random' or 'grid'.")

    def _extract_random(
        self,
        img: np.ndarray,
        mask: np.ndarray,
        tissue_indices: Tuple[np.ndarray, np.ndarray],
        num_patches: int,
        h: int,
        w: int,
    ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Random sampling strategy: randomly select tissue pixels as patch centers.

        Tissue pixels whose centred patch would leave the image are discarded
        up front, so the attempt budget is only spent on positions that can
        actually produce a patch. Sampling uses the global ``np.random`` state.
        """
        size = self.patch_size
        half_size = size // 2

        # Top-left corner of the patch centred on every tissue pixel.
        tops = np.asarray(tissue_indices[0]) - half_size
        lefts = np.asarray(tissue_indices[1]) - half_size

        # Keep only the corners whose full patch lies inside the image.
        fits = (tops >= 0) & (lefts >= 0) & (tops + size <= h) & (lefts + size <= w)
        tops, lefts = tops[fits], lefts[fits]

        # No centre fits (e.g. image smaller than one patch): nothing to sample.
        if tops.size == 0:
            return [], []

        patches_img = []
        patches_mask = []

        # Protection mechanism: the tissue-ratio test can still reject a
        # position, so cap the number of attempts to avoid an endless loop.
        attempts = 0
        max_attempts = num_patches * 50

        while len(patches_img) < num_patches and attempts < max_attempts:
            attempts += 1

            pick = np.random.randint(tops.size)
            y_min, x_min = int(tops[pick]), int(lefts[pick])
            y_max, x_max = y_min + size, x_min + size

            mask_patch = mask[y_min:y_max, x_min:x_max]
            current_ratio = np.count_nonzero(mask_patch) / mask_patch.size

            if current_ratio >= self.min_tissue_ratio:
                patches_img.append(img[y_min:y_max, x_min:x_max])
                patches_mask.append(mask_patch)

        return patches_img, patches_mask

    def _extract_grid(
        self,
        img: np.ndarray,
        mask: np.ndarray,
        num_patches: int,
        h: int,
        w: int,
        stride: Optional[int] = None,
        shuffle: bool = True,
    ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Grid strategy: systematically slide across the image.

        Args:
            img: RGB image
            mask: Binary mask
            num_patches: Maximum number of patches to extract
            h, w: Image dimensions
            stride: Step size between patches. Defaults to patch_size (no overlap).
            shuffle: If True, shuffle valid patches before selecting to add
                diversity; otherwise prefer the patches with the most tissue.
        """
        valid_patches = self._scan_grid(mask, h, w, stride)

        if shuffle:
            np.random.shuffle(valid_patches)
        else:
            # Sort by tissue ratio (descending) to prioritize patches with more tissue
            valid_patches.sort(key=lambda entry: entry[2], reverse=True)

        size = self.patch_size
        patches_img = []
        patches_mask = []
        for y_min, x_min, _ in valid_patches[:num_patches]:
            patches_img.append(img[y_min : y_min + size, x_min : x_min + size])
            patches_mask.append(mask[y_min : y_min + size, x_min : x_min + size])

        return patches_img, patches_mask

    def get_all_valid_patches(
        self,
        img: np.ndarray,
        mask: np.ndarray,
        stride: Optional[int] = None,
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[Tuple[int, int]]]:
        """
        Extract ALL valid patches from the image using grid strategy.
        Useful for inference or when you need complete coverage.

        Args:
            img: RGB image (H, W, 3)
            mask: Binary or Index mask (H, W)
            stride: Step size between patches. Defaults to patch_size.

        Returns:
            images: List of RGB patches
            masks: List of corresponding mask patches
            coordinates: List of (y_min, x_min) for each patch, in scan order

        Raises:
            ValueError: If ``stride`` is zero or negative.
        """
        h, w = img.shape[:2]
        mask = self._as_single_channel(mask)

        size = self.patch_size
        patches_img = []
        patches_mask = []
        coordinates = []

        for y_min, x_min, _ in self._scan_grid(mask, h, w, stride):
            patches_img.append(img[y_min : y_min + size, x_min : x_min + size])
            patches_mask.append(mask[y_min : y_min + size, x_min : x_min + size])
            coordinates.append((y_min, x_min))

        return patches_img, patches_mask, coordinates

In [None]:
class PathologyDataset(Dataset):
    """Dataset optimized for histopathology images.

    Files are expected as ``img_<sample>.png`` with an optional
    ``mask_<sample>.png`` next to it in ``data_dir``.

    Each item is ``(image_tensor, label)`` for training/validation or
    ``(image_tensor, sample_idx)`` for test. With ``use_patches=True`` the image
    tensor is a stack of ``num_patches`` patch tensors; otherwise it is the
    whole image, cropped to the tissue bounding box when a mask is available.

    Args:
        data_dir: Directory containing images and masks.
        labels_df: DataFrame with 'sample_index' and 'label' columns for training/validation.
        transform: Callable applied to each PIL image (e.g. a torchvision
            ``Compose``). Falls back to ``ToTensor`` when not given.
        use_mask: Whether to use existing masks for tissue extraction.
        use_patches: Whether to load images as patches.
        patch_size: Size of each patch if using patches.
        num_patches: Number of patches to extract per image.
        patch_strategy: Strategy for patch extraction ('random' or 'grid').
        min_tissue_ratio: Minimum tissue ratio for valid patches.
        use_stain_norm: Whether to apply stain normalization. Has no effect
            while ``stain_normalizer`` is None (no normalizer is created here).
        is_test: Whether the dataset is for testing (no labels).
        label_encoder: Optional LabelEncoder for encoding labels.

    Raises:
        ValueError: If ``is_test`` is False and ``labels_df`` is not provided.
    """

    # Pixels of context kept on every side of the tissue bounding box.
    BBOX_PADDING = 10

    def __init__(
        self,
        data_dir: str,
        labels_df: Optional[pd.DataFrame] = None,
        transform: Optional[Callable] = None,
        use_mask: bool = True,
        use_patches: bool = False,
        patch_size: int = 224,
        num_patches: int = 8,
        patch_strategy: str = "random",
        min_tissue_ratio: float = 0.05,
        use_stain_norm: bool = True,
        is_test: bool = False,
        label_encoder: Optional["LabelEncoder"] = None,
    ):
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.use_mask = use_mask
        self.use_patches = use_patches
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.patch_strategy = patch_strategy
        self.use_stain_norm = use_stain_norm
        self.is_test = is_test
        self.label_encoder = label_encoder

        # Initialize helpers
        self.tissue_extractor = TissueExtractor(
            patch_size=patch_size,
            min_tissue_ratio=min_tissue_ratio,
        )
        # NOTE: no stain normalizer is constructed here; assign an object with
        # a ``normalize(img)`` method to enable stain normalization.
        self.stain_normalizer = None

        if is_test:
            self.samples = self._get_test_samples()
            self.labels = None
            self.encoded_labels = None
        else:
            if labels_df is None:
                raise ValueError("labels_df must be provided for training/validation.")

            self.samples = [
                self._clean_sample_idx(str(idx))
                for idx in labels_df["sample_index"].tolist()
            ]
            self.labels = labels_df["label"].tolist()

            if self.label_encoder is None:
                self.label_encoder = LabelEncoder()
                self.label_encoder.fit(
                    ["Luminal A", "Luminal B", "HER2(+)", "Triple negative"]
                )
            self.encoded_labels = self.label_encoder.transform(self.labels)

    def _clean_sample_idx(self, sample_idx: str) -> str:
        """Strip a leading ``img_`` and a trailing ``.png`` from a sample id."""
        sample_idx = str(sample_idx)
        if sample_idx.startswith("img_"):
            sample_idx = sample_idx[4:]
        if sample_idx.endswith(".png"):
            sample_idx = sample_idx[:-4]
        return sample_idx

    def _get_test_samples(self) -> List[str]:
        """Return the sorted sample ids of every ``img_*.png`` in ``data_dir``."""
        return [
            self._clean_sample_idx(path.stem)
            for path in sorted(self.data_dir.glob("img_*.png"))
        ]

    def __len__(self) -> int:
        return len(self.samples)

    def _load_image_and_mask(
        self, sample_idx: str
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Load the RGB image and, if enabled and present on disk, its mask.

        Returns:
            ``(img, mask)`` where ``mask`` is a single-channel array or None.
        """
        img_path = self.data_dir / f"img_{sample_idx}.png"
        # Context managers close the underlying file once pixels are copied.
        with Image.open(img_path) as handle:
            img = np.array(handle.convert("RGB"))

        mask = None
        if self.use_mask:
            mask_path = self.data_dir / f"mask_{sample_idx}.png"
            if mask_path.exists():
                with Image.open(mask_path) as handle:
                    mask = np.array(handle.convert("L"))

        return img, mask

    def _crop_to_tissue_bbox(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Crop image to the padded bounding box of the tissue region.

        Returns the image unchanged when the mask contains no tissue.
        """
        rows = np.any(mask > 0, axis=1)
        cols = np.any(mask > 0, axis=0)

        if not rows.any() or not cols.any():
            return img  # No tissue found, return original

        # First / last row and column containing tissue (both inclusive).
        y_first, y_last = np.where(rows)[0][[0, -1]]
        x_first, x_last = np.where(cols)[0][[0, -1]]

        padding = self.BBOX_PADDING
        y_min = max(0, y_first - padding)
        x_min = max(0, x_first - padding)
        # +1 turns the inclusive last index into an exclusive slice bound, so
        # the padding is the same on both sides.
        y_max = min(img.shape[0], y_last + 1 + padding)
        x_max = min(img.shape[1], x_last + 1 + padding)

        return img[y_min:y_max, x_min:x_max]

    def _apply_stain_normalization(self, img: np.ndarray) -> np.ndarray:
        """Apply stain normalization; fall back to the input image on failure.

        Normalization is best-effort: a failure is reported as a
        ``RuntimeWarning`` and the un-normalized image is returned.
        """
        if self.use_stain_norm and self.stain_normalizer is not None:
            try:
                return self.stain_normalizer.normalize(img)
            except Exception as exc:
                warnings.warn(
                    f"Stain normalization failed ({exc!r}); "
                    "using the un-normalized image.",
                    RuntimeWarning,
                    stacklevel=2,
                )
        return img

    def _load_and_preprocess(self, sample_idx: str) -> np.ndarray:
        """Load and preprocess full image with optional tissue cropping."""
        img, mask = self._load_image_and_mask(sample_idx)

        # Crop to tissue region if mask is available
        if mask is not None:
            img = self._crop_to_tissue_bbox(img, mask)

        # Stain normalization
        if self.use_stain_norm:
            img = self._apply_stain_normalization(img)

        return img

    def _load_patches(self, sample_idx: str) -> List[np.ndarray]:
        """Load image as exactly ``num_patches`` patches using TissueExtractor.

        When fewer valid patches exist than requested, the found patches are
        repeated cyclically; when none exist, a resized center crop is repeated.
        """
        img, mask = self._load_image_and_mask(sample_idx)

        if mask is None:
            # If no mask, create a simple one (all tissue)
            mask = np.ones(img.shape[:2], dtype=np.uint8) * 255

        patches, _ = self.tissue_extractor.get_valid_patches(
            img=img,
            mask=mask,
            num_patches=self.num_patches,
            strategy=self.patch_strategy,
            # Grid extraction uses 50% overlap between neighbouring windows.
            stride=self.patch_size // 2 if self.patch_strategy == "grid" else None,
            shuffle=False,
        )

        found = len(patches)
        if found == 0:
            # Fallback: extract center crop
            h, w = img.shape[:2]
            cy, cx = h // 2, w // 2
            half = self.patch_size // 2
            y1 = max(0, cy - half)
            x1 = max(0, cx - half)
            y2 = min(h, y1 + self.patch_size)
            x2 = min(w, x1 + self.patch_size)
            # Resize covers images smaller than one patch.
            fallback_patch = cv2.resize(
                img[y1:y2, x1:x2], (self.patch_size, self.patch_size)
            )
            patches = [fallback_patch] * self.num_patches
        elif found < self.num_patches:
            # Cycle through the patches that were found (0, 1, ..., 0, 1, ...)
            # instead of repeating only the first one.
            patches = [patches[i % found] for i in range(self.num_patches)]

        return [self._apply_stain_normalization(patch) for patch in patches]

    def _to_tensor(self, array: np.ndarray) -> torch.Tensor:
        """Convert an image array to a tensor via ``transform`` or ``ToTensor``."""
        pil_img = Image.fromarray(array)
        if self.transform:
            return self.transform(pil_img)
        return transforms.ToTensor()(pil_img)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, ...]:
        sample_idx = self.samples[idx]

        if self.use_patches:
            # Stack patches [num_patches, C, H, W]
            img_tensor = torch.stack(
                [self._to_tensor(patch) for patch in self._load_patches(sample_idx)]
            )
        else:
            img_tensor = self._to_tensor(self._load_and_preprocess(sample_idx))

        if self.is_test:
            return img_tensor, sample_idx

        label = self.encoded_labels[idx]
        return img_tensor, torch.tensor(label, dtype=torch.long)

In [None]:
class PathologyDataModule(L.LightningDataModule):
    """Lightning DataModule for histopathology image classification.

    The labelled data is split into train/validation with a stratified,
    seeded split. Training samples patches randomly; validation and test use
    the grid strategy.

    Args:
        train_data_dir: Directory containing training images and masks.
        test_data_dir: Directory containing test images and masks.
        train_labels_path: Path to CSV file with training labels.
        batch_size: Batch size for dataloaders.
        num_workers: Number of workers for dataloaders.
        img_size: Target image size (used when not using patches).
        use_mask: Whether to use masks for tissue extraction.
        use_patches: Whether to use patch-based loading.
        patch_size: Size of patches to extract.
        num_patches: Number of patches per image.
        min_tissue_ratio: Minimum tissue ratio for valid patches.
        use_stain_norm: Whether to apply stain normalization.
        val_split: Fraction of training data to use for validation.
        random_seed: Random seed for reproducibility.
    """

    # Per-channel statistics passed to transforms.Normalize.
    NORM_MEAN = [0.485, 0.456, 0.406]
    NORM_STD = [0.229, 0.224, 0.225]

    def __init__(
        self,
        train_data_dir: str = Config.TRAIN_DATA_DIR,
        test_data_dir: str = Config.TEST_DATA_DIR,
        train_labels_path: str = Config.TRAIN_LABELS_PATH,
        batch_size: int = Config.BATCH_SIZE,
        num_workers: int = Config.NUM_WORKERS,
        img_size: int = Config.IMG_SIZE,
        use_mask: bool = Config.USE_MASK,
        use_patches: bool = Config.USE_PATCHES,
        patch_size: int = Config.PATCH_SIZE,
        num_patches: int = Config.NUM_PATCHES,
        min_tissue_ratio: float = Config.MIN_TISSUE_AREA,
        use_stain_norm: bool = Config.USE_STAIN_NORMALIZATION,
        val_split: float = Config.VAL_SPLIT,
        random_seed: int = Config.RANDOM_SEED,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.train_data_dir = train_data_dir
        self.test_data_dir = test_data_dir
        self.train_labels_path = train_labels_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.img_size = img_size
        self.use_mask = use_mask
        self.use_patches = use_patches
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.min_tissue_ratio = min_tissue_ratio
        self.use_stain_norm = use_stain_norm
        self.val_split = val_split
        self.random_seed = random_seed

        # One encoder shared by all datasets so class indices agree.
        self.label_encoder = LabelEncoder()
        self.label_encoder.fit(Config.CLASSES)

        # Will be set in setup()
        self.train_df = None
        self.val_df = None
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def _target_size(self) -> int:
        """Side length every image/patch is resized to."""
        return self.patch_size if self.use_patches else self.img_size

    def _get_train_transforms(self) -> transforms.Compose:
        """Get augmentation transforms for training."""
        target_size = self._target_size()
        return transforms.Compose(
            [
                transforms.Resize((target_size, target_size)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomRotation(degrees=90),
                transforms.ColorJitter(
                    brightness=0.2,
                    contrast=0.2,
                    saturation=0.1,
                    hue=0.05,
                ),
                transforms.ToTensor(),
                transforms.Normalize(mean=self.NORM_MEAN, std=self.NORM_STD),
            ]
        )

    def _get_val_transforms(self) -> transforms.Compose:
        """Get transforms for validation/test (no augmentation)."""
        target_size = self._target_size()
        return transforms.Compose(
            [
                transforms.Resize((target_size, target_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=self.NORM_MEAN, std=self.NORM_STD),
            ]
        )

    def _build_dataset(
        self,
        data_dir: str,
        labels_df: Optional[pd.DataFrame],
        transform: transforms.Compose,
        patch_strategy: str,
        is_test: bool,
    ) -> PathologyDataset:
        """Create a PathologyDataset that shares this module's settings."""
        return PathologyDataset(
            data_dir=data_dir,
            labels_df=labels_df,
            transform=transform,
            use_mask=self.use_mask,
            use_patches=self.use_patches,
            patch_size=self.patch_size,
            num_patches=self.num_patches,
            patch_strategy=patch_strategy,
            min_tissue_ratio=self.min_tissue_ratio,
            use_stain_norm=self.use_stain_norm,
            is_test=is_test,
            label_encoder=self.label_encoder,
        )

    def _make_loader(self, dataset, shuffle: bool, drop_last: bool) -> DataLoader:
        """Create a DataLoader with this module's batch size and worker count."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=drop_last,
        )

    def setup(self, stage: Optional[str] = None):
        """Setup datasets for each stage.

        ``"fit"`` and ``"validate"`` build the train/validation datasets,
        ``"test"`` and ``"predict"`` build the test dataset, ``None`` builds
        everything.
        """
        # "validate" is included so that running validation on its own
        # (without a preceding fit) still gets a validation dataset.
        if stage in ("fit", "validate", None):
            full_df = pd.read_csv(self.train_labels_path)

            # Stratified split; deterministic for a fixed random_seed.
            self.train_df, self.val_df = train_test_split(
                full_df,
                test_size=self.val_split,
                stratify=full_df["label"],
                random_state=self.random_seed,
            )

            # Training: random patch centers, re-sampled on every access.
            self.train_dataset = self._build_dataset(
                data_dir=self.train_data_dir,
                labels_df=self.train_df,
                transform=self._get_train_transforms(),
                patch_strategy="random",
                is_test=False,
            )

            # Validation: grid patches from the same (training) directory.
            self.val_dataset = self._build_dataset(
                data_dir=self.train_data_dir,
                labels_df=self.val_df,
                transform=self._get_val_transforms(),
                patch_strategy="grid",
                is_test=False,
            )

        if stage in ("test", "predict", None):
            # Test: grid patches, no labels.
            self.test_dataset = self._build_dataset(
                data_dir=self.test_data_dir,
                labels_df=None,
                transform=self._get_val_transforms(),
                patch_strategy="grid",
                is_test=True,
            )

    def train_dataloader(self) -> DataLoader:
        return self._make_loader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self) -> DataLoader:
        return self._make_loader(self.val_dataset, shuffle=False, drop_last=False)

    def test_dataloader(self) -> DataLoader:
        return self._make_loader(self.test_dataset, shuffle=False, drop_last=False)

    def predict_dataloader(self) -> DataLoader:
        return self.test_dataloader()

In [None]:
# %%


class MILGradCAM:
    """Grad-CAM for a patch-based (multiple-instance) classifier.

    The wrapped model must expose ``backbone``, ``classifier``,
    ``patch_aggregation`` ("attention" or "mean") and, for attention
    aggregation, ``attention``. Forward/backward hooks on one backbone layer
    capture the activations and gradients used to build the heatmaps.

    The hooks are removed by ``remove_hooks()``, on leaving a ``with`` block,
    or when the instance is garbage collected, so repeated use does not pile
    up hooks on the model.

    Args:
        model: Model to explain. It is switched to eval mode.
        target_layer_name: Name of the layer inside ``model.backbone`` (as
            listed by ``named_modules()``) to hook.

    Raises:
        KeyError: If ``target_layer_name`` is not a module of the backbone.
    """

    def __init__(self, model, target_layer_name="layer4"):
        self.model = model
        self.model.eval()
        self.gradients = None
        self.activations = None
        self.target_layer_name = target_layer_name
        self._hook_handles = []

        # Hook into the backbone (e.g., layer4 of ResNet)
        self._register_hooks()

    def _register_hooks(self):
        """Attach the forward/backward hooks to the target backbone layer."""
        modules = dict(self.model.backbone.named_modules())
        if self.target_layer_name not in modules:
            raise KeyError(
                f"Layer '{self.target_layer_name}' not found in model.backbone."
            )
        target_layer = modules[self.target_layer_name]

        # The hooks live on the model; holding only a weak reference to this
        # object lets it be garbage collected (and its hooks removed) even
        # though the model is still alive.
        self_ref = weakref.ref(self)

        def forward_hook(module, inputs, output):
            owner = self_ref()
            if owner is not None:
                owner.activations = output

        def backward_hook(module, grad_in, grad_out):
            owner = self_ref()
            if owner is not None:
                owner.gradients = grad_out[0]

        self._hook_handles = [
            target_layer.register_forward_hook(forward_hook),
            target_layer.register_full_backward_hook(backward_hook),
        ]

    def remove_hooks(self):
        """Detach the hooks from the model. Safe to call more than once."""
        for handle in getattr(self, "_hook_handles", ()):
            handle.remove()
        self._hook_handles = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.remove_hooks()

    def __del__(self):
        try:
            self.remove_hooks()
        except Exception:
            # Finalizers must not raise (e.g. during interpreter shutdown).
            pass

    def forward(self, x, target_class_idx=None):
        """Run the model on one bag of patches and back-propagate a class score.

        Args:
            x: Input tensor [1, num_patches, C, H, W]
            target_class_idx: The class index to visualize. If None, uses the
                predicted class of the first bag.

        Returns:
            dict with ``logits`` (detached tensor), ``attention_weights``
            (numpy, one weight per patch), ``activations`` and ``gradients``
            (numpy, as captured at the hooked layer) and ``target_class``.

        Raises:
            ValueError: If ``model.patch_aggregation`` is not supported.
            RuntimeError: If the hooks captured no activations or gradients.
        """
        b, n, c, h, w = x.shape

        # Drop anything captured by a previous call so stale data is never returned.
        self.activations = None
        self.gradients = None

        # 1. Zero grads
        self.model.zero_grad()

        # 2. Forward pass, run step by step so the attention weights can be
        #    returned to the caller.
        # A. Merge bag and patch dimensions (reshape also accepts
        #    non-contiguous input, unlike view).
        x_reshaped = x.reshape(b * n, c, h, w)

        # B. Backbone features (Hooks capture activations here)
        features = self.model.backbone(x_reshaped)

        # C. Back to one feature vector per patch
        features = features.reshape(b, n, -1)

        # D. Aggregate patches into one bag-level feature
        aggregation = self.model.patch_aggregation
        if aggregation == "attention":
            attention_scores = self.model.attention(features)
            attention_weights = F.softmax(attention_scores, dim=1)  # [B, N, 1]
            aggregated_features = (attention_weights * features).sum(dim=1)
        elif aggregation == "mean":
            aggregated_features = features.mean(dim=1)
            # Mean pooling weights every patch equally.
            attention_weights = torch.full(
                (b, n, 1), 1.0 / n, device=x.device, dtype=features.dtype
            )
        else:
            raise ValueError(
                f"Unsupported patch_aggregation: {aggregation!r}. "
                "Use 'attention' or 'mean'."
            )

        # E. Classifier
        logits = self.model.classifier(aggregated_features)

        # 3. Determine target class (from the first bag)
        if target_class_idx is None:
            target_class_idx = int(torch.argmax(logits[0]).item())

        # 4. Backward pass to get gradients
        score = logits[0, target_class_idx]
        score.backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Hooks captured no activations/gradients; "
                "were they removed before calling forward()?"
            )

        return {
            "logits": logits.detach(),
            "attention_weights": attention_weights.detach().cpu().numpy(),
            "activations": self.activations.detach().cpu().numpy(),
            "gradients": self.gradients.detach().cpu().numpy(),
            "target_class": target_class_idx,
        }

    def generate_cam(self, activations, gradients):
        """
        Compute Grad-CAM heatmaps from activations and gradients.
        Standard Grad-CAM formula: ReLU(Sum(Weights * Activations))

        Args:
            activations: Array [N, C, H, W].
            gradients: Array [N, C, H, W].

        Returns:
            Array [N, H, W]; each map is scaled so its maximum is 1, or left
            all-zero when nothing is positive.
        """
        # Global Average Pooling of gradients to get weights: [N, C, 1, 1]
        weights = np.mean(gradients, axis=(2, 3), keepdims=True)

        # Weighted sum over channels, then ReLU: [N, H, W]
        cam = np.maximum(np.sum(weights * activations, axis=1), 0)

        # Normalize per patch; maps with a zero peak stay zero.
        peaks = cam.max(axis=(1, 2), keepdims=True)
        return np.divide(cam, peaks, out=np.zeros_like(cam), where=peaks > 0)

In [None]:
# %%
def visualize_weakly_supervised_cam(
    model, img_path, mask_path=None, patch_size=224, device="cuda"
):
    """
    Plot the original image next to two heatmap overlays:
    1. The Attention Map (Which patches define the class).
    2. The Grad-CAM (Which pixels inside those patches define the class).

    Args:
        model: Model accepted by MILGradCAM; it is moved to ``device``.
        img_path: Path of the RGB image to explain.
        mask_path: Optional path of the tissue mask. When missing, the whole
            image is treated as tissue.
        patch_size: Side length of the non-overlapping grid patches.
        device: Device to run the model on.

    Returns:
        None. Shows a matplotlib figure, or prints a message and returns early
        when no valid patch is found.
    """
    model.to(device)
    grad_cam = MILGradCAM(model, target_layer_name="layer4")

    # --- 1. Load Image & Prepare Patches ---
    img = np.array(Image.open(img_path).convert("RGB"))
    original_h, original_w = img.shape[:2]

    # Create a dummy mask if none exists (all tissue)
    if mask_path and Path(mask_path).exists():
        mask = np.array(Image.open(mask_path).convert("L"))
    else:
        mask = np.ones((original_h, original_w), dtype=np.uint8) * 255

    # Use grid extraction to cover the image
    extractor = TissueExtractor(patch_size=patch_size, min_tissue_ratio=0.05)
    patches_img, _, coords = extractor.get_all_valid_patches(
        img, mask, stride=patch_size
    )

    if len(patches_img) == 0:
        print("No tissue found.")
        return

    # Transform patches
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    input_tensor = torch.stack([transform(Image.fromarray(p)) for p in patches_img])
    input_tensor = input_tensor.unsqueeze(0).to(device)  # [1, N, C, H, W]

    # --- 2. Run Grad-CAM ---
    # Grad-CAM needs gradients even if the caller is inside torch.no_grad().
    with torch.enable_grad():
        results = grad_cam.forward(input_tensor)

    # --- 3. Process Results ---
    attention_weights = results["attention_weights"][0, :, 0]  # [N]
    activations = results["activations"]
    gradients = results["gradients"]

    # One low-resolution heatmap per patch
    patch_cams = grad_cam.generate_cam(activations, gradients)

    # --- 4. Reconstruct Maps ---

    # Map 1: Attention Map (Blocky heatmap based on patch importance)
    attention_map = np.zeros((original_h, original_w), dtype=np.float32)
    # Map 2: Fine-grained Grad-CAM (Detailed heatmap)
    gradcam_map = np.zeros((original_h, original_w), dtype=np.float32)
    # Count map for averaging overlapping patches (if stride < patch_size)
    count_map = np.zeros((original_h, original_w), dtype=np.float32)

    # Normalize attention weights to [0, 1] for visualization.
    att_min = attention_weights.min()
    att_range = attention_weights.max() - att_min
    if np.isclose(att_range, 0.0):
        # All patches weigh the same (mean aggregation, or a single patch).
        # Min-max scaling would map them all to 0 and blank both heatmaps,
        # so treat every patch as fully important instead.
        att_vis = np.ones_like(attention_weights)
    else:
        att_vis = (attention_weights - att_min) / (att_range + 1e-8)

    for i, (y, x) in enumerate(coords):
        # Upsample the low-resolution CAM to the patch size
        resized_cam = cv2.resize(patch_cams[i], (patch_size, patch_size))

        # Attention value for this patch
        att_val = att_vis[i]

        # Add to global maps
        # Note: We multiply CAM by Attention. High attention patch + High CAM pixel = High importance
        gradcam_map[y : y + patch_size, x : x + patch_size] += resized_cam * att_val
        attention_map[y : y + patch_size, x : x + patch_size] += att_val
        count_map[y : y + patch_size, x : x + patch_size] += 1

    # Average out overlaps
    mask_indices = count_map > 0
    gradcam_map[mask_indices] /= count_map[mask_indices]
    attention_map[mask_indices] /= count_map[mask_indices]

    # --- 5. Plotting ---
    predicted_class = Config.CLASSES[results["target_class"]]

    fig, axs = plt.subplots(1, 3, figsize=(20, 6))

    # Original
    axs[0].imshow(img)
    axs[0].set_title(f"Original Image\nPred: {predicted_class}")
    axs[0].axis("off")

    # Attention Map (Instance Importance)
    heatmap_att = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
    heatmap_att = cv2.cvtColor(heatmap_att, cv2.COLOR_BGR2RGB)
    overlay_att = cv2.addWeighted(img, 0.6, heatmap_att, 0.4, 0)
    axs[1].imshow(overlay_att)
    axs[1].set_title("Attention Map (Weak Supervision)\nShows important patches")
    axs[1].axis("off")

    # Grad-CAM (Pixel Importance)
    # Normalize for visualization
    gradcam_map = gradcam_map / (np.max(gradcam_map) + 1e-8)
    heatmap_cam = cv2.applyColorMap(np.uint8(255 * gradcam_map), cv2.COLORMAP_JET)
    heatmap_cam = cv2.cvtColor(heatmap_cam, cv2.COLOR_BGR2RGB)
    overlay_cam = cv2.addWeighted(img, 0.6, heatmap_cam, 0.4, 0)
    axs[2].imshow(overlay_cam)
    axs[2].set_title("Weighted Grad-CAM\nAttention * Backbone Gradients")
    axs[2].axis("off")

    plt.tight_layout()
    plt.show()


# %% [markdown]
# ## Execute Visualization
# Select a sample from the test set or validation set to visualize.

In [None]:
# %%
# 1. Load Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_path = trainer.checkpoint_callback.best_model_path
if not checkpoint_path:
    # Fail with a clear message instead of an obscure loader error.
    raise RuntimeError("No best checkpoint recorded by the trainer.")
model = PathologyModel.load_from_checkpoint(checkpoint_path)
model.eval()

# 2. Pick an image
# Let's grab an image from the test folder
test_files = sorted(Path(Config.TEST_DATA_DIR).glob("img_*.png"))
if test_files:
    sample_img_path = test_files[0]
    # Swap only the "img_" prefix so sample ids that themselves contain
    # underscores keep their full name (and the .png extension).
    sample_mask_path = sample_img_path.with_name(
        "mask_" + sample_img_path.name[len("img_") :]
    )

    print(f"Visualizing: {sample_img_path}")

    visualize_weakly_supervised_cam(
        model=model,
        img_path=str(sample_img_path),
        mask_path=str(sample_mask_path),
        patch_size=Config.PATCH_SIZE,
        device=device,
    )
else:
    print("No test images found to visualize.")