# Cytokine Signaling Cascade Mapping via AB-MIL Dynamics

This notebook runs the full experiment:
1. Load config and data
2. Stage 1 — pre-train the InstanceEncoder with cell-type supervision
3. Stage 2 — train full AB-MIL (encoder frozen)
4. Stage 3 (optional) — fine-tune jointly
5. Dynamics analysis — learnability ranking, entropy, instance confidence
6. Validation — seed stability, known-group checks
7. Stage 2 vs Stage 3 comparison — robustness of the ranking to encoder fine-tuning

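Connect to the cluster kernel before running.
The config is read from `cytokines/cytokines-mil/configs/default.yaml`, relative to the kernel's working directory.
All paths inside that config point to cluster storage.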

In [1]:
import json
import random
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader

from cytokine_mil.data.label_encoder import CytokineLabel
from cytokine_mil.data.dataset import PseudoTubeDataset, CellDataset
from cytokine_mil.models.instance_encoder import InstanceEncoder
from cytokine_mil.models.attention import AttentionModule
from cytokine_mil.models.bag_classifier import BagClassifier
from cytokine_mil.models.cytokine_abmil import CytokineABMIL
from cytokine_mil.training.train_encoder import train_encoder
from cytokine_mil.training.train_mil import train_mil
from cytokine_mil.analysis.dynamics import (
    aggregate_to_donor_level,
    rank_cytokines_by_learnability,
    compute_cytokine_entropy_summary,
)
from cytokine_mil.analysis.validation import (
    check_seed_stability,
    check_functional_groupings,
)

In [2]:
# --- Config ---
# Path is relative to the kernel's working directory.
CONFIG_PATH = Path("cytokines/cytokines-mil/configs/default.yaml")
with open(CONFIG_PATH) as f:
    cfg = yaml.safe_load(f)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = cfg["dynamics"]["random_seeds"][0]

# Seed every global RNG up front. train_mil receives `seed` explicitly, but
# train_encoder does not: Stage 1 weight init and DataLoader(shuffle=True)
# draw from the global generators, so without this Stage 1 is not reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

print(f"Device: {DEVICE}")
print(f"Seed: {SEED}")

Device: cuda
Seed: 42


## 1. Data

In [3]:
# Tube manifest plus the HVG list that preprocess_tubes.ipynb writes into the
# same directory as the manifest.
MANIFEST_PATH = cfg["data"]["manifest_path"]
HVG_PATH = str(Path(MANIFEST_PATH).parent / "hvg_list.json")

with open(MANIFEST_PATH) as manifest_fh:
    manifest = json.load(manifest_fh)

with open(HVG_PATH) as hvg_fh:
    gene_names = json.load(hvg_fh)

print(f"Manifest entries: {len(manifest)}")
print(f"HVGs: {len(gene_names)}")

Manifest entries: 10920
HVGs: 4000


In [4]:
# Fit the cytokine -> class-index mapping on the full manifest and write it next
# to the manifest so the encoding used for training can be reloaded later.
# NOTE(review): this refits and overwrites the file on every run; that is only
# reproducible if CytokineLabel.fit is deterministic for a given manifest — confirm.
LABEL_ENCODER_PATH = str(Path(MANIFEST_PATH).parent / "label_encoder.json")
label_encoder = CytokineLabel()
label_encoder = label_encoder.fit(manifest)
label_encoder.save(LABEL_ENCODER_PATH)
print(f"Classes: {label_encoder.n_classes()} (PBS at index {label_encoder.encode('PBS')})")

Classes: 91 (PBS at index 90)


In [None]:
# Pseudo-tube dataset (Stage 2/3)
tube_dataset = PseudoTubeDataset(MANIFEST_PATH, label_encoder, gene_names=gene_names)
print(f"Tubes: {len(tube_dataset)}")


def _build_stage1_manifest(entries: list) -> list:
    """Select one tube per cytokine for Stage 1 encoder pre-training.

    Only entries with ``tube_idx == 0`` are candidates. Cytokines are visited in
    sorted order; the i-th cytokine takes candidate ``i % len(candidates)`` of
    its candidates sorted by donor, so the chosen tubes rotate across donors
    rather than all coming from the same one.

    Why a subset: ~91 tubes x ~450 cells ~= 40k cells ~= 640 MB when preloaded.
    Using the full 10k-tube manifest with shuffle=True would take ~38 hours
    (random h5ad reads defeat the LRU cache).

    Args:
        entries: manifest records with at least "tube_idx", "cytokine", "donor".

    Returns:
        List of selected manifest records, one per cytokine that has a
        ``tube_idx == 0`` entry.
    """
    by_cytokine = defaultdict(list)
    for entry in entries:
        if entry["tube_idx"] == 0:
            by_cytokine[entry["cytokine"]].append(entry)

    selected = []
    for i, cytokine in enumerate(sorted(by_cytokine)):
        candidates = sorted(by_cytokine[cytokine], key=lambda e: e["donor"])
        selected.append(candidates[i % len(candidates)])
    return selected


_stage1_manifest = _build_stage1_manifest(manifest)
assert _stage1_manifest, "no manifest entries with tube_idx == 0; Stage 1 has no data"

STAGE1_MANIFEST_PATH = str(Path(MANIFEST_PATH).parent / "manifest_stage1.json")
with open(STAGE1_MANIFEST_PATH, "w") as f:
    json.dump(_stage1_manifest, f)

# preload=True: loads all tubes at init → in-memory shuffling, no disk I/O per batch
cell_dataset = CellDataset(STAGE1_MANIFEST_PATH, gene_names=gene_names, preload=True)
print(f"Cells: {len(cell_dataset)}")
print(f"Cell types: {cell_dataset.n_cell_types()}")

cell_loader = DataLoader(cell_dataset, batch_size=256, shuffle=True, num_workers=0)

## 2. Stage 1 — Encoder Pre-training

In [6]:
# Stage 1 is very slow, so reuse the saved checkpoint when present.
# Set FORCE_RETRAIN_ENCODER = True after changing the data, the HVG list, or the
# model config; otherwise stale weights from a previous run are loaded.
ENCODER_CKPT = Path("encoder_stage1.pt")
FORCE_RETRAIN_ENCODER = False

encoder = InstanceEncoder(
    input_dim=len(gene_names),
    embed_dim=cfg["model"]["embedding_dim"],
    n_cell_types=cell_dataset.n_cell_types(),
)

if ENCODER_CKPT.exists() and not FORCE_RETRAIN_ENCODER:
    encoder.load_state_dict(torch.load(ENCODER_CKPT, map_location=DEVICE))
    encoder = encoder.to(DEVICE)
    print(f"Loaded Stage 1 encoder from {ENCODER_CKPT}.")
else:
    encoder = train_encoder(
        encoder,
        cell_loader,
        n_epochs=cfg["training"]["stage1_epochs"],
        lr=cfg["training"]["lr"],
        momentum=cfg["training"]["momentum"],
        device=DEVICE,
        verbose=True,
    )
    torch.save(encoder.state_dict(), ENCODER_CKPT)
    print("Encoder saved.")

                                                                                                      

KeyboardInterrupt: 

## 3. Stage 2 — MIL Training (encoder frozen)

In [None]:
# The classifier head needs exactly one logit per encoded label. The fitted
# label encoder produces the training targets (it is passed to PseudoTubeDataset),
# so it is the source of truth; fail fast if the config disagrees with it.
N_CLASSES = label_encoder.n_classes()
assert cfg["model"]["n_classes"] == N_CLASSES, (
    f"config model.n_classes={cfg['model']['n_classes']} does not match "
    f"the label encoder ({N_CLASSES} classes)"
)

attention = AttentionModule(
    embed_dim=cfg["model"]["embedding_dim"],
    attention_hidden_dim=cfg["model"]["attention_hidden_dim"],
)
classifier = BagClassifier(
    embed_dim=cfg["model"]["embedding_dim"],
    n_classes=N_CLASSES,
)
mil_model = CytokineABMIL(encoder, attention, classifier, encoder_frozen=True)

dynamics_stage2 = train_mil(
    mil_model,
    tube_dataset,
    n_epochs=cfg["training"]["stage2_epochs"],
    lr=cfg["training"]["lr"],
    momentum=cfg["training"]["momentum"],
    lr_scheduler=cfg["training"]["lr_scheduler"],
    lr_warmup_epochs=cfg["training"]["lr_warmup_epochs"],
    log_every_n_epochs=cfg["dynamics"]["log_every_n_epochs"],
    device=DEVICE,
    seed=SEED,
    verbose=True,
)

torch.save(mil_model.state_dict(), "mil_stage2.pt")
print("Stage 2 model saved.")

## 4. Stage 3 — Joint Fine-tuning (optional)

In [None]:
# Stage 3: unfreeze the encoder and keep training the same model object
# end-to-end at a reduced learning rate.
# NOTE(review): unlike Stage 2, lr_scheduler / lr_warmup_epochs are not passed,
# so train_mil's own defaults apply here — confirm this is intended.
mil_model.unfreeze_encoder()

train_cfg = cfg["training"]
finetune_lr = train_cfg["lr"] * 0.1  # lower LR for fine-tuning

dynamics_stage3 = train_mil(
    mil_model,
    tube_dataset,
    n_epochs=train_cfg["stage3_epochs"],
    lr=finetune_lr,
    momentum=train_cfg["momentum"],
    log_every_n_epochs=cfg["dynamics"]["log_every_n_epochs"],
    device=DEVICE,
    seed=SEED,
    verbose=True,
)

torch.save(mil_model.state_dict(), "mil_stage3.pt")
print("Stage 3 model saved.")

## 5. Dynamics Analysis

In [None]:
# Primary analysis uses the Stage 2 run: the encoder is frozen there, so the
# logged dynamics come only from the attention and classifier parameters.
donor_traj = aggregate_to_donor_level(dynamics_stage2["records"])

# PBS is the control condition and is left out of the biological ranking.
ranking = rank_cytokines_by_learnability(donor_traj, exclude=["PBS"])

print("Cytokine learnability ranking (highest AUC = learned earliest):")
for rank, entry in enumerate(ranking, start=1):
    cyt, auc = entry
    print(f"  {rank:2d}. {cyt:20s}  AUC={auc:.3f}")

In [None]:
# Donor-averaged learning curves for the earliest- and latest-learned cytokines.
# The bottom group is drawn from what remains after the top group, so a ranking
# shorter than 2 * N_SHOW never plots the same cytokine in both panels.
N_SHOW = 10
epochs = dynamics_stage2["logged_epochs"]  # shared x-axis for every curve

top10 = [cyt for cyt, _ in ranking[:N_SHOW]]
bot10 = [cyt for cyt, _ in ranking[N_SHOW:][-N_SHOW:]]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for ax, group, title in zip(axes, [top10, bot10], ["Top-10 (earliest)", "Bottom-10 (latest)"]):
    for cyt in group:
        # Mean across donors (donor_traj[cyt] maps donor -> curve)
        mean_curve = np.mean(list(donor_traj[cyt].values()), axis=0)
        ax.plot(epochs, mean_curve, label=cyt, alpha=0.8)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("P(Y_correct)")
    ax.set_title(title)
    ax.legend(fontsize=7, ncol=2)

plt.tight_layout()
plt.savefig("learning_curves.png", dpi=150)
plt.show()

In [None]:
# Per-cytokine attention-entropy statistics from the Stage 2 records, listed in
# ascending order of mean entropy.
entropy_summary = compute_cytokine_entropy_summary(dynamics_stage2["records"])
entropy_sorted = sorted(entropy_summary.items(), key=lambda kv: kv[1]["mean"])

print("Attention entropy (low=focused, high=pleiotropic):")
for cyt, stats in entropy_sorted:
    mean_h, std_h = stats["mean"], stats["std"]
    print(f"  {cyt:20s}  mean={mean_h:.3f}  std={std_h:.3f}")

## 6. Validation

In [None]:
# Seed stability: compare learnability orderings across training seeds.
# NOTE: Pre-register your directional predictions BEFORE looking at these results.
#
# Only the primary-seed Stage 2 run is collected here. To assess stability,
# re-run train_mil with each remaining seed in cfg["dynamics"]["random_seeds"]
# and append the returned dynamics to all_dynamics.
all_dynamics = [dynamics_stage2]

if len(all_dynamics) <= 1:
    print("Run with multiple seeds to assess stability. See config random_seeds.")
else:
    stability = check_seed_stability(all_dynamics, exclude=["PBS"])
    print(f"Mean Spearman rho across seeds: {stability['mean_rho']:.3f}")
    print(f"Stable ordering: {stability['stable']}")

In [None]:
# Known functional groupings (IL-2 / IL-15 should be similar)
known_groups = {
    "IL-2_IL-15_family": ["IL-2", "IL-15"],
    "type_I_IFN": ["IFN-alpha", "IFN-beta"],  # adjust to actual cytokine names
}

# Group members must match the cytokine labels in donor_traj exactly. Report any
# that don't, and only test groups with at least two members present, so a
# naming mismatch is visible here. (How check_functional_groupings treats
# unknown names is not visible from this notebook.)
available_cytokines = set(donor_traj)
validated_groups = {}
for group, members in known_groups.items():
    present = [m for m in members if m in available_cytokines]
    missing = [m for m in members if m not in available_cytokines]
    if missing:
        print(f"WARNING: {group}: not found in trajectories: {missing}")
    if len(present) >= 2:
        validated_groups[group] = present
    else:
        print(f"WARNING: {group}: fewer than 2 members present, skipping.")

grouping_result = check_functional_groupings(donor_traj, validated_groups)
for group, result in grouping_result.items():
    print(f"\n{group}:")
    for k, v in result.items():
        print(f"  {k}: {v}")

## 7. Stage 2 vs Stage 3 Comparison

If the learnability ordering is stable across Stage 2 and Stage 3, this is
evidence that the dynamics signal is robust to encoder fine-tuning.
This section requires the optional Stage 3 cell to have been run.

In [None]:
# Stage 3 learnability ranking, computed the same way as for Stage 2.
donor_traj_s3 = aggregate_to_donor_level(dynamics_stage3["records"])
ranking_s3 = rank_cytokines_by_learnability(donor_traj_s3, exclude=["PBS"])

# Reuse the seed-stability check (imported in the top import cell) to compare
# the two orderings: the two stages stand in for two seeds.
stability_s2_s3 = check_seed_stability(
    [dynamics_stage2, dynamics_stage3], exclude=["PBS"]
)
print(f"Stage2 vs Stage3 ranking correlation: {stability_s2_s3['mean_rho']:.3f}")
print(f"Stable across stages: {stability_s2_s3['stable']}")