Effendi Architecture

Overview

Effendi estimates the causal delay τ_g — how long it takes for chromatin accessibility changes (ATAC) to propagate to gene expression (RNA) — using a learned manifold and Neural ODE.

Core idea: the vector field v_θ defines cell-state trajectories. ATAC provides the initial condition from which the ODE integrates forward. RNA is the target — it never directly interacts with ATAC; it only supervises the result.


Architecture

                           gene_emb (n_genes, 512)  ← scGPT frozen
                           ┌──────────┴──────────┐
                           ↓                     ↓
                    AmortizedTau          GeneDecoder
                           ↓                     ↑
                        τ_g (B,n_genes)           │
                           ↓                     │
z_ATAC (256d)  ──┐        ↓                     │
  AlignMLP       │   ┌────┴─────┐               │
z_A (512d)  ────┤   │          │               │
                ├───┤   ODE    ├── hidden ─────┘
z_R (512d)      │   │ integrate │ (B,n_genes,D)
                │   └──────────┘
                │        ↑
                │   v_θ fields
                │        │
                └─── L_ode = ‖z_R - φ_τ(z_A)‖²
                     (RNA only as supervision)

Step-by-Step Forward Pass

Step 0: Encode Inputs

z_ATAC = HyenaDNA(atac_peaks)        # (B, 256)  peak sequences → embedding
z_R    = scGPT(rna_counts)           # (B, 512)  gene expression → embedding
z_A    = AlignMLP(z_ATAC)            # (B, 512)  align ATAC to shared latent space

Two separate encoding paths: scGPT embeds RNA directly in 512d, while AlignMLP lifts the 256d HyenaDNA ATAC embedding into the same 512d latent space.

Step 1: Compute τ_g

τ_g = AmortizedTauPredictor(z_A, gene_emb)  # (B, n_genes)

How τ_g is computed:

| Component | Formula | Shape |
|---|---|---|
| Gene baseline | τ_base = softplus(MLP(gene_emb)) | (n_genes,) |
| Cell modulation | mod = MLP(z_A) | (B, 512) |
| Gene-specific delta | delta = mod @ gene_emb.T | (B, n_genes) |
| Final τ_g | softplus(τ_base + delta) | (B, n_genes) |
  • gene_emb is frozen (from scGPT GeneEncoder). Only MLP(gene_emb) and MLP(z_A) are trained.
  • τ_base: each gene has a "default" delay learned from its embedding (e.g. FOS → short, CEBPA → long)
  • delta: cell state modulates the delay — same gene can have different delay in different cells
  • mod @ gene_emb.T: projects the cell's modulation vector onto each gene's embedding space

Step 2: ODE Integration with τ_g Sampling

# Integrate forward from z_A for τ_max time
trajectory = v_θ.integrate(z_A, [0, ..., τ_max])  # (200 steps, B, 512)

# At each gene g, sample the trajectory at that gene's τ_g time
hidden_states = interpolate(trajectory, τ_g)  # (B, n_genes, 512)

Key: ODE starts from z_A (ATAC). RNA is never an input to the ODE. The integration is done once to τ_max, then for each gene g, we sample at time τ_g[cell, g] via linear interpolation.
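
A minimal sketch of this per-gene readout (the helper name, uniform time grid, and indexing here are illustrative, not necessarily the repo's implementation):

import torch

def interpolate_trajectory(trajectory, tau_g, tau_max):
    """Sample the trajectory at each gene's delay by linear interpolation.
    trajectory: (T, B, D) states at uniform times in [0, tau_max]
    tau_g:      (B, n_genes) per-gene readout times
    returns:    (B, n_genes, D)
    """
    T, B, _ = trajectory.shape
    pos = (tau_g.clamp(0, tau_max) / tau_max) * (T - 1)  # fractional step index
    lo = pos.floor().long().clamp(max=T - 2)             # lower neighbor on the grid
    frac = (pos - lo.float()).unsqueeze(-1)              # (B, n_genes, 1) blend weight

    batch = torch.arange(B, device=trajectory.device).unsqueeze(1)  # broadcast over genes
    z_lo = trajectory[lo, batch]       # (B, n_genes, D)
    z_hi = trajectory[lo + 1, batch]
    return (1 - frac) * z_lo + frac * z_hi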

Step 3: Decode Expression

ŷ_g = GeneDecoder(hidden_states, gene_emb)  # (B, n_genes)
# internal: h = proj(hidden_states.mean(dim=1))
#           ŷ_g = sigmoid(h @ gene_emb.T)

Step 4: Compute Loss

# ATAC → ODE → should reach RNA state (median τ̄)
L_ode = ‖z_R - φ_τ̄(z_A)‖²

# Correct τ_g → correct expression prediction
L_recon = ‖ŷ_g - gene_expr‖²

# Regularization
L_τ = |τ_g|
L_reverse = ‖z_A - φ_{-τ̄}(φ_τ̄(z_A))‖²

L = L_ode + ½ L_recon + 0.01 L_τ + 0.1 L_reverse

How τ_g Is Learned

No direct ground truth for τ_g exists. τ_g is learned through gradient descent on two losses:

  1. L_ode: Use median τ_g to integrate ATAC → should reach RNA state
  2. L_recon: Decode expression at each gene's τ_g → should match true expression

If τ_g is wrong for gene g:

  • The hidden state at time τ_g won't contain the right information
  • L_recon will be high → gradient flows back → τ_g adjusts
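
A schematic training step makes this gradient path concrete (module names are placeholders; interpolate_trajectory is the sketch from Step 2):

tau_g = tau_predictor(z_A, gene_emb)                        # (B, n_genes)
traj = v_theta.integrate(z_A, torch.tensor([0.0, tau_max])) # full trajectory, computed once
hidden = interpolate_trajectory(traj, tau_g, tau_max)       # differentiable in tau_g
y_hat = decoder(hidden, gene_emb)                           # (B, n_genes)
loss = ((y_hat - gene_expr) ** 2).mean()                    # L_recon
loss.backward()  # gradient reaches tau_g through the interpolation blend weights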

Interaction Between ATAC and RNA

They never directly interact. The only coupling is through the ODE:

z_A (ATAC)  ──→ ODE integrate ──→ φ_τ(z_A) ← compares with → z_R (RNA)
  • ATAC is the initial condition of the ODE
  • RNA is the supervision target (provides the loss signal)
  • RNA teaches the vector field which direction to flow
  • RNA teaches τ_g when to read out each gene

Outputs

| Output | Shape | Meaning |
|---|---|---|
| τ_g | (B, n_genes) or (n_genes,) | Per-gene causal delay (primary result) |
| ŷ_g | (B, n_genes) | Predicted gene expression (used for supervision) |
| z_integrated | (B, 512) | ATAC integrated forward by median τ̄ |

The primary output is τ_g. ŷ_g exists so that reconstruction against measured expression can validate (and train) τ_g.


Training Pipeline

Training proceeds in 3 stages, each building on the previous:

Stage 0: Embedding Precomputation
─────────────────────────────────
CIMA RNA h5ad   →  scGPT encode  →  effendi_rna_emb.npy  (N, 512)
CIMA ATAC h5ad  →  HyenaDNA     →  effendi_atac_emb.npy  (N, 256)


Stage 1: Train Vector Field v_θ
───────────────────────────────
CIMA unpaired data (~400K cells)

z_ATAC (256d) ── AlignMLP ──→ z_A (512d)
                                │
z_R (512d) ─────────────────────┼──→ EffendiVectorField
                SlicedW2        │      (SpectralNorm MLP)
              distribution      │      v_θ.integrate(z_A, t)
              matching loss ────┤
                                │
                          φ_τ̄(z_A)
                                │
          L_reverse = ‖z_A - φ_{-τ̄}(φ_τ̄(z_A))‖²

→ vector_field_frozen.pt, align_mlp.pt


Stage 2.1: Train Delay Model (v_θ frozen)
────────────────────────────────────────
10X Multiome paired data

gene_emb (n_genes, 512) ← scGPT GeneEncoder (frozen)

z_A ──→ AmortizedTauPredictor(z_A, gene_emb) ──→ τ_g (B, n_genes)
         │
         └── tau_mlp(gene_emb) + cell_mod(z_A) @ gene_emb.T

z_A ──→ v_θ.integrate ──→ trajectory (200 steps)
                                │
                    interpolate at τ_g ──→ hidden_states (B, n_genes, 512)
                                         │
                    GeneDecoder(hidden_states, gene_emb) ──→ ŷ_g (B, n_genes)

L = L_ode + ½ L_recon + 0.01 L_τ

→ effendi_stage1.pt


Stage 2.2: Fine-tune Delay Model (v_θ unfrozen, lr=1e-6)
─────────────────────────────────────────────────────
Same architecture, v_θ unfrozen with small learning rate

→ effendi_final.pt, gene_tau_estimates.csv

Stage 0: Precompute Embeddings

# Encode CIMA cells (bottleneck: 44GB + 34GB h5ad files)
python precompute_embeddings.py rna --max-cells 200000  # → (N, 512) CIMA RNA embeddings
python precompute_embeddings.py atac --max-cells 200000  # → (N, 256) CIMA ATAC embeddings

Output: effendi_rna_emb.npy, effendi_rna_obs.pkl, effendi_atac_emb.npy, effendi_atac_obs.pkl, peak_embeddings.npy

Stage 1: Train Vector Field (CIMA, unpaired data)

python train_vf.py
Goal: learn shared manifold + ATAC→RNA flow direction

Input:  random immune-cell embeddings from CIMA (unpaired)
Model:  AlignMLP(256→512) + EffendiVectorField + global τ̄
Loss:   SlicedW2(φ_τ̄(z_A_distribution), z_R_distribution) + 0.1 × L_reverse
Output: vector_field_frozen.pt, align_mlp.pt

No bins, no pseudotime, no cell types. Pure distribution matching.
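
slice_w2.py provides the distance; the core idea is that W2 between 1-D projections reduces to sorting. A minimal sketch, assuming equal batch sizes (the repo's implementation may differ):

import torch

def sliced_w2(x, y, n_proj=200):
    # x, y: (B, D) point clouds; B must match (subsample otherwise)
    theta = torch.randn(n_proj, x.shape[1], device=x.device)
    theta = theta / theta.norm(dim=1, keepdim=True)  # random unit directions
    px = (x @ theta.T).sort(dim=0).values            # sorted 1-D projections
    py = (y @ theta.T).sort(dim=0).values            # = 1-D optimal transport matching
    return ((px - py) ** 2).mean()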

Stage 2: Train Delay Model (10X Multiome, paired data)

# First extract gene embeddings from scGPT
python -c "from scgpt_encoder import save_gene_embeddings; save_gene_embeddings(gene_names, config.SCGPT_MODEL_DIR, 'effendi_output/gene_embeddings.npy')"

# Then train delay model
python train_delay.py
Goal: learn per-gene τ_g

Input:  paired z_R (N,512), z_A (N,512), gene_expr (N,n_genes), gene_emb (n_genes,512)
Model:  Stage 2.1: Freeze v_θ → train AmortizedTau + GeneDecoder + AlignMLP
        Stage 2.2: Fine-tune v_θ (lr=1e-6) + continue training
Loss:   L_ode + ½ L_recon + 0.01 L_τ
Output: effendi.pt, gene_tau_estimates.csv

Key Design Decisions

| Aspect | Stage 1 | Stage 2.1 | Stage 2.2 |
|---|---|---|---|
| Data | CIMA unpaired | 10X paired | 10X paired |
| Supervision | Distribution matching | Reconstruction | Reconstruction |
| v_θ | Train | Frozen | Fine-tune (lr=1e-6) |
| AmortizedTau | — | Train | Train |
| GeneDecoder | — | Train | Train |
| AlignMLP | Train | Train | Train |

Key Components

AmortizedTauPredictor

Computes per-gene delay from gene embeddings and cell state:

import torch
import torch.nn as nn
import torch.nn.functional as F

class AmortizedTauPredictor(nn.Module):
    def __init__(self, hidden=128, mod_scale=0.1):
        super().__init__()
        self.hidden = hidden
        self.mod_scale = mod_scale

    def _init_mlps(self, z_dim, gene_dim):
        self._gene_dim = gene_dim
        # gene_emb → scalar τ_base
        self.tau_mlp = nn.Sequential(
            nn.Linear(gene_dim, self.hidden), nn.Tanh(), nn.Linear(self.hidden, 1)
        )
        # z_A → cell modulation vector
        self.cell_mod = nn.Sequential(
            nn.Linear(z_dim, self.hidden), nn.LayerNorm(self.hidden),
            nn.Tanh(), nn.Dropout(0.1), nn.Linear(self.hidden, z_dim),
        )

    def forward(self, z_A, gene_emb):
        gene_dim = gene_emb.shape[1]
        z_dim = z_A.shape[1] if z_A is not None else gene_dim
        # Lazily (re)build the MLPs so any gene-embedding dimension works
        if not hasattr(self, 'tau_mlp') or self._gene_dim != gene_dim:
            self._init_mlps(z_dim, gene_dim)
            self.to(gene_emb.device)  # keep lazily-built layers on the input device

        tau_base = F.softplus(self.tau_mlp(gene_emb)).squeeze(-1) + 1e-3  # (n_genes,)

        if z_A is None:
            return tau_base  # Global mode: gene baseline only

        mod = self.cell_mod(z_A)                      # (B, D)
        delta = (mod @ gene_emb.T) * self.mod_scale   # (B, n_genes), small-amplitude modulation
        return F.softplus(tau_base + delta) + 1e-3

Key design:

  • tau_mlp: learns "which semantic features determine delay" from gene embeddings
  • cell_mod: learns "how cell state modulates delay" from ATAC latent
  • mod @ gene_emb.T: projects cell modulation onto each gene's embedding → gene-specific delta
  • * 0.1: small amplitude modulation ensures training stability
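
Usage, following the two modes of forward (passing z_A=None returns the gene-level baseline only):

tau_pred = AmortizedTauPredictor(hidden=128, mod_scale=0.1)
tau_base = tau_pred(None, gene_emb)  # (n_genes,)  global per-gene delays
tau_cell = tau_pred(z_A, gene_emb)   # (B, n_genes) cell-modulated delays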

GeneDecoder

Dot-product decoder (gene-agnostic):

class GeneDecoder(nn.Module):
    def __init__(self, dim=512, hidden=256):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(dim, hidden), nn.ReLU(),
            nn.Dropout(0.1), nn.Linear(hidden, dim),
        )

    def forward(self, hidden_states, gene_emb):
        # Pool over the gene axis, then score against every gene embedding
        h = self.proj(hidden_states.mean(dim=1))  # (B, D)
        return torch.sigmoid(h @ gene_emb.T)      # (B, n_genes)

EffendiVectorField

Lipschitz-bounded MLP with SpectralNorm and manual ODE integration (no external dependencies):

class EffendiVectorField(nn.Module):
    def __init__(self, dim=512, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            SpectralNormLinear(dim, hidden_dim), nn.Tanh(),
            SpectralNormLinear(hidden_dim, hidden_dim), nn.Tanh(),
            SpectralNormLinear(hidden_dim, dim),
        )

    def integrate(self, z0, t, n_steps=100):
        """Manual Euler integration — no torchdiffeq, MPS-compatible."""
        dt = (t[-1] - t[0]) / n_steps  # dt < 0 integrates in reverse time (φ_{-τ})
        z = z0
        results = [z]
        for _ in range(n_steps):
            z = z + dt * self.net(z)  # Euler step
            results.append(z)
        return torch.stack(results, dim=0)

SpectralNorm on all linear layers bounds each layer's Lipschitz constant at 1 (Tanh is also 1-Lipschitz), which keeps ODE integration stable and the flow invertible. Manual integration uses float32 only — fully compatible with MPS (Apple Silicon).
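
SpectralNormLinear is defined in vector_field.py; a functionally equivalent stand-in using PyTorch's built-in parametrization (the repo's own implementation may differ):

import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

def SpectralNormLinear(in_dim, out_dim):
    # Rescales the weight by its largest singular value each forward pass,
    # bounding the layer's Lipschitz constant at ~1
    return spectral_norm(nn.Linear(in_dim, out_dim))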

AlignMLP

Nonlinear ATAC alignment:

class AlignMLP(nn.Module):
    def __init__(self, in_dim=256, out_dim=512, hidden=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.LayerNorm(hidden), nn.GELU(),
            nn.Dropout(0.1), nn.Linear(hidden, out_dim), nn.LayerNorm(out_dim),
        )

    def forward(self, z_atac):
        return self.net(z_atac)  # (B, 256) → (B, 512)

Gene-Agnostic Design

Problem

The previous architecture hardcoded a per-gene nn.Parameter of fixed size 3000 — it breaks on any dataset with a different gene set.

Solution

Replace parameter lookup with computation from scGPT gene embeddings:

# Previous: independent parameters
τ_g[g] = softplus(log_tau[g])     # log_tau = nn.Parameter(torch.zeros(3000))

# Now: computed from gene semantics
τ_g[g] = softplus(MLP(gene_emb[g]))

# Previous: fixed output dimension
ŷ = Linear(512, 3000)(h)

# Now: dot product with gene embeddings
ŷ = sigmoid(proj(h) @ gene_emb.T)

Benefits

| Aspect | Previous | Gene-agnostic |
|---|---|---|
| Generalization | Only training genes | Any gene in scGPT (~30K) |
| Parameters | ~1.5M | ~65K (23× reduction) |
| Dataset migration | Retrain required | Swap gene_emb |
| Semantic consistency | Independent | Functionally similar genes → similar τ |

Data Flow

Training

10X paired data:
  - RNA: (N, 36K genes) → HVG selection → (N, n_genes)
  - Extract scGPT gene embedding → (n_genes, 512) gene_emb

Forward pass:
  z_A → AmortizedTau(z_A, gene_emb) → τ_g (B, n_genes)
  z_A → ODE integrate → trajectory → interpolate at τ_g → hidden_states
  hidden_states → GeneDecoder(hidden_states, gene_emb) → ŷ_g (B, n_genes)
  loss = L_ode + L_recon + L_tau + L_reverse

Inference (new dataset)

New dataset:
  - Extract HVG → (n_genes_new, 512) new_gene_emb
  - scGPT encode → z_R

Inference:
  z_A_new → AmortizedTau(z_A_new, new_gene_emb) → τ_g_new  # works directly
  z_A_new → ODE integrate → hidden_states_new
  hidden_states_new → GeneDecoder(hidden, new_gene_emb) → ŷ_new  # works directly

No retraining needed. Just swap gene_emb lookup table.


File Structure

Effendi/
├── config.py                  # Paths and hyperparameters
├── effendi_loader.py          # h5ad load (backed/streaming) + data prep
│
├── scgpt_encoder.py           # scGPT inference + gene embedding extraction
├── hyenadna_encoder.py        # HyenaDNA inference (+ peak_embeddings export)
│
├── align_mlp.py               # AlignMLP — nonlinear ATAC alignment
├── slice_w2.py                # Sliced W2 distance (pure PyTorch)
├── vector_field.py            # EffendiVectorField (SpectralNorm + manual Euler ODE)
├── train_vf.py                # Stage 1: v_θ + AlignMLP + reverse consistency
│
├── effendi.py                 # AmortizedTauPredictor + GeneDecoder + Effendi model
├── train_delay.py             # Stage 2: delay training (frozen VF → fine-tune)
│
├── demo_real_emb.py           # Demo: real scGPT + HyenaDNA embeddings + Stage 1
├── inference.py               # scRNA_to_tau, reverse_ode
├── intervention.py            # Bootstrap inference + displacement classification
├── validate.py                # Five-layer validation
└── run.py                     # Main entry point

Demo: Real Embeddings (10X Multiome)

Run the full pipeline with real scGPT (RNA) and HyenaDNA (ATAC) embeddings:

python demo_real_emb.py

What it does:

Step 0: Auto-detect settings          → batch=16, peaks=auto, seq_len=auto
Step 1: Load 10X Multiome (10k cells)
Step 2: RNA embeddings with scGPT         → (10000, 512)  ✓
Step 3: HVG selection                    → 2001 genes
Step 4: Gene embeddings from scGPT      → (1654, 512) matched genes
Step 5: ATAC filtering                  → 111k → ~30k peaks (auto)
Step 6: ATAC embeddings with HyenaDNA   → (10000, 256)  ✓
Step 7: Align RNA + ATAC
Step 7.5: Stage 1 — Train Vector Field → SW2=0.733, Rev≈0
Step 8: Model forward pass              → Effendi ready
Step 9: Cell-specific τ                 → (10000, 1654)  ✓
Step 10: Full forward pass
Step 11: Tau analysis

Results (10k PBMC cells, 1654 genes):

| Metric | Value |
|---|---|
| RNA embeddings | (10000, 512) via scGPT |
| ATAC embeddings | (10000, 256) via HyenaDNA |
| Stage 1 SW2 | 0.733 (improving) |
| Rev error | ~0 (vector field invertible) |
| τ_g range | 0.649 – 0.722 |
| Gene-wise variation | 0.1% |
| Cell-wise variation | 0.7% |

Shortest delay genes: TRIB1, LINC02315, FKBP11, ID3, IL15 (immune regulation)

Longest delay genes: FCER1A, NAMPT, LIN7A, AP002370.2, IL7 (metabolism/signaling)

Hardware: Runs on MPS (Apple Silicon M-series) — no CUDA required. All layers use float32 for MPS compatibility.

Embedding Configuration

All embedding settings are in config.py. Auto-detection is the default — override any setting as needed:

# config.py

# scGPT: max sequence length per cell
# None = auto (min(n_vars + 1, 1200)) — uses actual gene count, capped at 1200
# 300   = fast (fewer genes per cell)
# 1200  = full vocab (max precision)
SCGPT_MAX_SEQ_LEN = None

# scGPT: use Flash Attention (requires pip install flash-attn)
# True = ~2-3x faster, ~50% less memory
SCGPT_USE_FAST_TRANSFORMER = False

# HyenaDNA: batch size for peak embedding computation
# None = auto: MPS=16, CUDA=64, CPU=8
# Higher = faster but more memory
HYENADNA_BATCH_SIZE = None

# ATAC peak filtering (removes sparse/noisy peaks before HyenaDNA)
# None  = auto: filter peaks present in <10% of cells → ~30k peaks
# 0     = disable (keep all 111k peaks, slow)
# 20000 = keep only top-20k by variance
ATAC_TOP_PEAKS = None
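
The auto rules could be resolved roughly as follows (a sketch; the actual logic lives in config.py and the function name is illustrative):

import torch

def resolve_hyenadna_batch_size(configured=None):
    # None = auto: MPS=16, CUDA=64, CPU=8 (per the comment above)
    if configured is not None:
        return configured
    if torch.backends.mps.is_available():
        return 16
    if torch.cuda.is_available():
        return 64
    return 8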

Performance guide:

| Config | RNA time | ATAC time | Total |
|---|---|---|---|
| Default (auto) | ~4.5 min | ~18 min | ~25 min |
| ATAC_TOP_PEAKS=20000 | ~4.5 min | ~3 min | ~10 min |
| All auto + fast transformer | ~2 min | ~18 min | ~22 min |

Loss Function

L = L_ode + ½ L_recon + 0.01 L_τ + 0.1 L_reverse

L_ode     = ‖z_R - φ_τ̄(z_A)‖²
            ATAC should reach RNA after ODE integration (median τ̄)

L_recon   = ‖ŷ_g - gene_expr‖²
            Expression prediction should match true expression

L_τ       = |τ_g|
            Keeps delays from exploding (the softplus + 1e-3 floor prevents collapse to zero)

L_reverse = ‖z_A - φ_{-τ̄}(φ_τ̄(z_A))‖²
            Vector field must be invertible (essential for scRNA→z_A inference)
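
Assembled in code, the full objective could look like the following sketch (weights from the hyperparameter table; variable names are illustrative):

import torch

tau_med = tau_g.median()                                   # global τ̄
t_fwd = torch.stack([torch.zeros_like(tau_med), tau_med])  # [0, τ̄]
z_fwd = v_theta.integrate(z_A, t_fwd)[-1]                  # φ_τ̄(z_A)
z_back = v_theta.integrate(z_fwd, t_fwd.flip(0))[-1]       # φ_{-τ̄}(φ_τ̄(z_A))

L_ode = ((z_R - z_fwd) ** 2).sum(-1).mean()
L_recon = ((y_hat - gene_expr) ** 2).mean()
L_tau = tau_g.abs().mean()
L_reverse = ((z_A - z_back) ** 2).sum(-1).mean()

loss = L_ode + 0.5 * L_recon + 0.01 * L_tau + 0.1 * L_reverse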

Hyperparameters

| Parameter | Default | Description |
|---|---|---|
| DIM | 512 | Latent dimension |
| HYENADNA_D_MODEL | 256 | HyenaDNA output dimension |
| ALIGN_HIDDEN | 512 | AlignMLP hidden dimension |
| TAU_HIDDEN | 128 | AmortizedTau hidden dimension |
| TAU_MOD_SCALE | 0.1 | Cell-specific τ modulation amplitude |
| DEC_HIDDEN | 256 | GeneDecoder hidden dimension |
| N_PROJ_W2 | 200 | Sliced W2 projections |
| VF_EPOCHS | 50 | Vector field epochs |
| VF_LR | 1e-3 | Vector field learning rate |
| STAGE1_EPOCHS | 10 | Stage 1 epochs (frozen v_θ) |
| STAGE2_EPOCHS | 40 | Stage 2 epochs (fine-tune v_θ) |
| STAGE1_LR | 1e-3 | Stage 1 learning rate |
| STAGE2_LR_VF | 1e-6 | Stage 2 v_θ learning rate |
| BATCH_SIZE | 128 | Training batch size |
| W_ODE / W_RECON / W_TAU | 1.0 / 0.5 / 0.01 | Main loss weights |
| W_REVERSE | 0.1 | Reverse consistency loss weight |
| EXPLODE_NORM | 50.0 | Gradient clipping threshold |

Validation Framework

Layer 1: Model Health (Implemented)

Internal consistency checks. No external data required.

| Metric | Target | Method |
|---|---|---|
| Lipschitz constant | ≤ 1.0 | Product of SpectralNorm across layers |
| Reversibility error | < 0.01 | ‖z_A - φ_{-τ}(φ_τ(z_A))‖ |
| Reverse consistency loss | < 0.1 | L_reverse during training |
| Numerical stability | No NaN/Inf | torch.isfinite() on outputs |
| τ_g distribution | Right-skewed | Shapiro-Wilk test p < 0.05 |

from validate import check_model_health

health = check_model_health(model, z_test)
assert health['lip_in_range']
assert health['rev_err_acceptable']
assert health['numerically_stable']

# τ_g distribution check
from scipy.stats import shapiro
tau_g = model.tau_predictor(None, gene_emb).cpu().numpy()
stat, p = shapiro(tau_g)
print(f"τ_g Shapiro-Wilk: stat={stat:.3f}, p={p:.2e}")
# Expected p < 0.05 (reject normality → right-skewed is reasonable)

Layer 2: Displacement Classification

Biological plausibility of τ_g estimates.

Short delay genes (immediate-early genes):

  • FOS, JUN, JUNB, EGR1, NR4A1, NR4A2, ATF3
  • Expected: τ < 25th percentile
  • Mechanism: ATAC changes quickly lead to RNA changes

Priming genes (pioneer TFs):

  • SPI1 (PU.1), CEBPA, CEBPB, GATA2, IRF8, RUNX1
  • Expected: τ > 75th percentile
  • Mechanism: ATAC primes chromatin, RNA response is delayed

Open-no-expression genes (chromatin accessible but silenced):

  • Requires: H3K27me3 ChIP-seq data (Roadmap Epigenomics)
  • Expected: high H3K27me3 signal at promoters

Reverse regulation genes (negative feedback):

  • AHR, NR4A1 (literature-validated feedback loops)
  • Expected: τ reflects feedback delay

from scipy.stats import mannwhitneyu

early_genes = {'FOS', 'JUN', 'JUNB', 'EGR1', 'NR4A1', 'NR4A2', 'ATF3'}
priming_tfs = {'SPI1', 'CEBPA', 'CEBPB', 'GATA2', 'IRF8', 'RUNX1'}

early_tau = [tau[gene_names.index(g)] for g in early_genes if g in gene_names]
priming_tau = [tau[gene_names.index(g)] for g in priming_tfs if g in gene_names]

u, p = mannwhitneyu(early_tau, priming_tau, alternative='less')
print(f"Immediate-early < Priming: U={u:.0f}, p={p:.3f}")
# Expected p < 0.05

Layer 3: xQTL Natural Experiment (Strongest Causal Evidence)

Genetic variants provide natural randomization (Mendelian randomization). caQTL = natural ATAC perturbation, eQTL = natural RNA response. SMR provides independent causal evidence.

Data sources: CIMA_caQTL_eQTL_SMR.csv, CIMA_Lead_cis-xQTL.csv

Validation 1: xQTL genes have higher τ than background

import numpy as np
import pandas as pd

smr = pd.read_csv("CIMA_Resource/xQTL/CIMA_caQTL_eQTL_SMR.csv")
xqtl_genes = set(smr['gene']).intersection(set(gene_names))

tau_xqtl = tau[np.isin(gene_names, list(xqtl_genes))]
tau_bg   = tau[~np.isin(gene_names, list(xqtl_genes))]

u, p = mannwhitneyu(tau_xqtl, tau_bg, alternative='greater')
print(f"xQTL genes τ > background: U={u:.0f}, p={p:.2e}")
# Expected p < 0.05

Validation 2: CE direction matches xQTL direction

# For each caQTL→eQTL gene pair, compute model CE
direction_consistent = 0
for gene in xqtl_genes:
    ce = compute_causal_effect(model, z_A_base, z_A_perturbed, gene_emb)
    if np.sign(ce) == np.sign(smr_beta[gene]):
        direction_consistent += 1

direction_rate = direction_consistent / len(xqtl_genes)
print(f"Direction consistency: {direction_rate:.1%}")
# Expected > 60%

Validation 3: CE magnitude correlates with SMR effect size

from scipy.stats import spearmanr

smr_effect = smr.set_index('gene')['beta'].abs()
ce_effect = pd.Series({g: abs(ce[g]) for g in xqtl_genes})

rho, p = spearmanr(smr_effect, ce_effect)
print(f"CE vs SMR Spearman ρ={rho:.3f}, p={p:.2e}")
# Expected ρ > 0.3

Validation 4: caQTL+eQTL colocalized genes have higher τ than caQTL-only

lead_xqtl = pd.read_csv("CIMA_Resource/xQTL/CIMA_Lead_cis-xQTL.csv")

caqtl_only = lead_xqtl[(lead_xqtl['caQTL_p'] < 1e-5) & (lead_xqtl['eQTL_p'] > 0.05)]
caqtl_eqtl = lead_xqtl[(lead_xqtl['caQTL_p'] < 1e-5) & (lead_xqtl['eQTL_p'] < 1e-5)]

tau_coloc  = tau[np.isin(gene_names, list(caqtl_eqtl['gene']))]
tau_caonly = tau[np.isin(gene_names, list(caqtl_only['gene']))]

u, p = mannwhitneyu(tau_coloc, tau_caonly, alternative='greater')
print(f"Colocalized τ > caQTL-only: U={u:.0f}, p={p:.2e}")
# Expected p < 0.05

Layer 4: Perturb-seq Validation (Requires External Data)

CRISPR perturbation data validates do-intervention predictions.

Baseline = 0 genes: Should they be activated?

  • Silence precision > 0.85 (should NOT falsely predict activation when gene stays silent)
  • Silence recall > 0.3 (should detect real activations)

Baseline ≠ 0 genes: Is the direction correct?

  • Direction rate > 0.6
  • Spearman ρ > 0.2
  • Recall@50 > 0.1

import numpy as np
import pandas as pd
from scipy.stats import spearmanr

# Assumes do_intervention_latent (repo intervention code) is in scope
def perturb_validation(model, z_R, tf_direction, perturb_data):
    # 1. Latent intervention: simulate TF knockout
    result = do_intervention_latent(model, z_R, tf_direction, alpha=1.0)
    # Align predicted causal effects with the perturbation table's gene order
    ce_pred = pd.Series(result['CE'].mean(axis=1), index=perturb_data['gene'])  # (n_genes,)

    # 2. Compare with experimental DE
    de_obs = perturb_data.set_index('gene')['DE_log2FC']
    y_before = perturb_data.set_index('gene')['y_before']
    y_after = perturb_data.set_index('gene')['y_after']

    # 3. Baseline = 0: silence specificity
    zero_mask = (y_before < 0.01)
    y_activated = (y_after[zero_mask] > 0.1)
    ce_activated = (ce_pred[zero_mask].abs() > 0.05)

    silence_precision = 1 - ce_activated[~y_activated].mean()
    silence_recall = ce_activated[y_activated].mean()

    # 4. Baseline ≠ 0: direction consistency
    nonzero_mask = ~zero_mask
    y_direction = np.sign(y_after[nonzero_mask] - y_before[nonzero_mask])
    ce_direction = np.sign(ce_pred[nonzero_mask])
    direction_rate = (y_direction == ce_direction).mean()

    # 5. Ranking consistency
    rho, p = spearmanr(ce_pred[nonzero_mask], de_obs[nonzero_mask])

    return {
        'silence_precision': silence_precision,  # > 0.85
        'silence_recall': silence_recall,        # > 0.3
        'direction_rate': direction_rate,        # > 0.6
        'spearman_rho': rho,                     # > 0.2
    }

Data sources (by PBMC relevance):

  • Schmidt et al. 2022 (primary T cells, ~100 TFs) — highest relevance
  • Norman et al. 2019 (K562, bidirectional CRISPR) — medium
  • Replogle et al. 2022 (K562/RPE1, genome-wide) — low but broad coverage

Layer 5: External Cross-Validation

Age-related genes (implemented, data exists):

age_ereg = pd.read_csv("CIMA_Resource/GRN/CIMA_Age_Related_eRegulons.csv")
age_genes = set(age_ereg['Gene_signature_name'])

tau_age = tau[np.isin(gene_names, list(age_genes))]
tau_nonage = tau[~np.isin(gene_names, list(age_genes))]

u, p = mannwhitneyu(tau_age, tau_nonage, alternative='greater')
print(f"Age genes τ > others: U={u:.0f}, p={p:.2e}")
# Expected p < 0.05

eRegulon enrichment (Pioneer TF targets):

ereg = pd.read_csv("CIMA_Resource/GRN/CIMA_eRegulons_Metadata.csv",
                    usecols=['TF', 'Gene_signature_name', 'R2G_importance'])

pioneer_tfs = {'PU.1', 'SPI1', 'CEBPA', 'GATA2', 'RUNX1'}
pioneer_targets = set(ereg[ereg['TF'].isin(pioneer_tfs)]['Gene_signature_name'])

high_tau_genes = set(gene_names[tau > np.percentile(tau, 75)])

from scipy.stats import fisher_exact
overlap = high_tau_genes & pioneer_targets
contingency = [
    [len(overlap), len(high_tau_genes - pioneer_targets)],
    [len(pioneer_targets - high_tau_genes),
     len(set(gene_names) - high_tau_genes - pioneer_targets)]
]
odds, p = fisher_exact(contingency)
print(f"Pioneer TF target enrichment: OR={odds:.2f}, p={p:.2e}")
# Expected OR > 1, p < 0.05

Time-series scRNA-seq (requires external data):

  • LPS-stimulated monocytes (public datasets)
  • Short-delay genes should respond rapidly post-stimulation
  • Expected: correlation between τ_g and response time

Roadmap/ENCODE epigenomics (requires external data):

  • H3K27me3 ChIP-seq for PBMC
  • Open-no-expression genes should have high H3K27me3 at promoters
  • Expected: enrichment OR > 2, p < 0.05

Validation Priority

P0: Immediate (No External Data Required)

Run after training completes:

  1. Model health (Lipschitz, reversibility, τ_g distribution)
  2. Immediate-early vs priming TF delay difference (Mann-Whitney U test)
  3. Expression reconstruction quality (Spearman ρ > 0.5)

P1: High Priority (Data Already Exists)

  1. xQTL natural experiment (4 sub-validations, strongest causal evidence)
  2. eRegulon pioneer TF target enrichment
  3. Age eRegulon enrichment

P2: Medium Priority (Requires Data Acquisition)

  1. Perturb-seq validation
  2. Time-series scRNA validation
  3. Roadmap/ENCODE epigenomic marks

Go/No-Go Checkpoints

| Stage | Checkpoint | Pass Criteria | Failure Action |
|---|---|---|---|
| Stage 1, epoch 10 | τ_g distribution | Not collapsed to 0 or ∞ | Increase L_τ weight |
| Stage 2, epoch 50 | ODE consistency | L_ODE < 0.1 | Check v_θ Lipschitz or adjust LR |
| Post-training | Lipschitz constant | ≤ 1.0 | Verify SpectralNorm is active |
| Post-training | Reversibility error | < 0.01 | Shorten integration interval |
| Post-training | Reverse ODE explosion rate | < 10% | Check v_θ Lipschitz, reduce step size |
| Post-training | eRegulon enrichment | OR > 1, p < 0.05 | Check τ_g distribution flatness |
| Post-training | xQTL direction consistency | > 60% | Re-examine v_θ training quality |

Execution Plan

Week 1-2: Infrastructure (Stage 0 + 1)

  1. Extract CIMA embeddings (bottleneck: ~20-40 hours GPU)
  2. Train vector field with reverse consistency
  3. Process 10X data + extract gene embeddings

Week 3: Delay Training (Stage 2)

  1. Train delay model (Stage 1 frozen VF + Stage 2 fine-tune)
  2. First τ_g estimates + biological sanity check

Week 4: Validation (P0 + P1)

  1. Model health + displacement classification
  2. xQTL natural experiment
  3. eRegulon + Age enrichment

Future Extensions

scRNA → scATAC Reconstruction

Add ATACDecoder to enable cross-modality prediction:

z_A (512d) → ATACDecoder → peak accessibility (n_peaks,)

Requires: cell × peak accessibility matrix from 10X h5, atac_decoder.py.
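
A hypothetical sketch of such a decoder (not implemented; atac_decoder.py is the planned home, everything else here is illustrative):

import torch
import torch.nn as nn

class ATACDecoder(nn.Module):
    def __init__(self, dim=512, n_peaks=30000, hidden=1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden), nn.GELU(),
            nn.Dropout(0.1), nn.Linear(hidden, n_peaks),
        )

    def forward(self, z_A):
        # Per-peak accessibility probability from the shared latent
        return torch.sigmoid(self.net(z_A))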

Gene-Level Perturbation

Perturb a gene and predict the expression cascade:

RNA (perturbed) → reverse_ode → z_A
  → query eRegulon: TF → target peaks
  → modify accessibility[peaks] ∝ δ × weights
  → re-encode → z_A' → forward ODE → RNA' (perturbed prediction)

Requires: eRegulon peak mapping, Perturb-seq validation data.


Architecture Optimizations (for CIMA-scale)

These are noted for CIMA-scale training (4M cells), not demo-scale.

τ_predictor: MLP → Attention

Problem: MLP from frozen 512d gene embeddings has limited expressive power. τ range is narrow (0.65-0.72), gene-wise variation is low (0.1%).

Direction: Replace MLP with attention over gene embeddings.

# Current (weak)
tau_base = MLP(gene_emb)  # MLP bottleneck

# Future: attention lets genes "see" each other
tau_base = multi_head_attention(gene_emb, gene_emb, gene_emb)
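
A concrete (hypothetical) version of this direction using nn.MultiheadAttention, sketched rather than committed:

import torch.nn as nn
import torch.nn.functional as F

class AttentionTauBase(nn.Module):
    def __init__(self, dim=512, n_heads=8, hidden=128):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.head = nn.Sequential(nn.Linear(dim, hidden), nn.Tanh(), nn.Linear(hidden, 1))

    def forward(self, gene_emb):                  # (n_genes, 512)
        x = gene_emb.unsqueeze(0)                 # (1, n_genes, 512)
        ctx, _ = self.attn(x, x, x)               # genes attend to each other
        return F.softplus(self.head(ctx.squeeze(0))).squeeze(-1)  # (n_genes,)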

Gene Embedding Adapter

Problem: Frozen scGPT embeddings may not match downstream task distribution.

Direction: Add learnable adapter (not full fine-tuning).

# Current: fully frozen
gene_emb = scGPT_encoder.weight

# Future: small adapter, preserves pretrained + adapts to data
gene_emb = scGPT_encoder.weight + adapter(scGPT_encoder.weight)

Trajectory Readout: Interpolation → Attention

Problem: Fixed-step Euler integration + interpolation wastes compute and introduces error.

Direction: O(1) attention-based readout.

# Current: full trajectory + fixed interpolation
trajectory = integrate(z_A, τ_max)  # 200 steps
hidden = interpolate(trajectory, τ_g)

# Future: direct O(1) readout
hidden = attention_readout(z_A, τ_g, gene_emb)  # no interpolation error

ATAC→RNA Contrastive Loss

Problem: SW2=0.733 indicates ATAC→RNA distribution alignment is not tight.

Direction: Add contrastive loss for paired cells.

# Current: only SW2 (unpaired distribution matching)
L = sliced_w2(z_integrated, z_R)

# Future: add contrastive signal for paired cells
L_contrastive = contrastive_loss(AlignMLP(z_ATAC), z_R)
L = SW2 + λ * L_contrastive
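
contrastive_loss is not yet defined; a standard InfoNCE formulation for paired cells could look like this (temperature and the symmetric form are assumptions):

import torch
import torch.nn.functional as F

def contrastive_loss(z_a, z_r, temperature=0.1):
    # Paired cells are positives; all other cells in the batch are negatives
    z_a = F.normalize(z_a, dim=1)
    z_r = F.normalize(z_r, dim=1)
    logits = (z_a @ z_r.T) / temperature              # (B, B) cross-modal similarities
    labels = torch.arange(z_a.shape[0], device=z_a.device)
    return 0.5 * (F.cross_entropy(logits, labels) +
                  F.cross_entropy(logits.T, labels))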

Hybrid Precision (ATAC Embedding)

# Current: ATAC embedding runs in float32
# Future: bfloat16 autocast for HyenaDNA inference
with torch.autocast(device_type='mps', dtype=torch.bfloat16):
    outputs = model(input_ids=tokens)
# Expected: 1.5-2x speedup, ~50% memory reduction

Priority: Hybrid precision (trivial) > Attention τ_predictor (high impact) > Adapter (medium) > Attention readout (high complexity) > Contrastive loss (requires paired data).

Languages