# GraphST

- **Creator**: Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Date of Creation:** 11.01.2023
- **Date of Last Modification:** 18.07.2024 (Sebastian Birk; <sebastian.birk@helmholtz-munich.de>)

- The GraphST source code is available at https://github.com/JinmiaoChenLab/GraphST.
- The corresponding preprint is "Long, Y. et al. DeepST: A versatile graph contrastive learning framework for spatially informed clustering, integration, and deconvolution of spatial transcriptomics. Preprint at https://doi.org/10.1101/2022.08.02.502407".
- The workflow of this notebook follows the tutorial from https://deepst-tutorials.readthedocs.io/en/latest/Tutorial%204_Horizontal%20Integration.html.

- Run this notebook in the nichecompass-reproducibility environment, which can be installed from `../../../envs/environment.yaml`.

## 1. Setup

### 1.1 Import Libraries

In [None]:
import gc
import os
import time
from datetime import datetime

import anndata as ad
import matplotlib.pyplot as plt
import multiprocessing as mp
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import squidpy as sq
import ot
import paste as pst
import torch
from GraphST import GraphST
from sklearn import metrics

### 1.2 Define Parameters

In [None]:
# Identifier of the benchmarked method; used as a prefix for stored keys/files.
model_name = "graphst"
latent_key = model_name + "_latent"  # obsm prefix for the per-run embeddings

# AnnData keys used throughout the notebook.
spatial_key = "spatial"                 # obsm key with spatial coordinates
adj_key = "spatial_connectivities"      # obsp key of the spatial neighbor graph
counts_key = "counts"                   # layer holding raw counts
condition_key = "batch"                 # obs column identifying the sample/batch
mapping_entity_key = "reference"        # obs column tagging reference cells

### 1.3 Run Notebook Setup

In [None]:
# Set scanpy's default figure size for all plots produced in this notebook.
sc.set_figure_params(figsize=(6, 6))

In [None]:
# Timestamp of this notebook execution (format ddmmYYYY_HHMMSS), used to tag
# saved artifacts.
now = datetime.now()
current_timestamp = f"{now:%d%m%Y_%H%M%S}"

In [None]:
# Select the device GraphST trains on: the first CUDA GPU if one is visible
# (recommended for speed), otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

### 1.4 Configure Paths and Directories

In [None]:
st_data_gold_folder_path = "../../../datasets/st_data/gold"
st_data_results_folder_path = "../../../datasets/st_data/results"
figure_folder_path = "../../../figures"
benchmarking_folder_path = "../../../artifacts/sample_integration_method_benchmarking"

# Create required directories. The benchmarking folder is included because the
# training function writes its result h5ad files there; without it the first
# write fails on a fresh checkout.
os.makedirs(st_data_gold_folder_path, exist_ok=True)
os.makedirs(st_data_results_folder_path, exist_ok=True)
os.makedirs(benchmarking_folder_path, exist_ok=True)

## 2. GraphST Model

### 2.1 Define Training Function

In [None]:
def _load_reference_batches(dataset, reference_batches, verbose=False):
    """Load the per-batch h5ad files of ``dataset`` and tag every cell as reference.

    Files are read from ``{st_data_gold_folder_path}/{dataset}_{batch}.h5ad``
    in the order given by ``reference_batches``. If ``verbose`` is True, a
    progress message is printed per batch. Returns a list of AnnData objects.
    """
    adata_batch_list = []
    for batch in reference_batches:
        if verbose:
            print(f"Processing batch {batch}...")
            print("Loading data...")
        adata_batch = ad.read_h5ad(
            f"{st_data_gold_folder_path}/{dataset}_{batch}.h5ad")
        adata_batch.obs[mapping_entity_key] = "reference"
        adata_batch_list.append(adata_batch)
    return adata_batch_list


def _compute_symmetric_spatial_graph(adata, n_neighbors):
    """Compute a kNN spatial graph into ``adata.obsp[adj_key]`` and symmetrize it.

    The graph is replaced by the element-wise maximum with its transpose, so an
    edge i->j always implies j->i.
    """
    sq.gr.spatial_neighbors(adata,
                            coord_type="generic",
                            spatial_key=spatial_key,
                            n_neighs=n_neighbors)
    adata.obsp[adj_key] = adata.obsp[adj_key].maximum(adata.obsp[adj_key].T)


def _init_results_adata(dataset, reference_batches, cell_type_key):
    """Create a storage-efficient AnnData holding only metadata for the results.

    X is an all-zero float32 sparse matrix of the original shape. Var/obs names,
    cell types (stored as ``obs["cell_type"]``), spatial coordinates, batch
    labels and the reference tag are copied from the original data.
    """
    if reference_batches is not None:
        adata_original = ad.concat(
            _load_reference_batches(dataset, reference_batches), join="inner")
    else:
        adata_original = ad.read_h5ad(f"{st_data_gold_folder_path}/{dataset}.h5ad")

    adata_new = sc.AnnData(sp.csr_matrix(adata_original.shape, dtype=np.float32))
    adata_new.var_names = adata_original.var_names
    adata_new.obs_names = adata_original.obs_names
    # Copy raw values (not Series) so no index alignment happens, which would
    # be fragile if concatenated batches share obs names.
    adata_new.obs["cell_type"] = adata_original.obs[cell_type_key].values
    adata_new.obsm["spatial"] = adata_original.obsm["spatial"]
    adata_new.obs[condition_key] = adata_original.obs[condition_key].values
    adata_new.obs[mapping_entity_key] = adata_original.obs[mapping_entity_key].values
    return adata_new


def _load_training_adata(dataset, reference_batches, n_neighbors, paste_alignment):
    """Load the training data for one run and attach a symmetric spatial graph.

    Returns ``(adata, elapsed_time)``. ``adata.X`` holds raw counts.
    ``elapsed_time`` is the PASTE alignment time in seconds (0 when PASTE is
    not used), so it can be added to the training time.
    """
    if reference_batches is None:
        adata = ad.read_h5ad(f"{st_data_gold_folder_path}/{dataset}.h5ad")
        adata.X = adata.layers[counts_key]
        _compute_symmetric_spatial_graph(adata, n_neighbors)
        return adata, 0

    adata_batch_list = _load_reference_batches(dataset, reference_batches,
                                               verbose=True)
    for adata_batch in adata_batch_list:
        adata_batch.X = adata_batch.layers[counts_key]

    if paste_alignment:
        # Align consecutive batches with PASTE and stack them into a common
        # coordinate frame; alignment time counts towards the run duration.
        start_time = time.time()
        pis = [pst.pairwise_align(slice_a, slice_b, numItermax=200000)
               for slice_a, slice_b in zip(adata_batch_list[:-1],
                                           adata_batch_list[1:])]
        adata_batch_list = pst.stack_slices_pairwise(adata_batch_list, pis)
        adata = ad.concat(adata_batch_list, join="inner")
        elapsed_time = time.time() - start_time
        # One joint spatial graph over the aligned coordinates.
        _compute_symmetric_spatial_graph(adata, n_neighbors)
        return adata, elapsed_time

    # Without alignment: per-batch graphs combined as disconnected components,
    # i.e. a block-diagonal adjacency matrix in concatenation order.
    for adata_batch in adata_batch_list:
        _compute_symmetric_spatial_graph(adata_batch, n_neighbors)
    adata = ad.concat(adata_batch_list, join="inner")
    adata.obsp[adj_key] = sp.block_diag(
        [adata_batch.obsp[adj_key] for adata_batch in adata_batch_list],
        format="csr")
    return adata, 0


def _select_spatially_variable_genes(adata, n_svg):
    """Return a copy of ``adata`` restricted to the first ``n_svg`` Moran's I genes.

    Genes are taken in the order of squidpy's ``uns["moranI"]`` table (ordered
    by decreasing Moran's I in squidpy -- confirm for the installed version).
    """
    # min_cells=0 keeps every gene.
    sc.pp.filter_genes(adata,
                       min_cells=0)
    sq.gr.spatial_autocorr(adata, mode="moran", genes=adata.var_names)
    sv_genes = adata.uns["moranI"].index[:n_svg].tolist()
    adata.var["spatially_variable"] = adata.var_names.isin(sv_genes)
    # Materialize the subset so later assignments (e.g. to .X) act on a real
    # AnnData instead of implicitly copying a view.
    adata = adata[:, adata.var["spatially_variable"].values].copy()
    print(f"Keeping {len(adata.var_names)} spatially variable genes.")
    return adata


def train_graphst_models(dataset,
                         reference_batches,
                         cell_type_key,
                         adata_new=None,
                         n_start_run=1,
                         n_end_run=8,
                         n_neighbor_list=(4, 4, 8, 8, 12, 12, 16, 16),
                         plot_latent_umaps: bool = False,
                         filter_genes: bool = False,
                         n_svg: int = 3000,
                         paste_alignment=True):
    """Train several GraphST models on a dataset and store the latent spaces.

    For every run, data are (re)loaded and a spatial neighborhood graph is
    computed. Batches are either aligned with PASTE or kept as disconnected
    graph components. Each run then trains a GraphST model and stores its
    embedding, training duration and UMAP in ``adata_new``, which is written
    to ``benchmarking_folder_path`` after every run and once more at the end.

    Parameters
    ----------
    dataset:
        Dataset name used to build the input/output file names.
    reference_batches:
        Batch names to load as ``{dataset}_{batch}.h5ad``; if None, the single
        file ``{dataset}.h5ad`` is used.
    cell_type_key:
        obs column with cell type labels (copied to ``obs["cell_type"]``).
    adata_new:
        Existing results AnnData to extend; created from the data if None.
    n_start_run, n_end_run:
        Inclusive range of run numbers. Run ``r`` uses random seed ``r - 1``
        (seeds 0-9 are available).
    n_neighbor_list:
        Number of spatial neighbors per run. NOTE: entries are consumed from
        index 0, i.e. run ``n_start_run + k`` uses ``n_neighbor_list[k]``;
        runs beyond the list length are silently skipped.
    plot_latent_umaps:
        Currently unused; kept for interface compatibility.
    filter_genes:
        If True, keep only the ``n_svg`` top spatially variable genes.
    n_svg:
        Number of spatially variable genes kept when ``filter_genes`` is True.
    paste_alignment:
        If True, align batches with PASTE before building a joint graph; the
        output file name then gets a ``_paste`` suffix.
    """
    if adata_new is None:
        adata_new = _init_results_adata(dataset, reference_batches, cell_type_key)

    output_suffix = "_paste" if paste_alignment else ""
    output_path = (f"{benchmarking_folder_path}/"
                   f"{dataset}_{model_name}{output_suffix}.h5ad")
    os.makedirs(benchmarking_folder_path, exist_ok=True)

    model_seeds = list(range(10))
    for run_number, n_neighbors in zip(np.arange(n_start_run, n_end_run+1),
                                       n_neighbor_list):
        adata, elapsed_time = _load_training_adata(
            dataset, reference_batches, n_neighbors, paste_alignment)

        if filter_genes:
            adata = _select_spatially_variable_genes(adata, n_svg)

        # GraphST expects raw counts in adata.X and no prior log1p record.
        adata.X = adata.layers[counts_key]
        adata.uns.pop("log1p", None)

        # Supply the precomputed spatial neighborhood graph and its
        # symmetrized version in the obsm slots GraphST reads.
        adata.obsm["graph_neigh"] = adata.obsp[adj_key]
        adata.obsm["adj"] = adata.obsm["graph_neigh"].maximum(
            adata.obsm["graph_neigh"].T)

        start_time = time.time()
        model = GraphST.GraphST(adata,
                                device=device,
                                random_seed=model_seeds[run_number-1])
        adata = model.train()

        # Duration = PASTE alignment time (if any) + model training time.
        elapsed_time += time.time() - start_time
        hours, rem = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(rem, 60)
        print(f"Duration of model training in run {run_number}: "
              f"{int(hours)} hours, {int(minutes)} minutes and {int(seconds)} seconds.")
        adata_new.uns[f"{model_name}_model_training_duration_run{run_number}"] = (
            elapsed_time)

        # Store latent representation and compute a per-run UMAP from it.
        latent_run_key = f"{latent_key}_run{run_number}"
        adata_new.obsm[latent_run_key] = adata.obsm["emb"]
        sc.pp.neighbors(adata_new,
                        use_rep=latent_run_key,
                        key_added=latent_run_key)
        sc.tl.umap(adata_new,
                   neighbors_key=latent_run_key)
        adata_new.obsm[f"{latent_run_key}_X_umap"] = adata_new.obsm["X_umap"]
        del adata_new.obsm["X_umap"]

        # Store intermediate results so finished runs survive later failures.
        adata_new.write(output_path)

        # Free host and GPU memory before the next run.
        del adata
        del model
        gc.collect()
        torch.cuda.empty_cache()

    # Store final adata to disk
    adata_new.write(output_path)

### 2.2 Train Models on Benchmarking Datasets

In [None]:
# Full seqFISH mouse organogenesis data (6 batches), per-batch spatial graphs
# kept as disconnected components (no PASTE alignment).
train_graphst_models(
    paste_alignment=False,
    dataset="seqfish_mouse_organogenesis",
    reference_batches=[f"batch{i}" for i in range(1, 7)],
    cell_type_key="celltype_mapped_refined",
    adata_new=None,
    n_start_run=1,
    n_end_run=8,
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
)

In [None]:
# Full seqFISH mouse organogenesis data (6 batches), batches aligned with PASTE
# before building one joint spatial graph.
train_graphst_models(
    paste_alignment=True,
    dataset="seqfish_mouse_organogenesis",
    reference_batches=[f"batch{i}" for i in range(1, 7)],
    cell_type_key="celltype_mapped_refined",
    adata_new=None,
    n_start_run=1,
    n_end_run=8,
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
)

In [None]:
# Subsampled seqFISH datasets, no PASTE alignment.
for subsample_pct in (50, 25, 10, 5, 1):
    train_graphst_models(
        paste_alignment=False,
        dataset=f"seqfish_mouse_organogenesis_subsample_{subsample_pct}pct",
        reference_batches=[f"batch{i}" for i in range(1, 7)],
        cell_type_key="celltype_mapped_refined",
        adata_new=None,
        n_start_run=1,
        n_end_run=8,
        n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
    )

In [None]:
# Subsampled seqFISH datasets, batches aligned with PASTE.
for subsample_pct in (50, 25, 10, 5, 1):
    train_graphst_models(
        paste_alignment=True,
        dataset=f"seqfish_mouse_organogenesis_subsample_{subsample_pct}pct",
        reference_batches=[f"batch{i}" for i in range(1, 7)],
        cell_type_key="celltype_mapped_refined",
        adata_new=None,
        n_start_run=1,
        n_end_run=8,
        n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
    )

In [None]:
# Imputed seqFISH data restricted to the top 3000 spatially variable genes,
# no PASTE alignment.
train_graphst_models(
    paste_alignment=False,
    filter_genes=True,
    n_svg=3000,
    dataset="seqfish_mouse_organogenesis_imputed",
    reference_batches=[f"batch{i}" for i in range(1, 7)],
    cell_type_key="celltype_mapped_refined",
    adata_new=None,
    n_start_run=1,
    n_end_run=8,
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
)

In [None]:
# Imputed seqFISH data restricted to the top 3000 spatially variable genes,
# batches aligned with PASTE.
train_graphst_models(
    paste_alignment=True,
    filter_genes=True,
    n_svg=3000,
    dataset="seqfish_mouse_organogenesis_imputed",
    reference_batches=[f"batch{i}" for i in range(1, 7)],
    cell_type_key="celltype_mapped_refined",
    adata_new=None,
    n_start_run=1,
    n_end_run=8,
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
)

In [None]:
# Subsampled imputed seqFISH datasets (top 3000 spatially variable genes),
# no PASTE alignment.
for subsample_pct in (50, 25, 10, 5, 1):
    train_graphst_models(
        paste_alignment=False,
        filter_genes=True,
        n_svg=3000,
        dataset=f"seqfish_mouse_organogenesis_imputed_subsample_{subsample_pct}pct",
        reference_batches=[f"batch{i}" for i in range(1, 7)],
        cell_type_key="celltype_mapped_refined",
        adata_new=None,
        n_start_run=1,
        n_end_run=8,
        n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
    )

In [None]:
# Subsampled imputed seqFISH datasets (top 3000 spatially variable genes),
# batches aligned with PASTE.
for subsample_pct in (50, 25, 10, 5, 1):
    train_graphst_models(
        paste_alignment=True,
        filter_genes=True,
        n_svg=3000,
        dataset=f"seqfish_mouse_organogenesis_imputed_subsample_{subsample_pct}pct",
        reference_batches=[f"batch{i}" for i in range(1, 7)],
        cell_type_key="celltype_mapped_refined",
        adata_new=None,
        n_start_run=1,
        n_end_run=8,
        n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
    )

In [None]:
# Subsampled NanoString CosMx NSCLC datasets (3 batches) with PASTE alignment
# (the function's default, stated explicitly here). The 50% and 25% subsamples
# are skipped because they exhaust memory.
for subsample_pct in (10, 5, 1):
    train_graphst_models(
        paste_alignment=True,
        dataset=f"nanostring_cosmx_human_nsclc_subsample_{subsample_pct}pct",
        reference_batches=[f"batch{i}" for i in range(1, 4)],
        cell_type_key="cell_type",
        adata_new=None,
        n_start_run=1,
        n_end_run=8,
        n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
    )