In [None]:
import os
import sys
import json
import scanpy as sc
import torch
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import json
import random
import scipy.sparse as sp

In [None]:
# Project root; every other path used in this notebook is derived from it.
work_dir = "/home/yyuan/ICB_TCE/"  # adjust if needed
script_dir = os.path.join(work_dir, "scripts")

# Make the project's helper modules importable. The membership check keeps
# sys.path free of duplicates when this cell is re-run.
if script_dir not in sys.path:
    sys.path.append(script_dir)

# Root directory for per-seed outputs (run_single_seed writes its run folders here).
iter_dir = os.path.join(work_dir, "iter_results")
os.makedirs(iter_dir, exist_ok=True)

# Sub-directory for summaries; it is only created here, no visible cell writes to it.
summary_dir = os.path.join(iter_dir, "summaries")
os.makedirs(summary_dir, exist_ok=True)

# NOTE(review): wildcard imports from the project scripts. Later cells use names
# that are not defined anywhere in the visible notebook (e.g. VAE_scRNA, elbo_loss,
# VESDE, MLP, SchrodingerBridgePolicy, run_joint_training_loop), so they presumably
# come from these modules - which module exports which name cannot be told from
# here. With star imports a later module silently shadows same-named exports of an
# earlier one, so the order of these lines may matter.
from vae import *
from sde import *
from bio_con import *
from bio_util import *
from training_util import *
from joint_train import *

In [None]:
# Global matplotlib defaults for every figure in this notebook:
# Nimbus Roman font, 300 dpi both on screen and when saving to disk.
plt.rcParams.update(
    {
        "font.family": "Nimbus Roman",
        "figure.dpi": 300,
        "savefig.dpi": 300,
    }
)

In [None]:
def set_seed(seed: int = 0):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU generator and all CUDA devices), switches cuDNN to deterministic
    mode with auto-tuning disabled, and stores the seed in the
    ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Integer seed applied to all generators. Defaults to 0.

    NOTE(review): ``PYTHONHASHSEED`` is assigned while the interpreter is
    already running, so it presumably only influences child processes and
    not string hashing in the current kernel - confirm if that matters.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    os.environ["PYTHONHASHSEED"] = str(seed)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Load the BRCA T-cell AnnData object from the project data directory.
adata_brca_t = sc.read_h5ad(os.path.join(work_dir, "data/brca_t_cell.h5ad"))

# Quality control, then restrict to the 3000 most variable genes.
# The results are not re-assigned: these calls modify `adata_brca_t` itself,
# so their order matters.
sc.pp.filter_cells(adata_brca_t, min_genes=200)
sc.pp.filter_genes(adata_brca_t, min_cells=3)
sc.pp.highly_variable_genes(adata_brca_t, n_top_genes=3000, subset=True)

# Per-cell boolean masks, built after filtering so that they line up with the
# rows of the expression matrix extracted below.
post_treatment_mask = adata_brca_t.obs["pre_post"] == "Post"
pre_treatment_mask = adata_brca_t.obs["pre_post"] == "Pre"

# Expression matrix as a float32 tensor. Anything exposing `.toarray()`
# (i.e. a sparse matrix) is densified first; a dense array is used as-is.
expression_data = torch.tensor(
    adata_brca_t.X.toarray() if hasattr(adata_brca_t.X, "toarray") else adata_brca_t.X,
    dtype=torch.float32,
)

# Number of genes (columns) = input width of the VAE.
input_dim = expression_data.size(1)

In [None]:
# --- VAE pre-training hyperparameters ---------------------------------------
LATENT_DIM = 20        # size of the VAE latent space
BATCH_SIZE = 256       # mini-batch size of the pre-training DataLoader
NUM_EPOCHS = 20        # number of pre-training epochs
LEARNING_RATE = 1e-4   # Adam learning rate for pre-training

# --- KL annealing schedule ---------------------------------------------------
# The KL weight (beta) is 0 before KL_START_EPOCH, then grows linearly and is
# capped at 1.0 after KL_WARMUP_EPOCHS further epochs.
KL_START_EPOCH = 3     # start KL annealing earlier in light pre-training
KL_WARMUP_EPOCHS = 10

In [None]:
class DataSampler:
    """Draws random mini-batches (with replacement) from a fixed tensor.

    The tensor is moved to ``device`` once at construction time; every call
    to :meth:`sample` indexes into that stored copy.
    """

    def __init__(self, data, device):
        """Store ``data`` on ``device``.

        Args:
            data: Tensor whose first dimension enumerates the samples.
            device: Target device passed to ``Tensor.to``.
        """
        self.data = data.to(device)

    def sample(self, batch_size):
        """Return ``batch_size`` rows chosen uniformly at random.

        Row indices are drawn independently with ``torch.randint``, so the
        same row may appear more than once in a batch.
        """
        n_rows = len(self.data)
        picks = torch.randint(0, n_rows, (batch_size,))
        return self.data[picks]

In [None]:
# Base configuration: read the saved full-model config, then override the
# dimensions so they match the data and latent size used in this notebook.
full_config_path = os.path.join(work_dir, "trained_models/full_config.json")
with open(full_config_path, "r") as f:
    config = json.load(f)
config.update(input_dim=expression_data.shape[1], latent_dim=LATENT_DIM)

# Ablation configuration: a shallow copy of the base config in which the four
# lambda_* weights (bio, GRN, death, birth) are set to zero.
config_abl = {
    **config,
    "lambda_bio": 0.0,
    "lambda_grn": 0.0,
    "lambda_death": 0.0,
    "lambda_birth": 0.0,
}

In [None]:
def run_single_seed(seed: int, config: dict, config_abl: dict):
    """
    Run original + ablation training pipeline for a single random seed.

    Stages, in order:
      1. Pre-train a VAE on ``expression_data`` with KL annealing and save it.
      2. Encode all cells with the pre-trained VAE, split the latent
         embeddings with the pre-/post-treatment masks and wrap each split
         in a ``DataSampler``.
      3. Joint training ("original"): starts from the pre-trained VAE weights
         and uses ``config``.
      4. Joint training ("ablation"): uses ``config_abl`` and a freshly
         initialised VAE.
      5. Compute drift genes for both model sets and write them as CSV.

    All results are stored under:
        iter_results / run_{seed:03d}

    Args:
        seed: Random seed; also determines the run directory name.
        config: Configuration dict for the original run. Keys read here:
            ``input_dim``, ``t0``, ``T``, ``interval``, ``data_dim``, ``lr``
            and optionally ``lr_vae`` (defaults to 1e-4).
        config_abl: Configuration dict for the ablation run. Keys read here:
            ``input_dim``, ``data_dim`` and ``lr``.

    Returns:
        None. Everything is written to disk.

    Notes:
        Relies on notebook-level globals: ``iter_dir``, ``expression_data``,
        ``device``, ``input_dim``, ``adata_brca_t``, ``pre_treatment_mask``,
        ``post_treatment_mask``, ``BATCH_SIZE``, ``LEARNING_RATE``,
        ``NUM_EPOCHS``, ``LATENT_DIM``, ``KL_START_EPOCH`` and
        ``KL_WARMUP_EPOCHS``.
        NOTE(review): ``grn_data``, ``death_gene_indices`` and
        ``birth_gene_indices`` are not defined in any visible cell; they
        presumably come from the star imports or from a cell not shown here -
        confirm before a Restart & Run All.
    """

    print("\n" + "=" * 80)
    print(f"SEED {seed}: starting run (original + ablation)")
    print("=" * 80)

    # Set paths for this seed
    run_tag = f"run_{seed:03d}"
    run_dir = os.path.join(iter_dir, run_tag)
    os.makedirs(run_dir, exist_ok=True)

    # Shallow copies so that top-level key changes made downstream cannot
    # leak into the dicts shared across seeds.
    cfg = config.copy()
    cfg_abl = config_abl.copy()

    # -------------------------
    # VAE pretraining
    # -------------------------
    print("\nStarting VAE pre-training...")
    set_seed(seed)
    dataset = TensorDataset(expression_data)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

    # Fix: the original passed the undefined name `latent_dim`; the notebook
    # constant is LATENT_DIM (the same value later given to load_trained_vae).
    vae_pre = VAE_scRNA(input_dim=cfg["input_dim"], latent_dim=LATENT_DIM).to(device)
    optimizer = optim.Adam(vae_pre.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

    vae_pre.train()
    for epoch in range(NUM_EPOCHS):
        total_loss = 0.0

        # KL annealing: beta stays 0 until KL_START_EPOCH, then ramps
        # linearly and is capped at 1.0 after KL_WARMUP_EPOCHS epochs.
        if epoch < KL_START_EPOCH:
            beta = 0.0
        else:
            beta = min(1.0, (epoch - KL_START_EPOCH) / KL_WARMUP_EPOCHS)

        for (batch_features,) in dataloader:
            batch_features = batch_features.to(device)

            # Forward pass.
            # Fix: the original called `vae(...)`. `vae` is only assigned
            # further down in this function, so Python treats it as a local
            # and the call raised UnboundLocalError; the model being
            # optimised here is `vae_pre`.
            recon_x, mu, log_var = vae_pre(batch_features)

            # Compute loss
            loss = elbo_loss(batch_features, recon_x, mu, log_var, beta=beta)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # NOTE(review): divides by the full dataset size although
        # drop_last=True skips the final partial batch, and whether
        # elbo_loss is summed or averaged per batch is not visible here.
        # Treat avg_loss as a relative signal for the scheduler only.
        avg_loss = total_loss / len(dataloader.dataset)
        scheduler.step(avg_loss)

    print("\nVAE pre-training complete.")

    vae_pretrain_path = os.path.join(run_dir, f"vae_pretrain_{run_tag}.pth")
    torch.save(vae_pre.state_dict(), vae_pretrain_path)

    # -------------------------
    # Latent embeddings for bridge
    # -------------------------
    # Assumes compute_latent_embeddings returns a detached CPU tensor
    # (required by .numpy()) - TODO confirm.
    latent_mu = compute_latent_embeddings(vae_pre, expression_data, device=device)
    latent_embeddings = latent_mu.numpy()

    pre_embeddings = torch.tensor(
        latent_embeddings[pre_treatment_mask], dtype=torch.float32
    ).to(device)
    post_embeddings = torch.tensor(
        latent_embeddings[post_treatment_mask], dtype=torch.float32
    ).to(device)

    # Samplers over the pre- and post-treatment latent embeddings; they are
    # handed to VESDE as its two data sources.
    p_sampler = DataSampler(pre_embeddings, device=device)
    q_sampler = DataSampler(post_embeddings, device=device)

    # -------------------------
    # Joint training (original)
    # -------------------------
    print("\nStarting joint training (original)...")
    set_seed(seed)

    # Load pretrained VAE weights
    vae = load_trained_vae(
        model_path=vae_pretrain_path,
        input_dim=input_dim,
        latent_dim=LATENT_DIM,
        device=device,
    )

    dyn = VESDE(cfg, p_sampler, q_sampler)
    ts = torch.linspace(cfg["t0"], cfg["T"], cfg["interval"]).to(device)

    net_f = MLP(input_dim=cfg["data_dim"][0], output_dim=cfg["data_dim"][0]).to(device)
    net_b = MLP(input_dim=cfg["data_dim"][0], output_dim=cfg["data_dim"][0]).to(device)

    z_f = SchrodingerBridgePolicy(cfg, "forward", dyn, net_f)
    z_b = SchrodingerBridgePolicy(cfg, "backward", dyn, net_b)

    optimizer_f = torch.optim.Adam(z_f.parameters(), lr=cfg["lr"])
    optimizer_b = torch.optim.Adam(z_b.parameters(), lr=cfg["lr"])
    optimizer_vae = torch.optim.Adam(vae.parameters(), lr=cfg.get("lr_vae", 1e-4))

    # Latent -> expression decoder built from the VAE's decoder modules.
    vae_decoder = lambda z: vae.decoder_output(vae.decoder(z))

    training_history = run_joint_training_loop(
        config=cfg,
        dyn=dyn,
        ts=ts,
        vae=vae,
        vae_decoder=vae_decoder,
        z_f=z_f,
        z_b=z_b,
        optimizer_f=optimizer_f,
        optimizer_b=optimizer_b,
        optimizer_vae=optimizer_vae,
        expression_data=expression_data,          # computed once outside
        grn_data=grn_data,
        death_gene_indices=death_gene_indices,
        birth_gene_indices=birth_gene_indices,
        device=device,
    )

    # Save original models + history for this run
    torch.save(vae.state_dict(), os.path.join(run_dir, f"vae_original_{run_tag}.pth"))
    torch.save(z_f.state_dict(), os.path.join(run_dir, f"z_f_original_{run_tag}.pth"))
    torch.save(z_b.state_dict(), os.path.join(run_dir, f"z_b_original_{run_tag}.pth"))

    hist_path = os.path.join(run_dir, f"training_history_original_{run_tag}.json")
    with open(hist_path, "w") as f:
        json.dump(training_history, f, indent=2)
    print(f"[Seed {seed}] Joint training complete; history saved to {hist_path}")

    # -------------------------
    # Ablation (no biology constraints) - fresh VAE
    # -------------------------
    # NOTE(review): unlike the original run, the ablation (a) is not
    # re-seeded, (b) starts from a randomly initialised VAE instead of the
    # pre-trained weights, (c) uses weight_decay=1e-4 and cfg_abl["lr"] for
    # the VAE optimizer, and (d) reuses `dyn`, which was built from `cfg`.
    # Confirm these differences are intended for the comparison.
    vae_abl = VAE_scRNA(
        input_dim=cfg_abl["input_dim"], latent_dim=cfg_abl["data_dim"][0]
    ).to(device)

    z_f_abl = SchrodingerBridgePolicy(
        cfg_abl,
        "forward",
        dyn,
        MLP(cfg_abl["data_dim"][0], cfg_abl["data_dim"][0]).to(device),
    ).to(device)

    z_b_abl = SchrodingerBridgePolicy(
        cfg_abl,
        "backward",
        dyn,
        MLP(cfg_abl["data_dim"][0], cfg_abl["data_dim"][0]).to(device),
    ).to(device)

    optimizer_vae_abl = torch.optim.Adam(
        vae_abl.parameters(), lr=cfg_abl["lr"], weight_decay=1e-4
    )
    optimizer_f_abl = torch.optim.Adam(
        z_f_abl.parameters(), lr=cfg_abl["lr"], weight_decay=1e-4
    )
    optimizer_b_abl = torch.optim.Adam(
        z_b_abl.parameters(), lr=cfg_abl["lr"], weight_decay=1e-4
    )
    vae_decoder_abl = lambda z: vae_abl.decoder_output(vae_abl.decoder(z))

    ablation_history = run_joint_training_loop(
        config=cfg_abl,
        dyn=dyn,
        ts=ts,
        vae=vae_abl,
        vae_decoder=vae_decoder_abl,
        z_f=z_f_abl,
        z_b=z_b_abl,
        optimizer_f=optimizer_f_abl,
        optimizer_b=optimizer_b_abl,
        optimizer_vae=optimizer_vae_abl,
        expression_data=expression_data,
        grn_data=grn_data,
        death_gene_indices=death_gene_indices,
        birth_gene_indices=birth_gene_indices,
        device=device,
    )

    # Save ablation models + history
    torch.save(vae_abl.state_dict(), os.path.join(run_dir, f"vae_ablation_{run_tag}.pth"))
    torch.save(z_f_abl.state_dict(), os.path.join(run_dir, f"z_f_ablation_{run_tag}.pth"))
    torch.save(z_b_abl.state_dict(), os.path.join(run_dir, f"z_b_ablation_{run_tag}.pth"))

    ablation_hist_path = os.path.join(run_dir, f"training_history_ablation_{run_tag}.json")
    with open(ablation_hist_path, "w") as f:
        json.dump(ablation_history, f, indent=2)

    # -------------------------
    # Drift genes (original + ablation)
    # -------------------------
    drift_fwd_orig, drift_bwd_orig = compute_drift_genes_for_models(
        vae=vae,
        z_f=z_f,
        z_b=z_b,
        adata=adata_brca_t,   # use global adata
        config=cfg,
        device=device,
    )
    drift_fwd_abl, drift_bwd_abl = compute_drift_genes_for_models(
        vae=vae_abl,
        z_f=z_f_abl,
        z_b=z_b_abl,
        adata=adata_brca_t,
        config=cfg_abl,
        device=device,
    )

    drift_fwd_orig.to_csv(os.path.join(run_dir, "drift_genes_forward_original.csv"), index=False)
    drift_bwd_orig.to_csv(os.path.join(run_dir, "drift_genes_backward_original.csv"), index=False)
    drift_fwd_abl.to_csv(os.path.join(run_dir, "drift_genes_forward_ablation.csv"), index=False)
    drift_bwd_abl.to_csv(os.path.join(run_dir, "drift_genes_backward_ablation.csv"), index=False)

    print(f"[Seed {seed}] Drift genes saved in {run_dir}")
    print(f"[Seed {seed}] run finished.")

In [None]:
# Run the full pipeline (original + ablation) once per seed, for seeds 1..NUM_SEEDS.
NUM_SEEDS = 20
SEEDS = [offset + 1 for offset in range(NUM_SEEDS)]

for s in SEEDS:
    run_single_seed(config=config, config_abl=config_abl, seed=s)