In [3]:
import os, sys, json, random
import numpy as np
import pandas as pd
import torch
import scanpy as sc
import scipy.sparse as sp

work_dir = "/home/yyuan/ICB_TCE/"
script_dir = os.path.join(work_dir, "scripts")
if script_dir not in sys.path:
    sys.path.append(script_dir)

from vae import VAE_scRNA, load_trained_vae, compute_latent_embeddings
from sde import VESDE, SchrodingerBridgePolicy, MLP
from bio_con import *
from bio_util import *
from training_util import *
from joint_train import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [4]:
# Load data and prepare constraints similarly to icb.ipynb
adata_brca_t = sc.read_h5ad(os.path.join(work_dir, 'data/brca_t_cell.h5ad'))

# HVG selection
sc.pp.filter_cells(adata_brca_t, min_genes=200)
sc.pp.filter_genes(adata_brca_t, min_cells=3)
sc.pp.highly_variable_genes(adata_brca_t, n_top_genes=3000, subset=True)

# Expression tensor
X = adata_brca_t.X.toarray() if hasattr(adata_brca_t.X, 'toarray') else adata_brca_t.X
expression_data = torch.tensor(X, dtype=torch.float32)
input_dim = expression_data.shape[1]
LATENT_DIM = 20

# Load pre-trained VAE and compute latent embeddings
MODEL_SAVE_PATH = os.path.join(work_dir, "trained_models/vae_pretrain.pth")
vae = load_trained_vae(
    model_path=MODEL_SAVE_PATH,
    input_dim=input_dim,
    latent_dim=LATENT_DIM,
    device=device,
)
latent_mu = compute_latent_embeddings(vae, expression_data, device=device)
latent_embeddings = latent_mu.numpy()
adata_brca_t.obsm["X_vae"] = latent_embeddings

# Prepare GO gene sets
cell_death_files = [
    os.path.join(work_dir, "GO_geneset/HALLMARK_APOPTOSIS.v2025.1.Hs.json"),
    os.path.join(work_dir, "GO_geneset/HALLMARK_P53_PATHWAY.v2025.1.Hs.json"),
    os.path.join(work_dir, "GO_geneset/HALLMARK_REACTIVE_OXYGEN_SPECIES_PATHWAY.v2025.1.Hs.json"),
    os.path.join(work_dir, "GO_geneset/HALLMARK_UNFOLDED_PROTEIN_RESPONSE.v2025.1.Hs.json"),
]
cell_birth_files = [
    os.path.join(work_dir, "GO_geneset/HALLMARK_E2F_TARGETS.v2025.1.Hs.json"),
    os.path.join(work_dir, "GO_geneset/HALLMARK_G2M_CHECKPOINT.v2025.1.Hs.json"),
    os.path.join(work_dir, "GO_geneset/HALLMARK_MYC_TARGETS_V1.v2025.1.Hs.json"),
    os.path.join(work_dir, "GO_geneset/HALLMARK_MYC_TARGETS_V2.v2025.1.Hs.json"),
]

all_death_genes, all_birth_genes = set(), set()
for fp in cell_death_files:
    _, genes = load_gene_set_from_json(fp)
    if genes:
        all_death_genes.update(genes)
for fp in cell_birth_files:
    _, genes = load_gene_set_from_json(fp)
    if genes:
        all_birth_genes.update(genes)

all_death_genes_list = sorted(list(all_death_genes))
all_birth_genes_list = sorted(list(all_birth_genes))

# GRN with prior
hvg_names_list = adata_brca_t.var['feature_name'].tolist()
adj_file = os.path.join(work_dir, "data/brca_t_cell_adj.csv")
prior_edges_file = os.path.join(work_dir, "data/TCE_prior_edges.csv")

grn_df_hvg, grn_data = build_grn_with_prior(
    adj_path=adj_file,
    prior_edges_path=prior_edges_file,
    hvg_names_list=hvg_names_list,
)

Loaded 161 genes from HALLMARK_APOPTOSIS
Loaded 200 genes from HALLMARK_P53_PATHWAY
Loaded 49 genes from HALLMARK_REACTIVE_OXYGEN_SPECIES_PATHWAY
Loaded 113 genes from HALLMARK_UNFOLDED_PROTEIN_RESPONSE
Loaded 200 genes from HALLMARK_E2F_TARGETS
Loaded 200 genes from HALLMARK_G2M_CHECKPOINT
Loaded 200 genes from HALLMARK_MYC_TARGETS_V1
Loaded 58 genes from HALLMARK_MYC_TARGETS_V2

Found 123 GRN edges overlapping with prior knowledge.
Loaded 53 GRN rules where both TF and target are in the HVG set.


In [5]:
# Prepare dynamics sampler (shared across runs)
config_path = os.path.join(work_dir, 'conceptual_val_results/full_config.json')
with open(config_path, 'r') as f:
    base_config = json.load(f)

base_config['data_dim'] = [LATENT_DIM]

pre_mask = adata_brca_t.obs['pre_post'] == 'Pre'
on_mask = adata_brca_t.obs['pre_post'] == 'Post'
pre_embeddings = torch.tensor(latent_embeddings[pre_mask], dtype=torch.float32).to(device)
on_embeddings = torch.tensor(latent_embeddings[on_mask], dtype=torch.float32).to(device)

class DataSampler:
    def __init__(self, data, device):
        self.data = data.to(device)
    def sample(self, batch_size):
        idx = torch.randint(0, len(self.data), (batch_size,))
        return self.data[idx]

p_sampler = DataSampler(pre_embeddings, device=device)
q_sampler = DataSampler(on_embeddings, device=device)

dyn = VESDE(base_config, p_sampler, q_sampler)
ts = torch.linspace(base_config['t0'], base_config['T'], base_config['interval']).to(device)

In [6]:
# Helper tensors reused across runs for drift computation
_treatment_col = [col for col in adata_brca_t.obs.columns
                  if 'treatment' in col.lower() or 'pre' in col.lower()][0]
_unique_vals = adata_brca_t.obs[_treatment_col].unique()
_pre_val = [v for v in _unique_vals if 'pre' in str(v).lower() or v == 0][0]
_on_val = [v for v in _unique_vals if 'pre' not in str(v).lower() and v != 0][0]

_pre_mask = adata_brca_t.obs[_treatment_col] == _pre_val
_on_treatment_mask = adata_brca_t.obs[_treatment_col] == _on_val

_pre_cells_expr = adata_brca_t[_pre_mask].X
if sp.issparse(_pre_cells_expr):
    _pre_cells_expr = _pre_cells_expr.toarray()
_pre_cells_expr = torch.tensor(_pre_cells_expr, dtype=torch.float32)

_on_cells_expr = adata_brca_t[_on_treatment_mask].X
if sp.issparse(_on_cells_expr):
    _on_cells_expr = _on_cells_expr.toarray()
_on_cells_expr = torch.tensor(_on_cells_expr, dtype=torch.float32)

_gene_symbols = adata_brca_t.var['feature_name'].values.tolist()

_death_genes = set(all_death_genes_list)
_birth_genes = set(all_birth_genes_list)

print("Prepared shared expression tensors for drift computation.")

Prepared shared expression tensors for drift computation.


In [5]:
# Multi-run loop: 10 runs of full training (original + ablation) and drift (top 50 only)
from collections import Counter

N_RUNS = 10
BASE_SEED = 42
TOP_K = 50
iter_dir = os.path.join(work_dir, 'iter_results')
os.makedirs(iter_dir, exist_ok=True)

# Stability counters across all runs
stab_forward_orig = Counter()
stab_backward_orig = Counter()
stab_forward_abl = Counter()
stab_backward_abl = Counter()

for run_idx in range(1, N_RUNS + 1):
    seed = BASE_SEED + run_idx
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    print(f"\n=== RUN {run_idx}/{N_RUNS} | seed={seed} ===")

    # ----- ORIGINAL TRAINING (with biology) -----
    config = dict(base_config)  # shallow copy
    config['data_dim'] = [LATENT_DIM]

    vae_orig = VAE_scRNA(input_dim=expression_data.shape[1], latent_dim=LATENT_DIM).to(device)
    vae_orig.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))
    vae_decoder_orig = lambda z: vae_orig.decoder_output(vae_orig.decoder(z))

    net_f_orig = MLP(input_dim=LATENT_DIM, output_dim=LATENT_DIM).to(device)
    net_b_orig = MLP(input_dim=LATENT_DIM, output_dim=LATENT_DIM).to(device)
    z_f_orig = SchrodingerBridgePolicy(config, 'forward', dyn, net_f_orig).to(device)
    z_b_orig = SchrodingerBridgePolicy(config, 'backward', dyn, net_b_orig).to(device)

    optimizer_vae_orig = torch.optim.Adam(vae_orig.parameters(), lr=config['lr'])
    optimizer_f_orig = torch.optim.Adam(z_f_orig.parameters(), lr=config['lr'])
    optimizer_b_orig = torch.optim.Adam(z_b_orig.parameters(), lr=config['lr'])

    print("Running joint training (original, with biology)...")
    training_history_orig = run_joint_training_loop(
        config=config,
        dyn=dyn,
        ts=ts,
        vae=vae_orig,
        vae_decoder=vae_decoder_orig,
        z_f=z_f_orig,
        z_b=z_b_orig,
        optimizer_f=optimizer_f_orig,
        optimizer_b=optimizer_b_orig,
        optimizer_vae=optimizer_vae_orig,
        expression_data=expression_data,
        grn_data=grn_data,
        death_gene_indices=[hvg_names_list.index(g) for g in all_death_genes_list if g in hvg_names_list],
        birth_gene_indices=[hvg_names_list.index(g) for g in all_birth_genes_list if g in hvg_names_list],
        device=device,
    )

    run_dir = os.path.join(iter_dir, f'run_{run_idx:02d}')
    os.makedirs(run_dir, exist_ok=True)

    # Save original run checkpoints and history
    torch.save({'model_state_dict': vae_orig.state_dict(), 'input_dim': expression_data.shape[1], 'latent_dim': LATENT_DIM},
               os.path.join(run_dir, 'vae_joint_trained_orig.pth'))
    torch.save({'model_state_dict': z_f_orig.net.state_dict()}, os.path.join(run_dir, 'z_f_policy_orig.pth'))
    torch.save({'model_state_dict': z_b_orig.net.state_dict()}, os.path.join(run_dir, 'z_b_policy_orig.pth'))
    with open(os.path.join(run_dir, 'training_history_orig.json'), 'w') as f:
        json.dump(training_history_orig, f, indent=2)

    # Drift for original model
    f_orig_df, b_orig_df = compute_drift_tables(
        vae=vae_orig,
        z_f=z_f_orig,
        z_b=z_b_orig,
        pre_cells_expr=_pre_cells_expr,
        on_cells_expr=_on_cells_expr,
        gene_symbols=_gene_symbols,
        epsilon=1e-4,
    )
    f_orig_df['is_death_gene'] = f_orig_df['gene'].isin(_death_genes)
    f_orig_df['is_birth_gene'] = f_orig_df['gene'].isin(_birth_genes)
    b_orig_df['is_death_gene'] = b_orig_df['gene'].isin(_death_genes)
    b_orig_df['is_birth_gene'] = b_orig_df['gene'].isin(_birth_genes)

    top_f_orig = f_orig_df.head(TOP_K)
    top_b_orig = b_orig_df.head(TOP_K)

    stab_forward_orig.update(top_f_orig['gene'].tolist())
    stab_backward_orig.update(top_b_orig['gene'].tolist())

    f_orig_df.to_csv(os.path.join(run_dir, 'drift_genes_forward_original.csv'), index=False)
    b_orig_df.to_csv(os.path.join(run_dir, 'drift_genes_backward_original.csv'), index=False)

    # ----- ABLATION TRAINING (no biology) -----
    config_abl = dict(base_config)
    config_abl['data_dim'] = [LATENT_DIM]
    config_abl['lambda_bio'] = 0.0
    config_abl['lambda_grn'] = 0.0
    config_abl['lambda_death'] = 0.0
    config_abl['lambda_birth'] = 0.0

    vae_abl = VAE_scRNA(input_dim=expression_data.shape[1], latent_dim=LATENT_DIM).to(device)
    vae_abl.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))
    vae_decoder_abl = lambda z: vae_abl.decoder_output(vae_abl.decoder(z))

    net_f_abl = MLP(input_dim=LATENT_DIM, output_dim=LATENT_DIM).to(device)
    net_b_abl = MLP(input_dim=LATENT_DIM, output_dim=LATENT_DIM).to(device)
    z_f_abl = SchrodingerBridgePolicy(config_abl, 'forward', dyn, net_f_abl).to(device)
    z_b_abl = SchrodingerBridgePolicy(config_abl, 'backward', dyn, net_b_abl).to(device)

    optimizer_vae_abl = torch.optim.Adam(vae_abl.parameters(), lr=config_abl['lr'])
    optimizer_f_abl = torch.optim.Adam(z_f_abl.parameters(), lr=config_abl['lr'])
    optimizer_b_abl = torch.optim.Adam(z_b_abl.parameters(), lr=config_abl['lr'])

    print("Running joint training (ablation, no biology)...")
    training_history_abl = run_joint_training_loop(
        config=config_abl,
        dyn=dyn,
        ts=ts,
        vae=vae_abl,
        vae_decoder=vae_decoder_abl,
        z_f=z_f_abl,
        z_b=z_b_abl,
        optimizer_f=optimizer_f_abl,
        optimizer_b=optimizer_b_abl,
        optimizer_vae=optimizer_vae_abl,
        expression_data=expression_data,
        grn_data=grn_data,
        death_gene_indices=[hvg_names_list.index(g) for g in all_death_genes_list if g in hvg_names_list],
        birth_gene_indices=[hvg_names_list.index(g) for g in all_birth_genes_list if g in hvg_names_list],
        device=device,
    )

    # Save ablation checkpoints and history
    torch.save({'model_state_dict': vae_abl.state_dict(), 'input_dim': expression_data.shape[1], 'latent_dim': LATENT_DIM},
               os.path.join(run_dir, 'vae_joint_trained_ablation.pth'))
    torch.save({'model_state_dict': z_f_abl.net.state_dict()}, os.path.join(run_dir, 'z_f_policy_ablation.pth'))
    torch.save({'model_state_dict': z_b_abl.net.state_dict()}, os.path.join(run_dir, 'z_b_policy_ablation.pth'))
    with open(os.path.join(run_dir, 'training_history_ablation.json'), 'w') as f:
        json.dump(training_history_abl, f, indent=2)

    # Drift for ablation model
    f_abl_df, b_abl_df = compute_drift_tables(
        vae=vae_abl,
        z_f=z_f_abl,
        z_b=z_b_abl,
        pre_cells_expr=_pre_cells_expr,
        on_cells_expr=_on_cells_expr,
        gene_symbols=_gene_symbols,
        epsilon=1e-4,
    )
    f_abl_df['is_death_gene'] = f_abl_df['gene'].isin(_death_genes)
    f_abl_df['is_birth_gene'] = f_abl_df['gene'].isin(_birth_genes)
    b_abl_df['is_death_gene'] = b_abl_df['gene'].isin(_death_genes)
    b_abl_df['is_birth_gene'] = b_abl_df['gene'].isin(_birth_genes)

    top_f_abl = f_abl_df.head(TOP_K)
    top_b_abl = b_abl_df.head(TOP_K)

    stab_forward_abl.update(top_f_abl['gene'].tolist())
    stab_backward_abl.update(top_b_abl['gene'].tolist())

    f_abl_df.to_csv(os.path.join(run_dir, 'drift_genes_forward_ablation.csv'), index=False)
    b_abl_df.to_csv(os.path.join(run_dir, 'drift_genes_backward_ablation.csv'), index=False)

print("\nAll runs completed.")


=== RUN 1/10 | seed=43 ===
Running joint training (original, with biology)...

STAGE 1/5

  [Backward Policy Training]
Running joint training (original, with biology)...

STAGE 1/5

  [Backward Policy Training]
    Epoch  1/10: total=21.9085, vae=2.6109, dsb=6.3122, bio=25.9708
    Epoch  1/10: total=21.9085, vae=2.6109, dsb=6.3122, bio=25.9708
    Epoch  2/10: total=3.2483, vae=1.0924, dsb=2.1280, bio=0.0558
    Epoch  2/10: total=3.2483, vae=1.0924, dsb=2.1280, bio=0.0558
    Epoch  3/10: total=3.0042, vae=1.0906, dsb=1.9016, bio=0.0239
    Epoch  3/10: total=3.0042, vae=1.0906, dsb=1.9016, bio=0.0239
    Epoch  4/10: total=2.8196, vae=1.0889, dsb=1.7233, bio=0.0149
    Epoch  4/10: total=2.8196, vae=1.0889, dsb=1.7233, bio=0.0149
    Epoch  5/10: total=2.7393, vae=1.0888, dsb=1.6449, bio=0.0111
    Epoch  5/10: total=2.7393, vae=1.0888, dsb=1.6449, bio=0.0111
    Epoch  6/10: total=2.6537, vae=1.0880, dsb=1.5608, bio=0.0099
    Epoch  6/10: total=2.6537, vae=1.0880, dsb=1.5608, bio

In [None]:
# Aggregate stability scores across runs and save rankings

def _counter_to_df(counter, count_col):
    return pd.DataFrame({
        'gene': list(counter.keys()),
        count_col: list(counter.values()),
    })

stab_f_orig_df = _counter_to_df(stab_forward_orig, 'stability_count_forward_orig')
stab_b_orig_df = _counter_to_df(stab_backward_orig, 'stability_count_backward_orig')
stab_f_abl_df = _counter_to_df(stab_forward_abl, 'stability_count_forward_abl')
stab_b_abl_df = _counter_to_df(stab_backward_abl, 'stability_count_backward_abl')

stab = stab_f_orig_df
stab = stab.merge(stab_b_orig_df, on='gene', how='outer')
stab = stab.merge(stab_f_abl_df, on='gene', how='outer')
stab = stab.merge(stab_b_abl_df, on='gene', how='outer')

stab = stab.fillna(0)
for col in stab.columns:
    if col.startswith('stability_count_'):
        stab[col] = stab[col].astype(int)

stab['stability_freq_forward_orig'] = stab['stability_count_forward_orig'] / float(N_RUNS)
stab['stability_freq_backward_orig'] = stab['stability_count_backward_orig'] / float(N_RUNS)
stab['stability_freq_forward_abl'] = stab['stability_count_forward_abl'] / float(N_RUNS)
stab['stability_freq_backward_abl'] = stab['stability_count_backward_abl'] / float(N_RUNS)

stab = stab.sort_values(
    ['stability_count_backward_orig', 'stability_count_forward_orig'],
    ascending=[False, False]
).reset_index(drop=True)

stab_path = os.path.join(iter_dir, 'drift_gene_stability_scores_orig_vs_ablation.csv')
stab.to_csv(stab_path, index=False)
print('Saved stability scores to:', stab_path)
stab.head()

Saved stability scores to: /home/yyuan/ICB_TCE/iter_results/drift_gene_stability_scores_orig_vs_ablation.csv


Unnamed: 0,gene,stability_count_forward_orig,stability_count_backward_orig,stability_count_forward_abl,stability_count_backward_abl,stability_freq_forward_orig,stability_freq_backward_orig,stability_freq_forward_abl,stability_freq_backward_abl
0,NKG7,9,8,5,5,0.9,0.8,0.5,0.5
1,TSC22D3,9,8,7,7,0.9,0.8,0.7,0.7
2,CD7,7,8,3,4,0.7,0.8,0.3,0.4
3,DUSP2,7,8,8,8,0.7,0.8,0.8,0.8
4,VIM,7,8,7,7,0.7,0.8,0.7,0.7
