In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import torch
import sys
import argparse
from options import Options
from dataset import ImputationDataset
from dataset import ImputationNano
from tqdm import tqdm
import numpy as np
import scanpy as sc
import pandas as pd
from torch.utils.data import Dataset
from utils import seed_everything, CalculateMeteics

# from model.SpaIM import ImputeModule
from model.model_nano_1layers import ImputeModule

# Emulate a command-line launch inside the notebook: Options().parse() presumably reads
# sys.argv (argparse-style) - TODO confirm against options.py.
cli_overrides = ['--kfold', '0', '--lr', '1e-3']
sys.argv = ['your_script.py'] + cli_overrides
opt = Options().parse()

In [2]:
# Absolute locations of the spatial (in-situ) and single-cell count files.
sp_path = '/data/boom/SpaIM/dataset/nano9-1/Insitu_count.h5ad'
sc_path = '/data/boom/SpaIM/dataset/nano9-1/scRNA_count_cluster.h5ad'

# Read both files in one pass: spatial first, single-cell second.
SP_adata, SC_adata = (sc.read(h5ad_path) for h5ad_path in (sp_path, sc_path))

print(SP_adata, '\n', SC_adata)

# Author's note: (67798, 3482) = SP index (observations) x SC genes - presumably the
# expected shape of the imputed matrix; verify against the shapes printed by this cell.

AnnData object with n_obs × n_vars = 77890 × 928
    obs: 'cell_type', 'niche', 'cx', 'cy', 'cx_g', 'cy_g', 'merge_cell_type', 'cell_type_map', 'label'
    obsm: 'spatial', 'spatial_global' 
 AnnData object with n_obs × n_vars = 50517 × 3495
    obs: 'merge_cell_type', 'leiden_0.00', 'leiden_0.01', 'leiden_0.0005', 'leiden_0.005', 'leiden_0.02', 'leiden_0.10', 'leiden_0.30', 'leiden_0.50', 'leiden_0.70', 'leiden_0.90'


In [3]:
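# Redefine the dataset loader so that only the validation ('val') split is loaded.
# NOTE: this class shadows the ImputationNano imported from `dataset` at the top of the notebook.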

class ImputationNano(Dataset):
    """Validation-only dataset of cluster-averaged scRNA-seq expression, one gene per item.

    The single-cell matrix is averaged per cluster and transposed, so item ``k`` is the
    vector of per-cluster mean expression of gene ``k`` together with an all-ones style
    vector of length ``opt.style_dim``.

    Args:
        opt: options object; this class reads ``seq_data``, ``spa_data``, ``cluster`` and
            ``style_dim`` from it.
        istrain: split selector. Only ``'val'`` is populated by this class.
    """

    def __init__(self, opt, istrain='val'):
        self.opt = opt
        print('opt.seq_data', opt.seq_data)
        self.seq_adata = self.load_data(opt.seq_data)
        print('shape', self.seq_adata.shape)

        # Spatial data; within this class only its obs names are used (get_eval_names).
        self.spa_data = self.load_data(opt.spa_data)

        if 'leiden' in opt.cluster:
            self.seq_cluster = self.seq_adata.obs[opt.cluster].cat.codes.values
        elif opt.cluster == 'annotate':
            # Nano setting: cluster labels come from the 'merge_cell_type' annotation,
            # so opt.seq_data must point at a file that carries that column.
            self.seq_cluster = self.seq_adata.obs['merge_cell_type'].cat.codes.values
        else:
            self.seq_cluster = [1]

        self.style_dim = opt.style_dim
        self.istrain = istrain

        self.seq_data = self.aggreate_cell_types(self.seq_adata)

        # Earlier revisions restricted the genes to a per-fold val_list.npy; now every gene
        # that survived filtering is evaluated.
        # Keep the genes in file order: list(set(...)) yields an arbitrary order, which
        # would detach the rows of seq_val from the gene labels used for the result table.
        val_gene = list(dict.fromkeys(self.seq_data.var.index))

        print('len(val_gene)', len(val_gene))  # 3482 in the recorded run

        # Transpose so that observations (rows) are genes and variables are clusters.
        self.seq_val = self.seq_data[:, val_gene].copy().T
        print(self.seq_val)

    def get_cluster_dim(self):
        """Return the number of distinct cluster codes."""
        return len(set(self.seq_cluster))

    def run_leiden(self):
        """Cluster ``self.seq_data`` with Leiden (resolution 0.5) and store the labels.

        Works on a copy restricted to highly variable genes (scaled, PCA, neighbours) and
        writes the result to ``self.seq_data.obs['leiden']``.
        """
        adata_label = self.seq_data.copy()
        sc.pp.highly_variable_genes(adata_label)
        adata_label = adata_label[:, adata_label.var.highly_variable]
        sc.pp.scale(adata_label, max_value=10)
        sc.tl.pca(adata_label)
        sc.pp.neighbors(adata_label)
        sc.tl.leiden(adata_label, resolution=0.5)
        self.seq_data.obs['leiden'] = adata_label.obs['leiden']

    def get_eval_names(self):
        """Return ``(spatial obs names, gene names)`` for labelling the imputed matrix.

        The gene names are taken from ``seq_val`` itself so they are guaranteed to follow
        the order in which ``__getitem__`` serves the genes.
        """
        return self.spa_data.obs_names, self.seq_val.obs_names

    def aggreate_cell_types(self, adata):
        """Average the expression of ``adata`` per cluster.

        Returns:
            A new AnnData with one row per cluster (or a single row holding the global
            mean when no clustering is configured) and the same var index as ``adata``.
        """
        x = adata.X
        if 'leiden' in self.opt.cluster or self.opt.cluster == 'annotate':
            cluster_ids = list(set(self.seq_cluster))
            new_x = np.zeros((len(cluster_ids), x.shape[1]))
            # NOTE(review): the code is used directly as the row index, which assumes the
            # cluster codes are contiguous 0..k-1 - confirm for filtered data.
            for i in cluster_ids:
                new_x[i] = np.mean(x[self.seq_cluster == i], axis=0)
        else:
            new_x = np.zeros((1, x.shape[1]))
            new_x[0] = np.mean(x, axis=0)
        df = pd.DataFrame(new_x, columns=adata.var.index)
        new_adata = sc.AnnData(df)
        new_adata.var.index = adata.var.index
        return new_adata

    def load_data(self, root):
        """Read an AnnData file, drop sparse genes/cells and log1p-transform it once.

        Genes seen in fewer than 3 cells and cells with fewer than 3 genes are removed.
        ``log1p`` is skipped when ``adata.uns`` already has a ``'log1p'`` entry.
        """
        adata = sc.read(root)
        sc.pp.filter_genes(adata, min_cells=3)
        sc.pp.filter_cells(adata, min_genes=3)
        if "log1p" not in adata.uns_keys():
            sc.pp.log1p(adata)
        return adata

    def _active_split(self):
        """Return the AnnData backing the current ``istrain`` mode.

        Raises:
            AttributeError: for any mode other than ``'val'`` when no ``seq_test``
                attribute has been attached, since this class only builds ``seq_val``.
        """
        if self.istrain == 'val':
            return self.seq_val
        split = getattr(self, 'seq_test', None)
        if split is None:
            raise AttributeError(
                "ImputationNano only loads the 'val' split; istrain=%r needs a "
                "'seq_test' attribute that was never set" % (self.istrain,)
            )
        return split

    def __len__(self):
        """Number of items (genes) in the active split."""
        return self._active_split().shape[0]

    def __getitem__(self, index):
        """Return ``(expression vector of item ``index``, all-ones style vector)``."""
        st_style = torch.ones(self.style_dim)
        if self.istrain != 'val':
            print("Error in spx")
        seq_x = self._active_split().X[index, ...]
        return torch.FloatTensor(seq_x), st_style


In [4]:
# Build the validation-only dataset; construction loads and aggregates the data.
valdataset = ImputationNano(opt, istrain='val')

# Author's expectation for the final matrix: (67798, 3482).
# NOTE(review): get_eval_names() returns (spatial obs names, single-cell gene names), so
# `gene_names` actually holds the spatial observation ids and `cell_names` the gene ids.
# The variable names are kept because later cells rely on them in this swapped role.
gene_names, cell_names = valdataset.get_eval_names()
print(*map(len, (gene_names, cell_names)))

opt.seq_data dataset/nano9-1/scRNA_count_cluster.h5ad
shape (50517, 3482)
len(val_gene) 3482
AnnData object with n_obs × n_vars = 3482 × 6
77890 3482


In [5]:
# Override the fold id that was parsed from sys.argv ('--kfold 0'); the following cells use
# it to pick the checkpoint file ('last_%d.pth') and to name the output pickle.
opt.kfold = 10

In [6]:
# Predict the SC genes that are not present in the SP data.

valdataloader = torch.utils.data.DataLoader(
    valdataset,
    batch_size=opt.batch_size,
    shuffle=False,  # keep gene order so rows line up with the names from get_eval_names()
    num_workers=0
)
opt.sc_dim = valdataset.get_cluster_dim()

model = ImputeModule(opt)
if opt.parallel:
    # `.module` unwraps the DataParallel wrapper again, so the object kept in `model`
    # is the underlying module after it has been moved to the default CUDA device.
    model = torch.nn.DataParallel(model).cuda().module
else:
    model = model.to(torch.device('cuda:%d'%(opt.gpu)))
# Alternative checkpoints used previously: 'best_pcc_%d.pth' / 'nano9-1/last_%d.pth'
# under opt.save_path.
model_path = '/data/boom/SpaIM/results/UM_256/nano9-1/last_%d.pth'%(opt.kfold)
print('model_path:',model_path)
model.load(model_path)

with torch.no_grad():
    input_result = None  # unused in this cell; kept in case another cell reads it
    batch_outputs = []
    for seq, st_style in valdataloader:
        inputs = {
            'scx': seq,
            'st_style': st_style
        }
        # The inference calls must run once per batch. They previously sat outside the
        # loop body, so only the final batch was ever imputed.
        model.set_input(inputs, istrain=0)
        out = model.inference()
        impute_result = out['st_fake'].detach().cpu().numpy()
        print("impute_result",impute_result.shape)
        batch_outputs.append(impute_result)

    if not batch_outputs:
        raise RuntimeError("valdataloader yielded no batches; nothing to impute")

    # Concatenate once at the end instead of re-copying the growing array every batch.
    eval_result = np.concatenate(batch_outputs, axis=0)
    print("eval_result", eval_result.shape)  # (3482, 77890) in the recorded run

model_path: /data/boom/SpaIM/results/UM_256/nano9-1/last_10.pth
impute_result (3482, 77890)
eval_result (3482, 77890)


In [7]:
# Transpose so the first axis matches the index labels used when the frame is built.
# NOTE: re-running this cell transposes again - run it exactly once per inference pass.
eval_result = eval_result.T
print(eval_result[0, :10])
# Clamp negative predictions to zero, in place.
np.putmask(eval_result, eval_result < 0, 0)

[1.4798865 1.6316684 1.4377447 1.5326391 1.3243272 1.619016  1.731115
 1.5837083 1.5570486 1.598061 ]


In [8]:
# Label the imputed matrix and persist it next to the other results for this fold.
result_path = os.path.join(opt.save_path, 'impute_sc_result_%d.pkl'%(opt.kfold))
df1 = pd.DataFrame(data=eval_result, index=gene_names, columns=cell_names)
df1.to_pickle(result_path)

In [9]:
# NOTE(review): `ss` is not defined anywhere visible, so this cell raises NameError.
# Presumably a deliberate stop marker that keeps "Run All" from continuing into the
# evaluation cells - confirm, or replace it with an explicit guard.
ss

NameError: name 'ss' is not defined

# 验证性能

In [11]:
import scanpy as sc
import numpy as np
import pandas as pd

# Load the measured spatial data and the imputed matrix of each method.
# NOTE: pd.read_pickle can execute code embedded in the file - only load trusted pickles.
adata = sc.read("./Insitu_count.h5ad")
SpaIM_adata1 = pd.read_pickle('./SpaIM/impute_sc_result_0.pkl')
Tangram_adata2 = pd.read_pickle('./Tangram/impute_sc_result_0.pkl')
StDiff_adata2 = pd.read_pickle("./StDiff/impute_sc_result_0.pkl")


# Expression matrices. The three method frames are aliases, not copies, so the
# relabelling below also changes SpaIM_adata1 / StDiff_adata2 / Tangram_adata2.
raw = adata.to_df()
spaim, stdiff, tangram = SpaIM_adata1, StDiff_adata2, Tangram_adata2

# Give every frame the same positional row labels (cell1..cellN) and upper-case gene
# names so the columns can be compared directly. This assumes the rows of all frames
# are already in the same order.
for frame in (raw, spaim, stdiff, tangram):
    frame.index = ['cell' + str(pos) for pos in range(1, len(frame) + 1)]
    frame.columns = frame.columns.str.upper()

# Pearson correlation between measured and imputed expression for selected genes.
genes = ['SOX4', 'TYK2', 'GPX1', 'EZH2']
for gene in genes:
    measured = raw[gene]
    print(f"PCC between raw and spaim for gene {gene}:", np.corrcoef(measured, spaim[gene])[0, 1])
    print(f"PCC between raw and stdiff for gene {gene}:", np.corrcoef(measured, stdiff[gene])[0, 1])
    print(f"PCC between raw and tangram for gene {gene}:", np.corrcoef(measured, tangram[gene])[0, 1])
    print('\n')


PCC between raw and spaim for gene SOX4: 0.7587358619247753
PCC between raw and stdiff for gene SOX4: 0.5683980537201061
PCC between raw and tangram for gene SOX4: 0.697424680178555


PCC between raw and spaim for gene TYK2: 0.7703004976585627
PCC between raw and stdiff for gene TYK2: 0.7362592818413901
PCC between raw and tangram for gene TYK2: 0.71610167420911


PCC between raw and spaim for gene GPX1: 0.7055524102949307
PCC between raw and stdiff for gene GPX1: 0.6371375643353872
PCC between raw and tangram for gene GPX1: 0.6836300159506713


PCC between raw and spaim for gene EZH2: 0.6750833938409269
PCC between raw and stdiff for gene EZH2: 0.5457309547193485
PCC between raw and tangram for gene EZH2: 0.6261168958741387




In [12]:
# Show the StDiff frame; its index and columns were relabelled in place by the comparison cell.
StDiff_adata2

Unnamed: 0,INHBA,LDLR,TIE1,CCL26,COL9A2,NPPC,CD69,IDO1,IL1RN,ANGPT1,...,LYN,TLR2,ACVR1B,P2RX5,HDAC5,HSP90AA1,CD53,KRT1,CHEK2,MERTK
cell1,1.063586,1.164445,1.487774,1.035214,1.155095,1.084211,1.166202,0.929605,1.079259,1.181175,...,1.336643,1.553326,1.149385,1.484229,1.794985,1.130400,1.116958,1.145909,0.891967,0.979264
cell2,0.519570,0.496891,0.648638,0.365379,0.493549,0.359684,0.519228,0.527985,0.415503,0.513308,...,0.551051,0.356723,0.494342,0.594713,0.687339,0.480835,0.502250,0.490179,0.227291,0.667386
cell3,0.388743,0.356920,0.504616,0.227486,0.354235,0.294949,0.376135,0.414485,0.311888,0.370574,...,0.389675,0.283712,0.355707,0.413471,0.465267,0.353314,0.365576,0.353562,0.151612,0.538312
cell4,1.803507,1.795272,1.518277,1.757318,1.803264,1.345231,1.772109,1.760620,1.631994,1.782756,...,1.701626,0.849249,1.804409,1.612099,1.392613,1.777482,1.810312,1.801392,1.644680,1.611301
cell5,0.611922,0.614582,0.436615,0.628017,0.616824,0.588203,0.595228,0.611810,0.622509,0.603515,...,0.572449,0.410215,0.616871,0.538351,0.464412,0.624087,0.617588,0.619532,0.640286,0.523360
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
cell77886,0.996578,0.870783,0.788603,0.918089,0.874129,0.576163,0.926645,1.105994,0.824202,0.894659,...,0.849412,0.429759,0.879537,0.828423,0.769889,0.847575,0.916437,0.869451,0.848863,1.234866
cell77887,0.825459,0.763467,0.675310,0.758599,0.766241,0.696206,0.784354,0.889616,0.770513,0.769318,...,0.736226,0.501075,0.768025,0.715558,0.670549,0.762326,0.786071,0.765050,0.738257,0.960389
cell77888,1.065229,1.003374,0.787455,1.071305,1.006730,1.020148,1.025412,1.131571,1.069064,1.006804,...,0.943120,0.782857,1.010793,0.894593,0.790181,1.018074,1.031948,1.011205,1.082727,1.104997
cell77889,0.406149,0.391647,0.413812,0.461464,0.393865,0.292364,0.394190,0.399038,0.371251,0.391244,...,0.379466,0.236469,0.393775,0.369282,0.340221,0.380896,0.396474,0.390835,0.450703,0.363694


# 验证其他模型的性能

In [13]:
# Show the Tangram frame; its index and columns were relabelled in place by the comparison cell.
Tangram_adata2

Unnamed: 0,INHBA,LDLR,TIE1,CCL26,COL9A2,NPPC,CD69,IDO1,IL1RN,ANGPT1,...,LYN,TLR2,ACVR1B,P2RX5,HDAC5,HSP90AA1,CD53,KRT1,CHEK2,MERTK
cell1,1.372462,1.385759,1.346315,1.618329,1.616426,1.452099,1.162762,1.342827,1.536951,1.460292,...,1.410905,4.444549,1.339869,1.437019,3.429359,1.284247,1.318124,1.618166,1.306803,1.295124
cell2,1.216525,0.915988,0.751930,0.688796,0.691021,1.226137,0.569900,0.680424,0.567582,0.511399,...,0.725879,1.410370,0.196378,0.482999,1.294981,1.330006,1.218706,0.684353,1.040007,0.599171
cell3,0.747721,0.601163,0.529847,0.502560,0.502491,0.716440,0.398147,0.511071,0.472359,0.370177,...,0.556616,1.024033,0.508798,0.286214,0.899212,0.781485,0.725441,0.500519,0.651080,0.510676
cell4,1.848191,2.322559,2.468894,2.488029,2.485056,2.000819,2.133903,2.501660,2.526745,2.381063,...,2.236371,0.263376,1.018038,2.280138,1.300602,1.896048,2.000386,2.486050,2.227812,2.505647
cell5,0.739802,0.752589,0.757213,0.796199,0.797085,0.751060,0.714752,0.751173,0.782606,0.759741,...,0.728349,0.009697,0.493267,0.777070,0.545175,0.765143,0.751068,0.795618,0.766805,0.737331
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
cell77886,1.656364,1.574232,1.507485,1.261637,1.259325,1.674238,1.383185,1.463977,1.235001,1.261404,...,1.463300,0.741257,0.639050,1.177239,1.066799,1.728534,1.692614,1.261506,1.626664,1.439381
cell77887,1.225514,1.186451,1.139241,1.060641,1.059703,1.253470,1.077305,1.114091,1.026913,0.989143,...,1.100366,0.303969,0.748894,0.958093,0.917352,1.279873,1.251452,1.059195,1.215932,1.076477
cell77888,1.480382,1.520379,1.537626,1.302895,1.299127,1.435289,1.586027,1.540093,1.374902,1.375334,...,1.572188,0.193837,0.923275,1.279765,0.801617,1.414461,1.450390,1.303429,1.493734,1.583049
cell77889,0.495527,0.523958,0.540362,0.474267,0.474815,0.493588,0.631919,0.544619,0.490439,0.565113,...,0.556394,0.121552,0.112974,0.597548,0.387331,0.497909,0.506833,0.474790,0.521408,0.548886


In [14]:
# Show the SpaIM frame; its index and columns were relabelled in place by the comparison cell.
SpaIM_adata1

Unnamed: 0,INHBA,LDLR,TIE1,CCL26,COL9A2,NPPC,CD69,IDO1,IL1RN,ANGPT1,...,LYN,TLR2,ACVR1B,P2RX5,HDAC5,HSP90AA1,CD53,KRT1,CHEK2,MERTK
cell1,0.047889,0.064791,0.008742,0.001069,0.028767,0.003033,0.593395,0.025338,0.219573,0.005402,...,0.324322,0.069959,0.019602,0.036572,0.067193,1.862892,0.689115,0.001099,0.019225,0.040306
cell2,0.031549,0.036776,0.000336,0.001010,0.009182,0.000964,0.139967,0.019882,0.061639,0.000625,...,0.125434,0.016854,0.006488,0.003710,0.022558,0.597873,0.195461,0.000119,0.006429,0.003323
cell3,0.005840,0.056018,0.001357,0.000648,0.009767,0.000322,0.169711,0.005461,0.031963,0.000701,...,0.056106,0.023898,0.021721,0.003196,0.021990,0.435931,0.121312,0.001037,0.002529,0.006293
cell4,0.022697,0.161332,0.015583,0.001777,0.052870,0.004608,0.573670,0.024471,0.090885,0.008441,...,0.294587,0.036523,0.038490,0.036925,0.088148,1.995933,0.556413,0.002407,0.027324,0.015695
cell5,0.008493,0.038584,0.002483,0.000592,0.013996,0.000768,0.207594,0.010939,0.012456,0.003739,...,0.040583,0.008697,0.009627,0.006169,0.024671,0.648862,0.162830,0.000148,0.004364,0.004183
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
cell77886,0.049949,0.117451,0.000738,0.001237,0.021260,0.000775,0.395784,0.007698,0.050153,0.002419,...,0.168235,0.025693,0.020248,0.027864,0.049854,1.174337,0.365251,0.000283,0.010118,0.011198
cell77887,0.011677,0.064093,0.005309,0.000780,0.022239,0.001442,0.328221,0.010350,0.057874,0.003482,...,0.192527,0.020633,0.015089,0.035855,0.045949,0.863022,0.321350,0.000419,0.010044,0.007813
cell77888,0.059046,0.086097,0.007109,0.000779,0.027178,0.003596,0.660674,0.019473,0.055189,0.002923,...,0.155285,0.018919,0.020824,0.022398,0.063135,1.661218,0.546224,0.000463,0.016341,0.006982
cell77889,0.020747,0.031646,0.005641,0.000263,0.012425,0.001577,0.159362,0.003221,0.020491,0.001412,...,0.062934,0.006795,0.007637,0.007063,0.019693,0.506529,0.170933,0.000537,0.006508,0.003857


In [81]:
import scanpy as sc
import numpy as np

# Load the measured spatial data and reuse the in-memory SpaIM result.
# NOTE: `df1` must already exist in the kernel (it is created by the cell that pickles the
# imputation result); alternatively load it from disk with pd.read_pickle(<path>).
adata = sc.read("dataset/nano9-1/Insitu_count.h5ad")
SpaIM_adata1 = df1


# Expression matrices. `spaim` is an alias of `df1`, so the relabelling below also
# rewrites the index and columns of `df1` itself.
raw = adata.to_df()
spaim = SpaIM_adata1

# Same positional row labels (cell1..cellN) and upper-case gene names for both frames.
# To compare StDiff / Tangram as well, add their frames to this tuple and print their
# correlations in the loop below.
for frame in (raw, spaim):
    frame.index = ['cell' + str(pos) for pos in range(1, len(frame) + 1)]
    frame.columns = frame.columns.str.upper()

# Pearson correlation between measured and imputed expression for selected genes.
genes = ['SOX4', 'TYK2', 'GPX1', 'EZH2']
for gene in genes:
    measured = raw[gene]
    print(f"PCC between raw and spaim for gene {gene}:", np.corrcoef(measured, spaim[gene])[0, 1])
    print('\n')
