# Cytokine Signaling Cascade Mapping via AB-MIL Dynamics

This notebook runs the full experiment:
1. Load config and data
2. Stage 1 — pre-train the InstanceEncoder with cell-type supervision
3. Stage 2 — train full AB-MIL (encoder frozen)
4. Stage 3 (optional) — fine-tune jointly
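5. Dynamics analysis — learnability ranking, attention entropy, confusion entropy, instance confidence (TODO: `build_cell_type_confidence_matrix` is imported but not yet called below)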
6. Validation — seed stability, known-group checks

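Connect to the cluster kernel before running.
The config is loaded from `cytokines/cytokines-mil/configs/default.yaml` (relative to the kernel's working directory); all data paths in it point to cluster storage. Checkpoints (`*.pt`) and figures are written to the working directory.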

In [None]:
import difflib
import json
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader

from cytokine_mil.data.label_encoder import CytokineLabel
from cytokine_mil.data.dataset import PseudoTubeDataset, CellDataset
from cytokine_mil.models.instance_encoder import InstanceEncoder
from cytokine_mil.models.attention import AttentionModule
from cytokine_mil.models.bag_classifier import BagClassifier
from cytokine_mil.models.cytokine_abmil import CytokineABMIL
from cytokine_mil.training.train_encoder import train_encoder
from cytokine_mil.training.train_mil import train_mil
from cytokine_mil.analysis.dynamics import (
    aggregate_to_donor_level,
    rank_cytokines_by_learnability,
    compute_cytokine_entropy_summary,
    compute_confusion_entropy_summary,
    build_cell_type_confidence_matrix,
)
from cytokine_mil.analysis.validation import (
    check_seed_stability,
    check_functional_groupings,
)

In [2]:
# --- Config ---
# Path is relative to the kernel's working directory.
CONFIG_PATH = "cytokines/cytokines-mil/configs/default.yaml"
with open(CONFIG_PATH) as f:
    cfg = yaml.safe_load(f)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Primary seed; the remaining entries of random_seeds are for seed-stability reruns.
SEED = cfg["dynamics"]["random_seeds"][0]

# Seed every global RNG up front. train_mil receives `seed` explicitly, but
# Stage 1 (InstanceEncoder weight init, train_encoder, DataLoader shuffling)
# is called without one and would otherwise depend on unseeded global state.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA device generators

print(f"Device: {DEVICE}")
print(f"Seed: {SEED}")

Device: cuda
Seed: 42


## 1. Data

In [3]:
MANIFEST_PATH = cfg["data"]["manifest_path"]

# Pseudo-tube manifest: a JSON list of per-tube entries.
manifest = json.loads(Path(MANIFEST_PATH).read_text())

# HVG list written by preprocess_tubes.ipynb, stored next to the manifest.
HVG_PATH = str(Path(MANIFEST_PATH).parent / "hvg_list.json")
gene_names = json.loads(Path(HVG_PATH).read_text())

print(f"Manifest entries: {len(manifest)}")
print(f"HVGs: {len(gene_names)}")

Manifest entries: 10920
HVGs: 4000


In [4]:
# Label encoder: fit on the manifest's labels and persist it next to the manifest
# so later runs / analyses can decode class indices.
# NOTE(review): this refits and overwrites label_encoder.json on every run, even though
# it is meant to be built once. Indices stay stable only if CytokineLabel.fit is
# deterministic for the same manifest — confirm.
LABEL_ENCODER_PATH = str(Path(MANIFEST_PATH).parent / "label_encoder.json")
label_encoder = CytokineLabel().fit(manifest)
label_encoder.save(LABEL_ENCODER_PATH)

_n_classes = label_encoder.n_classes()
_pbs_index = label_encoder.encode("PBS")
print(f"Classes: {_n_classes} (PBS at index {_pbs_index})")

Classes: 91 (PBS at index 90)


In [5]:
from collections import defaultdict

# --- Pseudo-tube dataset (Stage 2/3) ---
# preload=True: loads all 10k tubes as sparse matrices at init (~8-10 GB).
# Eliminates all disk I/O during training and dynamics logging.
tube_dataset = PseudoTubeDataset(MANIFEST_PATH, label_encoder, gene_names=gene_names, preload=True)
print(f"Tubes: {len(tube_dataset)}")

# --- Stage 1 manifest: one tube per cytokine, rotating donors ---
# ~91 tubes × ~450 cells ≈ 40k cells ≈ 640 MB when preloaded.
# Using the full 10k-tube manifest with shuffle=True would require
# ~38 hours (random h5ad reads defeat the LRU cache).
_cyt_to_entries: dict = defaultdict(list)
for _e in manifest:
    if _e["tube_idx"] == 0:
        _cyt_to_entries[_e["cytokine"]].append(_e)

# Cytokines are visited in sorted order and each cytokine's entries are sorted by
# donor, so the i-th cytokine takes entry (i mod n_entries). This is deterministic
# and spreads the selection across donors (presumably one tube_idx==0 entry per donor).
_stage1_manifest = []
for _i, _cyt in enumerate(sorted(_cyt_to_entries)):
    _entries = sorted(_cyt_to_entries[_cyt], key=lambda e: e["donor"])
    _stage1_manifest.append(_entries[_i % len(_entries)])

STAGE1_MANIFEST_PATH = str(Path(MANIFEST_PATH).parent / "manifest_stage1.json")
with open(STAGE1_MANIFEST_PATH, "w") as f:
    json.dump(_stage1_manifest, f)

# preload=True: loads all tubes at init → in-memory shuffling, no disk I/O per batch
cell_dataset = CellDataset(STAGE1_MANIFEST_PATH, gene_names=gene_names, preload=True)
print(f"Cells: {len(cell_dataset)}")
print(f"Cell types: {cell_dataset.n_cell_types()}")

# Sanity check before train_encoder: print diagnostics, then fail fast on non-finite
# values instead of letting them surface later as a NaN loss.
# Reads the private `_X` attribute (presumably the dense preloaded expression matrix).
_stage1_X = cell_dataset._X
print(f"NaN in X: {np.isnan(_stage1_X).any()}")
print(f"Inf in X: {np.isinf(_stage1_X).any()}")
print(f"X range: [{_stage1_X.min():.3f}, {_stage1_X.max():.3f}]")
assert np.isfinite(_stage1_X).all(), "Stage 1 expression matrix contains NaN/Inf values"

# A dedicated seeded generator makes the Stage 1 batch order reproducible,
# independent of how much of the global torch RNG earlier cells consumed.
cell_loader = DataLoader(
    cell_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=0,
    generator=torch.Generator().manual_seed(SEED),
)

Preloading tubes: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 10920/10920 [24:19<00:00,  7.48it/s]


Tubes: 10920
Cells: 39909
Cell types: 18
NaN in X: False
Inf in X: False
X range: [0.000, 7.628]


## 2. Stage 1 — Encoder Pre-training

In [6]:
# Stage 1 epoch count. This deliberately overrides cfg["training"]["stage1_epochs"];
# it is kept as a named constant (rather than a commented-out config line) so the
# deviation from the config stays explicit.
STAGE1_EPOCHS = 16
_cfg_stage1_epochs = cfg["training"].get("stage1_epochs")
if _cfg_stage1_epochs != STAGE1_EPOCHS:
    print(f"[Stage 1] Overriding config stage1_epochs={_cfg_stage1_epochs} with {STAGE1_EPOCHS}")

encoder = InstanceEncoder(
    input_dim=len(gene_names),
    embed_dim=cfg["model"]["embedding_dim"],
    n_cell_types=cell_dataset.n_cell_types(),
)

# Pre-train the encoder with cell-type supervision on the Stage 1 subset.
encoder = train_encoder(
    encoder,
    cell_loader,
    n_epochs=STAGE1_EPOCHS,
    lr=cfg["training"]["lr"],
    momentum=cfg["training"]["momentum"],
    device=DEVICE,
    verbose=True,
)

torch.save(encoder.state_dict(), "encoder_stage1.pt")
print("Encoder saved.")

                                                                                                                                                                

[Stage 1] Epoch   1/16 | loss=0.6498 | acc=0.7887


                                                                                                                                                                

[Stage 1] Epoch   2/16 | loss=0.2586 | acc=0.9027


                                                                                                                                                                

[Stage 1] Epoch   3/16 | loss=0.2056 | acc=0.9232


                                                                                                                                                                

[Stage 1] Epoch   4/16 | loss=0.1628 | acc=0.9413


                                                                                                                                                                

[Stage 1] Epoch   5/16 | loss=0.1240 | acc=0.9548


                                                                                                                                                                

[Stage 1] Epoch   6/16 | loss=0.0904 | acc=0.9692


                                                                                                                                                                

[Stage 1] Epoch   7/16 | loss=0.0608 | acc=0.9800


                                                                                                                                                                

[Stage 1] Epoch   8/16 | loss=0.0355 | acc=0.9900


                                                                                                                                                                

[Stage 1] Epoch   9/16 | loss=0.0236 | acc=0.9932


                                                                                                                                                                

[Stage 1] Epoch  10/16 | loss=0.0143 | acc=0.9963


                                                                                                                                                                

[Stage 1] Epoch  11/16 | loss=0.0089 | acc=0.9981


                                                                                                                                                                

[Stage 1] Epoch  12/16 | loss=0.0031 | acc=0.9998


                                                                                                                                                                

[Stage 1] Epoch  13/16 | loss=0.0053 | acc=0.9984


                                                                                                                                                                

[Stage 1] Epoch  14/16 | loss=0.0030 | acc=0.9992


                                                                                                                                                                

[Stage 1] Epoch  15/16 | loss=0.0043 | acc=0.9989


                                                                                                                                                                

[Stage 1] Epoch  16/16 | loss=0.0014 | acc=0.9996
Encoder saved.


## 3. Stage 2 — MIL Training (encoder frozen)

In [7]:
# The classifier needs exactly one output per label-encoder class. A mismatch with the
# config would misalign the classifier's output dimension with the encoded labels.
N_CLASSES = label_encoder.n_classes()
assert cfg["model"]["n_classes"] == N_CLASSES, (
    f"cfg['model']['n_classes']={cfg['model']['n_classes']} but the label encoder "
    f"has {N_CLASSES} classes"
)

attention = AttentionModule(
    embed_dim=cfg["model"]["embedding_dim"],
    attention_hidden_dim=cfg["model"]["attention_hidden_dim"],
)
classifier = BagClassifier(
    embed_dim=cfg["model"]["embedding_dim"],
    n_classes=N_CLASSES,
)
# Stage 2: the pre-trained encoder is frozen; only attention + classifier are trained.
mil_model = CytokineABMIL(encoder, attention, classifier, encoder_frozen=True)

dynamics_stage2 = train_mil(
    mil_model,
    tube_dataset,
    n_epochs=cfg["training"]["stage2_epochs"],
    lr=cfg["training"]["lr"],
    momentum=cfg["training"]["momentum"],
    lr_scheduler=cfg["training"]["lr_scheduler"],
    lr_warmup_epochs=cfg["training"]["lr_warmup_epochs"],
    log_every_n_epochs=cfg["dynamics"]["log_every_n_epochs"],
    device=DEVICE,
    seed=SEED,
    verbose=True,
)

torch.save(mil_model.state_dict(), "mil_stage2.pt")
print("Stage 2 model saved.")

                                                                                                                                                                

[Stage 2/3] Epoch   1/100 | loss=4.4924


                                                                                                                                                                

[Stage 2/3] Epoch   2/100 | loss=4.3441


                                                                                                                                                                

[Stage 2/3] Epoch   3/100 | loss=4.2676


                                                                                                                                                                

[Stage 2/3] Epoch   4/100 | loss=4.1268


                                                                                                                                                                

[Stage 2/3] Epoch   5/100 | loss=3.9487


                                                                                                                                                                

[Stage 2/3] Epoch   6/100 | loss=3.7950


                                                                                                                                                                

[Stage 2/3] Epoch   7/100 | loss=3.5910


                                                                                                                                                                

[Stage 2/3] Epoch   8/100 | loss=3.4568


                                                                                                                                                                

[Stage 2/3] Epoch   9/100 | loss=3.3224


                                                                                                                                                                

[Stage 2/3] Epoch  10/100 | loss=3.1926


                                                                                                                                                                

[Stage 2/3] Epoch  11/100 | loss=3.0712


                                                                                                                                                                

[Stage 2/3] Epoch  12/100 | loss=3.0636


                                                                                                                                                                

[Stage 2/3] Epoch  13/100 | loss=2.8444


                                                                                                                                                                

[Stage 2/3] Epoch  14/100 | loss=2.7606


                                                                                                                                                                

[Stage 2/3] Epoch  15/100 | loss=2.6773


                                                                                                                                                                

[Stage 2/3] Epoch  16/100 | loss=2.6041


                                                                                                                                                                

[Stage 2/3] Epoch  17/100 | loss=2.5319


                                                                                                                                                                

[Stage 2/3] Epoch  18/100 | loss=2.4321


                                                                                                                                                                

[Stage 2/3] Epoch  19/100 | loss=2.3499


                                                                                                                                                                

[Stage 2/3] Epoch  20/100 | loss=2.2348


                                                                                                                                                                

[Stage 2/3] Epoch  21/100 | loss=2.1870


                                                                                                                                                                

[Stage 2/3] Epoch  22/100 | loss=2.0711


                                                                                                                                                                

[Stage 2/3] Epoch  23/100 | loss=2.0174


                                                                                                                                                                

[Stage 2/3] Epoch  24/100 | loss=1.9651


                                                                                                                                                                

[Stage 2/3] Epoch  25/100 | loss=1.9409


                                                                                                                                                                

[Stage 2/3] Epoch  26/100 | loss=1.9236


                                                                                                                                                                

[Stage 2/3] Epoch  27/100 | loss=1.8326


                                                                                                                                                                

[Stage 2/3] Epoch  28/100 | loss=1.7595


                                                                                                                                                                

[Stage 2/3] Epoch  29/100 | loss=1.6844


                                                                                                                                                                

[Stage 2/3] Epoch  30/100 | loss=1.6655


                                                                                                                                                                

[Stage 2/3] Epoch  31/100 | loss=1.6048


                                                                                                                                                                

[Stage 2/3] Epoch  32/100 | loss=1.5549


                                                                                                                                                                

[Stage 2/3] Epoch  33/100 | loss=1.5394


                                                                                                                                                                

[Stage 2/3] Epoch  34/100 | loss=1.4558


                                                                                                                                                                

[Stage 2/3] Epoch  35/100 | loss=1.4056


                                                                                                                                                                

[Stage 2/3] Epoch  36/100 | loss=1.3916


                                                                                                                                                                

[Stage 2/3] Epoch  37/100 | loss=1.3642


                                                                                                                                                                

[Stage 2/3] Epoch  38/100 | loss=1.3352


                                                                                                                                                                

[Stage 2/3] Epoch  39/100 | loss=1.2954


                                                                                                                                                                

[Stage 2/3] Epoch  40/100 | loss=1.2584


                                                                                                                                                                

[Stage 2/3] Epoch  41/100 | loss=1.2408


                                                                                                                                                                

[Stage 2/3] Epoch  42/100 | loss=1.2310


                                                                                                                                                                

[Stage 2/3] Epoch  43/100 | loss=1.1615


                                                                                                                                                                

[Stage 2/3] Epoch  44/100 | loss=1.1145


                                                                                                                                                                

[Stage 2/3] Epoch  45/100 | loss=1.1420


                                                                                                                                                                

[Stage 2/3] Epoch  46/100 | loss=1.0965


                                                                                                                                                                

[Stage 2/3] Epoch  47/100 | loss=1.0928


                                                                                                                                                                

[Stage 2/3] Epoch  48/100 | loss=1.0915


                                                                                                                                                                

[Stage 2/3] Epoch  49/100 | loss=0.9859


                                                                                                                                                                

[Stage 2/3] Epoch  50/100 | loss=0.9745


                                                                                                                                                                

[Stage 2/3] Epoch  51/100 | loss=0.9579


                                                                                                                                                                

[Stage 2/3] Epoch  52/100 | loss=0.9373


                                                                                                                                                                

[Stage 2/3] Epoch  53/100 | loss=0.9033


                                                                                                                                                                

[Stage 2/3] Epoch  54/100 | loss=0.9078


                                                                                                                                                                

[Stage 2/3] Epoch  55/100 | loss=0.9089


                                                                                                                                                                

[Stage 2/3] Epoch  56/100 | loss=0.8889


                                                                                                                                                                

[Stage 2/3] Epoch  57/100 | loss=0.8181


                                                                                                                                                                

[Stage 2/3] Epoch  58/100 | loss=0.8360


                                                                                                                                                                

[Stage 2/3] Epoch  59/100 | loss=0.8033


                                                                                                                                                                

[Stage 2/3] Epoch  60/100 | loss=0.7673


                                                                                                                                                                

[Stage 2/3] Epoch  61/100 | loss=0.7440


                                                                                                                                                                

[Stage 2/3] Epoch  62/100 | loss=0.7583


                                                                                                                                                                

[Stage 2/3] Epoch  63/100 | loss=0.7272


                                                                                                                                                                

[Stage 2/3] Epoch  64/100 | loss=0.7888


                                                                                                                                                                

[Stage 2/3] Epoch  65/100 | loss=0.6843


                                                                                                                                                                

[Stage 2/3] Epoch  66/100 | loss=0.6482


                                                                                                                                                                

[Stage 2/3] Epoch  67/100 | loss=0.6751


                                                                                                                                                                

[Stage 2/3] Epoch  68/100 | loss=0.6108


                                                                                                                                                                

[Stage 2/3] Epoch  69/100 | loss=0.6057


                                                                                                                                                                

[Stage 2/3] Epoch  70/100 | loss=0.6343


                                                                                                                                                                

[Stage 2/3] Epoch  71/100 | loss=0.5852


                                                                                                                                                                

[Stage 2/3] Epoch  72/100 | loss=0.6073


                                                                                                                                                                

[Stage 2/3] Epoch  73/100 | loss=0.6372


                                                                                                                                                                

[Stage 2/3] Epoch  74/100 | loss=0.6121


                                                                                                                                                                

[Stage 2/3] Epoch  75/100 | loss=0.5819


                                                                                                                                                                

[Stage 2/3] Epoch  76/100 | loss=0.6146


                                                                                                                                                                

[Stage 2/3] Epoch  77/100 | loss=0.5504


                                                                                                                                                                

[Stage 2/3] Epoch  78/100 | loss=0.5683


                                                                                                                                                                

[Stage 2/3] Epoch  79/100 | loss=0.5285


                                                                                                                                                                

[Stage 2/3] Epoch  80/100 | loss=0.4732


                                                                                                                                                                

[Stage 2/3] Epoch  81/100 | loss=0.5678


                                                                                                                                                                

[Stage 2/3] Epoch  82/100 | loss=0.4923


                                                                                                                                                                

[Stage 2/3] Epoch  83/100 | loss=0.5005


                                                                                                                                                                

[Stage 2/3] Epoch  84/100 | loss=0.4887


                                                                                                                                                                

[Stage 2/3] Epoch  85/100 | loss=0.5270


                                                                                                                                                                

[Stage 2/3] Epoch  86/100 | loss=0.5062


                                                                                                                                                                

[Stage 2/3] Epoch  87/100 | loss=0.5092


                                                                                                                                                                

[Stage 2/3] Epoch  88/100 | loss=0.4534


                                                                                                                                                                

[Stage 2/3] Epoch  89/100 | loss=0.4263


                                                                                                                                                                

[Stage 2/3] Epoch  90/100 | loss=0.4285


                                                                                                                                                                

[Stage 2/3] Epoch  91/100 | loss=0.4164


                                                                                                                                                                

[Stage 2/3] Epoch  92/100 | loss=0.3985


                                                                                                                                                                

[Stage 2/3] Epoch  93/100 | loss=0.4213


                                                                                                                                                                

[Stage 2/3] Epoch  94/100 | loss=0.4356


                                                                                                                                                                

[Stage 2/3] Epoch  95/100 | loss=0.4060


                                                                                                                                                                

[Stage 2/3] Epoch  96/100 | loss=0.3573


                                                                                                                                                                

[Stage 2/3] Epoch  97/100 | loss=0.3889


                                                                                                                                                                

[Stage 2/3] Epoch  98/100 | loss=0.3850


                                                                                                                                                                

[Stage 2/3] Epoch  99/100 | loss=0.3672


                                                                                                                                                                

[Stage 2/3] Epoch 100/100 | loss=0.3970
Stage 2 model saved.


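## 4. Stage 3 — Joint Fine-tuning (optional)

If you skip this stage, also skip the two Stage 2 vs Stage 3 comparison cells at the end of the notebook; they require `dynamics_stage3`.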

In [None]:
# Stage 3: unfreeze the encoder (presumably making its parameters trainable) and
# fine-tune the whole model. NOTE: this mutates `mil_model` in place; the Stage 2
# weights survive only in mil_stage2.pt.
mil_model.unfreeze_encoder()

# 10x lower LR than Stage 2 for fine-tuning.
_stage3_lr = cfg["training"]["lr"] * 0.1
# NOTE(review): unlike Stage 2, lr_scheduler / lr_warmup_epochs are not passed, so
# train_mil's defaults apply here — confirm this is intended.
dynamics_stage3 = train_mil(
    mil_model,
    tube_dataset,
    n_epochs=cfg["training"]["stage3_epochs"],
    lr=_stage3_lr,
    momentum=cfg["training"]["momentum"],
    log_every_n_epochs=cfg["dynamics"]["log_every_n_epochs"],
    device=DEVICE,
    seed=SEED,
    verbose=True,
)

torch.save(mil_model.state_dict(), "mil_stage3.pt")
print("Stage 3 model saved.")

## 5. Dynamics Analysis

In [None]:
# Primary analysis uses the Stage 2 dynamics: the encoder is frozen there, so only
# attention + classifier move, which gives cleaner trajectories.
donor_traj = aggregate_to_donor_level(dynamics_stage2["records"])

# PBS (presumably the vehicle control) is left out of the biological ranking.
learnability_result = rank_cytokines_by_learnability(donor_traj, exclude=["PBS"])
ranking = learnability_result["ranking"]

print(
    "Cytokine learnability ranking",
    f"Metric: {learnability_result['metric_description']}",
    "",
    sep="\n",
)
for rank_pos, (cyt, auc) in enumerate(ranking, start=1):
    print(f"  {rank_pos:2d}. {cyt:20s}  AUC(mean_donor_p_correct_trajectory) = {auc:.3f}")

In [None]:
# Plot learning curves for the top-10 and bottom-10 cytokines by learnability AUC.
top10 = [r[0] for r in ranking[:10]]
bot10 = [r[0] for r in ranking[-10:]]

# Shared x-axis: trajectories are only logged on these epochs (loop-invariant).
epochs = dynamics_stage2["logged_epochs"]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for ax, group, group_label in zip(
    axes,
    [top10, bot10],
    ["Top-10 (highest AUC — learned earliest)", "Bottom-10 (lowest AUC — learned latest)"],
):
    for cyt in group:
        # donor_traj[cyt] presumably maps donor -> per-epoch curve; average across donors.
        donor_curves = list(donor_traj[cyt].values())
        mean_curve = np.mean(donor_curves, axis=0)
        ax.plot(epochs, mean_curve, label=cyt, alpha=0.8)
    ax.set_xlabel("Epoch")
    # Wrapped so the label fits the 5-inch-tall axis instead of overflowing it.
    ax.set_ylabel("P(Y_correct | t)\nsoftmax probability of correct cytokine class")
    ax.set_title(group_label)
    ax.legend(fontsize=7, ncol=2)

fig.suptitle(
    "Stage 2 learning curves\n"
    "Metric: mean p_correct_trajectory(t), aggregated to donor level "
    "(median across pseudo-tubes per donor, then mean across donors)",
    fontsize=9,
)
fig.tight_layout()
fig.savefig("learning_curves.png", dpi=150)
plt.show()

In [None]:
# Per-cytokine attention-entropy summary computed from the Stage 2 records.
entropy_result = compute_cytokine_entropy_summary(dynamics_stage2["records"])
entropy_summary = entropy_result["summary"]

# Ascending mean entropy: low = focused attention, high = pleiotropic.
entropy_sorted = sorted(entropy_summary.items(), key=lambda kv: kv[1]["mean_entropy"])

print(
    "Cytokine attention entropy summary",
    f"Metric: {entropy_result['metric_description']}",
    "",
    sep="\n",
)
for cyt, stats in entropy_sorted:
    mean_h = stats["mean_entropy"]
    std_h = stats["std_entropy"]
    print(f"  {cyt:20s}  mean_entropy = {mean_h:.3f}  std = {std_h:.3f}")

In [None]:
# Confusion-entropy ranking from the Stage 2 trajectory (PBS excluded).
confusion_result = compute_confusion_entropy_summary(
    dynamics_stage2["confusion_entropy_trajectory"],
    exclude=["PBS"],
)
_confusion_ranking = confusion_result["ranking"]

print(
    "Cytokine confusion entropy ranking",
    f"Metric: {confusion_result['metric_description']}",
    "",
    sep="\n",
)
for cyt, auc in _confusion_ranking:
    print(f"  {cyt:20s}  AUC(confusion_entropy_trajectory) = {auc:.3f}")

## 6. Validation

In [11]:
# Seed stability: compare learnability orderings across training seeds.
# NOTE: Pre-register your directional predictions BEFORE looking at these results.
#
# Only the primary-seed Stage 2 run is collected here. To assess stability, re-run
# train_mil with each remaining seed in cfg["dynamics"]["random_seeds"] and append
# the resulting dynamics to all_dynamics.
all_dynamics = [dynamics_stage2]  # Add dynamics from other seeds here

if len(all_dynamics) <= 1:
    print("Run with multiple seeds to assess stability. See config random_seeds.")
else:
    stability = check_seed_stability(all_dynamics, exclude=["PBS"])
    print(f"Mean Spearman rho across seeds: {stability['mean_rho']:.3f}")
    print(f"Stable ordering: {stability['stable']}")

Run with multiple seeds to assess stability. See config random_seeds.


In [12]:
# Known functional groupings (IL-2 / IL-15 should be similar).
# Member names must match the cytokine labels exactly. Any member missing from
# donor_traj is reported below with the closest available labels, so a typo or
# naming mismatch does not silently disable a group check.
known_groups = {
    "IL-2_IL-15_family": ["IL-2", "IL-15"],
    "type_I_IFN": ["IFN-alpha", "IFN-beta"],  # adjust to actual cytokine names
}

_available_cytokines = sorted(donor_traj)
for _group, _members in known_groups.items():
    for _member in _members:
        if _member not in donor_traj:
            _suggestions = difflib.get_close_matches(
                _member, _available_cytokines, n=3, cutoff=0.5
            )
            print(f"[warning] {_group}: '{_member}' not found; closest labels: {_suggestions}")

grouping_result = check_functional_groupings(donor_traj, known_groups)
for group, result in grouping_result.items():
    print(f"\n{group}:")
    for k, v in result.items():
        print(f"  {k}: {v}")


IL-2_IL-15_family:
  members_found: ['IL-2', 'IL-15']
  within_auc_std: 2.9490238849151282
  between_auc_std: 13.025140774941931
  passes: True

type_I_IFN:
  error: fewer than 2 members found: ['IFN-beta']


# Stage 3 learnability ranking, using the same donor-level aggregation as Stage 2.
donor_traj_s3 = aggregate_to_donor_level(dynamics_stage3["records"])
result_s3 = rank_cytokines_by_learnability(donor_traj_s3, exclude=["PBS"])
ranking_s3 = result_s3["ranking"]

# check_seed_stability (imported in the first cell) compares orderings across runs;
# here it is reused to compare the Stage 2 and Stage 3 orderings.
stability_s2_s3 = check_seed_stability([dynamics_stage2, dynamics_stage3], exclude=["PBS"])

_metric_line = (
    "Metric: Spearman rho between cytokine learnability rankings "
    "(AUC of donor-level p_correct_trajectory, median per donor, mean across donors)"
)
print("Stage 2 vs Stage 3 ranking correlation")
print(_metric_line)
print(f"  Spearman rho = {stability_s2_s3['mean_rho']:.3f}")
print(f"  Stable across stages (rho > 0.7): {stability_s2_s3['stable']}")

In [None]:
# Stage 3 learnability ranking, aggregated to donor level exactly as for Stage 2.
donor_traj_s3 = aggregate_to_donor_level(dynamics_stage3["records"])
# rank_cytokines_by_learnability returns a result mapping. The (cytokine, AUC) list
# lives under "ranking", so extract it to keep `ranking_s3` the same type as the
# Stage 2 `ranking`; previously this cell overwrote it with the whole result.
ranking_s3 = rank_cytokines_by_learnability(donor_traj_s3, exclude=["PBS"])["ranking"]

# NOTE(review): this repeats the Stage 2 vs Stage 3 comparison of the preceding
# unnumbered cell; consider keeping only one of the two.
# check_seed_stability is imported in the first cell; it is reused here to compare
# the Stage 2 and Stage 3 orderings.
stability_s2_s3 = check_seed_stability(
    [dynamics_stage2, dynamics_stage3], exclude=["PBS"]
)
print(f"Stage2 vs Stage3 ranking correlation: {stability_s2_s3['mean_rho']:.3f}")
print(f"Stable across stages: {stability_s2_s3['stable']}")