
# Learning Series #6 — Image × Omics Fusion (ABMIL demo)

This notebook trains an **attention-based multiple-instance learning (ABMIL)** model on tile-level image features, using **omics-derived binary sample labels** (`CD274_high` or `KRAS_mut`) as targets.



## 0) Setup


In [None]:

# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
# !pip install pandas numpy scikit-learn matplotlib
import os, json, math, random
import numpy as np, pandas as pd, matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve, accuracy_score
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

plt.rcParams.update({"figure.dpi": 140, "savefig.bbox": "tight"})

# Every figure, metrics file, checkpoint and attention CSV is written under this folder.
OUT_DIR = "outputs"
os.makedirs(OUT_DIR, exist_ok=True)
print("Saving outputs to:", OUT_DIR)

# Seed each RNG the notebook uses (Python, NumPy, PyTorch) so runs are reproducible.
SEED = 1337
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn



## 1) Load data
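Change `tiles_path`, `labels_path`, and `splits_path` to use your own data. Expected columns:
- tiles: `tile_id`, `sample_id`, and numeric feature columns whose names start with `f` (the `x`, `y` tile coordinates are used by the attention export in section 6);
- labels: `sample_id`, `CD274_high`, `KRAS_mut`;
- splits: `sample_id` and `split` (`train` / `val` / `test`).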


In [None]:

tiles_path = "example_tiles.csv"
labels_path = "example_labels.csv"
splits_path = "example_splits.csv"

tiles = pd.read_csv(tiles_path)
labels = pd.read_csv(labels_path)
splits = pd.read_csv(splits_path)

# Basic schema checks, so a malformed input fails here rather than deep inside training.
assert {"tile_id", "sample_id"}.issubset(tiles.columns), "tiles must have 'tile_id' and 'sample_id' columns"

# Feature columns: numeric columns whose names start with "f" (e.g. f0, f1, ...).
# Requiring a numeric dtype keeps a text column that merely starts with "f"
# from being treated as a feature and breaking the float32 cast later.
feat_cols = [
    c for c in tiles.columns
    if c.startswith("f") and pd.api.types.is_numeric_dtype(tiles[c])
]
assert len(feat_cols) > 0, "No feature columns starting with 'f' found"

assert {"sample_id", "CD274_high", "KRAS_mut"}.issubset(labels.columns), \
    "labels must have 'sample_id', 'CD274_high' and 'KRAS_mut' columns"
assert {"sample_id", "split"}.issubset(splits.columns), \
    "splits must have 'sample_id' and 'split' columns"

# One row per sample: duplicates would be silently collapsed (last row wins)
# when these tables are turned into sample_id -> value dicts below.
assert labels["sample_id"].is_unique, "duplicate sample_id values in labels"
assert splits["sample_id"].is_unique, "duplicate sample_id values in splits"

print("Tiles:", tiles.shape, "| Labels:", labels.shape, "| Splits:", splits["split"].value_counts().to_dict())
tiles.head()



## 2) Build MIL bags
Each bag holds all tiles of one `sample_id`. Set `LABEL_COL` to choose the target label.


In [None]:

LABEL_COL = "CD274_high"   # or "KRAS_mut"

# One MIL bag per sample: (sample_id, float32 array of shape [n_tiles, n_features]).
bags = [
    (sample_id, group[feat_cols].values.astype("float32"))
    for sample_id, group in tiles.groupby("sample_id")
]

# sample_id -> label value, and sample_id -> split name, used to partition the bags.
lab_map = dict(zip(labels["sample_id"], labels[LABEL_COL]))
split_map = dict(zip(splits["sample_id"], splits["split"]))

def split_bags(bags, lab_map=None, split_map=None):
    """Partition bags into train / val / test lists of ``(sample_id, feats, label)``.

    Parameters
    ----------
    bags : iterable of (sample_id, feats)
        Typically the ``bags`` list built above.
    lab_map : dict, optional
        sample_id -> label value. Defaults to the notebook-level ``lab_map``.
    split_map : dict, optional
        sample_id -> split name. Defaults to the notebook-level ``split_map``.

    Samples missing from either map, or whose label is missing (NaN/None), are
    skipped. Split names are matched after stripping whitespace and lower-casing.
    Labels are converted with ``int()``.

    Returns
    -------
    tuple of three lists
        ``(train, val, test)``, each preserving the input order of ``bags``.

    Raises
    ------
    ValueError
        If a split name is not "train", "val" or "test". Silently routing an
        unrecognised value (e.g. a typo) to one of the splits would contaminate it.
    """
    if lab_map is None:
        lab_map = globals()["lab_map"]
    if split_map is None:
        split_map = globals()["split_map"]

    buckets = {"train": [], "val": [], "test": []}
    for sid, feats in bags:
        if sid not in lab_map or sid not in split_map:
            continue
        label = lab_map[sid]
        if pd.isna(label):
            # No omics label for this sample: it cannot supervise a bag.
            continue
        split = str(split_map[sid]).strip().lower()
        if split not in buckets:
            raise ValueError(
                f"Unknown split {split_map[sid]!r} for sample {sid!r}; "
                "expected 'train', 'val' or 'test'"
            )
        buckets[split].append((sid, feats, int(label)))
    return buckets["train"], buckets["val"], buckets["test"]

train_bags, val_bags, test_bags = split_bags(bags)

# Fail fast. An empty train split means nothing is learned. An empty val split
# means no checkpoint can be selected and the val metrics below are undefined.
assert len(train_bags) > 0, "No training bags: check sample_id overlap between tiles, labels and splits"
assert len(val_bags) > 0, "No validation bags: check the 'split' column in the splits file"

{"train": len(train_bags), "val": len(val_bags), "test": len(test_bags)}



## 3) ABMIL model
Attention-based pooling as in Ilse et al. (2018). For simplicity, each training step uses a single bag (batch size 1).


In [None]:

class ABMIL(nn.Module):
    """Attention-based MIL classifier (Ilse et al., 2018) with non-gated tanh attention.

    Each tile feature vector is embedded (Linear + Tanh) and scored by a small
    attention MLP. The scores are softmax-normalised over the tiles of a bag, and
    the bag embedding is the attention-weighted sum of tile embeddings. A linear
    head maps the bag embedding to one logit.

    Parameters
    ----------
    in_dim : int
        Number of input features per tile.
    hid_dim : int, default 128
        Size of the tile/bag embedding and of the attention MLP hidden layer.

    ``forward(x)`` accepts a single bag ``[n_tiles, in_dim]`` or a batch of
    equally sized bags ``[batch, n_tiles, in_dim]``. Attention is normalised over
    the tile axis in both cases. It returns ``(logit, w, M)``:

    - logit: scalar tensor for one bag, or ``[batch]``;
    - w: attention weights ``[n_tiles]`` or ``[batch, n_tiles]``, summing to 1 per bag;
    - M: bag embedding ``[hid_dim]`` or ``[batch, hid_dim]``.
    """

    def __init__(self, in_dim, hid_dim=128):
        super().__init__()
        self.embed = nn.Sequential(
            nn.Linear(in_dim, hid_dim),
            nn.Tanh()
        )
        self.attn = nn.Sequential(
            nn.Linear(hid_dim, hid_dim),
            nn.Tanh(),
            nn.Linear(hid_dim, 1)
        )
        self.clf = nn.Linear(hid_dim, 1)

    def forward(self, x):
        # Negative axes make the same code work for [n, d] and [B, n, d] inputs.
        H = self.embed(x)                             # [..., n_tiles, hid_dim]
        a = self.attn(H).squeeze(-1)                  # [..., n_tiles]
        w = torch.softmax(a, dim=-1)                  # normalise over tiles
        M = torch.sum(H * w.unsqueeze(-1), dim=-2)    # [..., hid_dim]
        logit = self.clf(M).squeeze(-1)               # [...]; scalar for one bag
        return logit, w, M



## 4) Training utilities


In [None]:

class BagDataset(Dataset):
    """Map-style dataset that yields one MIL bag per index.

    Wraps a list of ``(sample_id, feats, label)`` tuples, where ``feats`` is a
    NumPy array. Items come back as ``(sample_id, feats_tensor, label_tensor)``.
    ``feats_tensor`` is built with ``torch.from_numpy`` (so it shares memory
    with the array), and ``label_tensor`` is a float32 scalar.
    """

    def __init__(self, bag_tuples):
        self.bag_tuples = bag_tuples

    def __len__(self):
        return len(self.bag_tuples)

    def __getitem__(self, idx):
        sample_id, features, label = self.bag_tuples[idx]
        features_t = torch.from_numpy(features)
        label_t = torch.tensor(label, dtype=torch.float32)
        return sample_id, features_t, label_t

def run_epoch(model, loader, opt=None, max_grad_norm=5.0, threshold=0.5):
    """Run one pass over ``loader``; train when ``opt`` is given, otherwise evaluate.

    Expects a batch_size=1 loader yielding ``(sample_id, feats, y)`` batches
    (as produced by ``BagDataset``) and a model whose forward returns
    ``(logit, w, M)``.

    Parameters
    ----------
    model : nn.Module
    loader : iterable of batches
    opt : torch optimizer or None
        If given, the model is put in train mode and its parameters are updated.
        If None, the model is put in eval mode and autograd is disabled.
    max_grad_norm : float, default 5.0
        Gradient-norm clipping threshold used during training.
    threshold : float, default 0.5
        Probability cut-off used for accuracy.

    Returns
    -------
    (mean_loss, auroc, accuracy)
        AUROC is NaN when only one class is present. All three values are NaN
        for an empty loader.
    """
    is_train = opt is not None
    model.train(is_train)
    y_true, y_score, losses = [], [], []
    # Build autograd graphs only when training, even if the caller forgot torch.no_grad().
    with torch.set_grad_enabled(is_train):
        for _sid, feats, y in loader:
            feats, y = feats[0], y[0]  # batch_size=1: unwrap the single bag
            if is_train:
                opt.zero_grad()
            logit, _w, _M = model(feats)
            loss = F.binary_cross_entropy_with_logits(logit, y)
            if is_train:
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
                opt.step()
            y_true.append(int(y.item()))
            y_score.append(torch.sigmoid(logit).detach().cpu().item())
            losses.append(loss.item())

    if not y_true:
        # Nothing to score: avoid mean-of-empty warnings and undefined metrics.
        nan = float("nan")
        return nan, nan, nan

    auroc = roc_auc_score(y_true, y_score) if len(set(y_true)) > 1 else float("nan")
    acc = accuracy_score(y_true, [1 if s >= threshold else 0 for s in y_score])
    return np.mean(losses), auroc, acc

def evaluate_and_save_curves(y_true, y_score, name):
    """Plot a ROC curve and save it as ``roc_curve_{name}.png`` (300 dpi) and ``.svg`` in OUT_DIR.

    The AUROC is shown in the legend next to a dashed chance diagonal. When
    ``y_true`` contains fewer than two classes the ROC is undefined, so nothing
    is plotted and a short note is printed instead. Always returns None.
    """
    if len(set(y_true)) < 2:
        print(f"Skipping ROC for {name}: y_true needs both classes")
        return None

    fpr, tpr, _ = roc_curve(y_true, y_score)
    auroc = roc_auc_score(y_true, y_score)

    fig, ax = plt.subplots(figsize=(5, 4))
    ax.plot(fpr, tpr, label=f"AUROC = {auroc:.3f}")
    ax.plot([0, 1], [0, 1], linestyle="--", label="chance")
    ax.set(xlabel="FPR", ylabel="TPR", title=f"ROC — {name}")
    ax.grid(True, alpha=0.3)
    ax.legend(loc="lower right")
    fig.tight_layout()
    fig.savefig(os.path.join(OUT_DIR, f"roc_curve_{name}.png"), dpi=300)
    fig.savefig(os.path.join(OUT_DIR, f"roc_curve_{name}.svg"))
    plt.show()
    # Release the figure explicitly; non-inline backends keep it alive otherwise.
    plt.close(fig)
    return None



## 5) Train


In [None]:

# Model / optimiser hyperparameters, kept together so they are easy to tune.
HID_DIM = 128
LR = 1e-3
WEIGHT_DECAY = 1e-4
EPOCHS = 8

in_dim = len(feat_cols)
model = ABMIL(in_dim=in_dim, hid_dim=HID_DIM)
opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# One bag per batch: bags may contain different numbers of tiles, so they are not stacked.
train_loader = DataLoader(BagDataset(train_bags), batch_size=1, shuffle=True)
val_loader   = DataLoader(BagDataset(val_bags), batch_size=1, shuffle=False)
test_loader  = DataLoader(BagDataset(test_bags), batch_size=1, shuffle=False)

history = {"train_loss": [], "val_loss": [], "val_auroc": [], "val_acc": []}

# Checkpoint selection on validation AUROC. A NaN AUROC (single-class val) never wins.
best_val = -1
best_state = None
best_epoch = None

for epoch in range(1, EPOCHS+1):
    tr_loss, _, _ = run_epoch(model, train_loader, opt)
    with torch.no_grad():
        val_loss, val_auroc, val_acc = run_epoch(model, val_loader, None)
    history["train_loss"].append(tr_loss)
    history["val_loss"].append(val_loss)
    history["val_auroc"].append(val_auroc)
    history["val_acc"].append(val_acc)
    if not math.isnan(val_auroc) and val_auroc > best_val:
        best_val = val_auroc
        best_epoch = epoch
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
    print(f"Epoch {epoch:02d} | train_loss={tr_loss:.4f} val_loss={val_loss:.4f} val_auroc={val_auroc:.3f} val_acc={val_acc:.3f}")

if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored best checkpoint from epoch {best_epoch} (val_auroc={best_val:.3f})")
else:
    print("No valid val AUROC during training; keeping final-epoch weights")
torch.save(model.state_dict(), os.path.join(OUT_DIR, "model.pt"))

# Loss curve. Epochs are numbered from 1 to match the training log.
epochs_axis = list(range(1, len(history["train_loss"]) + 1))
fig, ax = plt.subplots(figsize=(5, 4))
ax.plot(epochs_axis, history["train_loss"], marker="o", label="train")
ax.plot(epochs_axis, history["val_loss"], marker="o", label="val")
ax.set(xlabel="Epoch", ylabel="BCE loss", title="Training/Validation Loss")
ax.grid(True, alpha=0.3)
ax.legend()  # show which line is train and which is val
fig.tight_layout()
fig.savefig(os.path.join(OUT_DIR, "loss_curve.png"), dpi=300)
fig.savefig(os.path.join(OUT_DIR, "loss_curve.svg"))
plt.show()

# Final eval on val + test and ROC curves
def collect_scores(loader, model=None):
    """Return ``(y_true, y_score)`` for every bag in ``loader``.

    ``y_true`` holds integer labels and ``y_score`` holds sigmoid probabilities
    of the model's logit. Scoring runs with ``model.eval()`` and autograd
    disabled; the model is left in eval mode afterwards.

    Parameters
    ----------
    loader : iterable of (sample_id, feats, y) batches with batch_size=1
    model : nn.Module, optional
        Defaults to the notebook-level ``model``.
    """
    if model is None:
        model = globals()["model"]
    y_true, y_score = [], []
    model.eval()
    with torch.no_grad():
        for _sid, feats, y in loader:
            logit, _w, _M = model(feats[0])  # batch_size=1: unwrap the single bag
            y_true.append(int(y[0].item()))
            y_score.append(torch.sigmoid(logit).item())
    return y_true, y_score

y_val, s_val = collect_scores(val_loader)
y_test, s_test = collect_scores(test_loader)
evaluate_and_save_curves(y_val, s_val, f"{LABEL_COL}_val")
evaluate_and_save_curves(y_test, s_test, f"{LABEL_COL}_test")


def summarize_split(y_true, y_score, threshold=0.5):
    """Metrics for one split.

    Returns AUROC (None if fewer than two classes), accuracy at ``threshold``
    (None if the split is empty), the number of samples, and the number of positives.
    """
    y_pred = [1 if z >= threshold else 0 for z in y_score]
    return {
        "auroc": float(roc_auc_score(y_true, y_score)) if len(set(y_true)) > 1 else None,
        "acc": float(accuracy_score(y_true, y_pred)) if y_true else None,
        "n": len(y_true),
        "n_pos": int(sum(y_true)),
    }


# NOTE: val AUROC was used to pick the checkpoint, so val metrics are optimistic;
# the test split gives the less biased estimate.
metrics = {
    "label": LABEL_COL,
    "val": summarize_split(y_val, s_val),
    "test": summarize_split(y_test, s_test),
}
with open(os.path.join(OUT_DIR, "metrics.json"), "w") as f:
    json.dump(metrics, f, indent=2)
metrics



## 6) Export attention for interpretability
Saves tile-level attention weights for each selected sample, for downstream overlays.


In [None]:

def export_attention_for_samples(sample_ids, top_k=3):
    """Write per-tile attention weights for up to ``top_k`` samples to CSV files.

    Takes the first ``top_k`` ids from ``sample_ids`` in the given order; it does
    not re-rank them. For each id it runs the notebook-level ``model`` on that
    sample's rows of ``tiles``. It then writes ``OUT_DIR/attention_{sid}.csv``
    with these columns:

    - ``tile_id`` and ``sample_id``;
    - the ``x``/``y`` coordinate columns, when present in ``tiles``;
    - an ``attention`` column.

    Rows are sorted by attention, highest first. Ids with no tiles are skipped
    with a printed note.

    Returns the list of written file paths.
    """
    coord_cols = [c for c in ("x", "y") if c in tiles.columns]
    keep_cols = ["tile_id", "sample_id", *coord_cols]
    model.eval()  # do not depend on whichever mode the model was last left in
    exported = []
    for sid in list(sample_ids)[:top_k]:
        df_s = tiles[tiles["sample_id"] == sid].reset_index(drop=True)
        if df_s.empty:
            print(f"No tiles found for sample {sid!r}; skipping")
            continue
        feats = torch.from_numpy(df_s[feat_cols].values.astype("float32"))
        with torch.no_grad():
            _, w, _ = model(feats)
        out = df_s[keep_cols].copy()
        # df_s rows are in the same order as the feature rows fed to the model,
        # so the attention weights line up with their tiles.
        out["attention"] = w.detach().cpu().numpy().reshape(-1)
        out = out.sort_values("attention", ascending=False)
        out_path = os.path.join(OUT_DIR, f"attention_{sid}.csv")
        out.to_csv(out_path, index=False)
        exported.append(out_path)
    return exported

# Example: score every test bag, then export attention for the two with the
# highest predicted probability.
model.eval()
with torch.no_grad():
    sample_scores = [
        (batch_sid[0], torch.sigmoid(model(batch_feats[0])[0]).item())
        for batch_sid, batch_feats, _ in test_loader
    ]

sample_scores.sort(key=lambda pair: pair[1], reverse=True)
selected = [pair[0] for pair in sample_scores[:2]]
paths = export_attention_for_samples(selected, top_k=2)
paths
