# GraphST

- **Creator**: Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Date of Creation:** 11.01.2023
- **Date of Last Modification:** 22.07.2024 (Sebastian Birk; <sebastian.birk@helmholtz-munich.de>)

- The GraphST source code is available at https://github.com/JinmiaoChenLab/GraphST.
- The corresponding preprint is "Long, Y. et al. DeepST: A versatile graph contrastive learning framework for spatially informed clustering, integration, and deconvolution of spatial transcriptomics. Preprint at https://doi.org/10.1101/2022.08.02.502407".
- The workflow of this notebook follows the tutorial from https://deepst-tutorials.readthedocs.io/en/latest/Tutorial%201_10X%20Visium.html.

- Run this notebook in the nichecompass-reproducibility environment, which can be installed from `../../../envs/environment.yaml`.

## 1. Setup

### 1.1 Import Libraries

In [None]:
# Enable the IPython autoreload extension; mode 2 reloads imported modules
# before each cell execution so that local code changes are picked up
%load_ext autoreload
%autoreload 2

In [None]:
import multiprocessing as mp
import os
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import squidpy as sq
import torch
from GraphST import GraphST
from sklearn import metrics

### 1.2 Define Parameters

In [None]:
# Identifier of the benchmarked method; used in artifact paths and keys
model_name = "graphst"
# Prefix of the `obsm` keys under which the latent embeddings are stored
latent_key = model_name + "_latent"

### 1.3 Run Notebook Setup

In [None]:
# Use a square 6 x 6 default figure size for scanpy plots
default_figsize = (6, 6)
sc.set_figure_params(figsize=default_figsize)

In [None]:
# Record the notebook execution time; the formatted timestamp (day, month,
# year, underscore, hours, minutes, seconds) is used to tag saved artifacts
now = datetime.now()
current_timestamp = f"{now:%d%m%Y_%H%M%S}"

### 1.4 Configure Paths and Directories

In [None]:
# Folder with the input datasets. The trailing slash is required because
# dataset file names are appended to this path by plain string concatenation.
data_folder_path = "../../../datasets/st_data/gold/"
# Folder to which the AnnData objects with the benchmarking results are written
benchmarking_folder_path = "../../../artifacts/single_sample_method_benchmarking"
# Root folder for figures (plain string; no f-string needed without placeholders)
figure_folder_path = "../../../figures"

## 2. GraphST Model

### 2.1 Define Training Function

In [None]:
def train_graphst_models(dataset,
                         cell_type_key,
                         niche_type_key=None,
                         adata_new=None,
                         n_start_run=1,
                         n_end_run=8,
                         n_neighbor_list=(4, 4, 8, 8, 12, 12, 16, 16)):
    """
    Train one GraphST model per run on a dataset and collect the latent
    embeddings of all runs in a single AnnData object that is written to disk.

    The function relies on the notebook-level variables ``data_folder_path``,
    ``benchmarking_folder_path``, ``figure_folder_path``, ``model_name``,
    ``latent_key``, ``current_timestamp`` and ``device``. In particular,
    ``device`` has to be defined before the function is called.

    Parameters
    ----------
    dataset:
        Name of the dataset; ``{data_folder_path}{dataset}.h5ad`` is loaded.
    cell_type_key:
        Column of ``adata.obs`` that is copied to ``obs["cell_type"]`` of the
        results object (only used if ``adata_new`` is ``None``).
    niche_type_key:
        Optional column of ``adata.obs`` that is copied to
        ``obs["niche_type"]`` of the results object. It is skipped if it is
        not a column of ``adata.obs``.
    adata_new:
        Existing results object to which the runs are added. If ``None``, a
        new AnnData object with an empty sparse matrix is created from the
        dataset.
    n_start_run:
        Number of the first run.
    n_end_run:
        Number of the last run (inclusive).
    n_neighbor_list:
        Number of spatial neighbors for each run. Entries are paired with the
        run numbers in order (the first entry belongs to ``n_start_run``);
        surplus run numbers or entries are ignored.

    Returns
    -------
    None. The results are written to
    ``{benchmarking_folder_path}/{dataset}_{model_name}.h5ad`` after every run
    and once more at the end.
    """
    # Configure figure folder path
    dataset_figure_folder_path = (
        f"{figure_folder_path}/{dataset}/single_sample_method_benchmarking/"
        f"{model_name}/{current_timestamp}")
    os.makedirs(dataset_figure_folder_path, exist_ok=True)

    # Make sure the output folder exists before training starts; otherwise
    # the first write would fail only after a complete training run
    os.makedirs(benchmarking_folder_path, exist_ok=True)
    results_file_path = f"{benchmarking_folder_path}/{dataset}_{model_name}.h5ad"

    dataset_file_path = data_folder_path + f"{dataset}.h5ad"

    # Create new adata to store results from training runs in storage-efficient way
    # (empty sparse matrix; only names, labels and coordinates are kept)
    if adata_new is None:
        adata_original = sc.read_h5ad(dataset_file_path)
        adata_new = sc.AnnData(sp.csr_matrix(
            (adata_original.shape[0], adata_original.shape[1]),
            dtype=np.float32))
        adata_new.var_names = adata_original.var_names
        adata_new.obs_names = adata_original.obs_names
        adata_new.obs["cell_type"] = adata_original.obs[cell_type_key].values
        if niche_type_key in adata_original.obs.columns:
            adata_new.obs["niche_type"] = adata_original.obs[niche_type_key].values
        adata_new.obsm["spatial"] = adata_original.obsm["spatial"]
        del adata_original

    # Run i uses seed i - 1; only 10 seeds are defined
    model_seeds = list(range(10))
    run_numbers = range(n_start_run, n_end_run + 1)
    for run_number, n_neighbors in zip(run_numbers, n_neighbor_list):
        # Load a fresh copy of the data for every run
        adata = sc.read_h5ad(dataset_file_path)

        # Store raw counts in adata.X
        adata.X = adata.layers["counts"]
        # NOTE(review): presumably removed so that the data is not treated as
        # already log-transformed during GraphST preprocessing - confirm
        if "log1p" in adata.uns:
            del adata.uns["log1p"]

        # Compute spatial neighborhood graph
        sq.gr.spatial_neighbors(adata,
                                coord_type="generic",
                                spatial_key="spatial",
                                n_neighs=n_neighbors)
        adata.obsm["graph_neigh"] = adata.obsp["spatial_connectivities"]

        # Make adjacency matrix symmetric
        adata.obsm["adj"] = adata.obsm["graph_neigh"].maximum(
            adata.obsm["graph_neigh"].T)

        start_time = time.time()

        # Define model (`device` is read from the notebook namespace)
        model = GraphST.GraphST(adata,
                                device=device,
                                random_seed=model_seeds[run_number - 1])

        # Train model
        adata = model.train()

        # Measure time for model training (includes model definition)
        elapsed_time = time.time() - start_time
        hours, rem = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(rem, 60)
        print(f"Duration of model training in run {run_number}: "
              f"{int(hours)} hours, {int(minutes)} minutes and {int(seconds)} seconds.")
        adata_new.uns[f"{model_name}_model_training_duration_run{run_number}"] = (
            elapsed_time)

        # Store the latent embedding of this run
        adata_new.obsm[latent_key + f"_run{run_number}"] = adata.obsm["emb"]

        # Store intermediate adata to disk
        adata_new.write(results_file_path)

    # Store final adata to disk
    adata_new.write(results_file_path)

### 2.2 Train Models on Benchmarking Datasets

In [None]:
# Select the compute device: the first CUDA GPU if one is available, else the
# CPU. GraphST runs on 'cpu' by default, but using a GPU is recommended.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train_graphst_models(
    dataset="seqfish_mouse_organogenesis_embryo2",
    cell_type_key="celltype_mapped_refined",
    adata_new=None,
    n_start_run=1,
    n_end_run=8,
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
)

In [None]:
# Select the compute device: the first CUDA GPU if one is available, else the
# CPU. GraphST runs on 'cpu' by default, but using a GPU is recommended.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Train on every subsampled version of the dataset
for subsample_pct in (50, 25, 10, 5, 1):
    train_graphst_models(
        dataset=f"seqfish_mouse_organogenesis_subsample_{subsample_pct}pct_embryo2",
        cell_type_key="celltype_mapped_refined",
        adata_new=None,
        n_start_run=1,
        n_end_run=8,
        n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
    )

In [None]:
# Select the compute device: the first CUDA GPU if one is available, else the
# CPU. GraphST runs on 'cpu' by default, but using a GPU is recommended.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Train on every subsampled version of the dataset
for subsample_pct in (50, 25, 10, 5, 1):
    train_graphst_models(
        dataset=f"nanostring_cosmx_human_nsclc_subsample_{subsample_pct}pct_batch5",
        cell_type_key="cell_type",
        adata_new=None,
        n_start_run=1,
        n_end_run=8,
        n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
    )

In [None]:
# Select the compute device: the first CUDA GPU if one is available, else the
# CPU. GraphST runs on 'cpu' by default, but using a GPU is recommended.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The 50 and 25 pct subsamples are left out because they exhaust memory
for subsample_pct in (10, 5, 1):
    train_graphst_models(
        dataset=f"vizgen_merfish_mouse_liver_subsample_{subsample_pct}pct",
        cell_type_key="Cell_Type",
        adata_new=None,
        n_start_run=1,
        n_end_run=8,
        n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
    )

In [None]:
# Select the compute device: the first CUDA GPU if one is available, else the
# CPU. GraphST runs on 'cpu' by default, but using a GPU is recommended.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train_graphst_models(
    dataset="slideseqv2_mouse_hippocampus",
    cell_type_key="cell_type",
    adata_new=None,
    n_start_run=1,
    n_end_run=8,
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
)

In [None]:
# Select the compute device: the first CUDA GPU if one is available, else the
# CPU. GraphST runs on 'cpu' by default, but using a GPU is recommended.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The 1 pct subsample is left out because it did not run:
# "There are other near singularities as well"
for subsample_pct in (50, 25, 10, 5):
    train_graphst_models(
        dataset=f"slideseqv2_mouse_hippocampus_subsample_{subsample_pct}pct",
        cell_type_key="cell_type",
        adata_new=None,
        n_start_run=1,
        n_end_run=8,
        n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
    )

In [None]:
# Select the compute device: the first CUDA GPU if one is available, else the
# CPU. GraphST runs on 'cpu' by default, but using a GPU is recommended.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train_graphst_models(
    dataset="sim1_1105genes_10000locs_strongincrements",
    cell_type_key="cell_types",
    niche_type_key="niche_types",
    adata_new=None,
    n_start_run=1,
    n_end_run=8,
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
)

In [None]:
# Select the compute device: the first CUDA GPU if one is available, else the
# CPU. GraphST runs on 'cpu' by default, but using a GPU is recommended.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train_graphst_models(
    dataset="starmap_mouse_mpfc",
    cell_type_key="cell_type",
    niche_type_key="niche_type",
    adata_new=None,
    n_start_run=1,
    n_end_run=8,
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
)

In [None]:
# Select the compute device: the first CUDA GPU if one is available, else the
# CPU. GraphST runs on 'cpu' by default, but using a GPU is recommended.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train_graphst_models(
    dataset="stereoseq_mouse_embryo",
    cell_type_key="leiden",
    niche_type_key="niche_type",
    adata_new=None,
    n_start_run=1,
    n_end_run=8,
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
)