In [1]:
import torch
import warnings
import pickle

import squidpy as sq
import numpy as np
import pandas as pd
import seaborn as sns

from sklearn.model_selection import KFold
from transpa.eval_util import calc_corr
from transpa.util import expTransImp, compute_autocorr, plot_genes

# Silence library warnings so cell output stays readable.
warnings.filterwarnings("ignore")
# NOTE(review): not referenced anywhere in this section.
pre_datapath = "../../output/preprocessed_dataset/seqFISH_single_cell.pkl"

seed = 10  # default random seed
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
import scanpy as sc
import scipy.stats as st
from sklearn.model_selection import train_test_split

In [3]:
# smFISH benchmark inputs (spatial measurements and scRNA-seq reference).
# The two reads are independent, so their order does not matter.
scrna_adata = sc.read("/gpfs/gibbs/pi/zhao/tl688/tangram/data_smfish/scrnaseq_data.h5ad")
spa_adata = sc.read("/gpfs/gibbs/pi/zhao/tl688/tangram/data_smfish/spatial_data.h5ad")

In [4]:
def calcualte_pse_correlation(adata_sc, adata_st, celltype, p_value_threshold = 0.05, cor_threshold = 0.5):
    """Select genes whose cell-type pseudo-bulk profiles agree between two datasets.

    For every gene measured in both ``adata_sc`` and ``adata_st``, the mean
    expression within each cell type shared by the two datasets is computed
    (one pseudo-bulk profile per dataset), and the two profiles are compared
    with a Pearson correlation across cell types. A gene is kept when its
    p-value is below ``p_value_threshold`` and its correlation is above
    ``cor_threshold``. The misspelled name is kept so existing callers work.

    Parameters
    ----------
    adata_sc : AnnData
        Reference (single-cell) dataset; ``.X`` may be sparse or dense.
    adata_st : AnnData
        Spatial dataset; ``.X`` may be sparse or dense.
    celltype : str
        ``.obs`` column holding the cell-type label in both datasets.
    p_value_threshold : float, default 0.05
        Keep genes whose correlation p-value is strictly below this value.
    cor_threshold : float, default 0.5
        Keep genes whose Pearson correlation is strictly above this value.

    Returns
    -------
    pandas.Index
        Selected gene names, in sorted order.

    Raises
    ------
    ValueError
        If the datasets share at least one gene but fewer than two cell types,
        so no correlation across cell types can be computed.
    """
    import random
    # Kept from the original implementation: seeds Python's global RNG. Nothing in
    # this function is random, so the result does not depend on it.
    random.seed(2023)

    def _to_dense(matrix):
        # AnnData .X is either a scipy sparse matrix (has .toarray) or a dense
        # array; calling .toarray() unconditionally fails on the dense case.
        return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)

    overlap_gene = sorted(set(adata_sc.var_names).intersection(adata_st.var_names))
    adata_sc = adata_sc[:, overlap_gene]
    adata_st = adata_st[:, overlap_gene]

    # Sorted (by string form, so mixed label types still compare) to make the row
    # order of the pseudo-bulk matrices reproducible between interpreter runs; the
    # correlations themselves do not depend on this order.
    cell_type_common = sorted(
        set(adata_sc.obs[celltype].unique()).intersection(adata_st.obs[celltype].unique()),
        key=str,
    )
    if overlap_gene and len(cell_type_common) < 2:
        raise ValueError(
            f"Need at least two cell types in column {celltype!r} shared by both "
            f"datasets to correlate pseudo-bulk profiles; found {len(cell_type_common)}."
        )

    # Pseudo-bulk matrices: one row per shared cell type, one column per overlap gene.
    pseudo_st = np.array([
        np.mean(_to_dense(adata_st[adata_st.obs[celltype] == ct].X), axis=0)
        for ct in cell_type_common
    ])
    pseudo_sc = np.array([
        np.mean(_to_dense(adata_sc[adata_sc.obs[celltype] == ct].X), axis=0)
        for ct in cell_type_common
    ])

    cor_pearson = []
    cor_pvalue = []
    for gene_idx in range(len(overlap_gene)):
        # A constant profile yields a NaN correlation, which fails both filters below.
        cor, pval = st.pearsonr(pseudo_st[:, gene_idx], pseudo_sc[:, gene_idx])
        cor_pearson.append(cor)
        cor_pvalue.append(pval)

    information_stat = pd.DataFrame(
        {'pearson': cor_pearson, 'pvalue': cor_pvalue},
        index=adata_st.var_names,
    )
    keep = (information_stat['pvalue'] < p_value_threshold) & (information_stat['pearson'] > cor_threshold)
    return information_stat.index[keep.to_numpy()]

In [5]:
# Keep genes whose per-cell-type pseudo-bulk means correlate between the scRNA-seq
# and spatial data (default thresholds: p < 0.05 and Pearson r > 0.5).
info_gene = calcualte_pse_correlation(adata_sc=scrna_adata, adata_st=spa_adata, celltype='scClassify')

In [6]:
# Restrict both datasets to the informative genes (AnnData slicing yields views).
seq_data, spatial_data = scrna_adata[:, info_gene], spa_adata[:, info_gene]

In [7]:
import random
random.seed(2023)
# Genes present in both (already gene-filtered) datasets, sorted so the split is reproducible.
g1 = sorted(set(spatial_data.var_names) & set(seq_data.var_names))
# Hold out one third of the genes as imputation targets.
train_gene, test_gene = train_test_split(g1, test_size=0.33, random_state=2023)

In [8]:
# Reference expression for all informative genes. Spatial expression is restricted to
# the training genes, so held-out test-gene measurements never reach TransImp (this
# matches the Breast/Brain loops, which pass spatial_data[:, train_gene]).
raw_scrna_df = seq_data.to_df()
raw_spatial_df = spatial_data[:, train_gene].to_df()

In [9]:
# Cell-type label of every reference cell; passed to TransImp as `classes`.
classes = seq_data.obs["scClassify"]
ct_list = np.unique(np.asarray(classes))  # distinct labels (not referenced below)

In [10]:
# transImpRes = expTransImp(
#                 df_ref=raw_scrna_df,
#                 df_tgt=raw_spatial_df,
#                 train_gene=train_gene,
#                 test_gene=test_gene,
#                 n_simulation=200,
#                 signature_mode='cell',
#                 mapping_mode='lowrank',
#                 classes=classes,
#                 n_epochs=2000,
#                 seed=seed,
#                 device=device
# )

In [11]:
# df_transImpLR = pd.DataFrame(np.zeros((spa_adata.n_obs, len(info_gene))), columns=info_gene)

In [12]:
# transImpRes[0]

In [13]:
# df_transImpLR[test_gene] = transImpRes[0]

In [14]:
# df_transImpLR[test_gene]

In [20]:
# Run TransImp (low-rank mapping) once per random seed and save each imputation.
# The loop variable is `run_seed`, not `seed`, so the config-cell constant `seed`
# is not silently overwritten for cells that run afterwards.
for run_seed in range(0, 10):
    # Spatial cells x informative genes. Only the test-gene columns are filled in
    # below; the training-gene columns stay 0.
    df_transImpLR = pd.DataFrame(np.zeros((spa_adata.n_obs, len(info_gene))), columns=info_gene)
    transImpRes = expTransImp(
                df_ref=raw_scrna_df,
                df_tgt=raw_spatial_df,
                train_gene=train_gene,
                test_gene=test_gene,
                n_simulation=200,
                signature_mode='cell',
                mapping_mode='lowrank',
                classes=classes,
                n_epochs=2000,
                seed=run_seed,
                device=device
    )
    df_transImpLR[test_gene] = transImpRes[0]
    adata_out = sc.AnnData(df_transImpLR)
    adata_out.write_h5ad(f"/gpfs/gibbs/pi/zhao/tl688/tangram/transpa/smfish/transimp_smfish_2000_seed{run_seed}.h5ad")

[TransImp] Epoch: 2000/2000, loss: 0.003520, (IMP) 0.003520: 100%|██████████| 2000/2000 [00:07<00:00, 274.68it/s]


Index(['Mrc1', 'Hexb', 'Itpr2'], dtype='object')
View of AnnData object with n_obs × n_vars = 4530 × 3


In [21]:
# Per-seed TransImp runs that additionally keep only the half of the test genes
# with the lowest predicted variance, and save that subset.
# `run_seed` is the loop variable so the config constant `seed` is not overwritten.
for run_seed in range(0, 10):
    # Spatial cells x informative genes; only the test-gene columns get filled in.
    df_transImpLR = pd.DataFrame(np.zeros((spa_adata.n_obs, len(info_gene))), columns=info_gene)
    transImpRes = expTransImp(
                df_ref=raw_scrna_df,
                df_tgt=raw_spatial_df,
                train_gene=train_gene,
                test_gene=test_gene,
                n_simulation=200,
                signature_mode='cell',
                mapping_mode='lowrank',
                classes=classes,
                n_epochs=2000,
                seed=run_seed,
                device=device
    )
    df_transImpLR[test_gene] = transImpRes[0]
    adata_out = sc.AnnData(df_transImpLR)
    # transImpRes[1] is treated as one predicted-variance value per test gene.
    df_var = pd.DataFrame(transImpRes[1], index=test_gene, columns=['pred_var'])
    df_var_sort = df_var.sort_values(axis=0, by='pred_var', ascending=True)
    # Lowest predicted variance first; presumably the most reliable predictions.
    low_var_genes = df_var_sort.index[0:len(df_var_sort) // 2]
    print(low_var_genes)
    # .copy() materialises the slice into a standalone AnnData before writing.
    adata_out_var = adata_out[:, list(low_var_genes)].copy()
    print(adata_out_var)
    adata_out_var.write_h5ad(f"/gpfs/gibbs/pi/zhao/tl688/tangram/transpa/smfish/transimp_smfish_2000_seed_cerd{run_seed}.h5ad")

[TransImp] Epoch: 2000/2000, loss: 0.003520, (IMP) 0.003520: 100%|██████████| 2000/2000 [00:07<00:00, 274.35it/s]


Index(['Mrc1', 'Hexb', 'Itpr2'], dtype='object')
View of AnnData object with n_obs × n_vars = 4530 × 3


[TransImp] Epoch: 2000/2000, loss: 0.004133, (IMP) 0.004133: 100%|██████████| 2000/2000 [00:07<00:00, 274.64it/s]


Index(['Mrc1', 'Hexb', 'Itpr2'], dtype='object')
View of AnnData object with n_obs × n_vars = 4530 × 3


[TransImp] Epoch: 2000/2000, loss: 0.003726, (IMP) 0.003726: 100%|██████████| 2000/2000 [00:07<00:00, 275.25it/s]


Index(['Mrc1', 'Hexb', 'Itpr2'], dtype='object')
View of AnnData object with n_obs × n_vars = 4530 × 3


[TransImp] Epoch: 2000/2000, loss: 0.003432, (IMP) 0.003432: 100%|██████████| 2000/2000 [00:07<00:00, 274.36it/s]


Index(['Mrc1', 'Hexb', 'Itpr2'], dtype='object')
View of AnnData object with n_obs × n_vars = 4530 × 3


[TransImp] Epoch: 2000/2000, loss: 0.003568, (IMP) 0.003568: 100%|██████████| 2000/2000 [00:07<00:00, 273.89it/s]


Index(['Mrc1', 'Hexb', 'Itpr2'], dtype='object')
View of AnnData object with n_obs × n_vars = 4530 × 3


[TransImp] Epoch: 2000/2000, loss: 0.004016, (IMP) 0.004016: 100%|██████████| 2000/2000 [00:07<00:00, 275.48it/s]


Index(['Mrc1', 'Hexb', 'Itpr2'], dtype='object')
View of AnnData object with n_obs × n_vars = 4530 × 3


[TransImp] Epoch: 2000/2000, loss: 0.004134, (IMP) 0.004134: 100%|██████████| 2000/2000 [00:07<00:00, 275.52it/s]


Index(['Mrc1', 'Hexb', 'Itpr2'], dtype='object')
View of AnnData object with n_obs × n_vars = 4530 × 3


[TransImp] Epoch: 2000/2000, loss: 0.004343, (IMP) 0.004343: 100%|██████████| 2000/2000 [00:07<00:00, 275.54it/s]


Index(['Mrc1', 'Hexb', 'Itpr2'], dtype='object')
View of AnnData object with n_obs × n_vars = 4530 × 3


[TransImp] Epoch: 2000/2000, loss: 0.003321, (IMP) 0.003321: 100%|██████████| 2000/2000 [00:07<00:00, 274.68it/s]


Index(['Mrc1', 'Hexb', 'Itpr2'], dtype='object')
View of AnnData object with n_obs × n_vars = 4530 × 3


[TransImp] Epoch: 2000/2000, loss: 0.004239, (IMP) 0.004239: 100%|██████████| 2000/2000 [00:07<00:00, 274.48it/s]


Index(['Mrc1', 'Hexb', 'Itpr2'], dtype='object')
View of AnnData object with n_obs × n_vars = 4530 × 3


In [21]:
# Imputed test-gene matrix from the last iteration of the seed loop above:
# one row per spatial cell, one column per test gene.
transImpRes[0]

array([[0.01040122, 0.04729909, 0.02260043, ..., 0.06649041, 0.02019353,
        0.01264029],
       [0.01454547, 0.03600146, 0.01033609, ..., 0.05045058, 0.01202893,
        0.00695648],
       [0.00949453, 0.03566999, 0.02917271, ..., 0.0554388 , 0.01957636,
        0.01418971],
       ...,
       [0.00674083, 0.0285861 , 0.00021085, ..., 0.02758951, 0.00078108,
        0.00091389],
       [0.00413906, 0.0199748 , 0.0004785 , ..., 0.02863125, 0.001014  ,
        0.0004684 ],
       [0.00621114, 0.01868274, 0.00077729, ..., 0.02963695, 0.00075953,
        0.00220204]], dtype=float32)

# Breast

In [24]:
import torch
import warnings
import pickle

import squidpy as sq
import numpy as np
import pandas as pd
import seaborn as sns

from sklearn.model_selection import KFold
from transpa.eval_util import calc_corr
from transpa.util import expTransImp, compute_autocorr, plot_genes

# Silence library warnings so cell output stays readable.
warnings.filterwarnings("ignore")
# NOTE(review): not referenced anywhere in this section.
pre_datapath = "../../output/preprocessed_dataset/seqFISH_single_cell.pkl"

seed = 10  # fixed TransImp seed used by the loop below
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [25]:
import scanpy as sc
import scipy.stats as st
from sklearn.model_selection import train_test_split

In [27]:
def calcualte_pse_correlation(adata_sc, adata_st, celltype, p_value_threshold = 0.05, cor_threshold = 0.5):
    """Select genes whose cell-type pseudo-bulk profiles agree between two datasets.

    For every gene measured in both ``adata_sc`` and ``adata_st``, the mean
    expression within each cell type shared by the two datasets is computed
    (one pseudo-bulk profile per dataset), and the two profiles are compared
    with a Pearson correlation across cell types. A gene is kept when its
    p-value is below ``p_value_threshold`` and its correlation is above
    ``cor_threshold``. The misspelled name is kept so existing callers work.

    Parameters
    ----------
    adata_sc : AnnData
        Reference (single-cell) dataset; ``.X`` may be sparse or dense.
    adata_st : AnnData
        Spatial dataset; ``.X`` may be sparse or dense.
    celltype : str
        ``.obs`` column holding the cell-type label in both datasets.
    p_value_threshold : float, default 0.05
        Keep genes whose correlation p-value is strictly below this value.
    cor_threshold : float, default 0.5
        Keep genes whose Pearson correlation is strictly above this value.

    Returns
    -------
    pandas.Index
        Selected gene names, in sorted order.

    Raises
    ------
    ValueError
        If the datasets share at least one gene but fewer than two cell types,
        so no correlation across cell types can be computed.
    """
    import random
    # Kept from the original implementation: seeds Python's global RNG. Nothing in
    # this function is random, so the result does not depend on it.
    random.seed(2023)

    def _to_dense(matrix):
        # AnnData .X is either a scipy sparse matrix (has .toarray) or a dense
        # array; calling .toarray() unconditionally fails on the dense case.
        return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)

    overlap_gene = sorted(set(adata_sc.var_names).intersection(adata_st.var_names))
    adata_sc = adata_sc[:, overlap_gene]
    adata_st = adata_st[:, overlap_gene]

    # Sorted (by string form, so mixed label types still compare) to make the row
    # order of the pseudo-bulk matrices reproducible between interpreter runs; the
    # correlations themselves do not depend on this order.
    cell_type_common = sorted(
        set(adata_sc.obs[celltype].unique()).intersection(adata_st.obs[celltype].unique()),
        key=str,
    )
    if overlap_gene and len(cell_type_common) < 2:
        raise ValueError(
            f"Need at least two cell types in column {celltype!r} shared by both "
            f"datasets to correlate pseudo-bulk profiles; found {len(cell_type_common)}."
        )

    # Pseudo-bulk matrices: one row per shared cell type, one column per overlap gene.
    pseudo_st = np.array([
        np.mean(_to_dense(adata_st[adata_st.obs[celltype] == ct].X), axis=0)
        for ct in cell_type_common
    ])
    pseudo_sc = np.array([
        np.mean(_to_dense(adata_sc[adata_sc.obs[celltype] == ct].X), axis=0)
        for ct in cell_type_common
    ])

    cor_pearson = []
    cor_pvalue = []
    for gene_idx in range(len(overlap_gene)):
        # A constant profile yields a NaN correlation, which fails both filters below.
        cor, pval = st.pearsonr(pseudo_st[:, gene_idx], pseudo_sc[:, gene_idx])
        cor_pearson.append(cor)
        cor_pvalue.append(pval)

    information_stat = pd.DataFrame(
        {'pearson': cor_pearson, 'pvalue': cor_pvalue},
        index=adata_st.var_names,
    )
    keep = (information_stat['pvalue'] < p_value_threshold) & (information_stat['pearson'] > cor_threshold)
    return information_stat.index[keep.to_numpy()]

In [32]:
import random

# The single-cell reference is identical for every simulated sample, so read and
# annotate it once instead of on every iteration.
seq_data_full = sc.read_h5ad("/gpfs/gibbs/pi/zhao/tl688/deconvdatasets/spatial_dataset/xenium_breast/sce_FFPE_full.h5ad")
seq_data_full.var_names_make_unique()
seq_data_full.obs['scClassify'] = seq_data_full.obs['graph_cluster_anno'].copy()

for sample_index in range(0, 10):
    spatial_data = sc.read_h5ad(f"/gpfs/gibbs/pi/zhao/tl688/tangram/human_breast_simulation/spe_xenium_data_0.1_seed{sample_index}.h5ad")
    spatial_data.var_names_make_unique()

    info_gene = calcualte_pse_correlation(seq_data_full, spatial_data, 'scClassify')

    # .copy() gives each iteration its own gene-subset object; seq_data_full is never mutated.
    seq_data = seq_data_full[:, info_gene].copy()
    spatial_data = spatial_data[:, info_gene]

    random.seed(2023)
    g1 = sorted(set(spatial_data.var_names).intersection(seq_data.var_names))
    train_gene, test_gene = train_test_split(g1, test_size=0.33, random_state=2023)
    # Only the training genes of the spatial data are given to TransImp; test genes are held out.
    spatial_data_partial = spatial_data[:, train_gene].copy()

    classes = seq_data.obs['scClassify']
    ct_list = np.unique(classes)

    raw_scrna_df = seq_data.to_df()
    raw_spatial_df = spatial_data_partial.to_df()

    # Spatial cells x informative genes; only the test-gene columns get filled in.
    df_transImpLR = pd.DataFrame(np.zeros((spatial_data_partial.n_obs, len(info_gene))), columns=info_gene)
    transImpRes = expTransImp(
                df_ref=raw_scrna_df,
                df_tgt=raw_spatial_df,
                train_gene=train_gene,
                test_gene=test_gene,
                n_simulation=200,
                signature_mode='cell',
                mapping_mode='lowrank',
                classes=classes,
                n_epochs=2000,
                seed=seed,  # fixed config seed; runs differ by simulated sample
                device=device
    )
    df_transImpLR[test_gene] = transImpRes[0]
    adata_out = sc.AnnData(df_transImpLR)

    adata_out.write_h5ad(f"/gpfs/gibbs/pi/zhao/tl688/tangram/transpa/breast/transimp_breast_seed{sample_index}.h5ad")

[TransImp] Epoch: 2000/2000, loss: 0.346361, (IMP) 0.346361: 100%|██████████| 2000/2000 [00:23<00:00, 83.46it/s]
[TransImp] Epoch: 2000/2000, loss: 0.343096, (IMP) 0.343096: 100%|██████████| 2000/2000 [00:24<00:00, 83.12it/s]
[TransImp] Epoch: 2000/2000, loss: 0.353855, (IMP) 0.353855: 100%|██████████| 2000/2000 [00:23<00:00, 83.53it/s]
[TransImp] Epoch: 2000/2000, loss: 0.366327, (IMP) 0.366327: 100%|██████████| 2000/2000 [00:24<00:00, 82.50it/s]
[TransImp] Epoch: 2000/2000, loss: 0.344171, (IMP) 0.344171: 100%|██████████| 2000/2000 [00:24<00:00, 83.32it/s]
[TransImp] Epoch: 2000/2000, loss: 0.319577, (IMP) 0.319577: 100%|██████████| 2000/2000 [00:23<00:00, 86.22it/s]
[TransImp] Epoch: 2000/2000, loss: 0.360639, (IMP) 0.360639: 100%|██████████| 2000/2000 [00:23<00:00, 84.76it/s]
[TransImp] Epoch: 2000/2000, loss: 0.345269, (IMP) 0.345269: 100%|██████████| 2000/2000 [00:24<00:00, 82.60it/s]
[TransImp] Epoch: 2000/2000, loss: 0.355712, (IMP) 0.355712: 100%|██████████| 2000/2000 [00:25<0

In [33]:
# Imputed values for the held-out test genes of the last simulated sample
# processed by the loop above; the training-gene columns of df_transImpLR stay zero.
df_transImpLR[test_gene]

Unnamed: 0,HOXD9,PIGR,VWF,ADGRE5,CCDC80,CXCL16,ELF3,HOXD8,CLDN4,CCL5,...,CD93,CD79B,MMP2,AKR1C3,C6orf132,C1QA,PPARG,KDR,SVIL,EIF4EBP1
0,0.008890,0.002796,0.081828,0.088017,0.382570,0.019902,0.078597,0.025988,0.060626,0.017476,...,0.088338,0.003149,0.562675,0.003351,0.007676,0.040858,0.064605,0.032480,0.153413,0.019302
1,0.093076,0.008001,0.890782,0.072491,0.327164,0.029000,0.198429,0.071965,0.141359,0.040791,...,0.493569,0.023464,0.485350,0.027805,0.020388,0.074989,0.112168,0.262443,0.167432,0.038762
2,0.003950,0.042791,0.044916,0.257559,0.180430,0.492387,0.331992,0.013231,0.251886,0.138886,...,0.170971,0.035714,0.408016,0.020482,0.056766,1.295562,0.148724,0.022152,0.231515,0.137983
3,0.078504,0.017580,0.523990,0.231010,1.570625,0.071688,0.298692,0.072918,0.204551,0.092365,...,0.406856,0.051969,1.644617,0.054219,0.037177,0.246142,0.165295,0.236070,0.579752,0.048302
4,0.009779,0.007493,0.041950,0.006800,0.173881,0.019256,0.137878,0.014193,0.110811,0.017441,...,0.031637,0.005307,0.174432,0.001998,0.009235,0.070200,0.017484,0.031513,0.078778,0.007346
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
16020,0.016019,0.001644,0.176941,0.221863,0.346012,0.020614,0.098653,0.011383,0.073834,0.436672,...,0.096923,0.027715,0.357563,0.005579,0.014055,0.067533,0.020004,0.014355,0.107310,0.017951
16021,0.004223,0.005470,0.032659,0.040685,0.097448,0.034254,0.514029,0.020914,0.341120,0.040918,...,0.017375,0.006079,0.054118,0.000097,0.057603,0.063697,0.014923,0.010071,0.059444,0.035629
16022,0.010049,0.009192,0.062292,0.066887,0.218966,0.059398,0.893084,0.039527,0.473350,0.063587,...,0.049724,0.008814,0.208358,0.001555,0.073437,0.237521,0.034408,0.024055,0.128522,0.060053
16023,0.006084,0.025310,0.077454,0.091341,0.514888,0.074015,0.976732,0.055640,0.617394,0.061365,...,0.044548,0.027903,0.796936,0.006261,0.120129,0.122409,0.044279,0.019980,0.213464,0.068315


# Brain

In [1]:
import torch
import warnings
import pickle

import squidpy as sq
import numpy as np
import pandas as pd
import seaborn as sns

from sklearn.model_selection import KFold
from transpa.eval_util import calc_corr
from transpa.util import expTransImp, compute_autocorr, plot_genes

# Silence library warnings so cell output stays readable.
warnings.filterwarnings("ignore")
# NOTE(review): not referenced anywhere in this section.
pre_datapath = "../../output/preprocessed_dataset/seqFISH_single_cell.pkl"

seed = 10  # fixed TransImp seed used by later cells
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
import scanpy as sc
import scipy.stats as st
from sklearn.model_selection import train_test_split

In [3]:
def calcualte_pse_correlation(adata_sc, adata_st, celltype, p_value_threshold = 0.05, cor_threshold = 0.5):
    """Select genes whose cell-type pseudo-bulk profiles agree between two datasets.

    For every gene measured in both ``adata_sc`` and ``adata_st``, the mean
    expression within each cell type shared by the two datasets is computed
    (one pseudo-bulk profile per dataset), and the two profiles are compared
    with a Pearson correlation across cell types. A gene is kept when its
    p-value is below ``p_value_threshold`` and its correlation is above
    ``cor_threshold``. The misspelled name is kept so existing callers work.

    Parameters
    ----------
    adata_sc : AnnData
        Reference (single-cell) dataset; ``.X`` may be sparse or dense.
    adata_st : AnnData
        Spatial dataset; ``.X`` may be sparse or dense.
    celltype : str
        ``.obs`` column holding the cell-type label in both datasets.
    p_value_threshold : float, default 0.05
        Keep genes whose correlation p-value is strictly below this value.
    cor_threshold : float, default 0.5
        Keep genes whose Pearson correlation is strictly above this value.

    Returns
    -------
    pandas.Index
        Selected gene names, in sorted order.

    Raises
    ------
    ValueError
        If the datasets share at least one gene but fewer than two cell types,
        so no correlation across cell types can be computed.
    """
    import random
    # Kept from the original implementation: seeds Python's global RNG. Nothing in
    # this function is random, so the result does not depend on it.
    random.seed(2023)

    def _to_dense(matrix):
        # AnnData .X is either a scipy sparse matrix (has .toarray) or a dense
        # array; calling .toarray() unconditionally fails on the dense case.
        return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)

    overlap_gene = sorted(set(adata_sc.var_names).intersection(adata_st.var_names))
    adata_sc = adata_sc[:, overlap_gene]
    adata_st = adata_st[:, overlap_gene]

    # Sorted (by string form, so mixed label types still compare) to make the row
    # order of the pseudo-bulk matrices reproducible between interpreter runs; the
    # correlations themselves do not depend on this order.
    cell_type_common = sorted(
        set(adata_sc.obs[celltype].unique()).intersection(adata_st.obs[celltype].unique()),
        key=str,
    )
    if overlap_gene and len(cell_type_common) < 2:
        raise ValueError(
            f"Need at least two cell types in column {celltype!r} shared by both "
            f"datasets to correlate pseudo-bulk profiles; found {len(cell_type_common)}."
        )

    # Pseudo-bulk matrices: one row per shared cell type, one column per overlap gene.
    pseudo_st = np.array([
        np.mean(_to_dense(adata_st[adata_st.obs[celltype] == ct].X), axis=0)
        for ct in cell_type_common
    ])
    pseudo_sc = np.array([
        np.mean(_to_dense(adata_sc[adata_sc.obs[celltype] == ct].X), axis=0)
        for ct in cell_type_common
    ])

    cor_pearson = []
    cor_pvalue = []
    for gene_idx in range(len(overlap_gene)):
        # A constant profile yields a NaN correlation, which fails both filters below.
        cor, pval = st.pearsonr(pseudo_st[:, gene_idx], pseudo_sc[:, gene_idx])
        cor_pearson.append(cor)
        cor_pvalue.append(pval)

    information_stat = pd.DataFrame(
        {'pearson': cor_pearson, 'pvalue': cor_pvalue},
        index=adata_st.var_names,
    )
    keep = (information_stat['pvalue'] < p_value_threshold) & (information_stat['pearson'] > cor_threshold)
    return information_stat.index[keep.to_numpy()]

In [4]:
import gc

In [5]:
import random

# The single-cell reference is identical for every simulated sample, so read and
# annotate it once instead of on every iteration.
seq_data_full = sc.read_h5ad("/gpfs/gibbs/pi/zhao/tl688/deconvdatasets/spatial_dataset/xenium_brain/aibs_mouse_ctx-hpf_smartseq_sce.h5ad")
seq_data_full.var_names_make_unique()
seq_data_full.obs['scClassify'] = seq_data_full.obs['cell_type_alias_label2'].copy()

# Cell-type labels of the reference cells; gene subsetting below keeps the same cells.
classes = seq_data_full.obs['scClassify']
ct_list = np.unique(classes)

# NOTE(review): starts at 4, presumably resuming an interrupted 0..9 run — confirm.
for sample_index in range(4, 10):
    spatial_data = sc.read_h5ad(f"/gpfs/gibbs/pi/zhao/tl688/tangram/data_brain/spe_xenium_data_0.1_seed{sample_index}.h5ad")
    # Same de-duplication as the Breast section; a no-op when names are already unique.
    spatial_data.var_names_make_unique()

    info_gene = calcualte_pse_correlation(seq_data_full, spatial_data, 'scClassify')

    # .copy() gives each iteration its own gene-subset object; seq_data_full is never mutated.
    seq_data = seq_data_full[:, info_gene].copy()
    spatial_data = spatial_data[:, info_gene]

    random.seed(2023)
    g1 = sorted(set(spatial_data.var_names).intersection(seq_data.var_names))
    train_gene, test_gene = train_test_split(g1, test_size=0.33, random_state=2023)
    # Only the training genes of the spatial data are given to TransImp; test genes are held out.
    spatial_data_partial = spatial_data[:, train_gene].copy()

    raw_scrna_df = seq_data.to_df()
    raw_spatial_df = spatial_data_partial.to_df()

    # Spatial cells x informative genes; only the test-gene columns get filled in.
    df_transImpLR = pd.DataFrame(np.zeros((spatial_data_partial.n_obs, len(info_gene))), columns=info_gene)
    transImpRes = expTransImp(
                df_ref=raw_scrna_df,
                df_tgt=raw_spatial_df,
                train_gene=train_gene,
                test_gene=test_gene,
                n_simulation=200,
                signature_mode='cell',
                mapping_mode='lowrank',
                classes=classes,
                n_epochs=2000,
                seed=seed,  # fixed config seed; runs differ by simulated sample
                device=device
    )
    df_transImpLR[test_gene] = transImpRes[0]
    adata_out = sc.AnnData(df_transImpLR)
    gc.collect()
    adata_out.write_h5ad(f"/gpfs/gibbs/pi/zhao/tl688/tangram/transpa/brain/transimp_brain_seed{sample_index}.h5ad")

[TransImp] Epoch: 2000/2000, loss: 0.283279, (IMP) 0.283279: 100%|██████████| 2000/2000 [00:33<00:00, 59.26it/s]
[TransImp] Epoch: 2000/2000, loss: 0.353030, (IMP) 0.353030: 100%|██████████| 2000/2000 [00:30<00:00, 64.93it/s]
[TransImp] Epoch: 2000/2000, loss: 0.299047, (IMP) 0.299047: 100%|██████████| 2000/2000 [00:31<00:00, 63.39it/s]
[TransImp] Epoch: 2000/2000, loss: 0.312734, (IMP) 0.312734: 100%|██████████| 2000/2000 [00:31<00:00, 63.65it/s]
[TransImp] Epoch: 2000/2000, loss: 0.303013, (IMP) 0.303013: 100%|██████████| 2000/2000 [00:31<00:00, 63.59it/s]
[TransImp] Epoch: 2000/2000, loss: 0.304878, (IMP) 0.304878: 100%|██████████| 2000/2000 [00:31<00:00, 63.79it/s]


In [20]:
# This cell sits in the section that restarts at In [1]. Within that section
# `scrna_adata` / `spa_adata` are never defined, so Restart & Run All hit a NameError.
# Load the smFISH inputs explicitly; the genes shown below (Itpr2, Hexb, ...) and the
# smfish output paths indicate these are the smFISH data.
scrna_adata = sc.read("/gpfs/gibbs/pi/zhao/tl688/tangram/data_smfish/scrnaseq_data.h5ad")
spa_adata = sc.read("/gpfs/gibbs/pi/zhao/tl688/tangram/data_smfish/spatial_data.h5ad")
info_gene = calcualte_pse_correlation(scrna_adata, spa_adata, 'scClassify')

In [21]:
# Restrict both datasets to the informative genes (AnnData slicing yields views).
seq_data, spatial_data = scrna_adata[:, info_gene], spa_adata[:, info_gene]

In [22]:
import random
random.seed(2023)
# Genes present in both (already gene-filtered) datasets, sorted so the split is reproducible.
g1 = sorted(set(spatial_data.var_names) & set(seq_data.var_names))
# Hold out one third of the genes as imputation targets.
train_gene, test_gene = train_test_split(g1, test_size=0.33, random_state=2023)

In [23]:
# Reference expression for all informative genes. Spatial expression is restricted to
# the training genes, so held-out test-gene measurements never reach TransImp (this
# matches the Breast/Brain loops, which pass spatial_data[:, train_gene]).
raw_scrna_df = seq_data.to_df()
raw_spatial_df = spatial_data[:, train_gene].to_df()

In [25]:
# Cell-type label of every reference cell; passed to TransImp as `classes`.
classes = seq_data.obs["scClassify"]
ct_list = np.unique(np.asarray(classes))  # distinct labels (not referenced below)

In [26]:
# Single TransImp run: cell-level signatures, low-rank mapping, config seed.
transImpRes = expTransImp(
    df_ref=raw_scrna_df,
    df_tgt=raw_spatial_df,
    classes=classes,
    train_gene=train_gene,
    test_gene=test_gene,
    signature_mode='cell',
    mapping_mode='lowrank',
    n_simulation=200,
    n_epochs=2000,
    seed=seed,
    device=device,
)

[TransImp] Epoch: 2000/2000, loss: 0.003844, (IMP) 0.003844: 100%|██████████| 2000/2000 [01:21<00:00, 24.41it/s]


In [31]:
# Zero-filled placeholder (spatial cells x informative genes); the test-gene
# columns are filled from the imputation result in a later cell.
df_transImpLR = pd.DataFrame(
    data=np.zeros(shape=(spa_adata.n_obs, len(info_gene))),
    columns=info_gene,
)

In [32]:
# Raw imputed matrix returned by the single expTransImp run above
# (first element of its result).
transImpRes[0]

array([[0.01008041, 0.05050546, 0.02460364, ..., 0.05231228, 0.02780909,
        0.01175085],
       [0.00987022, 0.0271963 , 0.01161936, ..., 0.0430636 , 0.01219665,
        0.00620246],
       [0.00967101, 0.04951893, 0.0244266 , ..., 0.06738168, 0.02638319,
        0.01242517],
       ...,
       [0.00683271, 0.02847621, 0.00018988, ..., 0.03571602, 0.00078927,
        0.00090942],
       [0.00491358, 0.04003397, 0.00018913, ..., 0.03880098, 0.00057765,
        0.00049592],
       [0.00669112, 0.02591524, 0.00102373, ..., 0.03220956, 0.00063202,
        0.00140854]], dtype=float32)

In [33]:
# Fill the held-out test-gene columns with TransImp's imputed values; the
# training-gene columns stay zero.
df_transImpLR[test_gene] = transImpRes[0]

In [34]:
# Imputed values as a labelled frame: one column per test gene.
df_transImpLR[test_gene]

Unnamed: 0,Itpr2,Hexb,Pthlh,Plp1,Mrc1,Crh,Gfap
0,0.010080,0.050505,0.024604,2.502050,0.052312,0.027809,0.011751
1,0.009870,0.027196,0.011619,2.834172,0.043064,0.012197,0.006202
2,0.009671,0.049519,0.024427,2.048846,0.067382,0.026383,0.012425
3,0.005823,0.018915,0.021167,0.977459,0.046649,0.007997,0.000052
4,0.006843,0.024053,0.027172,1.973019,0.037029,0.022219,0.000089
...,...,...,...,...,...,...,...
4525,0.004952,0.037006,0.000247,0.682299,0.035568,0.000251,0.001046
4526,0.005467,0.035986,0.000342,1.456438,0.038851,0.000730,0.000314
4527,0.006833,0.028476,0.000190,1.709622,0.035716,0.000789,0.000909
4528,0.004914,0.040034,0.000189,1.127144,0.038801,0.000578,0.000496


In [35]:
# Run TransImp (low-rank mapping) once per random seed and save each imputation.
# The loop variable is `run_seed`, not `seed`, so the config-cell constant `seed`
# used by the single run above is not silently overwritten if cells are re-run.
for run_seed in range(0, 10):
    # Spatial cells x informative genes. Only the test-gene columns are filled in
    # below; the training-gene columns stay 0.
    df_transImpLR = pd.DataFrame(np.zeros((spa_adata.n_obs, len(info_gene))), columns=info_gene)
    transImpRes = expTransImp(
                df_ref=raw_scrna_df,
                df_tgt=raw_spatial_df,
                train_gene=train_gene,
                test_gene=test_gene,
                n_simulation=200,
                signature_mode='cell',
                mapping_mode='lowrank',
                classes=classes,
                n_epochs=2000,
                seed=run_seed,
                device=device
    )
    df_transImpLR[test_gene] = transImpRes[0]
    adata_out = sc.AnnData(df_transImpLR)
    adata_out.write_h5ad(f"/gpfs/gibbs/pi/zhao/tl688/tangram/transpa/smfish/transimp_smfish_seed{run_seed}.h5ad")

[TransImp] Epoch: 2000/2000, loss: 0.003537, (IMP) 0.003537: 100%|██████████| 2000/2000 [01:22<00:00, 24.35it/s]
[TransImp] Epoch: 2000/2000, loss: 0.004132, (IMP) 0.004132: 100%|██████████| 2000/2000 [01:15<00:00, 26.54it/s]
[TransImp] Epoch: 2000/2000, loss: 0.003828, (IMP) 0.003828: 100%|██████████| 2000/2000 [01:21<00:00, 24.55it/s]
[TransImp] Epoch: 2000/2000, loss: 0.003433, (IMP) 0.003433: 100%|██████████| 2000/2000 [01:21<00:00, 24.59it/s]
[TransImp] Epoch: 2000/2000, loss: 0.003561, (IMP) 0.003561: 100%|██████████| 2000/2000 [01:21<00:00, 24.63it/s]
[TransImp] Epoch: 2000/2000, loss: 0.004002, (IMP) 0.004002: 100%|██████████| 2000/2000 [01:23<00:00, 24.09it/s]
[TransImp] Epoch: 2000/2000, loss: 0.004122, (IMP) 0.004122: 100%|██████████| 2000/2000 [01:15<00:00, 26.44it/s]
[TransImp] Epoch: 2000/2000, loss: 0.004345, (IMP) 0.004345: 100%|██████████| 2000/2000 [01:14<00:00, 26.77it/s]
[TransImp] Epoch: 2000/2000, loss: 0.003326, (IMP) 0.003326: 100%|██████████| 2000/2000 [01:15<0

In [37]:
# Imputed test-gene matrix from the last iteration of the seed loop above.
transImpRes[0]

array([[0.01026535, 0.04710996, 0.02292576, ..., 0.06517296, 0.0200364 ,
        0.01253883],
       [0.0146467 , 0.03590734, 0.01032561, ..., 0.05029929, 0.01200929,
        0.00694228],
       [0.00948082, 0.03555085, 0.02914999, ..., 0.05524319, 0.01959631,
        0.01420537],
       ...,
       [0.00674488, 0.02863701, 0.00021149, ..., 0.02759529, 0.00078623,
        0.00091469],
       [0.00413624, 0.01995281, 0.00047814, ..., 0.02856422, 0.00102261,
        0.00047131],
       [0.00619856, 0.01864921, 0.00077844, ..., 0.02953735, 0.00075872,
        0.00220262]], dtype=float32)