## Load the data

In [15]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
LatentAdditiveGET data loading with external train/val/test split.
Includes:
  - gene expression (X)
  - perturbation one-hot (p)
  - celltype covariates (cov)
  - GET gene-level embeddings (GET)
"""

import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
import scanpy as sc
from tqdm import tqdm

# =========================================
# 0Ô∏è‚É£ Reproducibility
# =========================================
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# =========================================
# 1Ô∏è‚É£ Load AnnData WITH GET embeddings
# =========================================
adata_path = "/gpfs/home/junxif/xin_lab/perturbench/data/boli_anndata/boli_with_GETembedding_celltypeaware_filled.h5ad"
split_path = "/gpfs/home/junxif/xin_lab/perturbench/data/boli_251006_1_qual_high_amt_high_split.csv"

adata = sc.read_h5ad(adata_path)
print("Loaded AnnData:", adata)

# =========================================
# 2Ô∏è‚É£ Load split CSV
# =========================================
df_split = pd.read_csv(split_path, header=None, names=["barcode", "split"])
df_split.index = df_split["barcode"]

adata.obs["split"] = adata.obs.index.map(df_split["split"])
print("Split counts:")
print(adata.obs["split"].value_counts())

# =========================================
# 3Ô∏è‚É£ Build train / val / test sets
# =========================================
train_adata = adata[adata.obs["split"] == "train"].copy()
val_adata   = adata[adata.obs["split"] == "val"].copy()
test_adata  = adata[adata.obs["split"] == "test"].copy()

print(f"Train: {train_adata.n_obs} | Val: {val_adata.n_obs} | Test: {test_adata.n_obs}")
print("Genes:", adata.n_vars)

# =========================================
# 4Ô∏è‚É£ Convert expression matrix to numpy
# =========================================
def to_numpy(X):
    return X.toarray() if not isinstance(X, np.ndarray) else X

X_train = torch.tensor(to_numpy(train_adata.X), dtype=torch.float32)
X_val   = torch.tensor(to_numpy(val_adata.X),   dtype=torch.float32)
X_test  = torch.tensor(to_numpy(test_adata.X),  dtype=torch.float32)

# =========================================
# 5Ô∏è‚É£ Perturbation one-hot
# =========================================
pert = adata.obs["condition"].astype("category")
pert_onehot = pd.get_dummies(pert)

p_train = torch.tensor(pert_onehot.loc[train_adata.obs.index].values, dtype=torch.float32)
p_val   = torch.tensor(pert_onehot.loc[val_adata.obs.index].values,   dtype=torch.float32)
p_test  = torch.tensor(pert_onehot.loc[test_adata.obs.index].values,  dtype=torch.float32)

n_perts = p_train.shape[1]
print("Perturbations:", list(pert_onehot.columns))
print("Pert dim:", n_perts)

# =========================================
# 6Ô∏è‚É£ Celltype covariates one-hot
# =========================================
celltypes = adata.obs["celltype_mapped"].astype("category")
cov_onehot = pd.get_dummies(celltypes)

cov_train = torch.tensor(cov_onehot.loc[train_adata.obs.index].values, dtype=torch.float32)
cov_val   = torch.tensor(cov_onehot.loc[val_adata.obs.index].values,   dtype=torch.float32)
cov_test  = torch.tensor(cov_onehot.loc[test_adata.obs.index].values,  dtype=torch.float32)

n_cov = cov_train.shape[1]
print("Covariates dim:", n_cov)

# encoder/decoder use same covariates
cov_train_enc = cov_train
cov_train_dec = cov_train
cov_val_enc   = cov_val
cov_val_dec   = cov_val
cov_test_enc  = cov_test
cov_test_dec  = cov_test

# =========================================
# 7Ô∏è‚É£ Prepare GET gene-level embeddings
# =========================================
print("Preparing GET embeddings...")

# Extract GET dictionary
get_dict = adata.uns["GET_embeddings"]

# Check gene dimension
example_ct = list(get_dict.keys())[0]
n_genes, d_get = get_dict[example_ct].shape
assert n_genes == adata.n_vars, "Mismatch GET vs adata gene count!"

print(f"Detected GET embedding: {n_genes} genes √ó {d_get}-dim")

# Function to map GET embeddings for each cell in a subset
def extract_get_for_subset(sub_adata, get_dict):
    N = sub_adata.n_obs
    celltypes = sub_adata.obs["celltype_mapped"].astype(str).values
    
    GET_mat = np.zeros((N, n_genes, d_get), dtype=np.float32)
    for i, ct in enumerate(celltypes):
        if ct not in get_dict:
            raise ValueError(f"Missing GET embedding for cell type '{ct}'")
        GET_mat[i] = get_dict[ct]  # shape (genes √ó d_get)
    return GET_mat

# Build GET train/val/test
GET_train = torch.tensor(extract_get_for_subset(train_adata, get_dict), dtype=torch.float32)
GET_val   = torch.tensor(extract_get_for_subset(val_adata,   get_dict), dtype=torch.float32)
GET_test  = torch.tensor(extract_get_for_subset(test_adata,  get_dict), dtype=torch.float32)

print("GET shapes:")
print("  GET_train:", GET_train.shape)
print("  GET_val:  ", GET_val.shape)
print("  GET_test: ", GET_test.shape)

# =========================================
# 8Ô∏è‚É£ Build PyTorch datasets (WITH GET)
# =========================================
train_ds = TensorDataset(
    X_train,
    p_train,
    cov_train_enc,
    cov_train_dec,
    GET_train
)

val_ds = TensorDataset(
    X_val,
    p_val,
    cov_val_enc,
    cov_val_dec,
    GET_val
)

test_ds = TensorDataset(
    X_test,
    p_test,
    cov_test_enc,
    cov_test_dec,
    GET_test
)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, drop_last=True)
val_loader   = DataLoader(val_ds, batch_size=64, shuffle=False)
test_loader  = DataLoader(test_ds, batch_size=64, shuffle=False)

print("DataLoaders (with GET) ready.")


Loaded AnnData: AnnData object with n_obs √ó n_vars = 21700 √ó 5000
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'seurat_clusters', 'Assign', 'scds', 'cxds', 'bcds', 'Sample', 'nCount_refAssay', 'nFeature_refAssay', 'predicted.subclass.score', 'predicted.subclass', 'CT', 'mito', 'BioSamp', 'CT2', 'ForPlot', 'Remove', 'active_ident', 'Assign_clean', 'condition', 'cell_type', 'cell_class', 'celltype_mapped'
    var: 'variable_gene', 'gene_name', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'ATAC_embeddings', 'GET_embeddings', 'hvg', 'log1p'
Split counts:
split
train    16070
test      3033
val       2597
Name: count, dtype: int64
Train: 16070 | Val: 2597 | Test: 3033
Genes: 5000
Perturbations: ['ANK3', 'BCL11B', 'CUL1', 'CX3CL1', 'DAB1', 'HERC1', 'RB1CC1', 'SATB2', 'TBR1', 'TRIO', 'XPO7', 'ctrl']
Pert dim: 12
Covariates dim: 4
Preparing GET embeddings...
Detected GET embedding: 5000 genes √ó 768-dim
GET shapes:
  GET_train: torch.Size([16070, 5000, 768])

In [17]:
# ======================================================
# üîç Verification: check GET embeddings assigned correctly
# ======================================================

import random

def verify_get_subset(sub_adata, GET_tensor, name="train"):
    print(f"\nVerifying GET mapping for {name} set...")
    
    N = sub_adata.n_obs
    idxs = random.sample(range(N), k=min(5, N))  # check 5 random cells
    ok_count = 0
    
    for i in idxs:
        ct = sub_adata.obs["celltype_mapped"].iloc[i]
        expected = get_dict[ct]                                    # numpy array
        loaded   = GET_tensor[i].cpu().numpy()                     # what we stored
        
        same = np.allclose(expected, loaded, atol=1e-6)
        
        print(f"  Cell {i:4d} | celltype={ct:15s} | GET match = {same}")
        if same:
            ok_count += 1
    
    print(f"‚úî Passed {ok_count}/{len(idxs)} checks for {name} set.\n")


verify_get_subset(train_adata, GET_train, "train")
verify_get_subset(val_adata,   GET_val,   "val")
verify_get_subset(test_adata,  GET_test,  "test")



Verifying GET mapping for train set...
  Cell 6851 | celltype=cr_glut         | GET match = True
  Cell 4631 | celltype=nonit_glut      | GET match = True
  Cell 2489 | celltype=nonit_glut      | GET match = True
  Cell 4484 | celltype=ctx-mge_gaba    | GET match = True
  Cell  219 | celltype=nonit_glut      | GET match = True
‚úî Passed 5/5 checks for train set.


Verifying GET mapping for val set...
  Cell 2584 | celltype=nonit_glut      | GET match = True
  Cell 2052 | celltype=nonit_glut      | GET match = True
  Cell 1577 | celltype=nonit_glut      | GET match = True
  Cell   62 | celltype=nonit_glut      | GET match = True
  Cell 1172 | celltype=it_glut         | GET match = True
‚úî Passed 5/5 checks for val set.


Verifying GET mapping for test set...
  Cell  393 | celltype=nonit_glut      | GET match = True
  Cell 2450 | celltype=it_glut         | GET match = True
  Cell 2255 | celltype=it_glut         | GET match = True
  Cell  271 | celltype=nonit_glut      | GET match = Tr

## Model training

In [None]:
# =========================================
# 5Ô∏è‚É£ LatentAdditiveGET architecture (with GET fusion)
# =========================================
from torch import nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, in_dim, width, out_dim, n_layers=3, dropout=0.1):
        super().__init__()
        layers = []
        for i in range(n_layers):
            layers.append(nn.Linear(in_dim if i == 0 else width, width))
            layers.append(nn.LayerNorm(width))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(width, out_dim))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)


class LatentAdditiveGET(nn.Module):
    """
    LatentAdditive + GET gene-level additive fusion (Strategy A)

    forward(
        x        : (batch √ó n_genes)
        p        : (batch √ó n_perts)
        cov_enc  : (batch √ó n_cov)
        cov_dec  : (batch √ó n_cov)
        get_emb  : (batch √ó n_genes √ó d_get)
    )
    """
    def __init__(
        self,
        n_genes,
        n_perts,
        n_covariates_enc,
        n_covariates_dec,
        d_get,                  # GET embedding dimension
        latent_dim=160,
        encoder_width=3072,
        n_layers=3,
        dropout=0.1,
        softplus_output=True,
    ):
        super().__init__()

        # Encoders
        self.gene_encoder = MLP(
            n_genes + n_covariates_enc,
            encoder_width,
            latent_dim,
            n_layers,
            dropout,
        )

        self.pert_encoder = MLP(
            n_perts,
            encoder_width,
            latent_dim,
            n_layers,
            dropout,
        )

        # Decoder
        self.decoder = MLP(
            latent_dim + n_covariates_dec,
            encoder_width,
            n_genes,
            n_layers,
            dropout,
        )

        # NEW: Linear GET ‚Üí per-gene scalar
        # maps each d_get-dimensional gene embedding to a single scalar
        self.linear_get = nn.Linear(d_get, 1)

        self.softplus_output = softplus_output

    def forward(self, x, p, cov_enc, cov_dec, get_emb):
        """
        get_emb: (batch √ó n_genes √ó d_get)
        """
        # ----- latent computation -----
        latent_ctrl = self.gene_encoder(torch.cat([x, cov_enc], dim=1))
        latent_pert = self.pert_encoder(p)
        latent_sum  = latent_ctrl + latent_pert  # (batch √ó latent_dim)

        # ----- base decoder output -----
        zcat = torch.cat([latent_sum, cov_dec], dim=1)
        base_out = self.decoder(zcat)  # (batch √ó n_genes)

        # ----- GET additive term -----
        # get_emb: (batch √ó n_genes √ó d_get)
        # linear_get -> (batch √ó n_genes √ó 1)
        get_term = self.linear_get(get_emb).squeeze(-1)

        # ----- combine base prediction + GET -----
        out = base_out + get_term

        if self.softplus_output:
            out = F.softplus(out)

        return out


# =========================================
# 6Ô∏è‚É£ Initialize model & optimizer
# =========================================
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

n_genes = X_train.shape[1]
n_covariates = cov_train.shape[1]

# GET embedding dimension
_, _, d_get = GET_train.shape

model = LatentAdditiveGET(
    n_genes=n_genes,
    n_perts=n_perts,
    n_covariates_enc=n_covariates,
    n_covariates_dec=n_covariates,
    d_get=d_get,                   # <---- NEW
).to(device)

opt = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-6,
    betas=(0.9, 0.999)
)

sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode="min", factor=0.1, patience=5
)

scaler = torch.cuda.amp.GradScaler()


# =========================================
# 7Ô∏è‚É£ Training loop with GET embeddings
# =========================================
n_epochs = 20
for epoch in range(n_epochs):

    # ---- TRAIN ----
    model.train()
    train_loss = 0.

    for xb, pb, cenc, cdec, getb in tqdm(
        train_loader, desc=f"Epoch {epoch+1}/{n_epochs} (train)"
    ):
        xb   = xb.to(device)
        pb   = pb.to(device)
        cenc = cenc.to(device)
        cdec = cdec.to(device)
        getb = getb.to(device)

        opt.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(dtype=torch.float16):
            recon = model(xb, p=pb, cov_enc=cenc, cov_dec=cdec, get_emb=getb)
            loss = F.mse_loss(recon, xb)

        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()

        train_loss += loss.item() * xb.size(0)

    train_loss /= len(train_loader.dataset)

    # ---- VALIDATION ----
    model.eval()
    val_loss = 0.

    with torch.no_grad():
        for xb, pb, cenc, cdec, getb in tqdm(
            val_loader, desc=f"Epoch {epoch+1}/{n_epochs} (val)"
        ):
            xb   = xb.to(device)
            pb   = pb.to(device)
            cenc = cenc.to(device)
            cdec = cdec.to(device)
            getb = getb.to(device)

            with torch.cuda.amp.autocast(dtype=torch.float16):
                recon = model(xb, p=pb, cov_enc=cenc, cov_dec=cdec, get_emb=getb)
                loss = F.mse_loss(recon, xb)

            val_loss += loss.item() * xb.size(0)

    val_loss /= len(val_loader.dataset)
    sched.step(val_loss)

    print(f"Epoch {epoch+1:02d} | train={train_loss:.6f} | val={val_loss:.6f} | lr={opt.param_groups[0]['lr']:.2e}")


Using device: cuda


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast(dtype=torch.float16):
Epoch 1/20 (train): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 251/251 [06:59<00:00,  1.67s/it]
  with torch.cuda.amp.autocast(dtype=torch.float16):
Epoch 1/20 (val): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [00:55<00:00,  1.35s/it]


Epoch 01 | train=0.068595 | val=0.046865 | lr=1.00e-04


Epoch 2/20 (train):  79%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ  | 198/251 [06:19<01:43,  1.95s/it]

## Evaluation

In [None]:
import os
print(os.getcwd())
import sys, os
sys.path.append(os.path.abspath("/gpfs/home/junxif/xin_lab/multiome"))
from utils.eval import evaluate_model
import importlib
import utils.eval
importlib.reload(utils.eval)

In [None]:
results = evaluate_model(model, test_loader, test_adata, device=device, k=50)
results
