In [None]:
import os
import yaml
import time
import random
from pathlib import Path

import pandas as pd
import numpy as np
import torch
import scanpy as sc

import STAGATE_pyG

In [None]:
# Expose only the GPU with CUDA index 1 to this process. This must happen
# before CUDA is initialised (the first torch.cuda call comes in a later cell).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Display figures at 150 dpi, save them at 200 dpi.
sc.set_figure_params(dpi=150, dpi_save=200)

In [None]:
# Sanity check: how many CUDA devices are visible after CUDA_VISIBLE_DEVICES was set.
n_visible_gpus = torch.cuda.device_count()
print(n_visible_gpus)

In [None]:
# parameter cell
# Execution parameters (presumably injected by papermill). The empty defaults
# are placeholders only: every path must be supplied for the notebook to run.
adata1_file: str = ''  # input .h5ad for the cells labelled 'dataset1'
adata2_file: str = ''  # input .h5ad for the cells labelled 'dataset2'
emb0_file: str = ''  # output CSV for dataset1's STAGATE embedding; run_time.yaml is written to its directory
emb1_file: str = ''  # output CSV for dataset2's STAGATE embedding

In [None]:
def global_seed(seed: int):
    r"""
    Seed Python's ``random``, NumPy's legacy global RNG and PyTorch.

    Parameters
    ----------
    seed
        int. ``-1`` draws a fresh non-deterministic seed via ``torch.seed()``.
        Values above ``2**32 - 1`` use their high 32 bits (as before). Anything
        still outside ``[0, 2**32 - 1]`` (negative, or wider than 64 bits) is
        wrapped modulo ``2**32``, because ``np.random.seed`` rejects such values.
    """
    # torch.seed() returns a 64-bit value.
    seed = seed if seed != -1 else torch.seed()
    if seed > 2**32 - 1:
        seed = seed >> 32
    # Identity for seeds already in range; folds negatives and >64-bit seeds.
    seed %= 2**32

    random.seed(seed)
    np.random.seed(seed)
    # Seed torch too, so torch-based code run afterwards is reproducible.
    # torch.manual_seed seeds the RNG on all devices, CUDA included.
    torch.manual_seed(seed)
    print(f"Global seed set to {seed}.")

In [None]:
# Fail fast with a clear message if the parameter cell was not overridden.
# The output paths are checked here too, so a missing one is caught before training.
for _param_name, _param_value in (
    ('adata1_file', adata1_file),
    ('adata2_file', adata2_file),
    ('emb0_file', emb0_file),
    ('emb1_file', emb1_file),
):
    if not _param_value:
        raise ValueError(f"Parameter '{_param_name}' is empty; set it in the parameter cell.")

adata1 = sc.read_h5ad(adata1_file)
adata2 = sc.read_h5ad(adata2_file)

# Suffix observation names so they stay unique once the two datasets are concatenated.
adata1.obs_names = adata1.obs_names + '_1'
adata2.obs_names = adata2.obs_names + '_2'

# Record each cell's origin; used later to split the joint embedding back apart.
adata1.obs['dataset'] = 'dataset1'
adata2.obs['dataset'] = 'dataset2'

# Run STAGATE

In [None]:
# Time the whole pipeline (graph construction, preprocessing, training) with a
# monotonic clock, so system clock adjustments cannot skew the measurement.
start = time.perf_counter()

# Build the KNN spatial graph (k_cutoff=20) separately for each dataset, so
# edges stay within a dataset.
STAGATE_pyG.Cal_Spatial_Net(adata1, k_cutoff=20, model='KNN')
STAGATE_pyG.Cal_Spatial_Net(adata2, k_cutoff=20, model='KNN')

# sc.concat does not merge .uns by default, so re-attach both edge lists explicitly.
adata = sc.concat([adata1, adata2])
adata.uns['Spatial_Net'] = pd.concat([adata1.uns['Spatial_Net'], adata2.uns['Spatial_Net']])

# Alternative previously tried: seurat_v3 HVG (n_top_genes=3000) on raw counts before normalising.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata)
adata = STAGATE_pyG.train_STAGATE(adata)

# Kept as a string to preserve the format written to run_time.yaml.
run_time = str(time.perf_counter() - start)
print('Runtime: ' + run_time)


# Save embeddings and runtime

In [None]:
# Split the joint STAGATE embedding back into one matrix per dataset. Row order
# follows adata, i.e. adata1's cells followed by adata2's cells (sc.concat order).
_embedding_by_label = {
    label: adata[adata.obs['dataset'] == label].obsm['STAGATE']
    for label in ('dataset1', 'dataset2')
}
embd0 = _embedding_by_label['dataset1']
embd1 = _embedding_by_label['dataset2']

In [None]:
# Record the measured runtime in run_time.yaml, next to dataset1's embedding file.
time_dic = {'run_time': run_time}

out_dir = Path(emb0_file).parent
# Create the output directory if it does not exist yet, so the write below cannot
# fail with FileNotFoundError on a fresh output location.
out_dir.mkdir(parents=True, exist_ok=True)
with open(out_dir / 'run_time.yaml', "w") as f:
    yaml.dump(time_dic, f)

In [None]:
# Write each per-dataset embedding as a comma-separated text matrix.
for _out_path, _embedding in ((emb0_file, embd0), (emb1_file, embd1)):
    np.savetxt(_out_path, _embedding, delimiter=',')