In [2]:
import anndata as ad
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import pandas as pd
from scipy import stats
import torch
import sys
sys.path.append('..')
from VAE.VAE_model import VAE
from torch.autograd import Variable
import celltypist

In [3]:
def load_VAE(
    checkpoint_path='/data1/lep/Workspace/guided-diffusion/VAE/checkpoint/muris_scimilarity_lognorm_finetune/model_seed=0_step=150000.pt',
    num_genes=18996,
    hidden_dim=128,
):
    """Build the VAE and restore its weights from a saved ``state_dict``.

    Args:
        checkpoint_path: file holding the ``state_dict`` to restore. Defaults to
            the checkpoint this notebook was written against.
        num_genes: width of the gene-expression input/output. Must match the
            checkpoint; the default (18996) equals the gene count of the
            filtered dataset used in this notebook.
        hidden_dim: size of the latent space. Must match the checkpoint.

    Returns:
        The ``VAE`` instance with the checkpoint weights loaded.

    Note:
        ``torch.load`` unpickles the file, so only point this at checkpoints
        you trust.
    """
    autoencoder = VAE(
        num_genes=num_genes,
        device='cuda',
        seed=0,
        loss_ae='mse',
        hidden_dim=hidden_dim,
        decoder_activation='ReLU',
    )
    # Deserialize onto the CPU: load_state_dict copies every tensor into the
    # model's own parameters, so the result does not depend on which GPU index
    # the checkpoint was saved from (CUDA_VISIBLE_DEVICES may hide that device).
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    autoencoder.load_state_dict(state_dict)
    return autoencoder

real data

In [4]:
# Load the real dataset and make gene names unique.
adata = sc.read_h5ad('/data1/lep/Workspace/guided-diffusion/data/tabula_muris/all.h5ad')
adata.var_names_make_unique()
# Basic QC: drop cells with fewer than 10 detected genes and genes seen in
# fewer than 3 cells. Both calls modify `adata` in place.
sc.pp.filter_cells(adata, min_genes=10)
sc.pp.filter_genes(adata, min_cells=3)
gene_names = adata.var_names

# Library-size normalisation to 1e4 counts per cell followed by log1p.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
# Keep every 5th cell. Slice the sparse matrix *before* densifying so only the
# retained rows are materialised as a dense array (same values as densifying
# everything first, at a fifth of the memory).
cell_data = adata.X[::5].toarray()

cell_data.shape

(11401, 18996)

unconditional generated data

In [23]:
# the generated data path
# NOTE: allow_pickle=True lets np.load unpickle arbitrary objects; only use it
# on files you produced yourself.
npzfile = np.load('/data1/lep/Workspace/guided-diffusion/output/muris_scimilarity.npz', allow_pickle=True)

# Use at most the first 10000 generated samples.
cell_gen_all = npzfile['cell_gen'][:10000]

autoencoder = load_VAE()
# Decode the generated samples back to gene space. This is pure inference, so
# disable autograd to avoid keeping a graph for the full decoded batch.
with torch.no_grad():
    cell_gen_all = autoencoder(torch.tensor(cell_gen_all).cuda(), return_decoded=True).detach().cpu().numpy()
# Round-trip through AnnData to get the matrix as float32.
ori = ad.AnnData(cell_gen_all, dtype=np.float32)
cell_gen = ori.X
cell_gen.shape

(10000, 18996)

correlation

In [14]:
# Correlate the per-gene mean expression of the real cells with that of the
# generated cells (axis 0 averages over cells, leaving one value per gene).
real_gene_mean = cell_data.mean(axis=0)
gen_gene_mean = cell_gen.mean(axis=0)

spearman_rho = stats.spearmanr(real_gene_mean, gen_gene_mean).correlation
pearson_r = np.corrcoef(real_gene_mean, gen_gene_mean)[0, 1]

print('spearman=', spearman_rho)
print('pearson=', pearson_r)

spearman= 0.9944149643330309
pearson= 0.9986732924303006


MMD

In [15]:
import torch

def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    '''
    Build a multi-bandwidth Gaussian (RBF) kernel matrix over source and target.

    The two sample sets are stacked row-wise and a kernel value is computed for
    every pair of rows, so the result has shape (n + m, n + m).

    Params:
        source: source-domain samples, 2-D tensor of shape (n, d)
        target: target-domain samples, 2-D tensor of shape (m, d)
        kernel_mul: multiplicative step between consecutive bandwidths
        kernel_num: number of Gaussian kernels that are summed
        fix_sigma: base bandwidth (number or tensor). When falsy (None or 0)
            the bandwidth is the mean pairwise squared distance of the data.
            The argument is never modified.
    Return:
        sum(kernel_val): element-wise sum of the kernel_num kernel matrices

    Note:
        An intermediate tensor of shape (n + m, n + m, d) is materialised, so
        memory grows quadratically with the number of samples.
    '''
    n_samples = int(source.size()[0]) + int(target.size()[0])
    total = torch.cat([source, target], dim=0)

    # Pairwise squared Euclidean distances: entry [i, j] is ||total[j] - total[i]||^2.
    L2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)

    if fix_sigma:
        bandwidth = fix_sigma
    else:
        # Mean over the off-diagonal pairs; the diagonal contributes zeros.
        bandwidth = torch.sum(L2_distance.detach()) / (n_samples ** 2 - n_samples)

    # Centre the geometric bandwidth series on the base value. This must be an
    # out-of-place division: `bandwidth /= ...` would silently overwrite a
    # tensor passed in as `fix_sigma`.
    bandwidth = bandwidth / (kernel_mul ** (kernel_num // 2))
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]

    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]

    return sum(kernel_val)

def mmd_rbf(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Estimate the squared MMD between two sample sets with an RBF kernel mix.

    Args:
        source: 2-D tensor of shape (n, d).
        target: 2-D tensor of shape (m, d); m may differ from n.
        kernel_mul, kernel_num, fix_sigma: forwarded unchanged to
            ``guassian_kernel``.

    Returns:
        Scalar tensor: mean(K_xx) + mean(K_yy) - mean(K_xy) - mean(K_yx).
    """
    n_source = int(source.size()[0])
    kernels = guassian_kernel(source, target,
        kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)

    # Split the joint kernel matrix into its four source/target quadrants.
    XX = kernels[:n_source, :n_source]
    YY = kernels[n_source:, n_source:]
    XY = kernels[:n_source, n_source:]
    YX = kernels[n_source:, :n_source]
    # Average each quadrant separately. Adding the quadrants element-wise first
    # only works when n == m, because XX is (n, n) while YY is (m, m).
    loss = XX.mean() + YY.mean() - XY.mean() - YX.mean()
    return loss

In [24]:
# Stack the real cells on top of the generated cells in one float32 AnnData and
# tag every row through its obs name so the two groups can be selected later.
# NOTE(review): this rebinds `adata`, replacing the dataset loaded with
# sc.read_h5ad. Later cells that expect that dataset (e.g. ones reading
# `adata.obs['celltype']`) only work if it is reloaded or run out of order.
adata = ad.AnnData(np.concatenate((cell_data, cell_gen), axis=0), dtype=np.float32)
adata.obs_names = ["true_Cell"] * cell_data.shape[0] + ["gen_Cell"] * cell_gen.shape[0]

In [25]:
# Project the combined real + generated matrix with PCA, then compare the two
# groups in PCA space with the RBF-kernel MMD.
sc.tl.pca(adata, svd_solver='arpack')

is_real = adata.obs_names == 'true_Cell'
is_gen = adata.obs_names == 'gen_Cell'

# Take every second cell and cap each group at 5000 rows.
# can not be set too large, the kernel might fail
real = adata[is_real].obsm['X_pca'][::2][:5000]
gen = adata[is_gen].obsm['X_pca'][::2][:5000]

X = Variable(torch.Tensor(real))
Y = Variable(torch.Tensor(gen))
print(mmd_rbf(X, Y))

tensor(0.0822)


scib

In [18]:
import scib

# Combined real + generated matrix, with real-vs-generated stored as the
# "batch" label that the iLISI graph score is computed against.
# NOTE(review): this rebinds `adata`, replacing the dataset loaded with
# sc.read_h5ad. Later cells that expect that dataset (e.g. ones reading
# `adata.obs['celltype']`) only work if it is reloaded or run out of order.
adata = ad.AnnData(np.concatenate((cell_data, cell_gen), axis=0), dtype=np.float32)
batch_labels = ["true_Cell"] * cell_data.shape[0] + ["gen_Cell"] * cell_gen.shape[0]
adata.obs['batch'] = pd.Categorical(batch_labels)

# kNN graph (10 neighbours on 20 PCs) that ilisi_graph reads.
sc.pp.neighbors(adata, n_neighbors=10, n_pcs=20)
scib.me.ilisi_graph(adata, batch_key="batch", type_="knn")

         Falling back to preprocessing with `sc.pp.pca` and default params.


0.6955686406343145

celltypist for conditional generation

In [8]:
# if not generated all type of cells, use the real cell to balance the batchnorm in the scimilarity
# NOTE(review): `adata` must be the sparse dataset loaded with sc.read_h5ad
# here, not one of the combined real + generated objects built in other cells.
# Slicing the AnnData directly gives the same rows as copying it first, without
# duplicating the whole object in memory.
adata_w = adata[::5].X.toarray()

autoencoder = load_VAE()
# Encode the real cells into the latent space; inference only, so no autograd.
with torch.no_grad():
    cell_w = autoencoder(torch.tensor(adata_w).cuda(), return_latent=True).detach().cpu().numpy()

# concat this cell_w with cell_gen and send them to the autoencoder
# cell_gen_all = autoencoder(torch.tensor(np.concatenate((cell_gen,cell_w),axis=0)).cuda(),return_decoded=True).cpu().detach().numpy()
cell_w.shape

(11401, 128)

In [None]:
# if generated all type of cells, combine them together
cato = ['Bladder', 'Heart_and_Aorta', 'Kidney', 'Limb_Muscle', 'Liver',
       'Lung', 'Mammary_Gland', 'Marrow', 'Spleen', 'Thymus', 'Tongue',
       'Trachea']
index = [0,1,2,3,4,5,6,7,8,9,10,11]
rf = []
diffu_acc = []

cell_gen_all = []
gen_class = []

# Number of generated cells kept per category. The celltypist cell slices
# `cell_gen_all` in blocks of 1000, so keep the two in sync.
length = 1000

for i, category in enumerate(cato):
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted files.
    npzfile = np.load(f'../output/muris_condi/muris_{i}_scimilarity_nodrop.npz', allow_pickle=True)
    cell_gen_all.append(npzfile['cell_gen'][:length])#.squeeze(1)

    gen_class += ['gen ' + category] * length

cell_gen_all = np.concatenate(cell_gen_all, axis=0)

autoencoder = load_VAE()
# Decode all categories in one batch; inference only, so skip autograd to avoid
# holding a graph for the whole decoded matrix.
with torch.no_grad():
    cell_gen_all = autoencoder(torch.tensor(cell_gen_all).cuda(), return_decoded=True).cpu().detach().numpy()

In [None]:
import celltypist

# log1p(10000) is the largest value a cell normalised to 1e4 counts can reach.
# Presumably celltypist rejects larger values as "not log1p-normalised" —
# TODO confirm against the celltypist input checks.
expr_cap = np.log1p(10000)

accs = []
for i in index:
    # Rows of category i: the generated cells are stored in blocks of 1000.
    cell = cell_gen_all[i*1000:(i+1)*1000]
    ori = ad.AnnData(cell, dtype=np.float32)
    ori.var_names = gene_names

    # Clamp decoded values to just below the cap. `>=` matters: a mask-based
    # clamp built from `>` and `<` sends values exactly equal to the cap to 0.
    ori.X = np.where(ori.X >= expr_cap, expr_cap - 1e-6, ori.X)

    predictions = celltypist.annotate(ori, model = '../checkpoint_old/celltypist_muris_all_re2.pkl')
    # Fraction of generated cells whose predicted label equals the category
    # they were conditioned on.
    acc = (predictions.predicted_labels.squeeze(1).values == cato[i]).sum()/cell.shape[0]
    # print(cato[i],acc)
    accs.append((cato[i],acc))
    diffu_acc.append(acc)
print(accs)

🔬 Input data has 1000 cells and 18996 genes
🔗 Matching reference genes in the model
🧬 18996 features used for prediction
⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!
🔬 Input data has 1000 cells and 18996 genes
🔗 Matching reference genes in the model
🧬 18996 features used for prediction
⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!
🔬 Input data has 1000 cells and 18996 genes
🔗 Matching reference genes in the model
🧬 18996 features used for prediction
⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!
🔬 Input data has 1000 cells and 18996 genes
🔗 Matching reference genes in the model
🧬 18996 features used for prediction
⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!
🔬 Input data has 1000 cells and 18996 genes
🔗 Matching reference genes in the model
🧬 18996 features used for prediction
⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!
🔬 Input data has 1000 cells and 18996 genes
🔗 Matching reference genes in the model
🧬

[('Bladder', 0.987), ('Heart_and_Aorta', 0.665), ('Kidney', 0.915), ('Limb_Muscle', 0.917), ('Liver', 0.992), ('Lung', 0.941), ('Mammary_Gland', 0.899), ('Marrow', 0.953), ('Spleen', 0.996), ('Thymus', 0.925), ('Tongue', 0.996), ('Trachea', 0.983)]


knn

In [None]:
from sklearn.neighbors import KNeighborsClassifier  
from sklearn.model_selection import  train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score

def knn_classify(adata):
    """Train a 5-NN classifier to tell real cells from generated ones.

    Rows are selected by obs name: ``'true_Cell'`` (label 1) and ``'gen_Cell'``
    (label 0). The classifier works on ``adata.X`` directly, not on a PCA
    embedding. 30% of the rows are held out (fixed ``random_state=1``) and the
    scores are computed on that hold-out set.

    Args:
        adata: AnnData whose ``obs_names`` are ``'true_Cell'`` / ``'gen_Cell'``.
            ``X`` may be a scipy sparse matrix or a dense array.

    Returns:
        ``(accuracy, auc)`` on the validation split. Both are also printed.
        Values near 0.5 mean the two groups are hard to separate.
    """
    def _dense(matrix):
        # AnnData.X can be sparse or already dense depending on how the object
        # was built; only sparse matrices provide .toarray().
        return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)

    real = _dense(adata[adata.obs_names=='true_Cell'].X)
    sim = _dense(adata[adata.obs_names=='gen_Cell'].X)

    data = np.concatenate((real,sim),axis=0)
    label = np.concatenate((np.ones((real.shape[0])),np.zeros((sim.shape[0]))))

    knn_classifier = KNeighborsClassifier(n_neighbors=5)

    # Split the data into a training set and a validation set.
    X_train,X_val,y_train,y_val = train_test_split(data, label,
                                                test_size = 0.3,random_state = 1)
    knn_classifier.fit(X_train, y_train)
    predicted_label = knn_classifier.predict(X_val)
    # scikit-learn metrics take (y_true, y_pred) in that order.
    accuracy = accuracy_score(y_val, predicted_label)

    # AUC needs the true labels and the predicted probability of the positive
    # class (column 1 = label 1.0 = real cells); binary problems only.
    predicted_probabilities = knn_classifier.predict_proba(X_val)[:, 1]
    auc = roc_auc_score(y_val, predicted_probabilities)
    print(f"AUC: {auc}, Accuracy: {accuracy}")

    return accuracy, auc

In [None]:
cato = ['Bladder', 'Heart_and_Aorta', 'Kidney', 'Limb_Muscle', 'Liver',
       'Lung', 'Mammary_Gland', 'Marrow', 'Spleen', 'Thymus', 'Tongue',
       'Trachea']
knn_acc = []
knn_auc = []
cell_gen_all = []
gen_class = []
index2 = list(range(12))
length_per_type = 1000

# Load the first `length_per_type` generated samples of every category.
for i in range(12):
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted files.
    npzfile = np.load(f'../output/muris_condi/muris_{i}_scimilarity_nodrop.npz', allow_pickle=True)
    cell_gen_all.append(npzfile['cell_gen'][:length_per_type])
    gen_class += ['gen ' + cato[i]] * length_per_type
cell_gen_all = np.concatenate(cell_gen_all, axis=0)
# print(cell_gen_all.shape)

autoencoder = load_VAE()
# Decode to gene space; inference only, so skip autograd to save GPU memory.
with torch.no_grad():
    cell_gen_all = autoencoder(torch.tensor(cell_gen_all).cuda(), return_decoded=True).cpu().detach().numpy()

# NOTE(review): `adata` must be the real dataset carrying obs['celltype'] here,
# not one of the combined real + generated objects built in other cells.
for i in range(12):
    cell_diff = cell_gen_all[i*length_per_type:(i+1)*length_per_type]
    ori = ad.AnnData(cell_diff, dtype=np.float32)
    ori.var_names = gene_names

    real_cells = adata[adata.obs['celltype'] == cato[i]]
    # Balance the two groups: use as many cells as the smaller one provides.
    # n_obs gives the row count without densifying the expression matrix.
    length = min(real_cells.n_obs, length_per_type)

    adata1 = ad.concat((real_cells[:length], ori[:length]))
    adata1.obs_names = ["true_Cell"] * length + ["gen_Cell"] * ori[:length].n_obs

    # NOTE(review): knn_classify reads adata1.X, so this PCA does not affect
    # the scores below; kept because it stores X_pca on adata1.
    sc.tl.pca(adata1, svd_solver='arpack')
    acc, auc = knn_classify(adata1)
    knn_acc.append(acc)
    knn_auc.append(auc)
print(np.mean(knn_acc))
print(knn_acc)
print(knn_auc)

AUC: 0.5050333892598806, Accuracy: 0.5016666666666667
AUC: 0.5, Accuracy: 0.48091603053435117
AUC: 0.5083612040133779, Accuracy: 0.5016666666666667
AUC: 0.5066889632107023, Accuracy: 0.5016666666666667
AUC: 0.5083612040133779, Accuracy: 0.5016666666666667
AUC: 0.5066889632107023, Accuracy: 0.5066666666666667
AUC: 0.5033444816053512, Accuracy: 0.5016666666666667
AUC: 0.5083612040133779, Accuracy: 0.5016666666666667
AUC: 0.500016666851854, Accuracy: 0.5016666666666667
AUC: 0.5016722408026756, Accuracy: 0.5016666666666667
AUC: 0.5050167224080268, Accuracy: 0.505


  utils.warn_names_duplicates("obs")


AUC: 0.5016722408026756, Accuracy: 0.5016666666666667
0.5006318914334182
[0.5016666666666667, 0.48091603053435117, 0.5016666666666667, 0.5016666666666667, 0.5016666666666667, 0.5066666666666667, 0.5016666666666667, 0.5016666666666667, 0.5016666666666667, 0.5016666666666667, 0.505, 0.5016666666666667]
[0.5050333892598806, 0.5, 0.5083612040133779, 0.5066889632107023, 0.5083612040133779, 0.5066889632107023, 0.5033444816053512, 0.5083612040133779, 0.500016666851854, 0.5016722408026756, 0.5050167224080268, 0.5016722408026756]
