In [1]:
import os
import torch
import pandas as pd
import scanpy as sc
import numpy as np
from sklearn.metrics.cluster import adjusted_rand_score,normalized_mutual_info_score,adjusted_mutual_info_score,silhouette_score
import squidpy as sq
import time,psutil,tracemalloc

In [2]:
# R_HOME must be present in the environment variables for the R bridge used
# by the mclust clustering step; if it is not already configured system-wide,
# define it here together with the matching R binary directory on PATH.
os.environ.update({
    "R_HOME": r"D:\R-4.3.1",
    "PATH": r"D:\R-4.3.1\bin\x64" + ";" + os.environ["PATH"],
})

def mk_dir(input_path):
    """Create ``input_path`` (including missing parents) and return it.

    Parameters
    ----------
    input_path : str or os.PathLike
        Directory to create. An already existing directory is left untouched.

    Returns
    -------
    The same ``input_path`` that was passed in, so the call can be chained.

    Raises
    ------
    FileExistsError
        If ``input_path`` exists but is not a directory.
    """
    # exist_ok=True replaces the separate exists() check, which could race
    # with another process creating the directory in between the two calls.
    os.makedirs(input_path, exist_ok=True)
    return input_path

def eval_model(pred, labels=None):
    """Score predicted cluster assignments against reference labels.

    Parameters
    ----------
    pred : array-like or pandas.Series
        Predicted cluster label per observation.
    labels : array-like or pandas.Series, optional
        Reference label per observation. Observations where either the
        reference or the prediction is missing are dropped before scoring.

    Returns
    -------
    tuple
        ``(ari, nmi, ami)``: adjusted Rand index, normalized mutual
        information and adjusted mutual information. When ``labels`` is
        ``None`` there is nothing to compare against and all three values
        are NaN.
    """
    if labels is None:
        # Previously this path fell through to the return statement with the
        # three names unbound and raised UnboundLocalError.
        nan = float("nan")
        return nan, nan, nan

    # Building a frame pairs each prediction with its label and lets dropna()
    # discard the unlabelled observations in one step.
    label_df = pd.DataFrame({"True": labels, "Pred": pred}).dropna()
    y_true = label_df["True"]
    y_pred = label_df["Pred"]

    ari = adjusted_rand_score(y_true, y_pred)
    nmi = normalized_mutual_info_score(y_true, y_pred)
    ami = adjusted_mutual_info_score(y_true, y_pred)
    return ari, nmi, ami


In [3]:
def run_STAGATE(adata, dataset, random_seed=None,
                device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
                save_data_path="/home/sda1/fangzy/data/st_data/Benchmark/STAGATE/",
                n_clusters=None, rad_cutoff=150):
    """Run the STAGATE pipeline on one dataset and score the clustering.

    The preprocessing calls operate on the object that is passed in, so hand
    over a copy if the caller still needs the untouched data.

    Parameters
    ----------
    adata : AnnData
        Dataset to process. If ``adata.obs`` has a ``"ground_truth"`` column,
        it determines the number of clusters and is used for evaluation.
    dataset : str
        Dataset identifier, stored under ``"dataset"`` in the result dict.
    random_seed : int, optional
        Seed forwarded to ``STAGATE.train_STAGATE``. When ``None`` a seed in
        ``[0, 100)`` is drawn for this call. (A random default placed in the
        signature would be evaluated only once, at definition time.)
    device : torch.device
        Device forwarded to ``STAGATE.train_STAGATE``.
    save_data_path : str
        Currently unused; kept so existing callers keep working.
    n_clusters : int, optional
        Number of clusters for mclust. Only used when ``adata.obs`` has no
        ``"ground_truth"`` column; required in that case.
    rad_cutoff : int or float
        Radius cutoff forwarded to ``STAGATE.Cal_Spatial_Net``.

    Returns
    -------
    (dict, AnnData)
        The result dict always holds ``"dataset"``, ``"time"`` (seconds) and
        ``"Memo"`` (change in resident memory, GB). With ground truth it also
        holds ``"ari"``, ``"nmi"``, ``"ami"``, ``"sc"`` (silhouette of the
        predicted clusters) and ``"SC_revise"`` (silhouette of the ground
        truth labels on the labelled observations). The returned AnnData
        carries ``obs["pred_label"]`` and ``obsm["embedding"]``.

    Raises
    ------
    ValueError
        If there is neither a ``"ground_truth"`` column nor ``n_clusters``.
    """
    # Fail before the expensive training instead of handing None to mclust.
    if "ground_truth" not in adata.obs.columns and n_clusters is None:
        raise ValueError(
            "n_clusters must be given when adata.obs has no 'ground_truth' column"
        )
    if random_seed is None:
        random_seed = np.random.randint(100)

    # Imported lazily so that defining this function does not require STAGATE.
    import STAGATE_pyG as STAGATE

    process = psutil.Process(os.getpid())
    start = time.time()
    start_gb = process.memory_info().rss / 1024 / 1024 / 1024

    # Highly variable genes are selected before normalisation and log1p.
    sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    STAGATE.Cal_Spatial_Net(adata, rad_cutoff=rad_cutoff)
    STAGATE.Stats_Spatial_Net(adata)
    adata = STAGATE.train_STAGATE(adata, device=device, random_seed=random_seed)
    sc.pp.neighbors(adata, use_rep='STAGATE')
    sc.tl.umap(adata)

    has_ground_truth = "ground_truth" in adata.obs.columns
    if has_ground_truth:
        # The annotated labels override any n_clusters passed by the caller.
        n_clusters = len(set(adata.obs["ground_truth"].dropna()))
    adata = STAGATE.mclust_R(adata, used_obsm='STAGATE', num_cluster=n_clusters)

    # Expose the results under method-independent names.
    adata.obs["pred_label"] = adata.obs["mclust"]
    adata.obsm["embedding"] = adata.obsm["STAGATE"]

    if has_ground_truth:
        ari, nmi, ami = eval_model(adata.obs['mclust'], adata.obs['ground_truth'])
        SC = silhouette_score(adata.obsm["embedding"], adata.obs['mclust'])

        # Silhouette of the ground-truth labels, restricted to labelled spots.
        used_adata = adata[adata.obs["ground_truth"].notna()]
        SC_revise = silhouette_score(used_adata.obsm["embedding"], used_adata.obs['ground_truth'])

    end = time.time()
    end_gb = process.memory_info().rss / 1024 / 1024 / 1024

    # Keys are inserted in the historical order so that result tables built
    # from this dict keep their column layout.
    res = {"dataset": dataset}
    if has_ground_truth:
        res["ari"] = ari
        res["nmi"] = nmi
        res["ami"] = ami
        res["sc"] = SC
    res["time"] = end - start
    res["Memo"] = end_gb - start_gb
    if has_ground_truth:
        res['SC_revise'] = SC_revise

    return res, adata


In [4]:
import sys
sys.path.append('../')
import utils_for_all as usa
if __name__ == '__main__':
    # Benchmark driver: for every dataset, run STAGATE for a number of rounds,
    # record time / memory / clustering scores per round, then write the last
    # round's AnnData and the per-round and aggregated tables to disk.
    # The whole driver lives under the guard so that Dataset_test is always
    # defined whenever the loop runs.

    # Other dataset names: "Stereo", "Breast_cancer", "Mouse_brain", "STARmap", "SeqFish"
    Dataset_test = ['151673']
    n_rounds = 1  # repeated runs per dataset

    for dataset in Dataset_test:
        print(f"====================begin test on {dataset}======================================")
        # Dataset ids starting with "15" are grouped under a DLPFC sub-folder.
        if dataset.startswith('15'):
            save_path = f'../../Output/STAGATE/DLPFC/{dataset}/'
        else:
            save_path = f'../../Output/STAGATE/{dataset}/'
        mk_dir(save_path)

        adata, n_clusters = usa.get_adata(dataset, data_path='../../Dataset/')
        adata.var_names_make_unique()

        random_seed = 0
        rad_cutoff = 150
        # Collect one dict per round and build the frame once at the end,
        # instead of growing a DataFrame through the private _append method.
        records = []
        for i in range(n_rounds):
            num = i + 1
            print("===epoch:{}===".format(num))
            start = time.time()
            tracemalloc.start()
            start_GB = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
            # Each round works on a fresh copy so the loaded data stays untouched.
            res, adata_h5 = run_STAGATE(adata.copy(), dataset, random_seed=random_seed,
                                        rad_cutoff=rad_cutoff, n_clusters=n_clusters)

            end = time.time()
            end_GB = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
            used_time = end - start
            used_memo = end_GB - start_GB
            _, peak = tracemalloc.get_traced_memory()
            # stop() also discards the collected traces, so no separate
            # clear_traces() call is needed afterwards.
            tracemalloc.stop()
            peak = peak / 1024.0 / 1024.0 / 1024.0
            print(u'Current memory usage_end:：%.4f GB' % used_memo)
            print('time: {:.4f} s'.format(used_time))
            print('memory blocks peak:{:>10.4f} GB'.format(peak))

            # The index column must exist even if the run returned no metrics.
            res.setdefault("dataset", dataset)
            # The measurements taken around the call replace the ones taken
            # inside run_STAGATE.
            res["time"] = used_time
            res["Memo"] = used_memo
            res["Memo_peak"] = peak
            res["round"] = num
            records.append(res)

        # Only the AnnData of the last round is persisted.
        adata_h5.write_h5ad(save_path + str(dataset) + ".h5ad")
        results = pd.DataFrame(records).set_index('dataset')
        results.to_csv(os.path.join(save_path, f"result_{dataset}.csv"), header=True)
        print(results.head())
        res_mean = results.mean()
        res_mean.to_csv(f'{save_path}{dataset}_mean.csv', header=True)
        # With a single round the standard deviation is NaN.
        res_std = results.std()
        res_std.to_csv(f'{save_path}{dataset}_std.csv', header=True)
        res_median = results.median()
        res_median.to_csv(f'{save_path}{dataset}_median.csv', header=True)


Current memory usage_end:：0.0000 GB
time: 0.0000 s
memory blocks peak:    0.0000 GB
load DLPFC dataset:


  utils.warn_names_duplicates("var")


===epoch:1===
------Calculating spatial graph...


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)


The graph contains 21124 edges, 3639 cells.
5.8049 neighbors per cell on average.


100%|██████████| 400/400 [01:33<00:00,  4.26it/s]
R[write to console]:                    __           __ 
   ____ ___  _____/ /_  _______/ /_
  / __ `__ \/ ___/ / / / / ___/ __/
 / / / / / / /__/ / /_/ (__  ) /_  
/_/ /_/ /_/\___/_/\__,_/____/\__/   version 6.0.0
Type 'citation("mclust")' for citing this R package in publications.



fitting ...
Current memory usage_end:：16.0069 GB
time: 122.0316 s
memory blocks peak:    0.3739 GB
              ari       nmi       ami        sc        time       Memo  \
dataset                                                                  
151673   0.591607  0.716298  0.715533  0.185218  122.031552  16.006878   

         SC_revise  Memo_peak  round  
dataset                               
151673    0.131393   0.373933      1  
