## Load data

In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
LatentAdditive model training using external
train/val/test split file.
"""

import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
import scanpy as sc
from tqdm import tqdm

# =========================================
# 0️⃣ Reproducibility setup
# =========================================
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# =========================================
# 1️⃣ Load full AnnData (NOT control-only)
# =========================================
adata_path = "/gpfs/home/junxif/xin_lab/perturbench/data/boli_anndata/boli_with_GETembedding_celltypeaware_subset.h5ad"
split_path = "/gpfs/home/junxif/xin_lab/perturbench/data/boli_251006_1_qual_high_amt_high_split.csv"

adata = sc.read_h5ad(adata_path)

print("Loaded AnnData:", adata)

# =========================================
# 2️⃣ Load split CSV & attach to AnnData
# =========================================
df_split = pd.read_csv(split_path, header=None, names=["barcode", "split"])
df_split.index = df_split["barcode"]

adata.obs["split"] = adata.obs.index.map(df_split["split"])

print("Split counts:")
print(adata.obs["split"].value_counts())

# =========================================
# 3️⃣ Build train/val/test AnnData subsets
# =========================================
train_adata = adata[adata.obs["split"] == "train"].copy()
val_adata   = adata[adata.obs["split"] == "val"].copy()
test_adata  = adata[adata.obs["split"] == "test"].copy()

print(f"Train: {train_adata.n_obs} | Val: {val_adata.n_obs} | Test: {test_adata.n_obs}")
print(f"Genes: {adata.n_vars}")

# =========================================
# 4️⃣ Helper to convert X to numpy
# =========================================
def to_numpy(X):
    return X.toarray() if not isinstance(X, np.ndarray) else X

X_train = to_numpy(train_adata.X)
X_val   = to_numpy(val_adata.X)
X_test  = to_numpy(test_adata.X)

X_train = torch.tensor(X_train, dtype=torch.float32)
X_val   = torch.tensor(X_val, dtype=torch.float32)
X_test  = torch.tensor(X_test, dtype=torch.float32)

# =========================================
# 5️⃣ Perturbation one-hot for train/val/test
# =========================================
pert = adata.obs["condition"].astype("category")
pert_onehot = pd.get_dummies(pert)

p_train = torch.tensor(
    pert_onehot.loc[train_adata.obs.index].values, dtype=torch.float32
)
p_val = torch.tensor(
    pert_onehot.loc[val_adata.obs.index].values, dtype=torch.float32
)
p_test = torch.tensor(
    pert_onehot.loc[test_adata.obs.index].values, dtype=torch.float32
)

n_perts = p_train.shape[1]
print("Perturbations:", list(pert_onehot.columns))
print("Pert dim:", n_perts)

# =========================================
# 6️⃣ Celltype covariates one-hot
# =========================================
celltypes = adata.obs["celltype_mapped"].astype("category")
cov_onehot = pd.get_dummies(celltypes)

cov_train = torch.tensor(cov_onehot.loc[train_adata.obs.index].values, dtype=torch.float32)
cov_val   = torch.tensor(cov_onehot.loc[val_adata.obs.index].values,   dtype=torch.float32)
cov_test  = torch.tensor(cov_onehot.loc[test_adata.obs.index].values,  dtype=torch.float32)

n_cov = cov_train.shape[1]
print("Covariates dim:", n_cov)

# reuse same covariates for encoder + decoder
cov_train_enc = cov_train
cov_train_dec = cov_train
cov_val_enc   = cov_val
cov_val_dec   = cov_val
cov_test_enc  = cov_test
cov_test_dec  = cov_test

# =========================================
# 7️⃣ Build PyTorch datasets
# =========================================
train_ds = TensorDataset(X_train, p_train, cov_train_enc, cov_train_dec)
val_ds   = TensorDataset(X_val,   p_val,   cov_val_enc,   cov_val_dec)
test_ds  = TensorDataset(X_test,  p_test,  cov_test_enc,  cov_test_dec)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, drop_last=True)
val_loader   = DataLoader(val_ds,   batch_size=64, shuffle=False)
test_loader  = DataLoader(test_ds,  batch_size=64, shuffle=False)

print("DataLoaders ready.")


Loaded AnnData: AnnData object with n_obs × n_vars = 21700 × 3332
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'seurat_clusters', 'Assign', 'scds', 'cxds', 'bcds', 'Sample', 'nCount_refAssay', 'nFeature_refAssay', 'predicted.subclass.score', 'predicted.subclass', 'CT', 'mito', 'BioSamp', 'CT2', 'ForPlot', 'Remove', 'active_ident', 'Assign_clean', 'condition', 'cell_type', 'cell_class', 'celltype_mapped', 'split'
    var: 'variable_gene', 'gene_name', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'gene_name_upper'
    uns: 'ATAC_embeddings', 'GET_embeddings', 'hvg', 'log1p'
Split counts:
split
train    16070
test      3033
val       2597
Name: count, dtype: int64
Train: 16070 | Val: 2597 | Test: 3033
Genes: 3332
Perturbations: ['ANK3', 'BCL11B', 'CUL1', 'CX3CL1', 'DAB1', 'HERC1', 'RB1CC1', 'SATB2', 'TBR1', 'TRIO', 'XPO7', 'ctrl']
Pert dim: 12
Covariates dim: 4
DataLoaders ready.


## Model training

In [2]:
# =========================================
# 5️⃣ LatentAdditive architecture
# =========================================
from torch import nn

class MLP(nn.Module):
    def __init__(self, in_dim, width, out_dim, n_layers=3, dropout=0.1):
        super().__init__()
        layers = []
        for i in range(n_layers):
            layers.append(nn.Linear(in_dim if i == 0 else width, width))
            layers.append(nn.LayerNorm(width))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(width, out_dim))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)


class LatentAdditive(nn.Module):
    def __init__(self, n_genes, n_perts, n_covariates_enc, n_covariates_dec,
                 latent_dim=160, encoder_width=3072, n_layers=3,
                 dropout=0.1, softplus_output=True):
        super().__init__()
        self.gene_encoder = MLP(n_genes + n_covariates_enc, encoder_width, latent_dim,
                                n_layers, dropout)
        self.pert_encoder = MLP(n_perts, encoder_width, latent_dim,
                                n_layers, dropout)
        self.decoder      = MLP(latent_dim + n_covariates_dec, encoder_width, n_genes,
                                n_layers, dropout)
        self.softplus_output = softplus_output

    def forward(self, x, p, cov_enc, cov_dec):
        latent_ctrl = self.gene_encoder(torch.cat([x, cov_enc], dim=1))
        latent_pert = self.pert_encoder(p)
        latent_sum  = latent_ctrl + latent_pert
        out = self.decoder(torch.cat([latent_sum, cov_dec], dim=1))
        if self.softplus_output:
            out = F.softplus(out)
        return out


# =========================================
# 6️⃣ Initialize model & optimizer
# =========================================
device = "cuda" if torch.cuda.is_available() else "cpu"

n_genes = X_train.shape[1]
n_covariates = cov_train.shape[1]

model = LatentAdditive(
    n_genes=n_genes,
    n_perts=n_perts,
    n_covariates_enc=n_covariates,
    n_covariates_dec=n_covariates,
).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=1e-4,
                        weight_decay=1e-6, betas=(0.9, 0.999))
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min",
                                                   factor=0.1, patience=5)
scaler = torch.cuda.amp.GradScaler()


# =========================================
# 7️⃣ Training loop with AMP, val set
# =========================================
n_epochs = 20
for epoch in range(n_epochs):

    # ---- Train ----
    model.train()
    train_loss = 0
    for xb, pb, cenc, cdec in tqdm(train_loader,
                                   desc=f"Epoch {epoch+1}/{n_epochs} (train)"):
        xb, pb, cenc, cdec = xb.to(device), pb.to(device), cenc.to(device), cdec.to(device)
        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(dtype=torch.float16):
            recon = model(xb, p=pb, cov_enc=cenc, cov_dec=cdec)
            loss = F.mse_loss(recon, xb)
        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()
        train_loss += loss.item() * xb.size(0)

    train_loss /= len(train_loader.dataset)

    # ---- Validation ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for xb, pb, cenc, cdec in tqdm(val_loader,
                                       desc=f"Epoch {epoch+1}/{n_epochs} (val)"):
            xb, pb, cenc, cdec = xb.to(device), pb.to(device), cenc.to(device), cdec.to(device)
            with torch.cuda.amp.autocast(dtype=torch.float16):
                recon = model(xb, p=pb, cov_enc=cenc, cov_dec=cdec)
                loss = F.mse_loss(recon, xb)
            val_loss += loss.item() * xb.size(0)

    val_loss /= len(val_loader.dataset)
    sched.step(val_loss)

    print(f"Epoch {epoch+1:02d} | train={train_loss:.6f} | val={val_loss:.6f} | lr={opt.param_groups[0]['lr']:.2e}")


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast(dtype=torch.float16):
Epoch 1/20 (train): 100%|██████████| 251/251 [00:38<00:00,  6.56it/s]
  with torch.cuda.amp.autocast(dtype=torch.float16):
Epoch 1/20 (val): 100%|██████████| 41/41 [00:01<00:00, 27.47it/s]


Epoch 01 | train=0.088179 | val=0.066119 | lr=1.00e-04


Epoch 2/20 (train): 100%|██████████| 251/251 [00:31<00:00,  8.08it/s]
Epoch 2/20 (val): 100%|██████████| 41/41 [00:01<00:00, 24.95it/s]


Epoch 02 | train=0.064604 | val=0.059377 | lr=1.00e-04


Epoch 3/20 (train): 100%|██████████| 251/251 [00:31<00:00,  8.00it/s]
Epoch 3/20 (val): 100%|██████████| 41/41 [00:01<00:00, 31.01it/s]


Epoch 03 | train=0.059164 | val=0.055553 | lr=1.00e-04


Epoch 4/20 (train): 100%|██████████| 251/251 [00:31<00:00,  8.04it/s]
Epoch 4/20 (val): 100%|██████████| 41/41 [00:00<00:00, 69.68it/s] 


Epoch 04 | train=0.056165 | val=0.053702 | lr=1.00e-04


Epoch 5/20 (train): 100%|██████████| 251/251 [00:31<00:00,  8.01it/s]
Epoch 5/20 (val): 100%|██████████| 41/41 [00:01<00:00, 31.25it/s]


Epoch 05 | train=0.054457 | val=0.052791 | lr=1.00e-04


Epoch 6/20 (train): 100%|██████████| 251/251 [00:28<00:00,  8.85it/s]
Epoch 6/20 (val): 100%|██████████| 41/41 [00:00<00:00, 49.60it/s] 


Epoch 06 | train=0.053370 | val=0.051987 | lr=1.00e-04


Epoch 7/20 (train): 100%|██████████| 251/251 [00:28<00:00,  8.78it/s]
Epoch 7/20 (val): 100%|██████████| 41/41 [00:01<00:00, 29.58it/s]


Epoch 07 | train=0.052529 | val=0.051523 | lr=1.00e-04


Epoch 8/20 (train): 100%|██████████| 251/251 [00:28<00:00,  8.75it/s]
Epoch 8/20 (val): 100%|██████████| 41/41 [00:01<00:00, 35.13it/s]


Epoch 08 | train=0.051830 | val=0.051018 | lr=1.00e-04


Epoch 9/20 (train): 100%|██████████| 251/251 [00:28<00:00,  8.68it/s]
Epoch 9/20 (val): 100%|██████████| 41/41 [00:01<00:00, 28.95it/s]


Epoch 09 | train=0.051173 | val=0.050531 | lr=1.00e-04


Epoch 10/20 (train): 100%|██████████| 251/251 [00:28<00:00,  8.80it/s]
Epoch 10/20 (val): 100%|██████████| 41/41 [00:01<00:00, 30.91it/s]


Epoch 10 | train=0.050615 | val=0.050265 | lr=1.00e-04


Epoch 11/20 (train): 100%|██████████| 251/251 [00:28<00:00,  8.71it/s]
Epoch 11/20 (val): 100%|██████████| 41/41 [00:01<00:00, 29.31it/s]


Epoch 11 | train=0.050112 | val=0.050066 | lr=1.00e-04


Epoch 12/20 (train): 100%|██████████| 251/251 [00:28<00:00,  8.83it/s]
Epoch 12/20 (val): 100%|██████████| 41/41 [00:01<00:00, 29.00it/s]


Epoch 12 | train=0.049635 | val=0.049842 | lr=1.00e-04


Epoch 13/20 (train): 100%|██████████| 251/251 [00:29<00:00,  8.65it/s]
Epoch 13/20 (val): 100%|██████████| 41/41 [00:01<00:00, 30.19it/s]


Epoch 13 | train=0.049193 | val=0.049595 | lr=1.00e-04


Epoch 14/20 (train): 100%|██████████| 251/251 [00:28<00:00,  8.89it/s]
Epoch 14/20 (val): 100%|██████████| 41/41 [00:01<00:00, 31.47it/s]


Epoch 14 | train=0.048777 | val=0.049540 | lr=1.00e-04


Epoch 15/20 (train): 100%|██████████| 251/251 [00:28<00:00,  8.71it/s]
Epoch 15/20 (val): 100%|██████████| 41/41 [00:01<00:00, 28.76it/s]


Epoch 15 | train=0.048347 | val=0.049329 | lr=1.00e-04


Epoch 16/20 (train): 100%|██████████| 251/251 [00:28<00:00,  8.75it/s]
Epoch 16/20 (val): 100%|██████████| 41/41 [00:01<00:00, 29.51it/s]


Epoch 16 | train=0.047941 | val=0.049111 | lr=1.00e-04


Epoch 17/20 (train): 100%|██████████| 251/251 [00:28<00:00,  8.78it/s]
Epoch 17/20 (val): 100%|██████████| 41/41 [00:00<00:00, 43.42it/s]


Epoch 17 | train=0.047550 | val=0.048998 | lr=1.00e-04


Epoch 18/20 (train): 100%|██████████| 251/251 [00:28<00:00,  8.85it/s]
Epoch 18/20 (val): 100%|██████████| 41/41 [00:01<00:00, 29.30it/s]


Epoch 18 | train=0.047157 | val=0.048891 | lr=1.00e-04


Epoch 19/20 (train): 100%|██████████| 251/251 [00:28<00:00,  8.75it/s]
Epoch 19/20 (val): 100%|██████████| 41/41 [00:01<00:00, 40.58it/s]


Epoch 19 | train=0.046766 | val=0.048760 | lr=1.00e-04


Epoch 20/20 (train): 100%|██████████| 251/251 [00:28<00:00,  8.76it/s]
Epoch 20/20 (val): 100%|██████████| 41/41 [00:01<00:00, 29.77it/s]

Epoch 20 | train=0.046359 | val=0.048724 | lr=1.00e-04





## Evaluation

In [10]:
import os
print(os.getcwd())


/gpfs/home/junxif/xin_lab


In [11]:
import sys, os
# add parent directory of the notebook (perturbench/)
sys.path.append(os.path.abspath("perturbench"))
from multiome.eval import evaluate_model
import importlib
import multiome.eval
importlib.reload(multiome.eval)

<module 'multiome.eval' from '/gpfs/home/junxif/xin_lab/perturbench/multiome/eval.py'>

In [12]:
results = evaluate_model(model, test_loader, test_adata, device=device, k=50)
results


  "per_pert": { pert_name: {...}, ... }


{'global': {'rmse': 0.2246228338438571,
  'pearson_mean': 0.8598196506500244,
  'cosine_mean': 0.8714019443546404,
  'global_logfc_cosine': 0.0},
 'per_pert': {'BCL11B': {'n_cells': 1380,
   'logfc_corr': 0.7387996912002563,
   'logfc_cosine': 0.7402163281244087,
   'top50_recall': 0.08},
  'XPO7': {'n_cells': 392,
   'logfc_corr': 0.6945080757141113,
   'logfc_cosine': 0.7193422239942252,
   'top50_recall': 0.06},
  'TBR1': {'n_cells': 607,
   'logfc_corr': 0.7969701886177063,
   'logfc_cosine': 0.7666682400719369,
   'top50_recall': 0.24},
  'CX3CL1': {'n_cells': 183,
   'logfc_corr': 0.7026352882385254,
   'logfc_cosine': 0.6510339120557519,
   'top50_recall': 0.58},
  'HERC1': {'n_cells': 162,
   'logfc_corr': 0.6920470595359802,
   'logfc_cosine': 0.65323934777674,
   'top50_recall': 0.7},
  'ctrl': {'n_cells': 309,
   'logfc_corr': 0.29403549432754517,
   'logfc_cosine': 0.23814074675678354,
   'top50_recall': 0.1}}}