In [None]:
# Standard library imports
import os
import re
import random
import shutil
from datetime import datetime
from collections import OrderedDict

# Data handling and visualization libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# PyTorch and TorchVision libraries for deep learning
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
import torchvision.models as models
from torchvision import transforms, datasets
from PIL import Image

# Scikit-Learn for evaluation metrics and data splitting
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report

# Imbalanced-learn for oversampling to address class imbalance
from imblearn.over_sampling import RandomOverSampler

from torchvision import transforms
from torchvision.transforms import functional as F
from PIL import Image
from PIL import Image
from PIL import UnidentifiedImageError
from tqdm import tqdm

import os, re, random
from collections import defaultdict, Counter
from sklearn.model_selection import train_test_split
from google.colab import drive
from typing import List, Dict, Any, Tuple, Optional

In [None]:
# ─── 1) MOUNT GOOGLE DRIVE ──────────────────────────────────────────────
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
# ─── LOAD LABELS & BUILD CLASS MAP ───────────────────────────────────────
# Read the case -> grade table, drop the row whose index label is 64 (silently
# skipped if that label is absent), then renumber the index from 0.
# NOTE(review): index 64 is presumably the unlabelled case 65 mentioned in the
# next cell -- confirm against the CSV, since a positional drop breaks if the
# file is ever reordered.
labels = pd.read_csv(
    "/content/drive/MyDrive/case_grade_match.csv"
).drop(index=64, errors='ignore').reset_index(drop=True)

In [None]:
# Note that case 65 is not included in the analysis as it was not labelled
labels.loc[62:65,:]

Unnamed: 0,Case,Class
62,63,3
63,64,3
64,66,3
65,67,4


In [None]:
# Root folder holding every patch image; also used later as the prefix when
# file names are turned into full paths.
filtered_patches_dir = '/content/drive/MyDrive/CMIL_SP2025_Patches_Apr27'

# Reuse a saved directory listing when one exists, otherwise list the folder.
# NOTE(review): allow_pickle=True unpickles arbitrary objects on load -- only
# acceptable because the cache file is one we wrote ourselves. The cache is
# never refreshed here, so it goes stale if patches are added to the folder.
if os.path.exists("/content/drive/MyDrive/patch_dir.npy"):
    all_files = np.load("/content/drive/MyDrive/patch_dir.npy", allow_pickle=True).tolist()
else:
    # List the same directory that is used as the path root, rather than a
    # second hard-coded copy of the path that could drift out of sync.
    all_files = os.listdir(filtered_patches_dir)

len(all_files)

83717

In [None]:
# === Slice-Level Grouping ===

# Edit 10/19: identified scenario where patches are not being matched:
#     1) missing space between 'match' / 'unmatched' and number, eg. case_38_match14_h&e_patch25.png
#     2) Additionally, some names include 'labels' (not relevant to this regex which is only looking at case matching but thought I would flag),
#         eg. case_82_unmatched3_h&e-labels_patch32.png
def group_patches_by_slice(all_files, root):
    """Group patch image paths by the (case id, slice id) encoded in their names.

    Only names ending in ".png" are considered. Two spellings are accepted:
    ``case_<n>_<word>_<k>_...`` and ``case_<n>_<word><k>_...`` (missing
    underscore); the latter is normalized to ``<word>_<k>`` so both spellings
    land under the same key.

    Args:
        all_files: iterable of file names (not full paths).
        root: directory prefix joined onto every file name.

    Returns:
        defaultdict(list) mapping (int case id, str slice id) to the list of
        full paths belonging to that slice.

    Side effects:
        Prints either every ".png" path that matched neither spelling, or, when
        there are none, how many names needed the underscore normalization.
    """
    # Tried in order; the flag says whether the captured slice id still needs
    # an underscore inserted between its word and its number.
    name_patterns = (
        (r"case_(\d+)_([a-z]+_\d+)_", False),
        (r"case_(\d+)_([a-z]+\d+)_", True),
    )

    grouped = defaultdict(list)
    unparsed_paths = []
    normalized_count = 0

    for fname in all_files:
        if not fname.endswith(".png"):
            continue
        full_path = os.path.join(root, fname)

        for pattern, needs_underscore in name_patterns:
            hit = re.match(pattern, fname)
            if not hit:
                continue
            slice_name = hit.group(2)
            if needs_underscore:
                slice_name = re.sub(r'([A-Za-z])(\d)', r'\1_\2', slice_name)
                normalized_count += 1
            grouped[(int(hit.group(1)), slice_name)].append(full_path)
            break
        else:
            # Neither spelling matched this .png name.
            unparsed_paths.append(full_path)

    # Print summary of invalid files
    if unparsed_paths:
        print(f"Found {len(unparsed_paths)} files not following naming convention:")
        for bad_path in unparsed_paths:
            print("  ", bad_path)
    else:
        print(f"All {normalized_count} invalid file names were handled.")
    return grouped

In [None]:
# === Load Patches by Slice ===
patches = group_patches_by_slice(all_files, filtered_patches_dir)

All 3808 invalid file names were handled.


In [None]:
len(patches)

378

In [None]:
# Total number of patch paths across every (case, slice) group; should equal
# len(all_files) when no file was dropped during grouping.
tot_patches = sum(len(slice_paths) for slice_paths in patches.values())
tot_patches

83717

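Sanity check: the grouped total above (83717) equals `len(all_files)` (83717), so with the fallback regex no patches are currently lost. This note was written when some patches were getting lost; if new naming variants appear, the `match = re.match(....)` code will need to be made more general again.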

In [None]:
# Deterministic preprocessing (passed to the validation and test datasets
# below): resize to 224x224, convert to a tensor, then normalize per channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics --
# confirm they match the pretrained backbone that is used.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [None]:
# Training transform (augmented): random resized crop to 224, random
# horizontal/vertical flips, rotation of up to 15 degrees and mild colour
# jitter, followed by the same tensor conversion and normalization as the
# evaluation transform. Augmentation is random, so repeated loads of the same
# patch give different tensors.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [None]:
# === Build Label Map by Slice ===
# Binary target per (case, slice): grade 1 -> 0, grades 3 and 4 -> 1.
# Slices whose case has no row in `labels`, or whose grade is outside
# valid_classes, are left out of the map.
valid_classes = [1.0, 3.0, 4.0]
slice_to_class = {}
for case_id, slice_id in patches:
    case_rows = labels.loc[labels['Case'] == case_id, 'Class']
    if case_rows.empty:
        continue
    grade = case_rows.item()  # raises if a case has more than one label row
    if grade not in valid_classes:
        continue
    slice_to_class[(case_id, slice_id)] = 0 if grade == 1.0 else 1

In [None]:
len(slice_to_class) # Low-grade slices are removed

366

In [None]:
slice_to_class[(106, 'match_1')]

1

In [None]:
# === Group slices by class (input to the stratified split below) ===
# Inverts slice_to_class into {label: [(case_id, slice_id), ...]}.
slices_by_class = defaultdict(list)
for key, label in slice_to_class.items():
    slices_by_class[label].append(key) # Dictionary with one entry per label (two here: 0 and 1)

In [None]:
slices_by_class.keys()

dict_keys([1, 0])

In [None]:
# number of cases
unique_case = {t[0] for v in slices_by_class.values() for t in v}
len(unique_case)

86

In [None]:
def split_by_case_stratified(slices_by_class, random_state=42):
    """Split slices into train/val/test (60/20/20 of cases), keeping each case whole.

    The split is made at case level and stratified by the case label, then
    mapped back to slices, so no case contributes slices to more than one
    split.

    Args:
        slices_by_class: mapping label -> list of (case_id, slice_id) tuples.
        random_state: seed forwarded to both ``train_test_split`` calls.

    Returns:
        (train_slices, val_slices, test_slices), each a list of
        (case_id, slice_id) tuples in the iteration order of ``slices_by_class``.

    Raises:
        ValueError: if any case appears under more than one label, because a
            single case-level label is required for stratification.
    """
    # 1) Build case -> set of labels seen for that case
    case_to_labels = defaultdict(set)
    for label, items in slices_by_class.items():
        for case_id, _ in items:
            case_to_labels[case_id].add(label)

    # Validate: a case with several labels would otherwise get an arbitrary
    # one picked below, silently skewing the stratification.
    mixed = {cid: labs for cid, labs in case_to_labels.items() if len(labs) > 1}
    if mixed:
        raise ValueError(
            f"Cases with more than one label cannot be stratified by case: {mixed}"
        )

    # Flatten to a case list and an aligned list of labels
    case_ids = list(case_to_labels)
    case_labels = [next(iter(case_to_labels[cid])) for cid in case_ids]

    # 2) Split cases with stratification by case-level label
    # 60% train, 40% temp
    case_train, case_temp, _, y_temp = train_test_split(
        case_ids, case_labels, test_size=0.4, stratify=case_labels, random_state=random_state
    )
    # temp -> 50/50 to make 20%/20% val/test
    case_val, case_test, _, _ = train_test_split(
        case_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=random_state
    )

    case_train = set(case_train)
    case_val   = set(case_val)
    case_test  = set(case_test)

    # 3) Map case splits back to slice-level lists (no leakage)
    train_slices, val_slices, test_slices = [], [], []
    for items in slices_by_class.values():
        for case_id, slice_key in items:
            if case_id in case_train:
                train_slices.append((case_id, slice_key))
            elif case_id in case_val:
                val_slices.append((case_id, slice_key))
            elif case_id in case_test:
                test_slices.append((case_id, slice_key))
            else:
                # Not expected: every case id above came from slices_by_class.
                print('critical error! found case id not in train, test, or val')

    return train_slices, val_slices, test_slices

train_slices, val_slices, test_slices = split_by_case_stratified(slices_by_class)


In [None]:
# Check that train, validation and test sets do not share any case.
# The check is pairwise: a case that sits in only two of the three splits is
# still leakage, and a three-way intersection would not reveal it.
train_case_ids = {s[0] for s in train_slices}
val_case_ids = {s[0] for s in val_slices}
test_case_ids = {s[0] for s in test_slices}
(train_case_ids & val_case_ids) | (train_case_ids & test_case_ids) | (val_case_ids & test_case_ids)

set()

In [None]:
# number of cases after spliting
len(set([slice[0] for slice in train_slices]) | set([slice[0] for slice in test_slices]) | set([slice[0] for slice in val_slices]))

86

In [None]:
train_patches = {k: patches[k] for k in train_slices}
val_patches   = {k: patches[k] for k in val_slices}
test_patches  = {k: patches[k] for k in test_slices}

In [None]:
# Provide summarize info of number of cases, patches, class distribution and number of slices
def summarize_split(patch_dict, label_map=None, split_name=""):
    """Print case, slice, patch and per-class slice counts for one split.

    Args:
        patch_dict: mapping (case_id, slice_id) -> list of patch paths.
        label_map: optional mapping with the same keys -> class label; when
            given (and non-empty) the number of slices per class is printed.
        split_name: label used in the printed header.

    Returns:
        None; the summary is only printed.
    """
    case_ids = set()
    n_patches = 0
    slices_per_class = Counter()

    # Single pass over the split collecting all three tallies.
    for key, paths in patch_dict.items():
        if isinstance(key, tuple) and len(key) > 0:
            case_ids.add(key[0])
        n_patches += len(paths)
        if label_map and key in label_map:
            slices_per_class[label_map[key]] += 1

    # Output
    print(f"\n🔹 {split_name} Split")
    print(f"Cases: {len(case_ids)}")
    print(f"Slices: {len(patch_dict)}")
    print(f"Patches: {n_patches}")
    if slices_per_class:
        print(f"Class distribution: {dict(slices_per_class)}")

# Run for all splits
summarize_split(train_patches, slice_to_class, split_name="Train")
summarize_split(val_patches, slice_to_class, split_name="Validation")
summarize_split(test_patches, slice_to_class, split_name="Test")


🔹 Train Split
Cases: 51
Slices: 208
Patches: 45518
Class distribution: {1: 111, 0: 97}

🔹 Validation Split
Cases: 17
Slices: 74
Patches: 16041
Class distribution: {1: 56, 0: 18}

🔹 Test Split
Cases: 18
Slices: 84
Patches: 18569
Class distribution: {1: 71, 0: 13}


## From slice level to case level

In [None]:
def build_case_dict(patches, slice_to_class):
    """Regroup slice-level patch lists into a per-case, per-stain structure.

    Args:
        patches: {(case_id, slice_id): [path, ...]} where each list may mix
            stains.
        slice_to_class: {(case_id, slice_id): int label}.

    Returns:
        (case_dict, label_map) where
          case_dict = {case_id: {stain: [[paths of slice 1], [paths of slice 2], ...]}}
          label_map = {case_id: int label}

    Notes:
        - The stain ("h&e", "melan" or "sox10", lower-cased) is read from the
          file name; paths whose name contains none of them are skipped.
        - Slices are ordered by the plain string sort of their slice id, so
          e.g. "match_10" comes before "match_2".
        - A case gets a label only if at least one of its labelled slices has a
          patch with a recognized stain; when slices disagree the last one
          processed wins.
    """
    # simple pattern to get stain from filename
    stain_re = re.compile(r"(h&e|melan|sox10)", re.IGNORECASE)

    nested = {}      # case -> stain -> slice -> [paths]
    label_map = {}

    for (case_id, slice_id), path_list in patches.items():
        slice_label = slice_to_class.get((case_id, slice_id))
        for patch_path in path_list:
            found = stain_re.search(os.path.basename(patch_path).lower())
            if found is None:
                continue
            stain = found.group(1).lower()

            per_stain = nested.setdefault(case_id, {})
            per_slice = per_stain.setdefault(stain, {})
            per_slice.setdefault(slice_id, []).append(patch_path)

            if slice_label is not None:
                label_map[case_id] = slice_label

    # Replace the innermost slice dicts by lists of lists in a stable order.
    case_dict_lists = {
        cid: {
            stain: [by_slice[sid] for sid in sorted(by_slice)]
            for stain, by_slice in by_stain.items()
        }
        for cid, by_stain in nested.items()
    }

    return case_dict_lists, label_map


# check no data leakage
def get_case_ids(case_dict):
    """Return the set of case ids (the top-level keys) of a case dictionary."""
    return {case_id for case_id in case_dict}

def get_all_paths(case_dict):
    """Collect every patch path in a case dictionary into one set.

    ``case_dict`` is {case_id: {stain: [[path, ...], ...]}}; the nesting of
    case, stain and slice is discarded and duplicates collapse.
    """
    return {
        path
        for stains in case_dict.values()       # dict: stain -> list of slices
        for slice_lists in stains.values()     # list of slices
        for paths in slice_lists               # each slice is a list[str]
        for path in paths
    }

def check_disjoint_sets(A, B, nameA="A", nameB="B"):
    """Return (is_disjoint, shared) for two sets.

    ``shared`` is ``A & B`` and ``is_disjoint`` is True when it is empty.
    ``nameA`` and ``nameB`` are accepted but not used.
    """
    shared = A & B
    return len(shared) == 0, shared

def report_no_leak(train_case_dict, val_case_dict, test_case_dict):
    """Print whether the three splits share any case id or any patch path.

    For case ids and then for file paths, prints the per-split counts and
    runs the three pairwise overlap checks; afterwards prints one line per
    check stating either that there is no leakage or how many items overlap.
    Returns None.
    """
    # (count header, extractor, labels for the three pairwise checks)
    levels = (
        ("Cases per split:", get_case_ids,
         ("train & val (cases)", "train & test (cases)", "val & test (cases)")),
        ("Paths per split:", get_all_paths,
         ("train & val (paths)", "train & test (paths)", "val & test (paths)")),
    )

    findings = []
    for header, collect, pair_labels in levels:
        in_train = collect(train_case_dict)
        in_val = collect(val_case_dict)
        in_test = collect(test_case_dict)

        print(header, len(in_train), len(in_val), len(in_test))

        pair_inputs = (
            (in_train, in_val, "train", "val"),
            (in_train, in_test, "train", "test"),
            (in_val, in_test, "val", "test"),
        )
        for (left, right, left_name, right_name), text in zip(pair_inputs, pair_labels):
            ok, overlap = check_disjoint_sets(left, right, left_name, right_name)
            findings.append((ok, overlap, text))

    # Summary: all case-level results first, then all path-level results.
    for ok, overlap, text in findings:
        if ok:
            print(f"No leakage between {text}.")
        else:
            print(f"[LEAK!!!! Nooo] {text} overlap count = {len(overlap)}")

def summarize_case_dict(case_dict, label_map=None, split_name="train"):
    """Build a per-case summary table for one split.

    Args:
        case_dict: {case_id: {stain: [[path, ...], ...]}}.
        label_map: optional {case_id: label}; cases missing from it get None.
        split_name: value written to the "split" column.

    Returns:
        DataFrame with one row per case and the columns case_id, split, then
        for each of "h&e", "melan", "sox10": <stain>_slices, <stain>_patches
        and <stain>_missing (1 when the stain has no patches), followed by
        total_patches and label.
    """
    rows = []

    for case_id, stains in case_dict.items():
        row = {"case_id": case_id, "split": split_name}
        patches_in_case = 0

        for stain in ("h&e", "melan", "sox10"):
            slices = stains.get(stain, [])
            n_patches = sum(map(len, slices))
            row.update({
                f"{stain}_slices": len(slices),
                f"{stain}_patches": n_patches,
                f"{stain}_missing": int(n_patches == 0),
            })
            patches_in_case += n_patches

        row["total_patches"] = patches_in_case
        has_label = bool(label_map) and case_id in label_map
        row["label"] = label_map[case_id] if has_label else None

        rows.append(row)

    return pd.DataFrame.from_records(rows)

class StainBagCaseDataset(Dataset):
    """
    Case-level dataset: one item per case, with patches kept grouped by stain
    and by slice.

    For each CASE returns:
    {
      "case_id": int/str,
      "stain_slices": {
          "h&e":   [ Tensor(P1, C, H, W), Tensor(P2, C, H, W), ... ],  # list of S_h&e slices
          "melan": [ ... ],
          "sox10": [ ... ],
      },
      "label": LongTensor scalar
    }

    Inputs:
      case_dict: { case_id: { stain: [ [patches_of_slice1], [patches_of_slice2], ... ] } }
      label_map: { case_id: int_label }
      transform: callable applied to each RGB PIL image; it must return tensors
        of a common shape, because the patches of a slice are stacked.
      stains: stain keys (lower-case) to emit, in this order.
      per_slice_cap: keep at most this many patches per slice (None/0 = no cap).
      max_slices_per_stain: keep at most this many slices per stain (None = all).
      shuffle_patches: shuffle a slice's patches before the cap is applied, so a
        capped slice yields a different random subset on each access.
      drop_empty_slices: stored, but currently has no effect -- slices from
        which no patch could be loaded are always dropped.

    Notes:
      - No merging across slices: each slice stays a separate bag (list element).
      - Missing stains return an empty list [].
      - Cases absent from label_map are skipped.
      - Patches that fail to load or transform are skipped silently.
    """

    def __init__(
        self,
        case_dict: Dict[Any, Dict[str, List[List[str]]]],
        label_map: Dict[Any, int],
        transform=None,
        stains: Tuple[str, ...] = ("h&e", "melan", "sox10"),
        per_slice_cap: Optional[int] = 800,        # cap patches per slice at 800
        max_slices_per_stain: Optional[int] = None, # optional: cap #slices per stain
        shuffle_patches: bool = True,
        drop_empty_slices: bool = True,            # see class docstring: currently a no-op
    ):

        self.transform = transform
        self.stains = list(stains)
        self.per_slice_cap = per_slice_cap
        self.max_slices_per_stain = max_slices_per_stain
        self.shuffle_patches = shuffle_patches
        self.drop_empty_slices = drop_empty_slices

        # Flatten case_dict to an indexable list; normalize stain keys to lower case
        self.items = []
        for case_id, stain_map in case_dict.items():
            if case_id not in label_map:
                continue
            norm_map = {stain.lower(): slices for stain, slices in stain_map.items()}
            self.items.append((case_id, norm_map))

        self.label_map = label_map

    def __len__(self):
        """Number of cases that have a label."""
        return len(self.items)

    def _load_slice_tensor(self, paths: List[str]) -> Optional[torch.Tensor]:
        """
        One slice = list of patch paths -> Tensor(P, C, H, W).
        Applies transform; shuffles & caps per-slice; skips unreadable images.
        Returns None when no patch of the slice could be loaded.
        """
        patch_paths = list(paths)  # copy so the shuffle never reorders the caller's list
        if self.shuffle_patches:
            random.shuffle(patch_paths)

        if self.per_slice_cap and len(patch_paths) > self.per_slice_cap:
            patch_paths = patch_paths[:self.per_slice_cap]

        imgs = []
        for p in patch_paths:
            try:
                # Context manager closes the file handle even if decoding
                # fails; convert() returns an independent in-memory image.
                with Image.open(p) as raw:
                    img = raw.convert("RGB")
                if self.transform:
                    img = self.transform(img)
                imgs.append(img)
            except Exception:
                # Best effort: an unreadable or untransformable patch is skipped.
                continue

        if len(imgs) == 0:
            return None
        return torch.stack(imgs)  # (P, C, H, W)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load every slice of every requested stain for the case at ``idx``."""
        case_id, stain_map = self.items[idx]
        label = torch.tensor(self.label_map[case_id], dtype=torch.long)

        stain_slices: Dict[str, List[torch.Tensor]] = {}
        for stain in self.stains:
            # Get list of slices (each slice is list[str] of patch paths)
            slice_lists = stain_map.get(stain, []) or []

            # Optionally cap the number of slices per stain
            if self.max_slices_per_stain is not None and len(slice_lists) > self.max_slices_per_stain:
                # deterministic crop
                slice_lists = slice_lists[:self.max_slices_per_stain]

            tensors_for_stain: List[torch.Tensor] = []
            for sl in slice_lists:
                if not sl:
                    continue
                sl_tensor = self._load_slice_tensor(sl)
                if sl_tensor is None:
                    # Dropped regardless of drop_empty_slices: no placeholder
                    # representation for an empty slice has been defined.
                    continue
                tensors_for_stain.append(sl_tensor)

            stain_slices[stain] = tensors_for_stain

        return {
            "case_id": case_id,
            "stain_slices": stain_slices,  # dict[stain] -> list[Tensor(P,C,H,W)]
            "label": label,
        }


def case_collate_fn(batch):
    """
    Identity collate function: returns the list of per-case sample dicts as is.

    The samples hold variable-length lists of variable-size tensors, so they
    are not stacked or padded. The loaders below use batch_size=1, so each
    batch is a one-element list and the case dict is read as ``batch[0]``.
    """
    return batch


In [None]:
# Regroup each split's slice-level patch lists into case -> stain -> slices,
# together with one label per case.
train_case_dict, train_label_map = build_case_dict(train_patches, slice_to_class)
val_case_dict,   val_label_map   = build_case_dict(val_patches, slice_to_class)
test_case_dict,  test_label_map  = build_case_dict(test_patches, slice_to_class)

# Print case-level and path-level overlap checks between the three splits.
report_no_leak(train_case_dict, val_case_dict, test_case_dict)

# Build the tables for each split
train_df = summarize_case_dict(train_case_dict, train_label_map, "train")
val_df   = summarize_case_dict(val_case_dict,   val_label_map,   "val")
test_df  = summarize_case_dict(test_case_dict,  test_label_map,  "test")

# Combine
all_df = pd.concat([train_df, val_df, test_df], ignore_index=True)

# Number of distinct cases per (split, label)
print("\nLabel distribution per split:")
print(all_df.groupby(["split", "label"])["case_id"].nunique().unstack(fill_value=0))

# Per-case patch counts, averaged over the cases of each split
stain_patch_cols = ["h&e_patches", "melan_patches", "sox10_patches"]
print("\nMean patches per stain per split:")
print(all_df.groupby("split")[stain_patch_cols].mean().round(1))

print("\nMedian patches per stain per split:")
print(all_df.groupby("split")[stain_patch_cols].median().round(1))

# The *_missing columns are 0/1 flags, so their mean is the share of cases
# in the split that have no patches for that stain.
missing_cols = ["h&e_missing", "melan_missing", "sox10_missing"]
print("\nMissing stain proportion:")
print((all_df.groupby("split")[missing_cols].mean()).round(3))

# Training set: augmented transform, patches shuffled within each slice.
train_ds = StainBagCaseDataset(
    train_case_dict, train_label_map,
    transform=train_transform,
    shuffle_patches=True,
)

# Validation / test sets: deterministic transform, no patch shuffling.
val_ds = StainBagCaseDataset(
    val_case_dict, val_label_map,
    transform=transform,
    shuffle_patches=False, # determinism for eval
)

test_ds = StainBagCaseDataset(
    test_case_dict, test_label_map,
    transform=transform,
    shuffle_patches=False,
)

# batch_size=1: every batch is a single case; case_collate_fn returns the
# one-element list unchanged. Only the training loader shuffles case order.
train_loader = DataLoader(
    train_ds, batch_size=1, shuffle=True,
    num_workers=2, pin_memory=True, collate_fn=case_collate_fn
)
val_loader = DataLoader(
    val_ds, batch_size=1, shuffle=False,
    num_workers=2, pin_memory=True, collate_fn=case_collate_fn
)
test_loader = DataLoader(
    test_ds, batch_size=1, shuffle=False,
    num_workers=2, pin_memory=True, collate_fn=case_collate_fn
)

print('DataLoader created successfully!')

Cases per split: 51 17 18
Paths per split: 45518 16041 18569
No leakage between train & val (cases).
No leakage between train & test (cases).
No leakage between val & test (cases).
No leakage between train & val (paths).
No leakage between train & test (paths).
No leakage between val & test (paths).

Label distribution per split:
label   0   1
split        
test    6  12
train  17  34
val     5  12

Mean patches per stain per split:
       h&e_patches  melan_patches  sox10_patches
split                                           
test         444.1          293.8          293.8
train        468.9          235.7          187.9
val          460.7          275.2          207.7

Median patches per stain per split:
       h&e_patches  melan_patches  sox10_patches
split                                           
test         283.0          222.5          242.0
train        323.0          182.0          165.0
val          429.0          172.0          124.0

Missing stain proportion:
       h&

In [None]:

# # get one case from the train loader
# sample_batch = next(iter(train_loader))
# sample = sample_batch[0]

# print(f"Case ID: {sample['case_id']}")
# print(f"Label: {sample['label'].item()}")

# for stain, slices in sample["stain_slices"].items():
#     if not slices:  # empty list (no slices for this stain)
#         print(f"{stain}: missing")
#         continue

#     # quick summary per stain
#     n_slices = len(slices)
#     n_patches_total = sum(sl.shape[0] for sl in slices)
#     shapes = [tuple(sl.shape) for sl in slices[:3]]  # show up to 3 slice shapes

#     print(f"{stain}: {n_slices} slices | {n_patches_total} patches total")
#     print(f"  example slice shapes: {shapes}")

In [None]:
# this pools the patch level features into single bag level representation for MIL
class AttentionPool(nn.Module):
    """Attention pooling for MIL: reduce a bag of embeddings to one embedding.

    A small scoring network (Linear -> Tanh -> Linear) assigns one scalar score
    to every item of the bag; the scores are softmax-normalized over the bag
    and used as weights of a weighted sum of the items.

    Args:
        input_dim: size D of each input embedding.
        hidden_dim: width of the hidden layer of the scoring network.
    """

    def __init__(self, input_dim, hidden_dim=128):
        super().__init__()
        scorer_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),                  # squashes hidden activations into (-1, 1)
            nn.Linear(hidden_dim, 1),   # one scalar score per item
        ]
        self.attention = nn.Sequential(*scorer_layers)

    def forward(self, x, return_weights=False):
        """Pool ``x`` of shape (B, M, D): B bags, M items per bag, D features.

        Returns the pooled embeddings of shape (B, D). With
        ``return_weights=True`` returns ``(pooled, weights)`` where ``weights``
        has shape (B, M), is non-negative and sums to 1 over M; it can be used
        to see which items received high attention.
        """
        scores = self.attention(x)            # (B, M, 1)
        attn = scores.softmax(dim=1)          # normalize across the M items
        pooled = torch.sum(attn * x, dim=1)   # (B, D)
        if not return_weights:
            return pooled
        return pooled, attn.squeeze(-1)       # (B, D), (B, M)

In [None]:
class HierarchicalAttnMIL(nn.Module):
    """Three-level attention MIL classifier for one case at a time.

    Patches are embedded by a shared CNN and pooled with attention into a slice
    embedding; the slices of a stain are pooled into a stain embedding; the
    stain embeddings are pooled into a case embedding, which is classified.

    Args:
        base_model: model exposing ``.features`` (the convolutional feature
            extractor) and ``.classifier.in_features`` (its channel count).
            NOTE(review): this matches e.g. torchvision DenseNet models --
            confirm which backbone is passed in.
        num_classes: number of output logits.
        embed_dim: size D of patch, slice, stain and case embeddings.
    """

    def __init__(self, base_model, num_classes=2, embed_dim=512):
        super().__init__()

        # Shared feature extractor (pretrained CNN)
        self.features = base_model.features

        # Adaptive pooling to a 2x2 grid: keeps 4 spatial cells per channel
        self.pool = nn.AdaptiveAvgPool2d((2, 2))

        # Patch projector: maps the flattened 2x2 CNN features to patch embeddings
        self.patch_projector = nn.Linear(base_model.classifier.in_features * 4, embed_dim)

        # First level: Patch-level attention (within each stain-slice)
        self.patch_attention = AttentionPool(embed_dim)

        # Second level: Stain-level attention (across slices within each stain)
        self.stain_attention = AttentionPool(embed_dim)

        # Third level: Case-level attention (across different stains)
        self.case_attention = AttentionPool(embed_dim)

        # Final classifier
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, stain_slices_dict, return_attn_weights=False):
        """
        Input: stain_slices_dict = {
            "h&e": [slice1_tensor, slice2_tensor, ...],   # each slice: (P, C, H, W)
            "melan": [slice1_tensor, slice2_tensor, ...],
            "sox10": [slice1_tensor, slice2_tensor, ...]
        }
        Stains with an empty list are skipped. Slice tensors may live on any
        device; they are moved to the model's device here.

        Returns:
            logits of shape (num_classes,) for the single case. With
            ``return_attn_weights=True`` returns ``(logits, weights)`` where
            ``weights`` holds 'case_weights' (one per stain used),
            'stain_order' (stain names in that order) and 'stain_weights'
            ({stain: {'slice_weights', 'patch_weights'}}).
            If no stain has any slice, the logits are all zeros (not connected
            to the parameters, so they carry no gradient) and ``weights`` is {}.
        """
        device = next(self.parameters()).device

        stain_embeddings = {}
        stain_attention_weights = {}

        # Process each stain type separately
        for stain_name, slice_list in stain_slices_dict.items():
            if not slice_list:  # Skip if no slices for this stain
                continue

            slice_embeddings = []
            slice_attention_weights = []

            # Process each slice within this stain
            for slice_tensor in slice_list:
                # slice_tensor shape: (P, C, H, W) where P = number of patches
                P, C, H, W = slice_tensor.shape

                # The data loader yields CPU tensors; bring them to the device
                # the model lives on (no-op when they are already there).
                slice_tensor = slice_tensor.to(device)

                # Extract features for all patches in this slice
                patch_features = self.features(slice_tensor)  # (P, F, h, w)
                pooled = self.pool(patch_features).view(P, -1)  # (P, 4*F)
                patch_embeddings = self.patch_projector(pooled)  # (P, D)

                # Apply patch-level attention to get slice embedding
                if return_attn_weights:
                    slice_emb, patch_weights = self.patch_attention(
                        patch_embeddings.unsqueeze(0), return_weights=True
                    )
                    slice_attention_weights.append(patch_weights.squeeze(0))
                else:
                    slice_emb = self.patch_attention(patch_embeddings.unsqueeze(0))

                slice_embeddings.append(slice_emb.squeeze(0))  # (D,)

            # slice_list is non-empty here, so there is at least one embedding
            stain_slice_embeddings = torch.stack(slice_embeddings)  # (num_slices, D)

            # Apply stain-level attention across slices
            if return_attn_weights:
                stain_emb, stain_weights = self.stain_attention(
                    stain_slice_embeddings.unsqueeze(0), return_weights=True
                )
                stain_attention_weights[stain_name] = {
                    'slice_weights': stain_weights.squeeze(0),
                    'patch_weights': slice_attention_weights
                }
            else:
                stain_emb = self.stain_attention(stain_slice_embeddings.unsqueeze(0))

            stain_embeddings[stain_name] = stain_emb.squeeze(0)  # (D,)

        # If no stains have data, return zero logits with the same shape as
        # the normal path, (num_classes,), so callers can treat both alike.
        if not stain_embeddings:
            logits = self.classifier.weight.new_zeros(self.classifier.out_features)
            if return_attn_weights:
                return logits, {}
            return logits

        # Stack stain embeddings for case-level attention (fusion point)
        stain_emb_list = list(stain_embeddings.values())
        case_stain_embeddings = torch.stack(stain_emb_list)  # (num_stains, D)

        # Apply case-level attention across stains
        if return_attn_weights:
            case_emb, case_weights = self.case_attention(
                case_stain_embeddings.unsqueeze(0), return_weights=True
            )
            # Package all attention weights; useful for finding the stains,
            # slices and patches that drove the prediction.
            all_weights = {
                'case_weights': case_weights.squeeze(0),
                'stain_weights': stain_attention_weights,
                'stain_order': list(stain_embeddings.keys())
            }
        else:
            case_emb = self.case_attention(case_stain_embeddings.unsqueeze(0))

        # Final classification
        logits = self.classifier(case_emb)  # (1, num_classes)

        if return_attn_weights:
            return logits.squeeze(0), all_weights

        # One case per call (batch_size=1 in the data loaders), so drop the
        # leading batch dimension and return a single prediction.
        return logits.squeeze(0)

## hannah's code below

In [None]:
from collections import defaultdict

def count_patches_by_class(patch_dict, case_to_class, split_name):
    """Count patches per class label for one split and print the totals.

    Args:
        patch_dict: mapping key -> list of patch paths.
        case_to_class: mapping with the same keys -> class label; a key
            missing from it raises KeyError.
        split_name: label used in the printed header.

    Returns:
        defaultdict(int) mapping label -> number of patches. Labels 0 and 1
        are always present in it because they are read when printing.
    """
    tally = defaultdict(int)
    for key, paths in patch_dict.items():
        tally[case_to_class[key]] += len(paths)

    print(f"\n🧬 Patch count by class for {split_name}:")
    print(f"  Benign (0):     {tally[0]} patches")
    print(f"  High-grade (1): {tally[1]} patches")

    return tally

# Count patches per class for each split
train_counts = count_patches_by_class(train_patches, slice_to_class, "Train")
val_counts   = count_patches_by_class(val_patches, slice_to_class, "Validation")
test_counts  = count_patches_by_class(test_patches, slice_to_class, "Test")


🧬 Patch count by class for Train:
  Benign (0):     18125 patches
  High-grade (1): 27393 patches

🧬 Patch count by class for Validation:
  Benign (0):     3568 patches
  High-grade (1): 12473 patches

🧬 Patch count by class for Test:
  Benign (0):     4113 patches
  High-grade (1): 14456 patches


In [None]:
# def save_checkpoint_to_drive(model, arch, optimizer, epoch, drive_folder="MyDrive/Checkpoints"):
#     checkpoint_dir = os.path.join("/content/drive", drive_folder)
#     os.makedirs(checkpoint_dir, exist_ok=True)

#     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
#     filename = f"{checkpoint_dir}/{timestamp}_{arch}_epoch{epoch}.pth"

#     checkpoint = {
#         "arch": arch,
#         "model_state_dict": model.state_dict(),
#         "epoch": epoch,
#         "optimizer_state_dict": optimizer.state_dict(),
#     }

#     torch.save(checkpoint, filename)
#     print(f"✅ Checkpoint saved to Google Drive: {filename}")

# def validation(model, criterion, val_loader):
#     val_loss = 0
#     correct_total = 0
#     sample_total = 0
#     model.eval()
#     with torch.no_grad():
#         for bags, labels in val_loader:
#             bags = bags.to(device)
#             labels = labels.to(device)
#             outputs = model(bags)
#             loss = criterion(outputs, labels)
#             val_loss += loss.item() * labels.size(0)
#             preds = torch.argmax(outputs, dim=1)
#             correct_total += (preds == labels).sum().item()
#             sample_total += labels.size(0)
#     return val_loss / sample_total, correct_total / sample_total

# # def train_model(model, optimizer, criterion, train_loader, val_loader, arch, checkpoint_dir, epochs=5):
# def train_model(model, optimizer, criterion, train_loader, val_loader, arch, epochs=5, start_epoch=0):
#     for epoch in range(start_epoch, epochs):
#         model.train()
#         running_loss = 0
#         for bags, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
#             bags = bags.to(device)
#             labels = labels.to(device)
#             optimizer.zero_grad()
#             outputs = model(bags)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#             running_loss += loss.item()
#         val_loss, val_acc = validation(model, criterion, val_loader)
#         print(f"Epoch {epoch+1}/{epochs}, Train Loss: {running_loss/len(train_loader):.3f}, Val Loss: {val_loss:.3f}, Val Acc: {val_acc:.3f}")
#         save_checkpoint_to_drive(model, arch, optimizer, epoch+1, drive_folder="MyDrive/Checkpoints")
#     return model

In [None]:
def save_checkpoint_to_drive(model, arch, optimizer, epoch, drive_folder="MyDrive/Checkpoints"):
    """Write a training checkpoint for `model` and `optimizer` under the mounted Drive.

    The file is named ``<YYYYmmdd_HHMMSS>_<arch>_epoch<epoch>.pth`` and is placed
    in ``/content/drive/<drive_folder>``, which is created when missing.

    Args:
        model: object whose ``state_dict()`` is stored under "model_state_dict".
        arch: architecture tag, stored under "arch" and embedded in the file name.
        optimizer: object whose ``state_dict()`` is stored under "optimizer_state_dict".
        epoch: epoch number, stored under "epoch" and embedded in the file name.
        drive_folder: folder joined onto ``/content/drive``.
    """
    target_dir = os.path.join("/content/drive", drive_folder)
    os.makedirs(target_dir, exist_ok=True)

    # The timestamp prefix keeps checkpoints from different runs apart.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{target_dir}/{stamp}_{arch}_epoch{epoch}.pth"

    torch.save(
        {
            "arch": arch,
            "model_state_dict": model.state_dict(),
            "epoch": epoch,
            "optimizer_state_dict": optimizer.state_dict(),
        },
        filename,
    )
    print(f"✅ Checkpoint saved to Google Drive: {filename}")


In [None]:
# This cell provides:
# - a complete training loop with tqdm progress bars (train_model)
# - a validation function to monitor performance during training (validation)
# - tensor shape handling: a batch dimension is added because the loaders use batch_size=1
# - loss calculation and backpropagation
# - accuracy tracking

def train_model(model, optimizer, criterion, train_loader, val_loader, arch, epochs=5, start_epoch=0):
    """Train `model` one case at a time, validating and checkpointing after every epoch.

    Each batch from `train_loader` must be indexable, and its first element must be
    a mapping with "stain_slices" (passed to the model unchanged) and "label"
    (a tensor, moved to the global `device`). Only that first element is used,
    so the loaders are meant to run with batch_size=1.

    Args:
        model: module to train; updated in place.
        optimizer: optimizer stepping the trainable parameters of `model`.
        criterion: loss called as ``criterion(outputs, label)``.
        train_loader: iterable of training batches (see above); must support ``len``.
        val_loader: iterable handed to ``validation`` after every epoch.
        arch: architecture tag forwarded to ``save_checkpoint_to_drive``.
        epochs: total number of epochs (exclusive upper bound of the epoch range).
        start_epoch: number of epochs already completed, e.g. when resuming.

    Returns:
        The same `model` object.
    """
    if start_epoch >= epochs:
        # Resuming from a checkpoint that already reached `epochs` would otherwise do nothing silently.
        print(f"Nothing to train: start_epoch={start_epoch} >= epochs={epochs}")

    for epoch in range(start_epoch, epochs):
        # Enter training mode at the start of every epoch instead of relying on
        # validation() to switch the model back after evaluating.
        # NOTE(review): train() also affects sub-modules whose parameters are frozen
        # (e.g. BatchNorm layers of a frozen extractor) - confirm the model handles that.
        model.train()
        running_loss = 0.0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            case_data = batch[0]  # first (and, with batch_size=1, only) case in the batch

            stain_slices = case_data["stain_slices"]
            label = case_data["label"].to(device)

            outputs = model(stain_slices)

            # The loss expects [batch, classes] logits and [batch] labels; add the
            # batch dimension only when it is actually missing.
            if outputs.dim() == 1:
                outputs = outputs.unsqueeze(0)
            if label.dim() == 0:
                label = label.unsqueeze(0)

            loss = criterion(outputs, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        val_loss, val_acc = validation(model, criterion, val_loader)
        # Guard against an empty loader, mirroring the max(..., 1) guard in validation().
        num_batches = max(len(train_loader), 1)
        print(f"Epoch {epoch+1}/{epochs}, "
              f"Train Loss: {running_loss/num_batches:.3f}, "
              f"Val Loss: {val_loss:.3f}, Val Acc: {val_acc:.3f}")
        save_checkpoint_to_drive(model, arch, optimizer, epoch+1, drive_folder="MyDrive/Checkpoints")
    return model

def validation(model, criterion, val_loader):
    """Evaluate `model` on `val_loader` and return ``(mean_loss, accuracy)``.

    Each batch must be indexable, and its first element must be a mapping with
    "stain_slices" (passed to the model unchanged) and "label" (a tensor, moved
    to the global `device`). Only that first element is used, so every batch
    counts as one sample.

    The model is put in eval mode for the duration of the call and is then
    returned to whichever mode (train or eval) it was in on entry.

    Returns:
        Tuple of (summed loss / number of samples, correct / number of samples);
        both are 0.0 when `val_loader` is empty.
    """
    was_training = model.training
    model.eval()

    val_loss = 0.0
    correct_total = 0
    sample_total = 0

    try:
        with torch.no_grad():
            for batch in val_loader:
                case_data = batch[0]
                stain_slices = case_data["stain_slices"]
                label = case_data["label"].to(device)

                outputs = model(stain_slices)

                # Add the batch dimension only when it is missing:
                # logits [C] -> [1, C], scalar label -> [1].
                if outputs.dim() == 1:
                    outputs = outputs.unsqueeze(0)
                if label.dim() == 0:
                    label = label.unsqueeze(0)

                val_loss += criterion(outputs, label).item()

                pred = torch.argmax(outputs, dim=1)
                correct_total += (pred == label).sum().item()
                sample_total += 1
    finally:
        # Restore the caller's mode rather than forcing train mode, so calling
        # this on a model that is already in eval mode does not flip it.
        model.train(was_training)

    denom = max(sample_total, 1)  # avoid dividing by zero on an empty loader
    return val_loss / denom, correct_total / denom

In [None]:
# ------------------- Load pretrained patch model -------------------
# Classifies individual image patches; its convolutional layers are reused as the frozen feature extractor in MIL.
class PatchClassifier(nn.Module):
    """DenseNet-121 feature extractor, 2x2 adaptive average pooling and a 2-way linear head."""

    def __init__(self):
        super().__init__()
        # DenseNet-121 built with torchvision's default pretrained weights.
        densenet = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
        self.features = densenet.features
        # Every image's final convolutional output is compressed to a 2x2 feature map.
        self.pool = nn.AdaptiveAvgPool2d((2, 2))
        # A 2x2 map has 4 cells per channel, hence 4x the backbone's classifier input width.
        self.classifier = nn.Linear(4 * densenet.classifier.in_features, 2)

    def forward(self, x):
        """Return the class logits for the image batch `x`."""
        feature_maps = self.features(x)
        pooled = self.pool(feature_maps)
        # Flatten everything except the leading (batch) dimension.
        flattened = pooled.view(feature_maps.size(0), -1)
        return self.classifier(flattened)

# Instantiate the patch classifier defined above.
patch_model = PatchClassifier()
# NOTE(review): loading the saved patch-classifier weights is commented out below, so
# patch_model only carries whatever PatchClassifier.__init__ sets up - confirm this is intended.
# patch_model.load_state_dict(torch.load(os.path.join(filtered_patches_dir, "patch_classifier.pth")))
# Switch to eval mode; as the last expression of the cell this also displays the module.
patch_model.eval()

PatchClassifier(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      

In [None]:
# Build the hierarchical attention MIL model, its optimizer and the loss.

# Prefer the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# DenseNet-121 shell without weights; its feature extractor is replaced by the patch model's.
base_model = models.densenet121(weights=None)
base_model.features = patch_model.features

# Hierarchical attention MIL model built on top of that base model.
model = HierarchicalAttnMIL(base_model=base_model).to(device)

# Freeze the CNN feature extractor: only the attention mechanisms and the classifier are trained.
# NOTE(review): this only stops gradient updates; whether the extractor's normalisation
# layers also stay in eval mode during training depends on HierarchicalAttnMIL - confirm.
for param in model.features.parameters():
    param.requires_grad_(False)

# One parameter group per trainable sub-module, in this fixed order (the order
# determines the layout of the optimizer state stored in checkpoints).
trainable_modules = (
    model.patch_projector,
    model.patch_attention,
    model.stain_attention,
    model.case_attention,
    model.classifier,
)
optimizer = torch.optim.Adam(
    [{'params': module.parameters()} for module in trainable_modules],
    lr=0.001,
)

criterion = nn.CrossEntropyLoss()
arch = "densenet121_hierarchical_mil"

In [None]:
# check if we need to load from checkpoint before training
checkpoint_dir = "/content/drive/MyDrive/Checkpoints"
# Captures the epoch number from names such as "<timestamp>_<arch>_epoch12.pth".
checkpoint_pattern = re.compile(r'epoch(\d+)\.pth')
start_epoch = 0
checkpoint_path = None
# On a fresh run the checkpoint folder does not exist yet (it is only created by
# the first save), so listing it unconditionally would raise FileNotFoundError.
checkpoint_files = (
    [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
    if os.path.isdir(checkpoint_dir)
    else []
)

def load_latest_checkpoint(checkpoint_dir, model, optimizer, device, pattern=None):
    """Restore `model` and `optimizer` from the highest-epoch checkpoint in a folder.

    Args:
        checkpoint_dir: folder to scan for checkpoint files.
        model: object whose ``load_state_dict`` receives "model_state_dict".
        optimizer: object whose ``load_state_dict`` receives "optimizer_state_dict".
        device: forwarded to ``torch.load`` as ``map_location``.
        pattern: optional compiled regex whose first group captures the epoch
            number in a file name; defaults to ``epoch(\\d+)\\.pth``.

    Returns:
        ``(model, optimizer, start_epoch)``. `start_epoch` is the "epoch" value
        stored in the loaded checkpoint, or 0 when the folder is missing or
        holds no matching file (in which case nothing is loaded).
    """
    if pattern is None:
        pattern = re.compile(r'epoch(\d+)\.pth')

    # A fresh run has no checkpoint folder yet: start from scratch instead of crashing.
    if not os.path.isdir(checkpoint_dir):
        return model, optimizer, 0

    candidates = []
    for name in os.listdir(checkpoint_dir):
        match = pattern.search(name)
        if match:
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return model, optimizer, 0

    # Highest epoch wins (compared numerically); files sharing an epoch are ordered
    # by name, which makes the choice deterministic instead of depending on
    # the directory listing order.
    _, latest_name = max(candidates)
    checkpoint_path = os.path.join(checkpoint_dir, latest_name)
    print(f"Latest checkpoint in directory is **{checkpoint_path}**")

    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer, checkpoint['epoch']

# Resume from the latest checkpoint when one is found; start_epoch is the epoch stored in it, or 0 otherwise.
model, optimizer, start_epoch = load_latest_checkpoint(checkpoint_dir, model, optimizer, device)

In [None]:
# Train up to epoch 5 in total, continuing at start_epoch (train_model iterates range(start_epoch, epochs)).
train_model(model, optimizer, criterion, train_loader, val_loader, arch, epochs=5, start_epoch=start_epoch)

Epoch 1/5:   2%|▏         | 1/51 [17:36<14:40:04, 1056.09s/it]


KeyboardInterrupt: 

In [None]:
# Mini dataset used to check for errors before the real training run -- not used for actual training / testing :)

# Build a minimal set containing just a handful of cases
def create_mini_dataset(full_case_dict, full_label_map, num_cases=3):
    """Return ``(case_dict, label_map)`` restricted to the first `num_cases` case ids.

    Case ids are taken in the iteration order of `full_case_dict`; the values
    are the same objects as in the inputs (nothing is copied).
    """
    mini_case_dict = {}
    mini_label_map = {}
    for cid in list(full_case_dict.keys())[:num_cases]:
        mini_case_dict[cid] = full_case_dict[cid]
        mini_label_map[cid] = full_label_map[cid]
    return mini_case_dict, mini_label_map

# Create mini datasets: 2 training cases and 1 validation case
mini_train_case_dict, mini_train_label_map = create_mini_dataset(
    train_case_dict, train_label_map, num_cases=2
)
mini_val_case_dict, mini_val_label_map = create_mini_dataset(
    val_case_dict, val_label_map, num_cases=1
)

print(f"Mini train set: {len(mini_train_case_dict)} cases")
print(f"Mini val set: {len(mini_val_case_dict)} cases")

# Settings shared by both splits.
# Caps keep the debug run fast: limit patches per slice and slices per stain.
mini_ds_caps = dict(per_slice_cap=50, max_slices_per_stain=2)
# num_workers=0 avoids multiprocessing issues while debugging.
mini_loader_opts = dict(batch_size=1, num_workers=0, pin_memory=False, collate_fn=case_collate_fn)

# Create mini dataloaders
mini_train_ds = StainBagCaseDataset(
    mini_train_case_dict, mini_train_label_map,
    transform=train_transform,
    shuffle_patches=True,
    **mini_ds_caps,
)

mini_val_ds = StainBagCaseDataset(
    mini_val_case_dict, mini_val_label_map,
    transform=transform,
    shuffle_patches=False,
    **mini_ds_caps,
)

mini_train_loader = DataLoader(mini_train_ds, shuffle=True, **mini_loader_opts)

mini_val_loader = DataLoader(mini_val_ds, shuffle=False, **mini_loader_opts)

def quick_debug_training(model, optimizer, criterion, train_loader):
    """Run one batch through a forward and backward pass and print what happens.

    The first batch of `train_loader` is taken; its first element must be a
    mapping with "case_id", "label" (a tensor) and "stain_slices" (a mapping
    from stain name to a sequence of tensors). Errors raised during the
    forward/backward pass are printed with a traceback instead of propagating.

    Note that this is not a dry run: on success it calls ``optimizer.step()``,
    so `model` and `optimizer` are updated by one real step.
    """
    model.train()

    # Get just one batch
    batch = next(iter(train_loader))
    case_data = batch[0]

    print("=== DEBUGGING SINGLE BATCH ===")
    print(f"Case ID: {case_data['case_id']}")
    print(f"Label: {case_data['label']}")

    # Check stain slices structure
    for stain, slices in case_data["stain_slices"].items():
        print(f"{stain}: {len(slices)} slices")
        for i, slice_tensor in enumerate(slices):
            print(f"  Slice {i}: {slice_tensor.shape}")

    stain_slices = case_data["stain_slices"]
    label = case_data["label"].to(device)

    print("\n=== FORWARD PASS ===")
    try:
        outputs = model(stain_slices)
        print(f"Outputs: {outputs}")
        print(f"Output shape: {outputs.shape}")

        # The loss expects [batch, classes] logits and [batch] labels. Add the
        # batch dimension only when it is missing, exactly as train_model and
        # validation do, so this check exercises the same shapes as training.
        outputs_batch = outputs.unsqueeze(0) if outputs.dim() == 1 else outputs
        label_batch = label.unsqueeze(0) if label.dim() == 0 else label

        print("Shapes for loss calculation:")
        print(f"  Outputs: {outputs_batch.shape}")
        print(f"  Label: {label_batch.shape}")

        loss = criterion(outputs_batch, label_batch)
        print(f"Loss: {loss.item()}")

        print("\n=== BACKWARD PASS ===")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print("✅ SUCCESS: Forward and backward pass completed!")

    except Exception as e:
        # Deliberately broad: this helper exists to surface any failure with its traceback.
        print(f"❌ ERROR: {e}")
        import traceback
        traceback.print_exc()

# Run the quick debug
# Note: quick_debug_training calls optimizer.step(), so this updates `model` by one step on a mini-dataset case.
quick_debug_training(model, optimizer, criterion, mini_train_loader)



Mini train set: 2 cases
Mini val set: 1 cases
=== DEBUGGING SINGLE BATCH ===
Case ID: 89
Label: 1
h&e: 2 slices
  Slice 0: torch.Size([50, 3, 224, 224])
  Slice 1: torch.Size([50, 3, 224, 224])
melan: 1 slices
  Slice 0: torch.Size([50, 3, 224, 224])
sox10: 1 slices
  Slice 0: torch.Size([50, 3, 224, 224])

=== FORWARD PASS ===
Outputs: tensor([-0.1371,  0.1461], grad_fn=<SqueezeBackward1>)
Output shape: torch.Size([2])
Shapes for loss calculation:
  Outputs: torch.Size([1, 2])
  Label: torch.Size([1])
Loss: 0.5615600943565369

=== BACKWARD PASS ===
✅ SUCCESS: Forward and backward pass completed!


In [None]:
# ------------ model evaluation ---------------------
# Collects bag-level predictions, weak patch-level predictions and attention weights over test_loader.
# NOTE(review): this loop unpacks (X_bag, y) tensors from test_loader and calls
# model(X_bag, return_attn_weights=True) / model(X_bag, return_patch_logits=True), whereas
# train_model feeds the model a "stain_slices" mapping taken from batch[0] - confirm that
# `model` and `test_loader` here are the ones this cell was written for.
model.eval()
all_preds, all_trues = [], []              # one entry per bag
all_patch_preds, all_patch_trues = [], []  # flattened: one entry per patch across all bags
all_attn_weights = []                      # one array per bag

with torch.no_grad():
    for X_bag, y in tqdm(test_loader, desc="Evaluating", leave=False):
        X_bag, y = X_bag.to(device), y.to(device)

        # 1. Slice-level prediction (also getting attention weights)
        bag_logits,attn_weights = model(X_bag, return_attn_weights=True)
        bag_pred = bag_logits.argmax(dim=1)
        all_preds.extend(bag_pred.cpu().numpy())
        all_trues.extend(y.cpu().numpy())

        # 2. Patch-level prediction (second forward pass over the same bag)
        patch_logits = model(X_bag, return_patch_logits=True)
        patch_pred = patch_logits.argmax(dim=2).squeeze(0).cpu().numpy()
        # Weak labels: every patch inherits its bag's label. y.item() requires a
        # single-element y, i.e. one bag per batch.
        patch_labels = y.cpu().item() * np.ones_like(patch_pred)
        all_patch_preds.extend(patch_pred)
        all_patch_trues.extend(patch_labels)

        # 3. Save Attention Weights (for visualization)
        attn_weights = attn_weights.squeeze(0).cpu().numpy()
        all_attn_weights.append(attn_weights)

# 4. Case Level Majority Vote Aggregation
# The i-th prediction is attributed to the case id stored first in the i-th key of test_patches.
# NOTE(review): this presumes test_loader yields bags in the key order of test_patches - confirm.
slice_keys = list(test_patches.keys())
case_preds = defaultdict(list)
case_trues = {}

for i, pred in enumerate(all_preds):
    true_label = all_trues[i]
    case_id = slice_keys[i][0]
    case_preds[case_id].append(pred)

    # Every slice of a case carries the case's true label, so the last one seen is kept.
    case_trues[case_id] = true_label

# Majority vote per case, visiting the cases in sorted id order
ordered_case_ids = sorted(case_preds.keys())
final_preds = [
    # The most frequent vote among a case's slices wins.
    max(set(case_preds[case_id]), key=case_preds[case_id].count)
    for case_id in ordered_case_ids
]
final_trues = [case_trues[case_id] for case_id in ordered_case_ids]

# Report
# Every report pins labels=[0, 1] so the two target names always line up with
# the classes, even when one class is absent from a (small) set of cases.
print("=== Aggregated Case-Level Classification Report ===")
print(classification_report(final_trues, final_preds, target_names=['Benign', 'High-grade CMIL'], labels=[0,1]))

print("=== Slice-Level Classification Report ===")
print(classification_report(all_trues, all_preds, target_names=['Benign', 'High-grade CMIL'], labels=[0,1]))

print("\n=== Patch-Level Classification Report (weak labels) ===")
print(classification_report(all_patch_trues, all_patch_preds, target_names=['Benign', 'High-grade CMIL'], labels=[0,1]))

In [None]:
def plot_confusion_matrix(trues, preds, title, class_names=('Benign', 'High-grade CMIL')):
    """Draw a confusion-matrix heatmap for integer class labels.

    Args:
        trues: true labels.
        preds: predicted labels, aligned with `trues`.
        title: figure title.
        class_names: display names; class ``i`` is shown as ``class_names[i]``.

    The matrix is always ``len(class_names)`` square: labels are fixed to
    ``0 .. len(class_names) - 1`` so a class missing from both `trues` and
    `preds` still gets its row and column and the tick labels stay correct.
    """
    class_names = list(class_names)
    cm = confusion_matrix(trues, preds, labels=list(range(len(class_names))))
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(title)
    plt.show()

# --- Plotting the matrices ---

# 1. Slice-Level and Patch-Level (same title pattern, so they share one loop)
for name, preds, trues in (
    ("Slice-Level", all_preds, all_trues),
    ("Patch-Level", all_patch_preds, all_patch_trues),
):
    plot_confusion_matrix(trues=trues, preds=preds, title=f"{name} Confusion Matrix")

# 2. Aggregated Case-Level (majority vote over each case's slices)
plot_confusion_matrix(
    trues=final_trues,
    preds=final_preds,
    title="Aggregated Case-Level Confusion Matrix (Majority Vote)",
)

In [None]:
def visualize_top_attended_patches(test_ds, all_attn_weights, all_preds, all_trues, num_cases=3, top_k=5):
    """Show the `top_k` most-attended patches for randomly chosen bags of `test_ds`.

    Args:
        test_ds: dataset supporting ``len`` and exposing ``bags``, where
            ``test_ds.bags[idx]`` holds the patch image paths of bag ``idx``.
        all_attn_weights: per-bag attention weights, indexed like `test_ds`.
            Each entry must support ``argsort`` and integer indexing
            (assumed: a 1-D NumPy array with one weight per patch - TODO confirm).
        all_preds: per-bag predictions, indexed like `test_ds`.
        all_trues: per-bag true labels, indexed like `test_ds`.
        num_cases: number of bags to show; capped at ``len(test_ds)``.
        top_k: number of patches shown per bag.

    `all_preds` and `all_trues` are looked up with the bag index, so they must
    hold one entry per bag (not one entry per patch).
    """
    # random.sample raises ValueError when asked for more items than exist.
    num_cases = max(0, min(num_cases, len(test_ds)))
    selected_indices = random.sample(range(len(test_ds)), num_cases)

    for idx in selected_indices:
        # Everything needed was computed during evaluation; the bag itself is
        # not reloaded because only its patch paths are used.
        label = all_trues[idx]
        pred = all_preds[idx]
        attn_weights = all_attn_weights[idx]
        correct = (pred == label)

        patch_paths = test_ds.bags[idx]
        # Indices of the top_k largest weights, highest first.
        top_indices = attn_weights.argsort()[-top_k:][::-1]

        print(f"\n🧪 Case #{idx}: True label = {label}, Predicted = {pred}, Correct = {correct}")
        print("Top patches with highest attention:")

        plt.figure(figsize=(15, 3))
        for i, patch_idx in enumerate(top_indices):
            patch_path = patch_paths[patch_idx]
            try:
                # Copy the pixels so the file handle is closed right away instead
                # of staying open for every patch that gets displayed.
                with Image.open(patch_path) as handle:
                    img = handle.copy()
            except FileNotFoundError:
                print(f"Error: Patch image not found at {patch_path}. Skipping.")
                continue

            weight = attn_weights[patch_idx]

            print(f"  - {os.path.basename(patch_path)}: attention = {weight:.4f}")

            plt.subplot(1, top_k, i + 1)
            plt.imshow(img)
            plt.axis('off')
            plt.title(f"attn={weight:.3f}")
        plt.suptitle(f"Case {idx} | True: {label} | Pred: {pred} | Correct: {correct}")
        plt.tight_layout()
        plt.show()

In [None]:
# The function looks labels and predictions up by bag index, so pass the bag-level
# lists; the flattened patch-level lists are not aligned with test_ds. The number
# of cases is capped at the dataset size because random.sample cannot oversample.
visualize_top_attended_patches(
    test_ds, all_attn_weights, all_preds, all_trues,
    num_cases=min(100, len(test_ds)), top_k=2,
)