In [None]:
import os 

import pandas as pd
import numpy as np
import scanpy as sc


from DensityFlow import DensityFlow
from DensityFlow.perturb import LabelMatrix
from sklearn.model_selection import train_test_split
from eval_metrics import mmd_eval, r2_score_eval, pearson_eval

import torch
torch.set_float32_matmul_precision("high")

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # ä»…æ˜¾ç¤ºGPU 0


import re

def remove_g_suffix(text):
    """
    ç§»é™¤ä»¥ 'g' + æ•°å­—ç»“å°¾çš„éƒ¨åˆ†ï¼ˆå¦‚ "abcg123" â†’ "abc"ï¼‰
    """
    return re.sub(r'g\d+$', '', text)

np.random.seed(42) 

pert_col = 'perturbation'
control_label = 'control'
loss_func = 'poisson'

In [2]:
adata_ = sc.read_h5ad('PapalexiSatija2021_eccite_RNA.h5ad')
sc.pp.filter_genes(adata_, min_cells=10)
sc.pp.normalize_total(adata_)
sc.pp.log1p(adata_)
adata = adata_.copy()
adata.shape

(20729, 17734)

In [3]:
adata.obs[pert_col] = [re.sub(r'g\d+$', '', s) for s in adata.obs[pert_col]]

## Split data into two subsets for train and test

In [4]:
cells_pert = adata[adata.obs[pert_col]!=control_label].obs_names
cells_train, cells_test = train_test_split(cells_pert, test_size= adata.shape[0] // 8)
cells_train = cells_train.tolist() + adata[adata.obs[pert_col]==control_label].obs_names.tolist()
adata_train = adata[cells_train].copy()
adata_test = adata[cells_test].copy()

In [None]:
xs = adata_train.X

lb1 = LabelMatrix()
us1 = lb1.fit_transform(adata_train.obs[pert_col],control_label)
ln1 = lb1.labels_

us = us1 
ln = ln1 
us.shape


(18138, 25)

## Train the model

In [None]:
model = DensityFlow(input_size = xs.shape[1],
                      perturb_size=us.shape[1],
                      loss_func=loss_func,
                      seed=42,
                      use_cuda=True)

ðŸ§¬ DensityFlow Initialized:
   - Codebook size: 15
   - Latent Dimension: 50
   - Gene Dimension: 17734
   - Hidden Dimensions: [512]
   - Device: cuda:0
   - Parameters: 19,685,119


In [7]:
%%time 

model.fit(xs, 
          us=us, 
          num_epochs=200, 
          batch_size=1000, 
          use_jax=True)

Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 200/200 [09:01<00:00,  2.71s/epoch, loss=5692297.0099]

CPU times: user 9min 1s, sys: 1.11 s, total: 9min 2s
Wall time: 9min 2s





In [8]:
DensityFlow.save_model(model, f'densityflow_{loss_func}_model.pt')

Model saved to /home/oem/Workspace/PerturbFlow_Ex/THP1_99_PerturbFlow/densityflow_poisson_model.pt


In [9]:
model = DensityFlow.load_model(f'densityflow_{loss_func}_model.pt')

Model loaded from densityflow_poisson_model.pt
ðŸ§¬ DensityFlow Initialized:
   - Codebook size: 15
   - Latent Dimension: 50
   - Gene Dimension: 17734
   - Hidden Dimensions: [512]
   - Device: cuda:0
   - Parameters: 19,685,119


## Prediction for test data

In [10]:
adata_test = adata_test[adata_test.obs[pert_col].isin(ln)]

In [11]:
def predict_pert_effect(ad,pert):
    ad = ad.copy()
    xs_pert = ad.X.toarray()
    zs_basal = model.get_basal_embedding(xs_pert, show_progress=False)
    us_pert = np.ones([xs_pert.shape[0],1])

    ind = int(np.where(ln==pert)[0])
    dzs = model.get_cell_shift(ad.X.toarray(), perturb_idx=ind, perturb_us=us_pert, show_progress=False)
    
    counts = model.get_counts(zs_basal+dzs, library_sizes=ad.X.sum(1), show_progress=False)
    return counts.copy()


## Evaluation

In [12]:
adata_control = adata_train[adata_train.obs[pert_col]==control_label].copy()

In [None]:
results = []
pert_sets = adata_test.obs[pert_col].unique().tolist()
i = 0
for pert in pert_sets:
    i += 1
    
    ad_test = adata_test[adata_test.obs[pert_col]==pert].copy()
    xs_test = ad_test.X.toarray()
    
    ind = np.random.choice(np.arange(adata_control.shape[0]), size=ad_test.shape[0], replace=True)
    ad_ctrl = adata_control[ind].copy()
    ad_ctrl.obs_names_make_unique()
    xs_basal = ad_ctrl.X.toarray()
    
    xs_test_pred = predict_pert_effect(ad_test, pert)
    
    mmd_value=mmd_eval(xs_test_pred, xs_test)
    r2 = r2_score_eval(xs_test_pred, xs_test)
    pr = pearson_eval(xs_test_pred-xs_basal,xs_test-xs_basal)
    
    if i%5==0:
        print(f'{i}/{len(pert_sets)} - mmd:{mmd_value}; r2:{r2}; pearson:{pr}')
        
    results.append({'mmd':mmd_value,'r2':r2,'pearson':pr})
    

5/25 - mmd:7.745185048868564e-10; r2:0.9952535629272461; pearson:0.9398359656333923
10/25 - mmd:6.91945833963814e-10; r2:0.9954155683517456; pearson:0.7280877828598022
15/25 - mmd:0.0; r2:0.9951648116111755; pearson:0.885901153087616
20/25 - mmd:0.0; r2:0.9963154196739197; pearson:0.7236068844795227
25/25 - mmd:0.0; r2:0.9390547275543213; pearson:0.7179104089736938


In [14]:
df = pd.DataFrame(results)
df.mean(axis=0)

mmd        1.317307e-08
r2         9.873288e-01
pearson    7.527656e-01
dtype: float64