In [None]:
import os.path
from Methods.SpatialGlue.preprocess import pca, clr_normalize_each_cell, construct_neighbor_graph
import scanpy as sc
from matplotlib import pyplot as plt
from Methods.SpatialGlue.SpatialGlue_pyG import Train_SpatialGlue
from Methods.SpatialGlue.utils import clustering
from sklearn.metrics import adjusted_rand_score

# Run SpatialGlue on each simulated RNA + ADT dataset: preprocess both
# modalities, train the joint model, cluster the joint embedding with mclust,
# and save a UMAP / spatial-domain figure (annotated with ARI against the
# ground-truth 'cluster' column) per dataset.
DATASET_ROOT = '../Dataset'
RESULT_ROOT = '../Results'
N_PCA_COMPS = 50   # upper bound on PCA components used as model input features
N_CLUSTERS = 10    # number of clusters requested from mclust for every dataset


def _pca_features(adata, max_comps=N_PCA_COMPS):
    """Return PCA features for ``adata`` using at most ``max_comps`` components.

    NOTE(review): SpatialGlue's ``pca`` helper presumably wraps
    ``sklearn.decomposition.PCA``, which rejects
    ``n_components > min(n_obs, n_vars)``. Small feature panels (e.g. ADT, or
    RNA after gene filtering) can have fewer than ``max_comps`` features, so
    the component count is clamped. When both dimensions are >= ``max_comps``
    this is identical to ``pca(adata, n_comps=max_comps)``.
    """
    n_comps = min(max_comps, adata.n_obs, adata.n_vars)
    return pca(adata, n_comps=n_comps)


for dataset in ['1_Simulation', '2_Simulation', '3_Simulation']:
    data_dir = os.path.join(DATASET_ROOT, dataset)
    data_rna = sc.read_h5ad(os.path.join(data_dir, 'adata_RNA.h5ad'))
    data_pro = sc.read_h5ad(os.path.join(data_dir, 'adata_ADT.h5ad'))

    # RNA: drop genes seen in fewer than 10 cells, library-size normalise,
    # log-transform, scale, then reduce to PCA features.
    sc.pp.filter_genes(data_rna, min_cells=10)
    sc.pp.normalize_total(data_rna, target_sum=1e4)
    sc.pp.log1p(data_rna)
    sc.pp.scale(data_rna)
    data_rna.obsm['feat'] = _pca_features(data_rna)

    # Protein: per-cell CLR normalisation, scale, then PCA features.
    data_pro = clr_normalize_each_cell(data_pro)
    sc.pp.scale(data_pro)
    data_pro.obsm['feat'] = _pca_features(data_pro)

    # Build the graphs for both modalities, train, and store the joint embedding.
    data = construct_neighbor_graph(data_rna, data_pro)
    model = Train_SpatialGlue(data, datatype='10x')
    output = model.train()
    data_rna.obsm['SpatialGlue'] = output['SpatialGlue']
    clustering(data_rna, key='SpatialGlue', add_key='SpatialGlue',
               n_clusters=N_CLUSTERS, method='mclust', use_pca=True)

    # ARI is symmetric; ground truth is passed first to follow sklearn's
    # (labels_true, labels_pred) argument convention.
    ari = adjusted_rand_score(data_rna.obs['cluster'], data_rna.obs['SpatialGlue'])

    fig, ax_list = plt.subplots(1, 2, figsize=(8, 4))
    sc.pp.neighbors(data_rna, use_rep='SpatialGlue', n_neighbors=30)
    sc.tl.umap(data_rna)
    sc.pl.umap(data_rna, color='SpatialGlue', ax=ax_list[0],
               title='SpatialGlue\n' + dataset, s=60, show=False)
    sc.pl.embedding(data_rna, basis='spatial', color='SpatialGlue', ax=ax_list[1],
                    title='SpatialGlue\n' + 'ARI: {:.3f}'.format(ari), s=200, show=False)
    fig.tight_layout(w_pad=0.3)

    result_folder = os.path.join(RESULT_ROOT, dataset)
    os.makedirs(result_folder, exist_ok=True)  # no-op if it already exists
    fig.savefig(os.path.join(result_folder, 'SpatialGlue.pdf'))
    plt.show()
    # Release the figure: under a non-interactive backend, figures created in
    # this loop would otherwise accumulate across iterations.
    plt.close(fig)

  self.alpha = F.softmax(torch.squeeze(self.vu) + 1e-6)
 34%|███▍      | 68/200 [00:08<00:13,  9.67it/s]