Effendi estimates the causal delay τ_g — how long it takes for chromatin accessibility changes (ATAC) to propagate to gene expression (RNA) — using a learned manifold and Neural ODE.
Core idea: The vector field v_θ defines cell state trajectories. ATAC is the initial condition; the ODE integrates forward from ATAC. RNA is the target — it never directly interacts with ATAC, it only supervises the result.
gene_emb (n_genes, 512) ← scGPT frozen
┌──────────┴──────────┐
↓ ↓
AmortizedTau GeneDecoder
↓ ↑
τ_g (B,n_genes) │
↓ │
z_ATAC (256d) ──┐ ↓ │
AlignMLP │ ┌────┴─────┐ │
z_A (512d) ────┤ │ │ │
├───┤ ODE ├── hidden ─────┘
z_R (512d) │ │ integrate │ (B,n_genes,D)
│ └──────────┘
│ ↑
│ v_θ fields
│ │
└─── L_ode = ‖z_R - φ_τ(z_A)‖²
(RNA only as supervision)
z_ATAC = HyenaDNA(atac_peaks) # (B, 256) peak sequences → embedding
z_R = scGPT(rna_counts) # (B, 512) gene expression → embedding
z_A = AlignMLP(z_ATAC)       # (B, 512) align ATAC to the shared latent space

Two separate encoding paths independently produce embeddings in the same 512d space.
τ_g = AmortizedTauPredictor(z_A, gene_emb)   # (B, n_genes)

How τ_g is computed:
| Component | Formula | Shape |
|---|---|---|
| Gene baseline | τ_base = softplus(MLP(gene_emb)) | (n_genes,) |
| Cell modulation | mod = MLP(z_A) | (B, 512) |
| Gene-specific delta | delta = mod @ gene_emb.T | (B, n_genes) |
| Final τ_g | softplus(τ_base + delta) | (B, n_genes) |
gene_emb is frozen (from the scGPT GeneEncoder). Only MLP(gene_emb) and MLP(z_A) are trained.

- τ_base: each gene has a "default" delay learned from its embedding (e.g. FOS → short, CEBPA → long)
- delta: cell state modulates the delay — the same gene can have different delays in different cells
- mod @ gene_emb.T: projects the cell's modulation vector onto each gene's embedding space
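The table above can be sketched with toy stand-ins. The two `nn.Linear` layers below are hypothetical placeholders for the trained MLPs, and all dimensions except D=512 are illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, n_genes, D = 4, 6, 512            # toy batch and gene count; D matches the 512d latent

gene_emb = torch.randn(n_genes, D)   # stand-in for frozen scGPT gene embeddings
z_A = torch.randn(B, D)              # stand-in for aligned ATAC latents

tau_mlp = nn.Linear(D, 1)            # placeholder for MLP(gene_emb)
cell_mod = nn.Linear(D, D)           # placeholder for MLP(z_A)

tau_base = F.softplus(tau_mlp(gene_emb)).squeeze(-1)  # (n_genes,) per-gene baseline delay
mod = cell_mod(z_A) * 0.1                             # (B, D) small-amplitude cell modulation
delta = mod @ gene_emb.T                              # (B, n_genes) gene-specific delta
tau_g = F.softplus(tau_base + delta)                  # (B, n_genes), strictly positive

assert tau_g.shape == (B, n_genes)
assert (tau_g > 0).all()
```

softplus keeps every delay positive while leaving gradients smooth near zero.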
# Integrate forward from z_A for τ_max time
trajectory = v_θ.integrate(z_A, [0, ..., τ_max]) # (200 steps, B, 512)
# At each gene g, sample the trajectory at that gene's τ_g time
hidden_states = interpolate(trajectory, τ_g)   # (B, n_genes, 512)

Key: the ODE starts from z_A (ATAC). RNA is never an input to the ODE. The integration is done once to τ_max; then, for each gene g, the trajectory is sampled at time τ_g[cell, g] via linear interpolation.
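The readout step can be sketched as batched linear interpolation over a precomputed trajectory; sizes are toy and the tensors are random stand-ins:

```python
import torch

torch.manual_seed(0)
n_steps, B, n_genes, D = 200, 3, 5, 8
tau_max = 1.0

trajectory = torch.randn(n_steps + 1, B, D)   # stand-in for the ODE solver output
tau_g = torch.rand(B, n_genes) * tau_max      # per-cell, per-gene readout times

# Map each τ to a fractional step index, then blend the two nearest steps
idx = tau_g / tau_max * n_steps               # (B, n_genes), in [0, n_steps]
lo = idx.floor().long().clamp(max=n_steps - 1)
frac = (idx - lo.float()).unsqueeze(-1)       # (B, n_genes, 1) interpolation weight

batch = torch.arange(B).unsqueeze(1)          # (B, 1), broadcasts over genes
h_lo = trajectory[lo, batch]                  # (B, n_genes, D)
h_hi = trajectory[lo + 1, batch]
hidden_states = (1 - frac) * h_lo + frac * h_hi

assert hidden_states.shape == (B, n_genes, D)
```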
ŷ_g = GeneDecoder(hidden_states, gene_emb) # (B, n_genes)
# internal: h = proj(hidden_states.mean(dim=1))
#           ŷ_g = sigmoid(h @ gene_emb.T)

# ATAC → ODE → should reach RNA state (median τ̄)
L_ode = ‖z_R - φ_τ̄(z_A)‖²
# Correct τ_g → correct expression prediction
L_recon = ‖ŷ_g - gene_expr‖²
# Regularization
L_τ = |τ_g|
L_reverse = ‖z_A - φ_{-τ̄}(φ_τ̄(z_A))‖²
L = L_ode + ½ L_recon + 0.01 L_τ + 0.1 L_reverse

No direct ground truth for τ_g exists. τ_g is learned through gradient descent on two losses:
- L_ode: Use median τ_g to integrate ATAC → should reach RNA state
- L_recon: Decode expression at each gene's τ_g → should match true expression
If τ_g is wrong for gene g:
- The hidden state at time τ_g won't contain the right information
- L_recon will be high → gradient flows back → τ_g adjusts
They never directly interact. The only coupling is through the ODE:
z_A (ATAC) ──→ ODE integrate ──→ φ_τ(z_A) ← compares with → z_R (RNA)
- ATAC is the initial condition of the ODE
- RNA is the supervision target (provides the loss signal)
- RNA teaches the vector field which direction to flow
- RNA teaches τ_g when to read out each gene
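The coupling can be sketched end to end. Everything below is a toy stand-in (a random linear field in place of the trained v_θ), showing only where ATAC and RNA enter:

```python
import torch

torch.manual_seed(0)
B, D = 4, 16
v_theta = torch.nn.Linear(D, D)   # toy stand-in for the trained vector field
z_A = torch.randn(B, D)           # ATAC latent: initial condition of the ODE
z_R = torch.randn(B, D)           # RNA latent: supervision target only
tau_bar = 0.7                     # median of τ_g

# Euler integration of z_A forward by τ̄, i.e. φ_τ̄(z_A)
n_steps = 100
dt = tau_bar / n_steps
z = z_A
for _ in range(n_steps):
    z = z + dt * v_theta(z)

# RNA appears only here, in the loss, never inside the integration
L_ode = ((z_R - z) ** 2).sum(dim=-1).mean()
assert L_ode.item() >= 0.0
```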
| Output | Shape | Meaning |
|---|---|---|
| τ_g | (B, n_genes) or (n_genes,) | Per-gene causal delay (primary result) |
| ŷ_g | (B, n_genes) | Predicted gene expression (used for supervision) |
| z_integrated | (B, 512) | ATAC integrated forward by median τ̄ |
The primary goal is τ_g; ŷ_g is the supervision signal that confirms τ_g is correct.
Training proceeds in 3 stages, each building on the previous:
Stage 0: Embedding Precomputation
─────────────────────────────────
CIMA RNA h5ad → scGPT encode → effendi_rna_emb.npy (N, 512)
CIMA ATAC h5ad → HyenaDNA → effendi_atac_emb.npy (N, 256)
Stage 1: Train Vector Field v_θ
───────────────────────────────
CIMA unpaired data (~400K cells)
z_ATAC (256d) ── AlignMLP ──→ z_A (512d)
│
z_R (512d) ─────────────────────┼──→ EffendiVectorField
SlicedW2 │ (SpectralNorm MLP)
distribution │ v_θ.integrate(z_A, t)
matching loss ────┤
│
φ_τ̄(z_A)
│
L_reverse = ‖z_A - φ_{-τ̄}(φ_τ̄(z_A))‖²
→ vector_field_frozen.pt, align_mlp.pt
Stage 2.1: Train Delay Model (v_θ frozen)
────────────────────────────────────────
10X Multiome paired data
gene_emb (n_genes, 512) ← scGPT GeneEncoder (frozen)
z_A ──→ AmortizedTauPredictor(z_A, gene_emb) ──→ τ_g (B, n_genes)
│
└── tau_mlp(gene_emb) + cell_mod(z_A) @ gene_emb.T
z_A ──→ v_θ.integrate ──→ trajectory (200 steps)
│
interpolate at τ_g ──→ hidden_states (B, n_genes, 512)
│
GeneDecoder(hidden_states, gene_emb) ──→ ŷ_g (B, n_genes)
L = L_ode + ½ L_recon + 0.01 L_τ
→ effendi_stage1.pt
Stage 2.2: Fine-tune Delay Model (v_θ unfrozen, lr=1e-6)
─────────────────────────────────────────────────────
Same architecture, v_θ unfrozen with small learning rate
→ effendi_final.pt, gene_tau_estimates.csv
# Encode CIMA cells (bottleneck: 44GB + 34GB h5ad files)
python precompute_embeddings.py rna --max-cells 200000 # → (N, 512) CIMA RNA embeddings
python precompute_embeddings.py atac --max-cells 200000   # → (N, 256) CIMA ATAC embeddings

Output: effendi_rna_emb.npy, effendi_rna_obs.pkl, effendi_atac_emb.npy, effendi_atac_obs.pkl, peak_embeddings.npy
python train_vf.py

Goal: learn the shared manifold + ATAC→RNA flow direction
Input: random immune-cell embeddings from CIMA (unpaired)
Model: AlignMLP(256→512) + EffendiVectorField + global τ̄
Loss: SlicedW2(φ_τ̄(z_A_distribution), z_R_distribution) + 0.1 × L_reverse
Output: vector_field_frozen.pt, align_mlp.pt
No bins, no pseudotime, no cell types. Pure distribution matching.
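The distribution-matching objective (slice_w2.py) can be sketched in pure PyTorch. This is a minimal version assuming equal sample counts in the two batches, not the repository implementation:

```python
import torch

torch.manual_seed(0)

def sliced_w2(x, y, n_proj=200):
    """Average squared 1D Wasserstein-2 distance over random unit projections."""
    d = x.shape[1]
    proj = torch.randn(d, n_proj)
    proj = proj / proj.norm(dim=0, keepdim=True)   # unit directions
    px = torch.sort(x @ proj, dim=0).values        # sorted 1D projections
    py = torch.sort(y @ proj, dim=0).values
    return ((px - py) ** 2).mean()

x = torch.randn(500, 512)
y = torch.randn(500, 512) + 1.0   # shifted distribution

same = sliced_w2(x, x)
diff = sliced_w2(x, y)
assert same.item() < 1e-6         # identical samples give zero distance
assert diff.item() > same.item()  # the shifted distribution is farther away
```

Sorting the projections gives the optimal 1D transport plan, which is why no bins or pairing are needed.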
# First extract gene embeddings from scGPT
python -c "from scgpt_encoder import save_gene_embeddings; save_gene_embeddings(gene_names, config.SCGPT_MODEL_DIR, 'effendi_output/gene_embeddings.npy')"
# Then train delay model
python train_delay.py

Goal: learn per-gene τ_g
Input: paired z_R (N,512), z_A (N,512), gene_expr (N,n_genes), gene_emb (n_genes,512)
Model: Stage 2.1: Freeze v_θ → train AmortizedTau + GeneDecoder + AlignMLP
Stage 2.2: Fine-tune v_θ (lr=1e-6) + continue training
Loss: L_ode + ½ L_recon + 0.01 L_τ
Output: effendi_final.pt, gene_tau_estimates.csv
| Aspect | Stage 1 | Stage 2.1 | Stage 2.2 |
|---|---|---|---|
| Data | CIMA unpaired | 10X paired | 10X paired |
| Supervision | Distribution matching | Reconstruction | Reconstruction |
| v_θ | Train | Frozen | Fine-tune (lr=1e-6) |
| AmortizedTau | — | Train | Train |
| GeneDecoder | — | Train | Train |
| AlignMLP | Train | Train | Train |
Computes per-gene delay from gene embeddings and cell state:
class AmortizedTauPredictor(nn.Module):
    def __init__(self, hidden=128, mod_scale=0.1):
        super().__init__()
        self.hidden = hidden
        self.mod_scale = mod_scale
        self._gene_dim = None  # MLPs are built lazily on first forward

    def _init_mlps(self, z_dim, gene_dim):
        # gene_emb → scalar τ_base
        self.tau_mlp = nn.Sequential(
            nn.Linear(gene_dim, self.hidden), nn.Tanh(), nn.Linear(self.hidden, 1)
        )
        # z_A → cell modulation vector
        self.cell_mod = nn.Sequential(
            nn.Linear(z_dim, self.hidden), nn.LayerNorm(self.hidden),
            nn.Tanh(), nn.Dropout(0.1), nn.Linear(self.hidden, z_dim),
        )
        self._gene_dim = gene_dim

    def forward(self, z_A, gene_emb):
        gene_dim = gene_emb.shape[1]
        z_dim = z_A.shape[1] if z_A is not None else gene_dim
        if not hasattr(self, 'tau_mlp') or self._gene_dim != gene_dim:
            self._init_mlps(z_dim, gene_dim)
        tau_base = F.softplus(self.tau_mlp(gene_emb)).squeeze(-1) + 1e-3  # (n_genes,)
        if z_A is None:
            return tau_base  # Global mode
        mod = self.cell_mod(z_A) * self.mod_scale        # (B, D)
        delta = (mod @ gene_emb.T) * self.mod_scale      # (B, n_genes)
        return F.softplus(tau_base + delta) + 1e-3

Key design:

- tau_mlp: learns "which semantic features determine delay" from gene embeddings
- cell_mod: learns "how cell state modulates delay" from the ATAC latent
- mod @ gene_emb.T: projects cell modulation onto each gene's embedding → gene-specific delta
- * 0.1 (mod_scale): small-amplitude modulation ensures training stability
Dot-product decoder (gene-agnostic):
class GeneDecoder(nn.Module):
    def __init__(self, dim=512, hidden=256):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(dim, hidden), nn.ReLU(),
            nn.Dropout(0.1), nn.Linear(hidden, dim),
        )

    def forward(self, hidden_states, gene_emb):
        h = self.proj(hidden_states.mean(dim=1))  # (B, D)
        return torch.sigmoid(h @ gene_emb.T)      # (B, n_genes)

Lipschitz-bounded MLP with SpectralNorm and manual ODE integration (no external dependencies):
class EffendiVectorField(nn.Module):
    def __init__(self, dim=512, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            SpectralNormLinear(dim, hidden_dim), nn.Tanh(),
            SpectralNormLinear(hidden_dim, hidden_dim), nn.Tanh(),
            SpectralNormLinear(hidden_dim, dim),
        )

    def integrate(self, z0, t, n_steps=100):
        """Manual Euler integration — no torchdiffeq, MPS-compatible."""
        dt = (t[-1] - t[0]) / n_steps
        z = z0
        results = [z]
        for _ in range(n_steps):
            z = z + dt * self.net(z)  # Euler step
            results.append(z)
        return torch.stack(results, dim=0)

SpectralNorm on all linear layers bounds each layer's Lipschitz constant by 1, which keeps ODE integration stable and the flow invertible. Manual integration uses float32 only — fully compatible with MPS (Apple Silicon).
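SpectralNormLinear itself is not shown here; a plausible sketch of it uses PyTorch's built-in spectral_norm parametrization (the helper name and the toy network below are assumptions, not the vector_field.py code):

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

def SpectralNormLinear(in_dim, out_dim):
    # Divides the weight by its estimated largest singular value,
    # capping the layer's Lipschitz constant near 1
    return spectral_norm(nn.Linear(in_dim, out_dim))

torch.manual_seed(0)
net = nn.Sequential(
    SpectralNormLinear(16, 32), nn.Tanh(),
    SpectralNormLinear(32, 16),
)

# Each training-mode forward runs one power-iteration step,
# so a few passes refine the singular-value estimate
x = torch.randn(8, 16)
for _ in range(30):
    net(x)
net.eval()

with torch.no_grad():
    for layer in (net[0], net[2]):
        sigma = torch.linalg.matrix_norm(layer.weight, ord=2)
        assert sigma.item() <= 1.05   # ≈1-Lipschitz per layer; tanh is also 1-Lipschitz
```

Since tanh is 1-Lipschitz, the composed network inherits the per-layer bound, which is what makes both forward and reverse integration well behaved.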
Nonlinear ATAC alignment:
class AlignMLP(nn.Module):
    def __init__(self, in_dim=256, out_dim=512, hidden=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.LayerNorm(hidden), nn.GELU(),
            nn.Dropout(0.1), nn.Linear(hidden, out_dim), nn.LayerNorm(out_dim),
        )

The previous architecture hardcoded an nn.Parameter of size 3000 per gene — this breaks on any dataset with a different gene set.
Replace parameter lookup with computation from scGPT gene embeddings:
# Previous: independent parameters
τ_g[g] = softplus(log_tau[g]) # log_tau = nn.Parameter(3000)
# Now: computed from gene semantics
τ_g[g] = softplus(MLP(gene_emb[g]))
# Previous: fixed output dimension
ŷ = Linear(512, 3000)(h)
# Now: dot product with gene embeddings
ŷ = sigmoid(proj(h) @ gene_emb.T)

| Aspect | Previous | Gene-agnostic |
|---|---|---|
| Generalization | Only training genes | Any gene in scGPT (~30K) |
| Parameters | ~1.5M | ~65K (23× reduction) |
| Dataset migration | Retrain required | Swap gene_emb |
| Semantic consistency | Independent | Functionally similar genes → similar τ |
10X paired data:
- RNA: (N, 36K genes) → HVG selection → (N, n_genes)
- Extract scGPT gene embedding → (n_genes, 512) gene_emb
Forward pass:
z_A → AmortizedTau(z_A, gene_emb) → τ_g (B, n_genes)
z_A → ODE integrate → trajectory → interpolate at τ_g → hidden_states
hidden_states → GeneDecoder(hidden_states, gene_emb) → ŷ_g (B, n_genes)
loss = L_ode + L_recon + L_tau + L_reverse
New dataset:
- Extract HVG → (n_genes_new, 512) new_gene_emb
- scGPT encode → z_R
Inference:
z_A_new → AmortizedTau(z_A_new, new_gene_emb) → τ_g_new # works directly
z_A_new → ODE integrate → hidden_states_new
hidden_states_new → GeneDecoder(hidden, new_gene_emb) → ŷ_new # works directly
No retraining needed. Just swap gene_emb lookup table.
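The migration claim is a pure shape property of the dot-product decoder. A sketch with random embeddings standing in for real scGPT gene_emb tables:

```python
import torch
import torch.nn as nn

class GeneDecoder(nn.Module):
    """Dot-product decoder, as defined earlier in this document."""
    def __init__(self, dim=512, hidden=256):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(dim, hidden), nn.ReLU(),
            nn.Dropout(0.1), nn.Linear(hidden, dim),
        )
    def forward(self, hidden_states, gene_emb):
        h = self.proj(hidden_states.mean(dim=1))
        return torch.sigmoid(h @ gene_emb.T)

dec = GeneDecoder().eval()
hidden = torch.randn(4, 10, 512)   # (B, n_genes_sampled, D) hidden states
old_emb = torch.randn(1654, 512)   # training gene set
new_emb = torch.randn(3000, 512)   # new dataset's gene set, no retraining

assert dec(hidden, old_emb).shape == (4, 1654)
assert dec(hidden, new_emb).shape == (4, 3000)
```

Because genes enter only through the final dot product, swapping gene_emb changes the output dimension without touching any trained weights.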
Effendi/
├── config.py # Paths and hyperparameters
├── effendi_loader.py # h5ad load (backed/streaming) + data prep
│
├── scgpt_encoder.py # scGPT inference + gene embedding extraction
├── hyenadna_encoder.py # HyenaDNA inference (+ peak_embeddings export)
│
├── align_mlp.py # AlignMLP — nonlinear ATAC alignment
├── slice_w2.py # Sliced W2 distance (pure PyTorch)
├── vector_field.py # EffendiVectorField (SpectralNorm + manual Euler ODE)
├── train_vf.py # Stage 1: v_θ + AlignMLP + reverse consistency
│
├── effendi.py # AmortizedTauPredictor + GeneDecoder + Effendi model
├── train_delay.py # Stage 2: delay training (frozen VF → fine-tune)
│
├── demo_real_emb.py # Demo: real scGPT + HyenaDNA embeddings + Stage 1
├── inference.py # scRNA_to_tau, reverse_ode
├── intervention.py # Bootstrap inference + displacement classification
├── validate.py # Five-layer validation
└── run.py # Main entry point
Run the full pipeline with real scGPT (RNA) and HyenaDNA (ATAC) embeddings:
python demo_real_emb.py

What it does:
Step 0: Auto-detect settings → batch=16, peaks=auto, seq_len=auto
Step 1: Load 10X Multiome (10k cells)
Step 2: RNA embeddings with scGPT → (10000, 512) ✓
Step 3: HVG selection → 2001 genes
Step 4: Gene embeddings from scGPT → (1654, 512) matched genes
Step 5: ATAC filtering → 111k → ~30k peaks (auto)
Step 6: ATAC embeddings with HyenaDNA → (10000, 256) ✓
Step 7: Align RNA + ATAC
Step 7.5: Stage 1 — Train Vector Field → SW2=0.733, Rev≈0
Step 8: Model forward pass → Effendi ready
Step 9: Cell-specific τ → (10000, 1654) ✓
Step 10: Full forward pass
Step 11: Tau analysis
Results (10k PBMC cells, 1654 genes):
| Metric | Value |
|---|---|
| RNA embeddings | (10000, 512) via scGPT |
| ATAC embeddings | (10000, 256) via HyenaDNA |
| Stage 1 SW2 | 0.733 (improving) |
| Rev error | ~0 (vector field invertible) |
| τ_g range | 0.649 – 0.722 |
| Gene-wise variation | 0.1% |
| Cell-wise variation | 0.7% |
Shortest delay genes: TRIB1, LINC02315, FKBP11, ID3, IL15 (immune regulation)
Longest delay genes: FCER1A, NAMPT, LIN7A, AP002370.2, IL7 (metabolism/signaling)
Hardware: Runs on MPS (Apple Silicon M-series) — no CUDA required. All layers use float32 for MPS compatibility.
All embedding settings are in config.py. Auto-detection is the default — override any setting as needed:
# config.py
# scGPT: max sequence length per cell
# None = auto (min(n_vars + 1, 1200)) — uses actual gene count, capped at 1200
# 300 = fast (fewer genes per cell)
# 1200 = full vocab (max precision)
SCGPT_MAX_SEQ_LEN = None
# scGPT: use Flash Attention (requires pip install flash-attn)
# True = ~2-3x faster, ~50% less memory
SCGPT_USE_FAST_TRANSFORMER = False
# HyenaDNA: batch size for peak embedding computation
# None = auto: MPS=16, CUDA=64, CPU=8
# Higher = faster but more memory
HYENADNA_BATCH_SIZE = None
# ATAC peak filtering (removes sparse/noisy peaks before HyenaDNA)
# None = auto: filter peaks present in <10% of cells → ~30k peaks
# 0 = disable (keep all 111k peaks, slow)
# 20000 = keep only top-20k by variance
ATAC_TOP_PEAKS = None

Performance guide:
| Config | RNA time | ATAC time | Total |
|---|---|---|---|
| Default (auto) | ~4.5 min | ~18 min | ~25 min |
| ATAC_TOP_PEAKS=20000 | ~4.5 min | ~3 min | ~10 min |
| All auto + fast transformer | ~2 min | ~18 min | ~22 min |
L = L_ode + ½ L_recon + 0.01 L_τ + 0.1 L_reverse
L_ode = ‖z_R - φ_τ̄(z_A)‖²
ATAC should reach RNA after ODE integration (median τ̄)
L_recon = ‖ŷ_g - gene_expr‖²
Expression prediction should match true expression
L_τ = |τ_g|
Prevent τ from exploding or collapsing to zero
L_reverse = ‖z_A - φ_{-τ̄}(φ_τ̄(z_A))‖²
Vector field must be invertible (essential for scRNA→z_A inference)
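The reverse term can be sketched with a toy field: φ_{-τ̄} is just Euler integration with a negated time step. The linear v_theta below is a stand-in, not the trained field:

```python
import torch

torch.manual_seed(0)
B, D, tau_bar, n_steps = 4, 16, 0.7, 200
v_theta = torch.nn.Linear(D, D)   # toy stand-in for v_θ
z_A = torch.randn(B, D)

def flow(z, t_total, n_steps):
    # Forward Euler; a negative t_total runs the field backwards
    dt = t_total / n_steps
    for _ in range(n_steps):
        z = z + dt * v_theta(z)
    return z

z_fwd = flow(z_A, tau_bar, n_steps)       # φ_τ̄(z_A)
z_back = flow(z_fwd, -tau_bar, n_steps)   # φ_{-τ̄}(φ_τ̄(z_A))
L_reverse = ((z_A - z_back) ** 2).sum(dim=-1).mean()

assert L_reverse.item() < 1.0   # small round-trip error for a smooth, well-behaved field
```

The residual error comes only from Euler discretization; a Lipschitz-bounded field keeps it from compounding.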
| Parameter | Default | Description |
|---|---|---|
| DIM | 512 | Latent dimension |
| HYENADNA_D_MODEL | 256 | HyenaDNA output dimension |
| ALIGN_HIDDEN | 512 | AlignMLP hidden dimension |
| TAU_HIDDEN | 128 | AmortizedTau hidden dimension |
| TAU_MOD_SCALE | 0.1 | Cell-specific τ modulation amplitude |
| DEC_HIDDEN | 256 | GeneDecoder hidden dimension |
| N_PROJ_W2 | 200 | Sliced W2 projections |
| VF_EPOCHS | 50 | Vector field epochs |
| VF_LR | 1e-3 | Vector field learning rate |
| STAGE1_EPOCHS | 10 | Stage 1 epochs (frozen v_θ) |
| STAGE2_EPOCHS | 40 | Stage 2 epochs (fine-tune v_θ) |
| STAGE1_LR | 1e-3 | Stage 1 learning rate |
| STAGE2_LR_VF | 1e-6 | Stage 2 v_θ learning rate |
| BATCH_SIZE | 128 | Training batch size |
| W_ODE / W_RECON / W_TAU | 1.0 / 0.5 / 0.01 | Main loss weights |
| W_REVERSE | 0.1 | Reverse consistency loss weight |
| EXPLODE_NORM | 50.0 | Gradient clipping threshold |
Internal consistency checks. No external data required.
| Metric | Target | Method |
|---|---|---|
| Lipschitz constant | ≤ 1.0 | Product of SpectralNorm across layers |
| Reversibility error | < 0.01 | ‖z_A - φ_{-τ}(φ_τ(z_A))‖ |
| Reverse consistency loss | < 0.1 | L_reverse during training |
| Numerical stability | No NaN/Inf | torch.isfinite() on outputs |
| τ_g distribution | Right-skewed | Shapiro-Wilk test p < 0.05 |
from validate import check_model_health
health = check_model_health(model, z_test)
assert health['lip_in_range']
assert health['rev_err_acceptable']
assert health['numerically_stable']
# τ_g distribution check
from scipy.stats import shapiro
tau_g = model.tau_predictor(None, gene_emb).cpu().numpy()
stat, p = shapiro(tau_g)
print(f"τ_g Shapiro-Wilk: stat={stat:.3f}, p={p:.2e}")
# Expected p < 0.05 (reject normality → right-skewed is reasonable)

Biological plausibility of τ_g estimates.
Short delay genes (immediate-early genes):
- FOS, JUN, JUNB, EGR1, NR4A1, NR4A2, ATF3
- Expected: τ < 25th percentile
- Mechanism: ATAC changes quickly lead to RNA changes
Priming genes (pioneer TFs):
- SPI1 (PU.1), CEBPA, CEBPB, GATA2, IRF8, RUNX1
- Expected: τ > 75th percentile
- Mechanism: ATAC primes chromatin, RNA response is delayed
Open-no-expression genes (chromatin accessible but silenced):
- Requires: H3K27me3 ChIP-seq data (Roadmap Epigenomics)
- Expected: high H3K27me3 signal at promoters
Reverse regulation genes (negative feedback):
- AHR, NR4A1 (literature-validated feedback loops)
- Expected: τ reflects feedback delay
from scipy.stats import mannwhitneyu
early_genes = {'FOS', 'JUN', 'JUNB', 'EGR1', 'NR4A1', 'NR4A2', 'ATF3'}
priming_tfs = {'SPI1', 'CEBPA', 'CEBPB', 'GATA2', 'IRF8', 'RUNX1'}
early_tau = [tau[gene_names.index(g)] for g in early_genes if g in gene_names]
priming_tau = [tau[gene_names.index(g)] for g in priming_tfs if g in gene_names]
u, p = mannwhitneyu(early_tau, priming_tau, alternative='less')
print(f"Immediate-early < Priming: U={u:.0f}, p={p:.3f}")
# Expected p < 0.05Genetic variants provide natural randomization (Mendelian randomization). caQTL = natural ATAC perturbation, eQTL = natural RNA response. SMR provides independent causal evidence.
Data sources: CIMA_caQTL_eQTL_SMR.csv, CIMA_Lead_cis-xQTL.csv
Validation 1: xQTL genes have higher τ than background
smr = pd.read_csv("CIMA_Resource/xQTL/CIMA_caQTL_eQTL_SMR.csv")
xqtl_genes = set(smr['gene']).intersection(set(gene_names))
tau_xqtl = tau[np.isin(gene_names, list(xqtl_genes))]
tau_bg = tau[~np.isin(gene_names, list(xqtl_genes))]
u, p = mannwhitneyu(tau_xqtl, tau_bg, alternative='greater')
print(f"xQTL genes τ > background: U={u:.0f}, p={p:.2e}")
# Expected p < 0.05

Validation 2: CE direction matches xQTL direction
# For each caQTL→eQTL gene pair, compute model CE
direction_consistent = 0
for gene in xqtl_genes:
ce = compute_causal_effect(model, z_A_base, z_A_perturbed, gene_emb)
if np.sign(ce) == np.sign(smr_beta[gene]):
direction_consistent += 1
direction_rate = direction_consistent / len(xqtl_genes)
print(f"Direction consistency: {direction_rate:.1%}")
# Expected > 60%

Validation 3: CE magnitude correlates with SMR effect size
from scipy.stats import spearmanr
smr_effect = smr.set_index('gene')['beta'].abs()
ce_effect = pd.Series({g: abs(ce[g]) for g in xqtl_genes})
rho, p = spearmanr(smr_effect, ce_effect)
print(f"CE vs SMR Spearman ρ={rho:.3f}, p={p:.2e}")
# Expected ρ > 0.3

Validation 4: caQTL+eQTL colocalized genes have higher τ than caQTL-only
lead_xqtl = pd.read_csv("CIMA_Resource/xQTL/CIMA_Lead_cis-xQTL.csv")
caqtl_only = lead_xqtl[(lead_xqtl['caQTL_p'] < 1e-5) & (lead_xqtl['eQTL_p'] > 0.05)]
caqtl_eqtl = lead_xqtl[(lead_xqtl['caQTL_p'] < 1e-5) & (lead_xqtl['eQTL_p'] < 1e-5)]
tau_coloc = tau[np.isin(gene_names, list(caqtl_eqtl['gene']))]
tau_caonly = tau[np.isin(gene_names, list(caqtl_only['gene']))]
u, p = mannwhitneyu(tau_coloc, tau_caonly, alternative='greater')
print(f"Colocalized τ > caQTL-only: U={u:.0f}, p={p:.2e}")
# Expected p < 0.05

CRISPR perturbation data validates do-intervention predictions.
Baseline = 0 genes: does the model correctly predict which ones become activated?
- Silence precision > 0.85 (should NOT falsely predict activation when the gene stays silent)
- Silence recall > 0.3 (should detect real activations)
Baseline ≠ 0 genes: Is the direction correct?
- Direction rate > 0.6
- Spearman ρ > 0.2
- Recall@50 > 0.1
def perturb_validation(model, z_R, tf_direction, perturb_data):
# 1. Latent intervention: simulate TF knockout
result = do_intervention_latent(model, z_R, tf_direction, alpha=1.0)
ce_pred = result['CE'].mean(axis=1) # (n_genes,)
# 2. Compare with experimental DE
de_obs = perturb_data.set_index('gene')['DE_log2FC']
y_before = perturb_data.set_index('gene')['y_before']
y_after = perturb_data.set_index('gene')['y_after']
# 3. Baseline = 0: silence specificity
zero_mask = (y_before < 0.01)
y_activated = (y_after[zero_mask] > 0.1)
ce_activated = (ce_pred[zero_mask].abs() > 0.05)
silence_precision = 1 - ce_activated[~y_activated].mean()
silence_recall = ce_activated[y_activated].mean()
# 4. Baseline ≠ 0: direction consistency
nonzero_mask = ~zero_mask
y_direction = np.sign(y_after[nonzero_mask] - y_before[nonzero_mask])
ce_direction = np.sign(ce_pred[nonzero_mask])
direction_rate = (y_direction == ce_direction).mean()
# 5. Ranking consistency
rho, p = spearmanr(ce_pred[nonzero_mask], de_obs[nonzero_mask])
return {
'silence_precision': silence_precision, # > 0.85
'silence_recall': silence_recall, # > 0.3
'direction_rate': direction_rate, # > 0.6
'spearman_rho': rho, # > 0.2
    }

Data sources (by PBMC relevance):
- Schmidt et al. 2022 (primary T cells, ~100 TFs) — highest relevance
- Norman et al. 2019 (K562, bidirectional CRISPR) — medium
- Replogle et al. 2022 (K562/RPE1, genome-wide) — low but broad coverage
Age-related genes (implemented, data exists):
age_ereg = pd.read_csv("CIMA_Resource/GRN/CIMA_Age_Related_eRegulons.csv")
age_genes = set(age_ereg['Gene_signature_name'])
tau_age = tau[np.isin(gene_names, list(age_genes))]
tau_nonage = tau[~np.isin(gene_names, list(age_genes))]
u, p = mannwhitneyu(tau_age, tau_nonage, alternative='greater')
print(f"Age genes τ > others: U={u:.0f}, p={p:.2e}")
# Expected p < 0.05

eRegulon enrichment (Pioneer TF targets):
ereg = pd.read_csv("CIMA_Resource/GRN/CIMA_eRegulons_Metadata.csv",
usecols=['TF', 'Gene_signature_name', 'R2G_importance'])
pioneer_tfs = {'PU.1', 'SPI1', 'CEBPA', 'GATA2', 'RUNX1'}
pioneer_targets = set(ereg[ereg['TF'].isin(pioneer_tfs)]['Gene_signature_name'])
high_tau_genes = set(gene_names[tau > np.percentile(tau, 75)])
from scipy.stats import fisher_exact
overlap = high_tau_genes & pioneer_targets
contingency = [
[len(overlap), len(high_tau_genes - pioneer_targets)],
[len(pioneer_targets - high_tau_genes),
len(set(gene_names) - high_tau_genes - pioneer_targets)]
]
odds, p = fisher_exact(contingency)
print(f"Pioneer TF target enrichment: OR={odds:.2f}, p={p:.2e}")
# Expected OR > 1, p < 0.05

Time-series scRNA-seq (requires external data):
- LPS-stimulated monocytes (public datasets)
- Short-delay genes should respond rapidly post-stimulation
- Expected: correlation between τ_g and response time
Roadmap/ENCODE epigenomics (requires external data):
- H3K27me3 ChIP-seq for PBMC
- Open-no-expression genes should have high H3K27me3 at promoters
- Expected: enrichment OR > 2, p < 0.05
Run after training completes:
- Model health (Lipschitz, reversibility, τ_g distribution)
- Immediate-early vs priming TF delay difference (Mann-Whitney U test)
- Expression reconstruction quality (Spearman ρ > 0.5)
- xQTL natural experiment (4 sub-validations, strongest causal evidence)
- eRegulon pioneer TF target enrichment
- Age eRegulon enrichment
- Perturb-seq validation
- Time-series scRNA validation
- Roadmap/ENCODE epigenomic marks
| Stage | Checkpoint | Pass Criteria | Failure Action |
|---|---|---|---|
| Stage 1 Epoch 10 | τ_g distribution | Not collapsed to 0 or ∞ | Increase L_τ weight |
| Stage 2 Epoch 50 | ODE consistency | L_ODE < 0.1 | Check v_θ Lipschitz or adjust LR |
| Post-training | Lipschitz constant | ≤ 1.0 | Verify SpectralNorm is active |
| Post-training | Reversibility | error < 0.01 | Shorten integration interval |
| Post-training | Reverse ODE explosion rate | < 10% | Check v_θ Lipschitz, reduce step size |
| Post-training | eRegulon enrichment | OR > 1, p < 0.05 | Check τ_g distribution flatness |
| Post-training | xQTL direction | consistency > 60% | Re-examine v_θ training quality |
- Extract CIMA embeddings (bottleneck: ~20-40 hours GPU)
- Train vector field with reverse consistency
- Process 10X data + extract gene embeddings
- Train delay model (Stage 1 frozen VF + Stage 2 fine-tune)
- First τ_g estimates + biological sanity check
- Model health + displacement classification
- xQTL natural experiment
- eRegulon + Age enrichment
Add ATACDecoder to enable cross-modality prediction:
z_A (512d) → ATACDecoder → peak accessibility (n_peaks,)
Requires: cell × peak accessibility matrix from 10X h5, atac_decoder.py.
Perturb a gene and predict the expression cascade:
RNA (perturbed) → reverse_ode → z_A
→ query eRegulon: TF → target peaks
→ modify accessibility[peaks] ∝ δ × weights
→ re-encode → z_A' → forward ODE → RNA' (perturbed prediction)
Requires: eRegulon peak mapping, Perturb-seq validation data.
These are noted for CIMA-scale training (4M cells), not demo-scale.
Problem: MLP from frozen 512d gene embeddings has limited expressive power. τ range is narrow (0.65-0.72), gene-wise variation is low (0.1%).
Direction: Replace MLP with attention over gene embeddings.
# Current (weak)
tau_base = MLP(gene_emb) # MLP bottleneck
# Future: attention lets genes "see" each other
tau_base = multi_head_attention(gene_emb, gene_emb, gene_emb)

Problem: Frozen scGPT embeddings may not match the downstream task distribution.
Direction: Add learnable adapter (not full fine-tuning).
# Current: fully frozen
gene_emb = scGPT_encoder.weight
# Future: small adapter, preserves pretrained + adapts to data
gene_emb = scGPT_encoder.weight + adapter(scGPT_encoder.weight)

Problem: Fixed-step Euler integration plus interpolation wastes compute and introduces error.
Direction: O(1) attention-based readout.
# Current: full trajectory + fixed interpolation
trajectory = integrate(z_A, τ_max) # 200 steps
hidden = interpolate(trajectory, τ_g)
# Future: direct O(1) readout
hidden = attention_readout(z_A, τ_g, gene_emb)  # no interpolation error

Problem: SW2 = 0.733 indicates the ATAC→RNA distribution alignment is not yet tight.
Direction: Add contrastive loss for paired cells.
# Current: only SW2 (unpaired distribution matching)
L = sliced_w2(z_integrated, z_R)
# Future: add contrastive signal for paired cells
L_contrastive = contrastive_loss(AlignMLP(z_ATAC), z_R)
L = SW2 + λ * L_contrastive

# Current: float32
with torch.autocast(device_type='mps', dtype=torch.bfloat16):
outputs = model(input_ids=tokens)
# Expected: 1.5-2x speedup, 50% memory reduction

Priority: Hybrid precision (trivial) > Attention τ_predictor (high impact) > Adapter (medium) > Attention readout (high complexity) > Contrastive loss (requires paired data).