##1.0 Environment Setup

##2.0 Mount Drive

##3.0 Imports and Configuration

##4.0 Build Manifest (Map Feature Files to Slide Labels)

In [None]:
# Cell A - Build manifest mapping feature .npy -> slide label (numeric)
import pandas as pd
from pathlib import Path

# Project locations on the mounted Drive.
BASE_DIR = Path("/content/drive/MyDrive/WSI-MIL-Pipeline")
FEATURES_DIR = BASE_DIR.joinpath("data", "features")            # holds slide_XXX_features.npy files
LABELS_CSV = BASE_DIR.joinpath("data", "stage_labels.csv")      # CSV mixing slide-level and patch-level rows

# Read the raw label table.
labels_df = pd.read_csv(LABELS_CSV)

# Pick the identifier column: prefer 'patient', then 'slide_id', and fall back
# to the first column of the CSV when neither is present.
lab_col = next(
    (name for name in ('patient', 'slide_id') if name in labels_df.columns),
    labels_df.columns[0],
)

# Keep only slide-level rows: any identifier mentioning a node or a patch is
# treated as a sub-slide row and dropped. Identifiers are compared as text.
id_text = labels_df[lab_col].astype(str)
is_sub_slide_row = id_text.str.contains(r"_node_|node|_patch_", regex=True)
# Work on an independent copy so later column assignments do not touch labels_df.
slide_level = labels_df.loc[~is_sub_slide_row].copy()

# Normalize slide id strings to match your feature filenames.
# Example: patient_000.zip -> patient_000  OR patient_000.zip may correspond to slide folder name.
def normalize_slide_id(s):
    """Return the slide identifier *s* without trailing file extensions.

    The value is converted to ``str`` and any of the known extensions
    (``.zip``, ``.svs``, ``.tif``, ``.tiff``) is removed from the END of the
    string, repeatedly, so stacked extensions such as ``name.svs.zip`` are
    fully stripped. Matching is case-insensitive; the rest of the identifier
    is returned unchanged.

    Only suffixes are removed. Replacing the extension text anywhere in the
    string would corrupt identifiers that merely contain it, and would turn
    ``slide.tiff`` into ``slidef``.
    """
    s = str(s)
    extensions = ('.tiff', '.zip', '.svs', '.tif')
    stripped = True
    while stripped:
        stripped = False
        for ext in extensions:
            # Compare only the tail so the slice length always matches `ext`.
            if s[-len(ext):].lower() == ext:
                s = s[:-len(ext)]
                stripped = True
                break
    return s

# Derive the slide id that feature filenames are expected to carry.
slide_level['slide_id'] = slide_level[lab_col].apply(normalize_slide_id)

# Prefer an explicit 'label' column when the CSV already has one; otherwise keep
# the original assumption that the second CSV column holds the label. Renaming
# the second column while a 'label' column already exists would create two
# 'label' columns and break the encoding step below.
label_col = 'label' if 'label' in labels_df.columns else labels_df.columns[1]
slide_level = slide_level.rename(columns={label_col: 'label'})

# Numeric encoding: labels are sorted so the mapping is stable across runs.
label_map = {lab: i for i, lab in enumerate(sorted(slide_level['label'].unique()))}
slide_level['label_idx'] = slide_level['label'].map(label_map)

# Build the slide_id -> (label, label_idx) lookup once instead of filtering the
# DataFrame for every feature file. setdefault keeps the FIRST row per slide id,
# matching the previous `.iloc[0]` behaviour for duplicated ids.
label_lookup = {}
for sid, lab, idx in zip(slide_level['slide_id'], slide_level['label'], slide_level['label_idx']):
    label_lookup.setdefault(sid, (lab, idx))

# Match feature files against the lookup. The glob result is sorted because
# directory listing order is filesystem-dependent; a stable manifest order keeps
# any later seeded train/test split reproducible.
feature_suffix = "_features"
manifest = []
for p in sorted(FEATURES_DIR.glob("*_features.npy")):
    # Drop only the trailing "_features" marker from the file stem.
    slide_name = p.stem[:-len(feature_suffix)]
    if slide_name in label_lookup:
        label, label_idx = label_lookup[slide_name]
        manifest.append({
            'slide_id': slide_name,
            'feature_path': str(p),
            'label': label,
            'label_idx': int(label_idx),
        })
    else:
        # No slide-level row for this feature file: report it and move on.
        print("No slide-level label for:", slide_name)

# Explicit columns keep the CSV header intact even when nothing matched, so the
# saved file can still be read back with pd.read_csv.
manifest_df = pd.DataFrame(manifest, columns=['slide_id', 'feature_path', 'label', 'label_idx'])
print("Prepared manifest:", manifest_df.shape)
print(manifest_df.head())

# Save manifest next to the feature files.
manifest_csv = FEATURES_DIR / "manifest.csv"
manifest_df.to_csv(manifest_csv, index=False)
print("Saved manifest to:", manifest_csv)


##5.0 Bag Label (Collate Dataset)

In [None]:
# Cell B - PyTorch Dataset that yields (features_tensor, label, slide_id)
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class SlideBagDataset(Dataset):
    """Dataset that serves one slide (a bag of instance features) per index.

    Each item is a tuple ``(features, label, slide_id)``:
    ``features`` is the array stored at the row's ``feature_path`` converted to
    a float tensor (expected shape ``(N_instances, feat_dim)``), ``label`` is a
    scalar ``torch.long`` tensor built from ``label_idx``, and ``slide_id`` is
    passed through from the manifest row.
    """

    def __init__(self, manifest_df):
        # Re-number rows 0..n-1 so positional lookups do not depend on the
        # index of the frame that was passed in (e.g. after a split).
        self.df = manifest_df.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        # Instances are returned in their stored order (no shuffling).
        bag = torch.from_numpy(np.load(record['feature_path'])).float()
        target = torch.tensor(int(record['label_idx']), dtype=torch.long)
        return bag, target, record['slide_id']

# collate_fn: MIL typically uses batch_size=1 (bags of variable size),
# or you can batch multiple bags but here we use batch 1 for simplicity
def mil_collate(batch):
    """Collate a DataLoader batch into a single bag placed on ``device``.

    ``batch`` is a list of ``(features, label, slide_id)`` tuples. Only the
    FIRST element is returned, as ``(features, label, slide_id)`` with both
    tensors moved to the module-level ``device``.

    Bags beyond the first are not returned. When more than one bag is
    received a ``RuntimeWarning`` is emitted so that a loader configured with
    ``batch_size > 1`` does not silently train on a fraction of the slides.
    """
    import warnings  # local import: keeps the function usable on its own

    feats, labels, sids = zip(*batch)
    if len(feats) > 1:
        warnings.warn(
            f"mil_collate received {len(feats)} bags but returns only the first; "
            "use batch_size=1 to avoid dropping slides.",
            RuntimeWarning,
            stacklevel=2,
        )
    return feats[0].to(device), labels[0].to(device), sids[0]

# Usage:
import pandas as pd
# Reload the manifest that was written to FEATURES_DIR / "manifest.csv".
manifest_df = pd.read_csv(FEATURES_DIR / "manifest.csv")

# Use the GPU when one is visible, otherwise stay on the CPU.
# mil_collate reads this module-level `device` when it moves tensors.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

# 80/20 train/test split, stratified on the encoded label and seeded.
from sklearn.model_selection import train_test_split
train_df, test_df = train_test_split(
    manifest_df,
    test_size=0.2,
    random_state=42,
    stratify=manifest_df['label_idx'],
)

train_ds, test_ds = SlideBagDataset(train_df), SlideBagDataset(test_df)

# One bag per step: mil_collate returns a single bag, so batch_size stays at 1.
# Only the training loader shuffles.
train_loader = DataLoader(train_ds, batch_size=1, collate_fn=mil_collate, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=1, collate_fn=mil_collate, shuffle=False)

print("Train slides:", len(train_ds), "Test slides:", len(test_ds))


##6.0 Attention MIL Model

In [None]:
# Cell C - Attention MIL model (instance encoder already computed -> pooling + classifier)
import torch.nn as nn
import torch.nn.functional as F

class AttentionMIL(nn.Module):
    """Attention-pooling MIL head over pre-computed instance features.

    Instances are projected to ``hidden_dim``, scored by a two-layer attention
    network, softmax-normalised across the instances of the bag, and the
    attention-weighted sum is fed to a linear classifier.

    Args:
        in_dim: size of each instance feature vector.
        hidden_dim: size of the projected instance embedding.
        n_classes: number of output logits.
        attn_dim: width of the attention network's hidden layer
            (defaults to 128, the previously fixed value).
    """

    def __init__(self, in_dim, hidden_dim=512, n_classes=2, attn_dim=128):
        super().__init__()
        # Layer creation order is kept so seeded initialisation is unchanged.
        self.fc = nn.Linear(in_dim, hidden_dim)
        self.attn = nn.Sequential(
            nn.Linear(hidden_dim, attn_dim),
            nn.Tanh(),
            nn.Linear(attn_dim, 1),
        )
        self.classifier = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        """Classify one bag.

        Args:
            x: tensor of shape ``(n_instances, in_dim)``.

        Returns:
            ``(logits, attention)`` where ``logits`` has shape
            ``(1, n_classes)`` and ``attention`` has shape ``(n_instances,)``.

        Raises:
            ValueError: if ``x`` is not 2-D. The softmax runs over dim 0, so a
                batched 3-D input would be normalised across bags instead of
                across instances.
        """
        if x.dim() != 2:
            raise ValueError(
                f"AttentionMIL expects a single bag of shape (n_instances, feat_dim); "
                f"got tensor with shape {tuple(x.shape)}"
            )
        H = F.relu(self.fc(x))                 # (n, hidden)
        A = self.attn(H)                       # (n, 1) raw attention scores
        A = torch.softmax(A, dim=0)            # normalise over the bag's instances
        M = (A * H).sum(dim=0)                 # (hidden,) attention-weighted bag embedding
        out = self.classifier(M.unsqueeze(0))  # (1, n_classes)
        return out, A.squeeze(1)

# Instantiate
# The feature dimension is read from the first bag listed in the manifest.
first_bag_path = manifest_df['feature_path'].iloc[0]
feat_example = np.load(first_bag_path)
in_dim = feat_example.shape[1]
# One output logit per entry in label_map.
num_classes = len(label_map)
model = AttentionMIL(in_dim, n_classes=num_classes, hidden_dim=512)
model = model.to(device)
print("Model ready. in_dim=", in_dim, "n_classes=", num_classes)


##7.0 Training and Evaluation

In [None]:
# Cell D - Training loop (batch_size=1 bags)
import torch.optim as optim
from sklearn.metrics import accuracy_score, roc_auc_score

# Optimiser: Adam over all model parameters, with a small weight decay.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)
# Loss on raw logits with integer class targets.
criterion = nn.CrossEntropyLoss()
# Number of passes over the training loader.
EPOCHS = 10
# Best evaluation accuracy seen so far; drives checkpointing.
best_acc = 0.0

def evaluate(loader):
    """Run the global ``model`` over ``loader`` and score its predictions.

    The model is switched to eval mode and gradients are disabled. For every
    bag the arg-max class of the logits is taken as the prediction.

    Returns:
        ``(accuracy, true_labels, predicted_labels)`` where the two lists hold
        plain ints in loader order.
    """
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for bag, target, _slide_id in loader:
            # The loader's collate function already placed the bag on the device.
            logits, _attention = model(bag)
            y_pred.append(int(logits.argmax(dim=1).cpu().item()))
            y_true.append(int(target.cpu().item()))
    return accuracy_score(y_true, y_pred), y_true, y_pred

# Resolve the checkpoint location once and make sure its folder exists:
# torch.save does not create missing parent directories, so without this the
# first improving epoch would fail at save time.
best_model_path = BASE_DIR / "models" / "mil_best.pth"
best_model_path.parent.mkdir(parents=True, exist_ok=True)

for epoch in range(1, EPOCHS+1):
    model.train()  # evaluate() leaves the model in eval mode, so reset every epoch
    losses = []
    for feats, label, sid in train_loader:
        optimizer.zero_grad()
        out, att = model(feats)
        # out is (1, n_classes); the scalar label needs a matching batch dimension.
        loss = criterion(out, label.unsqueeze(0))
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    train_loss = float(np.mean(losses)) if losses else 0.0
    # NOTE(review): the checkpoint is selected on test_loader, so best_acc is an
    # optimistic estimate; consider a separate validation split.
    val_acc, _, _ = evaluate(test_loader)
    print(f"Epoch {epoch} loss={train_loss:.4f} val_acc={val_acc:.4f}")
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
print("Training done. Best val acc:", best_acc)
