In [19]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [20]:
import torch
import torch.nn as nn
import numpy as np
import pickle
import torch
import scanpy as sc
from sklearn.preprocessing import MinMaxScaler
from itertools import combinations
from tqdm import tqdm
import matplotlib.pyplot as plt
import geomloss
import sys
import argparse

from model.stDiff_model import DiT_stDiff
from utils import save_log_file, save_model, SinkhornLoss


In [21]:
'''
python train_scDYff_DiT.py --dataset DR --ntps 11 --batch_size 1024
python train_scDYff_DiT.py --dataset EB --ntps 5 --batch_size 2048
python train_scDYff_DiT.py --dataset MB --ntps 13 --batch_size 1024
python train_scDYff_DiT.py --dataset MP --ntps 4 --batch_size 2048
python train_scDYff_DiT.py --dataset ZB --ntps 12 --batch_size 512

'''

para_dict = {
    'DR': (11, 1024, [4,6,8]),
    'EB': (5, 2048, [2,]),
    'MB': (13, 1024, [4,6,8]),
    'MP': (4, 2048, [2,]),
    'ZB': (12, 512, [4,6,8]),
}


dataset_name = 'DR'
dataset_ntps = para_dict[dataset_name][0]
batch_size = para_dict[dataset_name][1]
test_label_list = para_dict[dataset_name][2]
label_list = list(range(dataset_ntps))

In [22]:
gene_num = 50
depth = 6
hidden_size = 512
head = 16


model = DiT_stDiff(
    input_size=gene_num * 2,
    output_size=gene_num,
    hidden_size=hidden_size,
    depth=depth,
    num_heads=head,
    classes=6,
    dit_type='dit',
    mlp_ratio=4.0,
)

device = torch.device('cuda:0')
model.to(device)

# 加载模型
epochs = 1400
model_path = f'/home/hanyuji/Results/scDYff/{dataset_name}/model_inter_{dataset_name}_{epochs}epochs.pt'
model.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [23]:
dataset_path = (
    f'/home/hanyuji/Results/VAE_result/data_latent/{dataset_name}_latent_50.pkl'
)

with open(dataset_path, 'rb') as f:
    data_list = pickle.load(f)


In [24]:
model.eval()
pred = {}
for test_tp in test_label_list:
    t1 = test_tp - 1
    t3 = test_tp + 1
    
    x1 = data_list[t1]
    x3 = data_list[t3]
    cell_idx_1 = np.random.choice(
        np.arange(x1.shape[0]), size=batch_size, replace=(x1.shape[0] < batch_size)
    )
    cell_idx_3 = np.random.choice(
        np.arange(x3.shape[0]), size=batch_size, replace=(x3.shape[0] < batch_size)
    )
    x1 = x1[cell_idx_1, :]
    x3 = x3[cell_idx_3, :]
    x13 = np.concatenate([x1, x3], axis=1)  # batchsize, 100


    x13 = torch.tensor(x13).type(torch.float32).to(device)
    
    t = (test_tp-t1) / (t3-t1)
    t = torch.tensor(np.full(x1.shape[0], t)).type(torch.float32).to(device)
    
    
    x2_pred = model(x13, t=t)
    pred[test_tp] = x2_pred.detach().cpu().numpy()


In [25]:
result_path = f'/home/hanyuji/Results/scDYff/interpolation_latent/{dataset_name}_result_dict_50_latent.pt'
with open(result_path, 'wb') as f:
    pickle.dump(pred, f)

In [26]:
from scipy.spatial.distance import cdist

# 评估结果
for i in test_label_list:
    x_pred = pred[i]
    x_pred = data_list[i + 1]
    x_true = data_list[i]

    l2_dist = cdist(x_true, x_pred, metric="euclidean")
    cos_dist = cdist(x_true, x_pred, metric="cosine")
    corr_dist = cdist(x_true, x_pred, metric="correlation")
    avg_l2 = l2_dist.sum() / np.prod(l2_dist.shape)
    avg_cos = cos_dist.sum() / np.prod(cos_dist.shape)
    avg_corr = corr_dist.sum() / np.prod(corr_dist.shape)


    ot = SinkhornLoss(
        torch.tensor(x_pred).type(torch.float32).to(device),
        torch.tensor(x_true).type(torch.float32).to(device),
    ).item()
    # l2 = nn.MSELoss(x_pred, x_true)
    print(f'ot: {ot}, l2: {avg_l2}, cos: {avg_cos}, corr: {avg_corr}')

ot: 26.12495231628418, l2: 9.66082114150935, cos: 0.9644107625787834, corr: 0.9648374750487079
ot: 24.935546875, l2: 9.82725093060234, cos: 0.9872741440194381, corr: 0.9870808272993952
ot: 25.15021514892578, l2: 9.947431277149597, cos: 0.9872738788033648, corr: 0.9870056809217979


In [27]:
with open(result_path, 'rb') as f:
    loaded_dict = pickle.load(f)

loaded_dict.keys()

dict_keys([4, 6, 8])