In [204]:
# Standard library imports
import os
import re
import random
import shutil
from datetime import datetime
from collections import OrderedDict

# Data handling and visualization libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# PyTorch and TorchVision libraries for deep learning
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
import torchvision.models as models
from torchvision import transforms, datasets
from PIL import Image

# Scikit-Learn for evaluation metrics and data splitting|
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report

# Imbalanced-learn for oversampling to address class imbalance
from imblearn.over_sampling import RandomOverSampler

from torchvision import transforms
from torchvision.transforms import functional as F
from PIL import Image
from PIL import Image
from PIL import UnidentifiedImageError
from tqdm import tqdm

import os, re, random
from collections import defaultdict, Counter
from sklearn.model_selection import train_test_split
from google.colab import drive

In [205]:
# ─── 1) MOUNT GOOGLE DRIVE ──────────────────────────────────────────────
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [206]:
# ─── LOAD LABELS & BUILD CLASS MAP ───────────────────────────────────────
labels = pd.read_csv(
    "/content/drive/MyDrive/case_grade_match.csv"
).drop(index=64, errors='ignore').reset_index(drop=True)

In [207]:
# Note that case 65 is not included in the analysis as it was not labelled
labels.loc[62:65,:]

Unnamed: 0,Case,Class
62,63,3
63,64,3
64,66,3
65,67,4


In [208]:
filtered_patches_dir = '/content/drive/MyDrive/CMIL_SP2025_Patches_Apr27'
os.path.exists(filtered_patches_dir)

True

In [209]:
all_files = os.listdir(filtered_patches_dir)
len(all_files)

83717

In [210]:
# === Slice-Level Grouping ===

# Edit 10/19: identified scenario where patches are not being matched:
#     1) missing space between 'match' / 'unmatched' and number, eg. case_38_match14_h&e_patch25.png
#     2) Additionally, some names include 'labels' (not relevant to this regex which is only looking at case matching but thought I would flag),
#         eg. case_82_unmatched3_h&e-labels_patch32.png
def group_patches_by_slice(root_dir):
    case_slices = defaultdict(list)
    # list of unconventionally-named patches
    invalid_file_names = []
    flexibility_needed_counter = 0
    for root, _, files in os.walk(root_dir):
        for filename in files:
            if filename.endswith(".png"):
              #  match = re.match(r"case_(\d+)_([\w&\-]+_\d+)_", filename)
                match = re.match(r"case_(\d+)_([a-z]+_\d+)_", filename) #check if some patches are named differently
                if match:
                    case_id = int(match.group(1))
                    slice_id = match.group(2)
                    key = (case_id, slice_id)
                    case_slices[key].append(os.path.join(root, filename))

                    continue
                # if a file doesn't match, try regex without "_" between "match" / "unmatched" and number
                match = re.match(r"case_(\d+)_([a-z]+\d+)_", filename)
                if match:
                    case_id = int(match.group(1))
                    slice_id = match.group(2)
                    # adding underscore between "match" / "unmatched" and number
                    slice_id = re.sub(r'([A-Za-z])(\d)', r'\1_\2', slice_id)
                    key = (case_id, slice_id)
                    case_slices[key].append(os.path.join(root, filename))
                    flexibility_needed_counter += 1

                    continue
                invalid_file_names.append(os.path.join(root, filename))

    # Print summary of invalid files
    if invalid_file_names:
        print(f"Found {len(invalid_file_names)} files not following naming convention:")
        for f in invalid_file_names:
            print("  ", f)
    else:
        print(f"All {flexibility_needed_counter} invalid file names were handled.")
    return case_slices

In [211]:
# === Load Patches by Slice ===
patches = group_patches_by_slice(filtered_patches_dir)

All 3808 invalid file names were handled.


In [212]:
len(patches)

378

In [213]:
tot_patches = 0
for ke in patches.keys():
  tot_patches = tot_patches + len(patches[ke])
tot_patches

83717

Some patches are getting lost (need to make the 'match = re.match(....)' code more general)

In [214]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [215]:
# Training transform (augmented)
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [216]:
# === Build Label Map by Slice ===
slice_to_class = {}
valid_classes = [1.0, 3.0, 4.0]
for (case_id, slice_id), paths in patches.items():
    raw_label = labels.loc[labels['Case'] == case_id, 'Class']
    if not raw_label.empty and raw_label.item() in valid_classes:
        label = 0 if raw_label.item() == 1.0 else 1
        slice_to_class[(case_id, slice_id)] = label

In [217]:
len(slice_to_class) # Low-grade slices are removed

366

In [218]:
slice_to_class[(106, 'match_1')]

1

In [219]:
# === Stratified Split by Slice ===
slices_by_class = defaultdict(list)
for key, label in slice_to_class.items():
    slices_by_class[label].append(key) # Dictonary of length 2

In [220]:
slices_by_class.keys()

dict_keys([1, 0])

In [221]:
slices_by_class[0][0:2]

[(25, 'match_1'), (25, 'match_2')]

In [222]:
train_slices, val_slices, test_slices = [], [], []
for label, slice_keys in slices_by_class.items():
    train, temp = train_test_split(slice_keys, test_size=0.4, random_state=42)
    val, test = train_test_split(temp, test_size=0.5, random_state=42)
    train_slices += train
    val_slices += val
    test_slices += test

In [223]:
label

0

In [224]:
slice_keys[0:2]

[(25, 'match_1'), (25, 'match_2')]

In [225]:
len(train_slices)

218

In [226]:
len(test_slices)

74

In [227]:
len(val_slices)

74

In [228]:
ele_list_train = []
for ele in train_slices:
  if ele[0] not in ele_list_train:
    ele_list_train.append(ele[0])
ele_list_train.sort()
print(ele_list_train)

[1, 2, 3, 4, 5, 7, 9, 11, 12, 14, 15, 16, 17, 18, 19, 21, 22, 23, 24, 25, 26, 27, 28, 30, 32, 34, 36, 38, 40, 41, 42, 43, 44, 45, 46, 48, 49, 50, 51, 52, 54, 56, 57, 58, 59, 61, 63, 64, 66, 67, 68, 69, 70, 72, 73, 75, 77, 78, 79, 80, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 93, 94, 95, 96, 97, 98, 100, 101, 104, 106]


In [229]:
ele_list_test = []
for ele in test_slices:
  if ele[0] not in ele_list_test:
    ele_list_test.append(ele[0])
ele_list_test.sort()
print(ele_list_test)

[1, 2, 4, 7, 9, 11, 14, 15, 19, 21, 22, 24, 25, 26, 34, 36, 38, 41, 48, 50, 55, 56, 66, 68, 69, 70, 72, 79, 80, 82, 83, 84, 87, 89, 91, 92, 93, 97, 98, 99, 104]


In [230]:
print(list(set(ele_list_train) & set(ele_list_test)))

[1, 2, 4, 7, 9, 11, 14, 15, 19, 21, 22, 24, 25, 26, 34, 36, 38, 41, 48, 50, 56, 66, 68, 69, 70, 72, 79, 80, 82, 83, 84, 87, 89, 91, 93, 97, 98, 104]


In [231]:
len(list(set(ele_list_train) & set(ele_list_test)))

38

In [232]:
len(ele_list_test)

41

In [233]:
for ele in train_slices:
  if ele[0] == 2:
    print(ele)

(2, 'match_1')
(2, 'unmatched_1')


In [234]:
for ele in test_slices:
  if ele[0] == 2:
    print(ele)

(2, 'unmatched_3')
(2, 'unmatched_2')


In [235]:
train_patches = {k: patches[k] for k in train_slices}
val_patches   = {k: patches[k] for k in val_slices}
test_patches  = {k: patches[k] for k in test_slices}

In [236]:
# === Filter to H&E Only ===

def filter_by_stain(d, keyword):
    out, dropped = {}, []
    for k, paths in d.items():
        filtered = [p for p in paths if keyword.lower() in os.path.basename(p).lower()]
        if filtered:
            out[k] = filtered
        else:
            dropped.append(k)
    if dropped:
        print(f"⚠️ Dropped slices with no '{keyword}' patches: {dropped}")
    return out


train_patches = filter_by_stain(train_patches, "h&e")
val_patches   = filter_by_stain(val_patches, "h&e")
test_patches  = filter_by_stain(test_patches, "h&e")

⚠️ Dropped slices with no 'h&e' patches: [(104, 'match_15'), (18, 'unmatched_4'), (73, 'unmatched_4'), (72, 'unmatched_5'), (50, 'match_4'), (38, 'match_22'), (15, 'unmatched_10'), (50, 'match_5'), (38, 'match_28'), (15, 'match_1'), (15, 'unmatched_11'), (27, 'unmatched_9'), (23, 'match_1'), (24, 'unmatched_9'), (27, 'match_2'), (21, 'match_1')]
⚠️ Dropped slices with no 'h&e' patches: [(73, 'unmatched_3'), (38, 'match_29'), (72, 'unmatched_6'), (27, 'unmatched_1'), (27, 'unmatched_8'), (22, 'unmatched_12')]
⚠️ Dropped slices with no 'h&e' patches: [(87, 'unmatched_2'), (15, 'match_2'), (15, 'unmatched_12'), (15, 'unmatched_13'), (38, 'match_27'), (22, 'unmatched_13'), (21, 'unmatched_4'), (24, 'unmatched_10')]


In [237]:
len(train_patches)

202

In [238]:
# === MILDataset with slice-level keys ===
class SliceMILDataset(Dataset):
    def __init__(self, patch_dict, label_map, transform=None, emergency_cap=800):
      #emergency_cap: upper limit on how many patches per bag to include, why have a limit? how set the limit?
      #If a bag has too many patches (say 3000), only randomly keeps up to emergency_cap (default 800)
      #Prevents out-of-memory errors during training. - needed or not?
      #transform: optional torchvision transform (e.g., resizing, normalization) - same as manual?
        self.transform = transform
        self.emergency_cap = emergency_cap
        # self.bags: list of lists of image paths (each “bag” = slice).
        # self.labels: list of corresponding labels for each bag.
        self.bags, self.labels = [], []
        for slice_key, paths in patch_dict.items():
            self.bags.append(paths)
            self.labels.append(label_map[slice_key])

    # Required for PyTorch datasets — tells how many samples (bags) exist.
    # Enables len(dataset) to work and is used by DataLoader for batching.
    def __len__(self): return len(self.bags)


    # The method below is called whenever PyTorch asks for a sample.
    # Retrieves the list of patch file paths for that bag.
    def __getitem__(self, idx):
        paths = self.bags[idx] # idx is the index of the bag you want.
        imgs = []
        for p in paths:   # Loops through each patch path in the bag.
            try:
                img = Image.open(p).convert('RGB')#Change the color transformation - YCbCr, HSV
                if self.transform: # Applies transforms if provided (resize, normalize, etc.)
                    img = self.transform(img)
                imgs.append(img)
            except:
                continue #why try-except, what error? Why corrupt patches?
        if len(imgs) == 0:
            raise ValueError(f"No usable patches in slice {paths}")
        if self.emergency_cap and len(imgs) > self.emergency_cap:
            imgs = random.sample(imgs, self.emergency_cap)
        return torch.stack(imgs), torch.tensor(self.labels[idx], dtype=torch.long)
        #Stacks all transformed image tensors into a single tensor of shape:
        #(num_patches, C, H, W)
        #(e.g., (500, 3, 224, 224) for 500 patches).
        #Converts the label to a PyTorch tensor of type long (needed for classification loss).
        #Returns (bag_tensor, label_tensor) as one sample.



In [239]:
# this pools the patch level features into single bag level representation for MIL
class AttentionPool(nn.Module):
    def __init__(self, input_dim, hidden_dim=128):
      #input_dim: the dimensionality of each input feature vector (e.g., 512 if you’re using a CNN encoder like ResNet).
      #hidden_dim: the size of the hidden layer used inside the attention mechanism (default = 128).
      #This defines how much capacity the attention sub-network has to learn complex relationships.
        super().__init__()
        # creates small neural network to compute attention scores for each patch
        # each patch embedding is passed through a linear layer, tanh for nonlinearity, and another linear layer to get a scalar score
        self.attention = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
    def forward(self, x, return_weights=False):
        # x is of shape (B, M, D) where B is batch size or number of cases,
        # M is number of patches per bag, and D is the embedding dimensions for each patch
        weights = self.attention(x)     # (B, M, 1)
        weights = torch.softmax(weights, dim=1) # softmax so that weights are positive
        # outputs attention scores for each patch and normalized with softmax
        # D is the embedding dimension which is size of feature vector for each patch after going through the patch classifier
        weighted_x = (weights * x).sum(dim=1)  # (B, D)
        # returning the raw attention weights per patch just to help with visualization of the weights for each patch
        if return_weights:
            return weighted_x, weights.squeeze(-1)  # (B, D), (B, M)
        return weighted_x

# using Densenet to extract features at the patch-level, and adaptaive pooling for patch-level feature pooling
# the patch features are passed into Attentionpool which learns the weights across patches and combines them into a single bag label

In [240]:
class AttnMIL(nn.Module):
    def __init__(self, base_model, num_classes=2, embed_dim=512):
        super().__init__()
        # grabbing the convolutional feature extractor from the pretrained model
        self.features = base_model.features
        # applying adaptive average pooling to compress to feature map of 2x2 grid
        # you get 4 spatial vectors per patch
        self.pool = nn.AdaptiveAvgPool2d((2,2))  # richer than (1,1) - tunable!
        # meaning that you'll get 4 vectors per patch which will then be flattened
        self.patch_projector = nn.Linear(base_model.classifier.in_features * 4, embed_dim)

        # Attention module (your earlier class) that produces per-patch weights and returns a bag embedding (weighted sum) of size embed_dim.
        self.attention_pool = AttentionPool(embed_dim)

        # Final classifier that maps the bag embedding to logits over classes.
        self.classifier = nn.Linear(embed_dim, num_classes)


    def forward(self, x, return_patch_logits=False, return_attn_weights=False):
        if x.dim() == 4: # why the need to check?
            x = x.unsqueeze(0)
        # typically after CNN you get 3D tensor with num channels, height and width of image
        # but we packed the patches into a bag by case (the tensor), so B is batch size, M is number of patches per bag
        B, M, C, H, W = x.shape
        x = x.view(B*M, C, H, W) # put all patches in the batch together

        features = self.features(x) # exxtracting cnn features for each patch, output shape: (B*M, C', H', W')
        pooled = self.pool(features).view(B*M, -1) # pool each feature map to a 2x2 grid and flatten, output shape: (B*M, 4*C')
        embedded = self.patch_projector(pooled).view(B, M, -1) # project each patch into shared embedding space, output shape (B, M, embed_dim)
        # just ensuring all the patches are transformed into vectors of the same length for attention

        # in order to get patch level predictions
        # Optional per-patch logits: apply the bag classifier to each patch embedding independently. Shape (B, M, num_classes). Useful for diagnostics/auxiliary losses.
        if return_patch_logits:
            logits = self.classifier(embedded)  # (B, M, 2)
            return logits
        # returning attention weights for visualization
        if return_attn_weights:
            bag_emb, attn_weights = self.attention_pool(embedded, return_weights=True)
            logits = self.classifier(bag_emb)
            return logits, attn_weights  # bag prediction + per-patch attention scores, why logits here when returning at the end

        # applying attention
        #computing a weighted average of patch embeddings using attention, and then is passed through the classifier to get bag level prediction
        bag_emb = self.attention_pool(embedded) #Shape: (B, embed_dim) + weights (B, M)
        logits = self.classifier(bag_emb) # Shape: (B, 2)
        return logits

# What is the effect of M varying in each bag?

In [241]:
# Set up datasets
# Set up slice-based MIL datasets
# dict mapping slice_key -> list_of_patch_paths for the train split.
# dict mapping slice_key -> class_label.
# image augmentations for training (e.g., flips/crops/normalization).
train_ds = SliceMILDataset(train_patches, slice_to_class, transform=train_transform)

# Uses (usually) lighter or no augmentation (transform) to mimic evaluation conditions.
val_ds   = SliceMILDataset(val_patches, slice_to_class, transform=transform)

# emergency_cap=None disables the cap on patches per bag (use all patches). Helpful to evaluate full performance without subsampling.
test_ds  = SliceMILDataset(test_patches, slice_to_class, transform=transform, emergency_cap=None)

# help generalize better and reduce overfitting
# Wraps train_ds in a PyTorch DataLoader.
# batch_size=1: in MIL, each bag (slice) can contain a variable number of patches; batching multiple bags together often requires padding/collate tricks. Why? Is it because bag sizes are different? Using 1 bag per batch keeps it simple and memory-safe.
# shuffle=True: randomizes bag order each epoch → improves generalization and reduces overfitting.
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)
# don't want to shuffle so taht you get consistent, repeatable evaluation
val_loader   = DataLoader(val_ds, batch_size=1, shuffle=False)
test_loader  = DataLoader(test_ds, batch_size=1, shuffle=False)

# If bags vary wildly in size and you want larger batches, you can implement a custom collate_fn to pad/truncate bags, but that adds complexity.

In [242]:
def summarize_split(patch_dict, label_map=None, split_name=""):
    # Count cases
    num_cases = len(patch_dict)

    # Count total patches
    total_patches = sum(len(paths) for paths in patch_dict.values())

    # Count per class (if label map provided)
    class_counts = Counter()
    if label_map:
        class_counts.update([label_map[c] for c in patch_dict.keys() if c in label_map])

    # Count unique slice IDs (if present in patch filenames)
    slice_set = set()
    for paths in patch_dict.values():
        for path in paths:
            # Extract slice identifier after 2nd underscore (e.g., unmatched_2 from case_01_unmatched_2_patch43.png)
            match = re.search(r"case_\d+_(\w+_\d+)", os.path.basename(path))
            if match:
                slice_set.add(match.group(1))

    # Output
    print(f"\n🔹 {split_name} Split")
    print(f"Cases: {num_cases}")
    print(f"Patches: {total_patches}")
    if class_counts:
        print(f"Class distribution: {dict(class_counts)}")
    print(f"Slices (unique IDs): {len(slice_set)}")

# Run for all splits
summarize_split(train_patches, slice_to_class, split_name="Train")
summarize_split(val_patches, slice_to_class, split_name="Validation")
summarize_split(test_patches, slice_to_class, split_name="Test")


🔹 Train Split
Cases: 202
Patches: 23849
Class distribution: {1: 131, 0: 71}
Slices (unique IDs): 26

🔹 Validation Split
Cases: 68
Patches: 7682
Class distribution: {1: 45, 0: 23}
Slices (unique IDs): 20

🔹 Test Split
Cases: 66
Patches: 8206
Class distribution: {1: 43, 0: 23}
Slices (unique IDs): 17
