In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import scanpy as sc
import numpy as np
import pickle

from model_VAE import VAE
from dataloader_VAE import get_h5ad_data, get_dataloader, normalize, inverse_normalize

# 准备数据

In [3]:
dataset_name = 'DR'


para_dict = {
    'DR': 'drosophila_scNODE2_2000genes_2489cells_11tps.h5ad',
    'EB': 'embryoid_scNODE5_2000genes_6232cells_5tps.h5ad',
    'MB': 'mammalian_scNODE1_2000genes_7542cells_13tps.h5ad',
    'MP': 'pancreatic_scNODE4_2000genes_9483cells_4tps.h5ad',
    'ZB': 'zebrafish_scNODE0_2000genes_3227cells_12tps.h5ad',
}


dataset_h5ad = para_dict[dataset_name]

# result_path = f'/home/hanyuji/Results/scDYff/interpolation_latent/{dataset_name}_result_dict_50_latent.pt'
result_path = f'/home/hanyuji/Results/scDYff/interpolation_latent/{dataset_name}_result_dict_50_latent_3000cell.pt'
with open(result_path, 'rb') as f:
    latent_data_dict = pickle.load(f)

if dataset_name in ['MP', 'EB']:
    test_index = [
        2,
    ]
else:
    test_index = [0,1,2,3,4,5,6,7,8,9,10]


# 训练集和测试集对应的下标
test_list = [latent_data_dict[i] for i in test_index]
test_dataloader = get_dataloader(test_list, test_index, batch_size=200)

In [4]:
# 准备数据缩放因子
data_list = get_h5ad_data(dataset_h5ad)

norm_data_list, scalers = normalize(data_list)

# 模型生成

In [5]:
# 加载模型
device = "cuda:0"
net = VAE().to(device)
model_path = f'/home/hanyuji/Results/VAE_result/model_para/vae_model_0604_{dataset_name}_all.pt'
net.load_state_dict(torch.load(model_path))


<All keys matched successfully>

In [6]:
recon_dict = {}  # 12个数组

# {4: [], 6: [], 8: []}
for index in test_index:
    recon_dict[index] = []

In [7]:
net.eval()
for (x, y) in test_dataloader:
    
    x = x.float().to(device)
    recon = net.decoder(x)
    
    y_np = y.detach().cpu().numpy()
    recon_np = recon.detach().cpu().numpy()
                
    for (recon_i,y_i) in zip(recon_np,y_np):
        recon_dict[y_i].append(recon_i)


In [8]:
recon_list = []
    
for index, arr in recon_dict.items():
    recon_list.append(np.asarray(arr))

inverse_norm_recon_list = inverse_normalize(recon_list, scalers)

In [9]:
# 保存这些数组到一个文件

file_name = f'{dataset_name}_2000_recon_3000cell.pkl'
with open('/home/hanyuji/Results/VAE_result/data_recon_0606/'+file_name, 'wb') as f:
    pickle.dump(inverse_norm_recon_list, f)


In [35]:
from scipy.spatial.distance import cdist
import geomloss

print(dataset_name,'scDYff_DiT')

for index_recon, index_true in zip(range(len(test_index)),test_index):
    # 评估结果
    x_pred = inverse_norm_recon_list[index_recon]
    # x_pred = data_list[index_true-1]  # naive method
    
    x_true = data_list[index_true]

    l2_dist = cdist(x_true, x_pred, metric="euclidean")
    cos_dist = cdist(x_true, x_pred, metric="cosine")
    corr_dist = cdist(x_true, x_pred, metric="correlation")
    avg_l2 = l2_dist.sum() / np.prod(l2_dist.shape)
    avg_cos = cos_dist.sum() / np.prod(cos_dist.shape)
    avg_corr = corr_dist.sum() / np.prod(corr_dist.shape)


    ot_solver = geomloss.SamplesLoss(
        "sinkhorn", p=2, blur=0.05, scaling=0.5, debias=True, backend="tensorized"
    )
    ot = ot_solver(
        torch.tensor(x_pred).type(torch.float32).to(device),
        torch.tensor(x_true).type(torch.float32).to(device),
    ).item()
    # l2 = nn.MSELoss(x_pred, x_true)

    print(f'ot: {ot}, l2: {avg_l2}, cos: {avg_cos}, corr: {avg_corr}')

MB scDYff_DiT
ot: 109.12289428710938, l2: 18.197466965094875, cos: 0.2345526444695823, corr: 0.2558484355573812
ot: 105.7564697265625, l2: 18.921055496911304, cos: 0.24143947949675806, corr: 0.26514125322703386
ot: 116.9522705078125, l2: 16.774277772741225, cos: 0.1932824699523682, corr: 0.20929105330507125


# Linear Method

In [36]:
# index_true = 4

# batch_size = 2000

# x1 = data_list[index_true - 1]
# x3 = data_list[index_true + 1]
# cell_idx_1 = np.random.choice(
#     np.arange(x1.shape[0]), size=batch_size, replace=(x1.shape[0] < batch_size)
# )
# cell_idx_3 = np.random.choice(
#     np.arange(x3.shape[0]), size=batch_size, replace=(x3.shape[0] < batch_size)
# )
# x1 = x1[cell_idx_1, :]
# x3 = x3[cell_idx_3, :]

# x_pred = (x1+x3)/2


In [37]:
from scipy.spatial.distance import cdist
import geomloss

print(dataset_name, 'Linear Method')

for index_recon, index_true in zip(range(len(test_index)),test_index):
    # 评估结果
    # x_pred = inverse_norm_recon_list[index_recon]

    
    
    batch_size = 2000

    x1 = data_list[index_true - 1]
    x3 = data_list[index_true + 1]
    cell_idx_1 = np.random.choice(
        np.arange(x1.shape[0]), size=batch_size, replace=(x1.shape[0] < batch_size)
    )
    cell_idx_3 = np.random.choice(
        np.arange(x3.shape[0]), size=batch_size, replace=(x3.shape[0] < batch_size)
    )
    x1 = x1[cell_idx_1, :]
    x3 = x3[cell_idx_3, :]

    x_pred = (x1+x3)/2

    
    
    
    
    
    
    
    
    
    
    
    x_true = data_list[index_true]

    l2_dist = cdist(x_true, x_pred, metric="euclidean")
    cos_dist = cdist(x_true, x_pred, metric="cosine")
    corr_dist = cdist(x_true, x_pred, metric="correlation")
    avg_l2 = l2_dist.sum() / np.prod(l2_dist.shape)
    avg_cos = cos_dist.sum() / np.prod(cos_dist.shape)
    avg_corr = corr_dist.sum() / np.prod(corr_dist.shape)


    ot_solver = geomloss.SamplesLoss(
        "sinkhorn", p=2, blur=0.05, scaling=0.5, debias=True, backend="tensorized"
    )
    ot = ot_solver(
        torch.tensor(x_pred).type(torch.float32).to(device),
        torch.tensor(x_true).type(torch.float32).to(device),
    ).item()
    # l2 = nn.MSELoss(x_pred, x_true)

    print(f'ot: {ot}, l2: {avg_l2}, cos: {avg_cos}, corr: {avg_corr}')

MB Linear Method
ot: 146.31344604492188, l2: 20.65794167020932, cos: 0.28849095546102554, corr: 0.3142006418408773
ot: 142.38583374023438, l2: 20.88380599968904, cos: 0.2828053162285965, corr: 0.30938833702307816
ot: 158.1946563720703, l2: 20.40762634291961, cos: 0.2699596243096057, corr: 0.29267287460356545


# Naive Method

In [38]:
# from scipy.spatial.distance import cdist
# import geomloss

# print(dataset_name, 'Naive Method')

# for index_recon, index_true in zip(range(len(test_index)),test_index):
#     # 评估结果
#     # x_pred = inverse_norm_recon_list[index_recon]
#     x_pred = data_list[index_true-1]  # naive method
    
#     x_true = data_list[index_true]

#     l2_dist = cdist(x_true, x_pred, metric="euclidean")
#     cos_dist = cdist(x_true, x_pred, metric="cosine")
#     corr_dist = cdist(x_true, x_pred, metric="correlation")
#     avg_l2 = l2_dist.sum() / np.prod(l2_dist.shape)
#     avg_cos = cos_dist.sum() / np.prod(cos_dist.shape)
#     avg_corr = corr_dist.sum() / np.prod(corr_dist.shape)


#     ot_solver = geomloss.SamplesLoss(
#         "sinkhorn", p=2, blur=0.05, scaling=0.5, debias=True, backend="tensorized"
#     )
#     ot = ot_solver(
#         torch.tensor(x_pred).type(torch.float32).to(device),
#         torch.tensor(x_true).type(torch.float32).to(device),
#     ).item()
#     # l2 = nn.MSELoss(x_pred, x_true)

#     print(f'ot: {ot}, l2: {avg_l2}, cos: {avg_cos}, corr: {avg_corr}')