In [18]:
import anndata as ad
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import pandas as pd
from scipy import stats
import torch
import sys
sys.path.append('..')
from VAE.VAE_model import VAE
from torch.autograd import Variable

In [19]:
def load_VAE():
    autoencoder = VAE(
        num_genes=19423,
        device='cuda',
        seed=0,
        loss_ae='mse',
        hidden_dim=128,
        decoder_activation='ReLU',
    )
    autoencoder.load_state_dict(torch.load('/home/workplace/cfDiffusion/checkpoint/VAE/wot/model_seed=0_step=800000.pt'))
    return autoencoder

real data

In [20]:
adata = sc.read_h5ad('/home/workplace/cfDiffusion/dataset/WOT_dataset/WOT_filted_data.h5ad')
adata = adata[np.where(np.in1d(adata.obs['period'], ['D0','D0.5','D1','D1.5','D2','D2.5','D3','D4.5','D5','D5.5','D6','D6.5','D7','D7.5','D8']))[0]]
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
adata = adata[np.where(adata.obs['period'].values.isnull()==0)[0]]#[::5]
adata.var_names_make_unique()
gene_names = adata.var_names
cell_data = adata.X
cell_data.shape

(82920, 19423)

In [21]:
82920/15

5528.0

In [22]:
22879/31

738.0322580645161

conditioned generated data

In [23]:
device = torch.device('cuda:0')
device

device(type='cuda', index=0)

In [24]:
cato = ['D0', 'D0.5', 'D1', 'D1.5', 'D2', 'D2.5', 'D3', 'D4.5', 'D5', 'D5.5', 'D6', 'D6.5',
        'D7', 'D7.5', 'D8']

cell_gen = []
gen_class = []
length_per_type = 5600

for i in range(15):
    npyfile=np.load(f'/home/workplace/cfDiffusion/generation/wot/cell{i}_cache5.npy',allow_pickle=True)
    cell_gen.append(npyfile[:length_per_type])
    gen_class+=[cato[i]]*length_per_type
    
cell_gen = np.concatenate(cell_gen,axis=0)

autoencoder = load_VAE().to(device)
cell_gen = autoencoder(torch.tensor(cell_gen).to(device), return_decoded=True).cpu().detach().numpy()

sim_adata = ad.AnnData(X=cell_gen)
sim_adata.obs['period'] = gen_class

In [25]:
cell_gen.shape

(84000, 19423)

correlation

In [12]:
print('spearman=',stats.spearmanr(cell_data.mean(axis=0), cell_gen.mean(axis=0)).correlation)
print('pearson=',np.corrcoef(cell_data.mean(axis=0), cell_gen.mean(axis=0))[0][1])

spearman= 0.9856944300753281
pearson= 0.9933166698867822


Wasserstein

In [13]:
from scipy.stats import wasserstein_distance
distance = wasserstein_distance(cell_data.mean(axis=0), cell_gen.mean(axis=0))
print(distance)

0.016375659741300917


MMD

In [26]:
import torch

def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    '''
    将源域数据和目标域数据转化为核矩阵, 即上文中的K
    Params: 
	    source: 源域数据(n * len(x))
	    target: 目标域数据(m * len(y))
	    kernel_mul: 
	    kernel_num: 取不同高斯核的数量
	    fix_sigma: 不同高斯核的sigma值
	Return:
		sum(kernel_val): 多个核矩阵之和
    '''
    n_samples = int(source.size()[0])+int(target.size()[0])
    total = torch.cat([source, target], dim=0)

    total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))

    total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))

    L2_distance = ((total0-total1)**2).sum(2) 

    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)

    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]

    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]

    return sum(kernel_val)

def mmd_rbf(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    batch_size = int(source.size()[0])
    kernels = guassian_kernel(source, target,
        kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)

    XX = kernels[:batch_size, :batch_size]
    YY = kernels[batch_size:, batch_size:]
    XY = kernels[:batch_size, batch_size:]
    YX = kernels[batch_size:, :batch_size]
    loss = torch.mean(XX + YY - XY -YX)
    return loss

In [27]:
adata = np.concatenate((cell_data, cell_gen),axis=0)
adata = ad.AnnData(adata, dtype=np.float32)
adata.obs_names = [f"true_Cell" for i in range(cell_data.shape[0])]+[f"gen_Cell" for i in range(cell_gen.shape[0])]

In [28]:
sc.tl.pca(adata, svd_solver='arpack')
real = adata[adata.obs_names=='true_Cell'].obsm['X_pca'][::2][:5000] # can not be set too large, the kernel might fail
gen = adata[adata.obs_names=='gen_Cell'].obsm['X_pca'][::2][:5000]
X = torch.Tensor(real)
Y = torch.Tensor(gen)
X, Y = Variable(X), Variable(Y)
print(mmd_rbf(X, Y))

tensor(2.6222)


scib(ILISI)

In [29]:
import scib
adata = np.concatenate((cell_data, cell_gen),axis=0)
adata = ad.AnnData(adata, dtype=np.float32)
adata.obs['batch'] = pd.Categorical([f"true_Cell" for i in range(cell_data.shape[0])]+[f"gen_Cell" for i in range(cell_gen.shape[0])])
sc.pp.neighbors(adata, n_neighbors=10, n_pcs=20)
scib.me.ilisi_graph(adata, batch_key="batch", type_="knn")

         Falling back to preprocessing with `sc.pp.pca` and default params.




0.8198410845510322

knn

In [30]:
from sklearn.neighbors import KNeighborsClassifier  
from sklearn.model_selection import  train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score

def knn_classify(adata):
    real = adata[adata.obs_names=='true_Cell'].X.toarray()#.obsm['X_pca']
    sim = adata[adata.obs_names=='gen_Cell'].X.toarray()#.obsm['X_pca']#

    data = np.concatenate((real,sim),axis=0)
    label = np.concatenate((np.ones((real.shape[0])),np.zeros((sim.shape[0]))))

    knn_classifier = KNeighborsClassifier(n_neighbors=5)  
    
    ##将训练集切分为训练集和验证集
    X_train,X_val,y_train,y_val = train_test_split(data, label,
                                                test_size = 0.3,random_state = 1)
    knn_classifier.fit(X_train, y_train)
    predicted_label = knn_classifier.predict(X_val)
    # print((predicted_label==y_val).sum()/X_val.shape[0])
    accuracy = accuracy_score(predicted_label, y_val)

    # 算AUC
    predicted_probabilities = knn_classifier.predict_proba(X_val)[:, 1]  
    
    # 计算AUC，只适用于二分类问题  
    # AUC需要真实标签和正类的预测概率  
    auc = roc_auc_score(y_val, predicted_probabilities)  
    print(f"AUC: {auc}, Accuracy: {accuracy}") 

    return accuracy, auc

In [31]:
adata = sc.read_h5ad('/home/zqzhao/workplace/cfDiffusion/dataset/WOT_dataset/WOT_filted_data.h5ad')
adata = adata[np.where(np.in1d(adata.obs['period'], ['D0','D0.5','D1','D1.5','D2','D2.5','D3','D4.5','D5','D5.5','D6','D6.5','D7','D7.5','D8']))[0]]
adata.var_names_make_unique()
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
adata = adata[np.where(adata.obs['period'].values.isnull()==0)[0]]#[::5]

gene_names = adata.var_names
cell_data = adata.X
cell_data.shape

(82920, 19423)

In [34]:
cato = ['D0', 'D0.5', 'D1', 'D1.5', 'D2', 'D2.5', 'D3', 'D4.5', 'D5', 'D5.5', 'D6', 'D6.5',
        'D7', 'D7.5', 'D8']
knn_acc = []
knn_auc = []
cell_gen = []
gen_class = []
index2 = list(range(15))
length_per_type = 5600

for i in range(15):
    npyfile=np.load(f'/home/zqzhao/workplace/cfDiffusion/generation/wot/cell{i}_cache5.npy',allow_pickle=True)
    cell_gen.append(npyfile[:length_per_type])
    # gen_class+=['gen_'+cato[i]]*length_per_type
cell_gen = np.concatenate(cell_gen,axis=0)
print(cell_gen.shape)

autoencoder = load_VAE()
cell_gen = autoencoder(torch.tensor(cell_gen).cuda(),return_decoded=True).cpu().detach().numpy()

for i in range(15):
    cell_diff = cell_gen[i*length_per_type:(i+1)*length_per_type]
    ori = ad.AnnData(cell_diff, dtype=np.float32)
    ori.var_names = gene_names

    length = min(adata[adata.obs['period'] == cato[i]].X.toarray().shape[0],length_per_type)

    adata1 = ad.concat((adata[adata.obs['period'] == cato[i]][:length],ori[:length]))
    adata1.obs_names = [f"true_Cell" for i in range(length)]+[f"gen_Cell" for i in range(ori[:length].X.shape[0])]

    sc.tl.pca(adata1, svd_solver='arpack')
    acc, auc = knn_classify(adata1)
    knn_acc.append(acc)
    knn_auc.append(auc)
print(np.mean(knn_acc))
print(knn_acc)
print(knn_auc)

(84000, 128)




AUC: 0.5, Accuracy: 0.49602601156069365




AUC: 0.5, Accuracy: 0.5154589371980677




AUC: 0.5, Accuracy: 0.5135869565217391




AUC: 0.5016474464579901, Accuracy: 0.4829642248722317




AUC: 0.5, Accuracy: 0.4973214285714286




AUC: 0.5, Accuracy: 0.4973214285714286




AUC: 0.5, Accuracy: 0.4973214285714286




AUC: 0.5002960331557135, Accuracy: 0.4973214285714286




AUC: 0.5, Accuracy: 0.4973214285714286




AUC: 0.5, Accuracy: 0.4973214285714286




AUC: 0.5840734162226169, Accuracy: 0.4973214285714286




AUC: 0.5, Accuracy: 0.502410283877879




AUC: 0.5, Accuracy: 0.4973214285714286




AUC: 0.5003287310979618, Accuracy: 0.4991768192295028




AUC: 0.5004317789291882, Accuracy: 0.5072340425531915
0.49969524695898226
[0.49602601156069365, 0.5154589371980677, 0.5135869565217391, 0.4829642248722317, 0.4973214285714286, 0.4973214285714286, 0.4973214285714286, 0.4973214285714286, 0.4973214285714286, 0.4973214285714286, 0.4973214285714286, 0.502410283877879, 0.4973214285714286, 0.4991768192295028, 0.5072340425531915]
[0.5, 0.5, 0.5, 0.5016474464579901, 0.5, 0.5, 0.5, 0.5002960331557135, 0.5, 0.5, 0.5840734162226169, 0.5, 0.5, 0.5003287310979618, 0.5004317789291882]


: 