In [None]:
import os

# Expose only GPU 0 to this process. CUDA reads this variable when it is first
# initialised, so it is set before torch / DensityFlow are imported rather than after.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import pandas as pd
import numpy as np
import scanpy as sc


from DensityFlow import DensityFlow
from DensityFlow.perturb import LabelMatrix

from eval_metrics import mmd_eval, r2_score_eval, pearson_eval

import torch
# "high" lets float32 matmuls use faster reduced-precision internals (e.g. TF32) where supported.
torch.set_float32_matmul_precision("high")

# Seed NumPy's global RNG (used for the control-cell resampling in the evaluation loop)
# and torch's RNG so that reruns are reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

pert_col = 'gene'         # obs column holding each cell's perturbation label
control_label = 'NT'      # control-cell label (presumably non-targeting guides)
loss_func = 'poisson'     # DensityFlow loss; also part of the checkpoint filename

In [2]:
# Training split of the perturbation screen, loaded as an AnnData object.
train_h5ad = 'jiang2025_train.h5ad'
adata_train = sc.read_h5ad(train_h5ad)

In [3]:
# Library-size normalise, then log1p-transform the training matrix in place.
# NOTE(review): not idempotent — re-running this cell without reloading the data
# applies log1p a second time.
sc.pp.normalize_total(adata_train, inplace=True)
sc.pp.log1p(adata_train, copy=False)

In [None]:


# Model input: the (log-normalised) training expression matrix.
xs = adata_train.X

# Encode each cell's perturbation label as a design matrix. `lb.labels_` keeps the
# label for each column, which is later used to map a perturbation name to its index.
# NOTE(review): how control cells are encoded is up to LabelMatrix — confirm.
lb = LabelMatrix()
perturbation_labels = adata_train.obs[pert_col]
us = lb.fit_transform(perturbation_labels, control_label=control_label)
ln = lb.labels_

us.shape


(95585, 52)

In [None]:
# One input per gene and one perturbation slot per design-matrix column.
model = DensityFlow(
    input_size=xs.shape[1],
    perturb_size=us.shape[1],
    z_dist='studentt',
    loss_func=loss_func,
    seed=42,
    use_cuda=True,
)

🧬 DensityFlow Initialized:
   - Codebook size: 15
   - Latent Dimension: 50
   - Gene Dimension: 33525
   - Hidden Dimensions: [512]
   - Device: cuda:0
   - Parameters: 37,309,940


%%time 

model.fit(xs, 
          us=us, 
          num_epochs=100, 
          batch_size=256, 
          use_jax=True)

DensityFlow.save_model(model, f'benchmark_densityflow_{loss_func}_model.pth')

In [6]:
# Reload the checkpoint written by the training cell (no retraining needed).
checkpoint_file = f'benchmark_densityflow_{loss_func}_model.pth'
model = DensityFlow.load_model(checkpoint_file)

Model loaded from benchmark_densityflow_poisson_model.pth
🧬 DensityFlow Initialized:
   - Codebook size: 15
   - Latent Dimension: 50
   - Gene Dimension: 33525
   - Hidden Dimensions: [512]
   - Device: cuda:0
   - Parameters: 37,309,940


In [7]:
import datatable as dt  # NOTE(review): `dt` is unused in the visible notebook — candidate for removal

# Held-out cells, preprocessed with the same two steps as the training data.
adata_test = sc.read_h5ad('jiang2025_test.h5ad')
for preprocess_step in (sc.pp.normalize_total, sc.pp.log1p):
    preprocess_step(adata_test)

# Pool of control cells taken from the (already log-normalised) training set.
is_control = adata_train.obs[pert_col] == control_label
adata_control = adata_train[is_control].copy()
adata_control.shape

(3906, 33525)

In [None]:
def predict_pert_effect(ad, pert_gene, flow=None, label_names=None):
    """Predict the expression of the cells in ``ad`` under perturbation ``pert_gene``.

    The cells are mapped to the model's basal embedding. That embedding is shifted
    by the model's cell-level shift for ``pert_gene``, and the result is decoded
    with ``flow.get_counts``.

    Parameters
    ----------
    ad : AnnData-like
        Cells to perturb. ``ad.X`` must support ``.toarray()`` (e.g. a scipy
        sparse matrix) and ``.sum(1)``.
    pert_gene : str
        Perturbation label to apply; must occur exactly once in ``label_names``.
    flow : DensityFlow, optional
        Trained model. Defaults to the notebook-level ``model``.
    label_names : array-like, optional
        Perturbation labels in the model's column order. Defaults to the
        notebook-level ``ln`` (``LabelMatrix.labels_``).

    Returns
    -------
    Whatever ``flow.get_counts`` returns for the shifted embeddings.

    Raises
    ------
    KeyError
        If ``pert_gene`` is not one of ``label_names``.
    ValueError
        If ``pert_gene`` occurs more than once in ``label_names``.
    """
    # Fall back to the notebook globals so existing two-argument calls keep working.
    flow = model if flow is None else flow
    label_names = ln if label_names is None else label_names

    # Resolve the perturbation column explicitly. The previous int() on a
    # 1-element array is deprecated in NumPy >= 1.25 and raised an opaque
    # TypeError when the label was missing or duplicated.
    matches = np.flatnonzero(np.asarray(label_names) == pert_gene)
    if matches.size == 0:
        raise KeyError(f'perturbation {pert_gene!r} is not among the model labels')
    if matches.size > 1:
        raise ValueError(
            f'perturbation {pert_gene!r} occurs {matches.size} times in the model labels'
        )
    pert_idx = int(matches[0])

    # Densify once and reuse it for both model calls.
    x_dense = ad.X.toarray()
    zs_basal = flow.get_basal_embedding(x_dense, show_progress=False)

    # A column of ones: the perturbation is applied to every cell
    # (presumably at full strength — TODO confirm the DensityFlow semantics).
    pert_us = np.ones([x_dense.shape[0], 1])
    dzs = flow.get_cell_shift(x_dense, pert_idx, pert_us, soft_assign=True, show_progress=False)

    # NOTE(review): in this notebook ad.X is log1p-normalised, so this row sum is a
    # sum of log values rather than a raw library size — confirm it matches what fit() used.
    ls = ad.X.sum(1)
    return flow.get_counts(zs_basal + dzs, library_sizes=ls, show_progress=False)

In [None]:
# Score every non-control test perturbation. The baseline is a size-matched set of
# training control cells, resampled with replacement.
results = []
pert_sets = adata_test.obs[pert_col].unique().tolist()
n_perts = len(pert_sets)
for i, pert in enumerate(pert_sets, start=1):
    print(f'{i}/{n_perts}')

    if pert == control_label:
        continue

    ad_test = adata_test[adata_test.obs[pert_col] == pert].copy()
    n_cells = ad_test.shape[0]
    xs_test = ad_test.X.toarray().astype(float)

    # Size-matched control baseline drawn from NumPy's global (seeded) RNG.
    ind = np.random.choice(np.arange(adata_control.shape[0]), size=n_cells, replace=True)
    ad_ctrl = adata_control[ind].copy()
    ad_ctrl.obs_names_make_unique()
    xs_basal = ad_ctrl.X.toarray().astype(float)

    # NOTE(review): the prediction starts from the perturbed test cells themselves
    # (ad_test), not from the resampled controls in ad_ctrl. The model therefore sees
    # the observed perturbed state it is scored against. Confirm this is intended;
    # a counterfactual benchmark would call predict_pert_effect(ad_ctrl, pert).
    xs_test_pred = predict_pert_effect(ad_test, pert).astype(float)

    mmd_value = mmd_eval(xs_test_pred, xs_test)
    r2 = r2_score_eval(xs_test_pred, xs_test)
    # Correlate predicted vs observed shifts relative to the same control baseline
    # (xs_basal already has exactly n_cells rows, so no slicing is needed).
    pr = pearson_eval(xs_test_pred - xs_basal, xs_test - xs_basal)
    print(f'mmd:{mmd_value}; r2:{r2}; pearson:{pr}')
    results.append({'mmd': mmd_value, 'r2': r2, 'pearson': pr})

1/53
mmd:0.0074235551406193765; r2:0.7585069107219019; pearson:0.5324237746909043
2/53
mmd:0.007620790687080934; r2:0.7468573105878911; pearson:0.5157430146296869
3/53
mmd:0.008266523615712226; r2:0.7476622175772891; pearson:0.5584762887347271
4/53
mmd:0.00791485099135075; r2:0.7554447180302898; pearson:0.5472891536728275
5/53
mmd:0.009471870250623021; r2:0.7490840786997632; pearson:0.5287144915593031
6/53
mmd:0.009355114990096305; r2:0.7546380838876846; pearson:0.5512254911064944
7/53
mmd:0.009214883556165515; r2:0.752258650081409; pearson:0.534806905095466
8/53
mmd:0.009013795583852293; r2:0.7550503325750668; pearson:0.530836097537226
9/53
mmd:0.008766027190598566; r2:0.7499356750463779; pearson:0.5233299361223853
10/53
mmd:0.0062055747780855; r2:0.7518199773689735; pearson:0.5329656353619386
11/53
mmd:0.008260422215231765; r2:0.7535060732090659; pearson:0.5473975817380365
12/53
mmd:0.007343054519738363; r2:0.7533639498568906; pearson:0.5287828676446708
13/53
mmd:0.009819043923426992

In [10]:
# Average each metric over all evaluated perturbations.
df = pd.DataFrame.from_records(results)
df.mean(axis=0)

mmd        0.008421
r2         0.748369
pearson    0.538324
dtype: float64