In [1]:
import random

import anndata
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy
import scanpy as sc
import scvi
import torch
from scipy.stats import spearmanr
from sklearn.model_selection import train_test_split

from imputevi import GIMVI_GCN

mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42

random.seed(2023)

[rank: 0] Global seed set to 0


In [2]:
from scipy import stats
import scipy.stats as st
import scipy

In [3]:
def calcualte_pse_correlation(adata_sc, adata_st, celltype, p_value_threshold = 0.05, cor_threshold = 0.5):
    """Select genes whose per-cell-type mean expression agrees between two datasets.

    Both inputs are restricted to their shared genes and shared cell types.
    For every shared cell type the mean expression of each gene is computed
    in each dataset (a "pseudo-bulk" profile), and each gene's profile across
    cell types is then compared between the datasets with a Pearson test.

    Parameters
    ----------
    adata_sc, adata_st
        AnnData-like objects exposing ``var_names``, ``obs`` and ``X`` and
        supporting ``adata[:, genes]`` / ``adata[row_mask]`` indexing.
        ``X`` may be a scipy sparse matrix or a dense array.
    celltype
        Name of the ``obs`` column holding the cell-type label in both objects.
    p_value_threshold
        A gene is kept only if its Pearson p-value is strictly below this.
    cor_threshold
        A gene is kept only if its Pearson r is strictly above this.

    Returns
    -------
    Index of the selected gene names, ordered as in ``adata_st.var_names``.

    Raises
    ------
    ValueError
        If there are shared genes but fewer than two shared cell types
        (a correlation needs at least two points).
    """
    def _mean_expression(adata):
        # `.X` is sparse in some AnnData objects and dense in others; only
        # sparse matrices have `.toarray()`.
        matrix = adata.X
        if hasattr(matrix, "toarray"):
            matrix = matrix.toarray()
        return np.atleast_2d(np.asarray(matrix)).mean(axis=0)

    # Follow adata_st's gene order instead of iterating a set, so the result
    # order is reproducible from run to run.
    sc_genes = set(adata_sc.var_names)
    overlap_gene = [gene for gene in adata_st.var_names if gene in sc_genes]
    adata_sc = adata_sc[:, overlap_gene]
    adata_st = adata_st[:, overlap_gene]

    sc_types = set(adata_sc.obs[celltype].unique())
    cell_type_common = [ct for ct in adata_st.obs[celltype].unique() if ct in sc_types]
    if overlap_gene and len(cell_type_common) < 2:
        raise ValueError(
            "Need at least two cell types shared by both datasets in column "
            f"{celltype!r} to compute a correlation; found {len(cell_type_common)}."
        )

    # Rows: shared cell types, columns: shared genes.
    pseudo_st = np.array([_mean_expression(adata_st[adata_st.obs[celltype] == ct])
                          for ct in cell_type_common])
    pseudo_sc = np.array([_mean_expression(adata_sc[adata_sc.obs[celltype] == ct])
                          for ct in cell_type_common])

    cor_pearson = []
    cor_pvalue = []
    for gene_idx in range(len(overlap_gene)):
        cor, pval = st.pearsonr(pseudo_st[:, gene_idx], pseudo_sc[:, gene_idx])
        cor_pearson.append(cor)
        cor_pvalue.append(pval)

    information_stat = pd.DataFrame(
        {
            'pearson': np.asarray(cor_pearson, dtype=float),
            'pvalue': np.asarray(cor_pvalue, dtype=float),
        },
        index=adata_st.var_names,
    )

    # NaN correlations (e.g. constant profiles) compare False and are dropped.
    keep = (information_stat['pvalue'] < p_value_threshold) & (information_stat['pearson'] > cor_threshold)
    information_stat_update = information_stat.loc[keep]

    return information_stat_update.index

In [4]:
# Load the two input datasets from .h5ad files: `seq_data` from
# scrnaseq_data.h5ad and `spatial_data` from spatial_data.h5ad.
# NOTE(review): these are absolute, machine-specific paths; the notebook only
# re-runs on a system where they exist.
seq_data = sc.read_h5ad("/gpfs/gibbs/pi/zhao/tl688/tangram/data_smfish/scrnaseq_data.h5ad")
spatial_data = sc.read_h5ad("/gpfs/gibbs/pi/zhao/tl688/tangram/data_smfish/spatial_data.h5ad")

In [5]:
# Mirror each dataset's observation names into two obs columns, 'names' and
# 'ind_x'. 'names' is the column later handed to GIMVI_GCN.setup_anndata via
# `obs_names='names'`.
seq_data.obs['names'] = seq_data.obs_names
seq_data.obs['ind_x'] = seq_data.obs_names

spatial_data.obs['names'] = spatial_data.obs_names
spatial_data.obs['ind_x'] = spatial_data.obs_names

# Assemble an (n_obs, 2) integer coordinate array from the x/y obs columns and
# store it under obsm['spatial'].
spatial_index = [list(spatial_data.obs[axis_column]) for axis_column in ('x_coord', 'y_coord')]
spatial_data.obsm['spatial'] = np.transpose(np.array(spatial_index)).astype('int')

In [6]:
# Genes whose per-cell-type mean expression correlates between the two datasets,
# using the 'scClassify' obs column as the cell-type label and the function's
# default thresholds (p-value < 0.05, Pearson r > 0.5).
info_gene = calcualte_pse_correlation(seq_data, spatial_data, 'scClassify')

In [7]:
# Re-seed Python's built-in `random` module (this does not seed numpy or torch).
import random 
random.seed(2023)
# Every gene present in the sequencing data is kept as an imputation target.
gene_for_impute = seq_data.var_names

In [8]:
# `gene_for_impute` is all of seq_data.var_names, so the first line keeps every
# gene; the spatial data is restricted to the correlated genes in `info_gene`.
seq_data = seq_data[:,gene_for_impute]
spatial_data = spatial_data[:,info_gene]

In [9]:
# Take independent copies of both datasets before registering them with the model.
seq_data = seq_data.copy()
spatial_data_partial = spatial_data.copy()

# Convenience handles on the sequencing gene set.
n_genes = seq_data.n_vars
seq_gene_names = seq_data.var_names

# Optional pre-filtering, currently disabled:
# scanpy.pp.filter_cells(spatial_data_partial, min_counts=1)
# scanpy.pp.filter_cells(seq_data, min_counts=1)

# Register both datasets with GIMVI_GCN. The spatial side supplies its batch
# column and the obs column that carries the observation names.
GIMVI_GCN.setup_anndata(spatial_data_partial, obs_names = 'names', batch_key="batch")
GIMVI_GCN.setup_anndata(seq_data)
# Alternative registration with labels, currently disabled:
# GIMVI.setup_anndata(seq_data, labels_key="graph_cluster_anno")

# Keep `spatial_data` on the same observations as the training copy, which
# matters only if the filtering above is re-enabled and drops cells.
spatial_data = spatial_data[spatial_data_partial.obs_names]

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [10]:
# Build the joint model over the sequencing data and the training copy of the
# spatial data. NOTE(review): n_latent=32 is presumably the latent dimension —
# confirm against the imputevi API.
model = GIMVI_GCN(seq_data, spatial_data_partial, n_latent = 32)

In [11]:
import scvi

In [12]:
# train for 200 epochs
model.train(200)

fish_imputation_norm = model.get_imputed_values(normalized=True)[0]
fish_imputation_raw = model.get_imputed_values(normalized=False)[0]
fish_imputation_theta = model.get_imputed_theta(normalized=False)[0]

spatial_data_imputed = sc.AnnData(fish_imputation_raw, obs = spatial_data_partial.obs, var = seq_data.var)

spatial_data_imputed.obsm['imputed'] = fish_imputation_norm
spatial_data_imputed.obsm['imputed_raw'] = fish_imputation_raw
spatial_data_imputed.obsm['imputed_raw_theta'] =  fish_imputation_theta

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 14/200:   6%|▋         | 13/200 [00:25<05:49,  1.87s/it, loss=3.37e+03, v_num=1]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [13]:
# Persist the imputed AnnData (X plus the three obsm matrices) to disk.
# Earlier output name, kept for reference:
# spatial_data_imputed.write_h5ad("/gpfs/gibbs/pi/zhao/tl688/tangram/data_smfish_cluster/gimvigcn_smfish_allgenes_best.h5ad")
# NOTE(review): the file name says "mode200", but the training output recorded
# in this notebook was interrupted at epoch 14/200 — confirm the file comes from
# a completed run.
spatial_data_imputed.write_h5ad("/gpfs/gibbs/pi/zhao/tl688/tangram/data_smfish_cluster/gimvigcn_smfish_allgenes_mode200.h5ad")

# Uncertainty quantification

In [13]:
from uncertainty.scvi_distribution import NegativeBinomial

In [15]:
import numpy as np
from numpy.linalg import norm

def cossim(A,B):
    """Return the cosine similarity of two vectors.

    Computed as ``dot(A, B) / (norm(A) * norm(B))``. No guard is applied for
    zero-norm inputs, so the division is by zero in that case.
    """
    numerator = np.dot(A, B)
    denominator = norm(A) * norm(B)
    return numerator / denominator

# Uncertainty quantification per cell

In [None]:
test_list = []

# NOTE(review): `test_g` (gene list) and `adata_st` (measured spatial data) are
# not assigned in any cell visible in this notebook; they must be defined
# before this cell can run on a fresh kernel.
imputed_path = "/gpfs/gibbs/pi/zhao/tl688/tangram/data_smfish/gimvigat_lat32_nei20sc_ep400_smfish_seed0.h5ad"
adata = sc.read_h5ad(imputed_path)
# adata = sc.read_h5ad(f"/gpfs/gibbs/pi/zhao/tl688/tangram/data_smfish/gimvigat_lat64_nei20_ep200_smfish_seed{seed}.h5ad")
import random
random.seed(2023)
# The samples below are torch tensors, so Python's `random` seed alone does not
# make them reproducible (presumably sample() draws from torch's RNG).
torch.manual_seed(2023)
adata.layers['imputed_raw'] = adata.obsm['imputed_raw']
adata.layers['imputed_raw_theta'] = adata.obsm['imputed_raw_theta']

adata_st_true = adata[:, test_g]
adata_st_raw = adata_st[:, test_g]

# Materialise the view's layers once instead of re-slicing them in the loop.
imputed_mu = np.asarray(adata_st_true.layers['imputed_raw'])
imputed_theta = np.asarray(adata_st_true.layers['imputed_raw_theta'])

distr = NegativeBinomial(mu = torch.FloatTensor(imputed_mu),
                         theta = torch.FloatTensor(imputed_theta))

# Number of draws: 10% of the cell count, capped at 100.
sample_0_1 = int(min(100, len(adata)*0.1))
sample_200 = [distr.sample() for _ in range(sample_0_1)]

# Per-cell score: median cosine similarity between that cell's sampled profile
# and its imputed mean. Each sample is sliced to the same cell's row, matching
# how the per-gene analysis slices columns; comparing against the whole sample
# matrix would not give a per-cell similarity.
cossim_mean = []
for cell in range(0,len(adata_st_true)):
    cossim_store = [cossim(sample_data[cell,:], imputed_mu[cell,:]) for sample_data in sample_200]
    cossim_mean.append(np.median(cossim_store))

# Keep the most self-consistent half, highest similarity first (same direction
# as the per-gene analysis). NaN scores (zero-norm vectors) are pushed to the
# bottom; otherwise the reversed argsort would rank them first.
cell_scores = np.asarray(cossim_mean, dtype=float)
cell_scores = np.where(np.isnan(cell_scores), -np.inf, cell_scores)
median_upper = np.argsort(cell_scores)[::-1][0:int(len(adata_st_true)//2)]

print(len(median_upper))

# adata = sc.read_h5ad(f"/gpfs/gibbs/pi/zhao/tl688/tangram/data_smfish/gimvigat_lat64_nei20_ep200_smfish_seed{seed}.h5ad")
adata = sc.read_h5ad(imputed_path)

adata.X = adata.obsm['imputed_raw']
# NOTE(review): `median_upper` holds positions in the file's cell order; the
# positional selection below is only correct if adata_st_raw.obs_names lists the
# cells in that same order — confirm.
adata = adata[adata_st_raw.obs_names,:]
adata_st_true = adata[median_upper,:]
print(adata_st_true) # final reliable result

# Uncertainty quantification per gene

In [None]:
test_list = []

# NOTE(review): `test_g` (gene list) and `adata_st` (measured spatial data) are
# not assigned in any cell visible in this notebook; they must be defined
# before this cell can run on a fresh kernel.
imputed_path = "/gpfs/gibbs/pi/zhao/tl688/tangram/data_smfish/gimvigat_lat32_nei20sc_ep400_smfish_seed0.h5ad"
adata = sc.read_h5ad(imputed_path)
import random
random.seed(2023)
# The samples below are torch tensors, so Python's `random` seed alone does not
# make them reproducible (presumably sample() draws from torch's RNG).
torch.manual_seed(2023)
adata.layers['imputed_raw'] = adata.obsm['imputed_raw']
adata.layers['imputed_raw_theta'] = adata.obsm['imputed_raw_theta']

adata_st_true = adata[:, test_g]
adata_st_raw = adata_st[:, test_g]

# Materialise the view's layers once instead of re-slicing them in the loop.
imputed_mu = np.asarray(adata_st_true.layers['imputed_raw'])
imputed_theta = np.asarray(adata_st_true.layers['imputed_raw_theta'])

distr = NegativeBinomial(mu = torch.FloatTensor(imputed_mu),
                         theta = torch.FloatTensor(imputed_theta))

# Number of draws: 10% of the cell count, capped at 100.
sample_0_1 = int(min(100, len(adata)*0.1))
sample_200 = [distr.sample() for _ in range(sample_0_1)]

# Per-gene score: mean cosine similarity between the gene's sampled values
# across cells and its imputed mean across cells.
cossim_mean = []
for gene in range(0,len(test_g)):
    cossim_store = [cossim(sample_data[:,gene], imputed_mu[:,gene]) for sample_data in sample_200]
    cossim_mean.append(np.mean(cossim_store))

# argsort places NaN last, so after reversal NaN-scored genes (zero-norm
# vectors) would lead the ranking; map NaN to -inf to rank them at the bottom.
gene_scores = np.asarray(cossim_mean, dtype=float)
gene_scores = np.where(np.isnan(gene_scores), -np.inf, gene_scores)
median_upper = np.argsort(gene_scores)[::-1][0:int(len(test_g)//2)] # can control the length by modifying this upper bound

print(median_upper)


adata = sc.read_h5ad(imputed_path)

adata.X = adata.obsm['imputed_raw']
# `median_upper` indexes into `test_g`, so subset to test_g first, then by rank.
adata_st_true = adata[:, test_g]
adata_st_true = adata_st_true[:,median_upper]
