# GraphST

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>).
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 11.01.2023
- **Date of Last Modification:** 16.08.2023

- The GraphST source code is available at https://github.com/JinmiaoChenLab/GraphST.
- The corresponding preprint is "Long, Y. et al. DeepST: A versatile graph contrastive learning framework for spatially informed clustering, integration, and deconvolution of spatial transcriptomics. Preprint at https://doi.org/10.1101/2022.08.02.502407".
- The workflow of this notebook follows the tutorial from https://deepst-tutorials.readthedocs.io/en/latest/Tutorial%201_10X%20Visium.html.
- The authors use raw counts as input to GraphST (stored in adata.X). Therefore, we also use raw counts.
- To define the spatial neighborhood graph, the original GraphST paper uses the 3 nearest neighbors of each cell and takes the union of all neighbor relations as the final spatial neighborhood graph (i.e., the adjacency matrix is made symmetric). We use the same method but vary the number of neighbors across training runs (4, 8, 12, 16 and 20).

## 1. Setup

### 1.1 Import Libraries

In [1]:
import os
import time
from datetime import datetime

import matplotlib.pyplot as plt
import multiprocessing as mp
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import squidpy as sq
import torch
from GraphST import GraphST
from sklearn import metrics

### 1.2 Define Parameters

In [2]:
# Identifier of the benchmarked method; reused in file names and AnnData keys
model_name = "graphst"

# Base key under which the latent representations are stored
latent_key = model_name + "_latent"

# Resolution used for Leiden clustering of the latent space
leiden_resolution = 0.3

# Random seed used for Leiden clustering
random_seed = 0

### 1.3 Run Notebook Setup

In [3]:
# Configure scanpy's figure parameters with a 6 x 6 figure size
sc.set_figure_params(figsize=(6, 6))

In [4]:
# Record when the notebook was executed; the formatted timestamp
# (DDMMYYYY_HHMMSS) is used to tag saved artifacts
now = datetime.now()
current_timestamp = f"{now:%d%m%Y_%H%M%S}"

### 1.4 Configure Paths and Directories

In [5]:
# Folder containing the input datasets ("<dataset>.h5ad"); the trailing slash
# is required because file names are appended by plain string concatenation
data_folder_path = "../../datasets/srt_data/gold/"

# Folder to which the AnnData objects with the benchmarking results are written
benchmarking_folder_path = "../../artifacts/single_sample_method_benchmarking"

# Root folder for saved figures
figure_folder_path = "../../figures"

## 2. GraphST Model

### 2.1 Define Training Function

In [6]:
def _create_graphst_results_adata(dataset, cell_type_key):
    """
    Create a lightweight AnnData object for collecting benchmarking results.

    The returned object has the shape, ``var_names`` and ``obs_names`` of the
    dataset stored on disk but an all-zero sparse ``X``; only the cell type
    labels (stored as ``obs["cell_type"]``) and the spatial coordinates are
    copied over.
    """
    adata_original = sc.read_h5ad(data_folder_path + f"{dataset}.h5ad")
    adata_new = sc.AnnData(sp.csr_matrix(
        (adata_original.shape[0], adata_original.shape[1]),
        dtype=np.float32))
    adata_new.var_names = adata_original.var_names
    adata_new.obs_names = adata_original.obs_names
    adata_new.obs["cell_type"] = adata_original.obs[cell_type_key].values
    adata_new.obsm["spatial"] = adata_original.obsm["spatial"]
    return adata_new


def _plot_graphst_latent_space(adata,
                               cell_type_key,
                               n_neighbors,
                               run_number,
                               dataset_figure_folder_path):
    """
    Plot and save the GraphST latent space of a single training run.

    Saves a UMAP of the latent space colored by cell type as well as a
    two-panel figure showing latent Leiden clusters in latent (UMAP) and
    physical space.
    """
    # Use GraphST latent space for UMAP generation
    sc.pp.neighbors(adata,
                    use_rep="emb",
                    n_neighbors=n_neighbors)
    sc.tl.umap(adata)
    fig = sc.pl.umap(adata,
                     color=[cell_type_key],
                     title="Latent Space with Cell Types: GraphST",
                     return_fig=True)
    fig.savefig(f"{dataset_figure_folder_path}/latent_{model_name}"
                f"_cell_types_run{run_number}.png",
                bbox_inches="tight")

    # Compute latent Leiden clustering. The key is built once from
    # ``model_name`` so that the key written here and the key read by the
    # plotting calls below can never diverge.
    leiden_key = f"latent_{model_name}_leiden_{str(leiden_resolution)}"
    sc.tl.leiden(adata=adata,
                 resolution=leiden_resolution,
                 random_state=random_seed,
                 key_added=leiden_key)

    # Create subplot of latent Leiden cluster annotations in physical and
    # latent space
    fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(6, 12))
    title = fig.suptitle(
        t="Latent and Physical Space with Leiden Clusters: GraphST")
    sc.pl.umap(adata=adata,
               color=[leiden_key],
               title="Latent Space with Leiden Clusters",
               ax=axs[0],
               show=False)
    sq.pl.spatial_scatter(adata=adata,
                          color=[leiden_key],
                          title="Physical Space with Leiden Clusters",
                          shape=None,
                          ax=axs[1])

    # Create and position shared legend
    handles, labels = axs[0].get_legend_handles_labels()
    lgd = fig.legend(handles, labels, bbox_to_anchor=(1.25, 0.9185))
    axs[0].get_legend().remove()
    axs[1].get_legend().remove()

    # Adjust, save and display plot
    plt.subplots_adjust(wspace=0, hspace=0.2)
    fig.savefig(f"{dataset_figure_folder_path}/latent_physical_comparison_"
                f"{model_name}_run{run_number}.png",
                bbox_extra_artists=(lgd, title),
                bbox_inches="tight")
    plt.show()


def train_graphst_models(dataset: str,
                         cell_type_key: str,
                         adata_new=None,
                         n_start_run: int = 1,
                         n_end_run: int = 8,
                         n_neighbor_list=(4, 4, 8, 8, 12, 12, 16, 16),
                         plot_latent_umaps: bool = False) -> None:
    """
    Train one GraphST model per run and store the latent representations.

    For every run, the dataset is loaded from
    ``data_folder_path + f"{dataset}.h5ad"``, raw counts are put into
    ``adata.X``, a symmetric spatial neighborhood graph is computed and a
    GraphST model is trained. The latent representation of each run is stored
    in ``adata_new.obsm[f"{latent_key}_run{run_number}"]`` and the training
    duration (in seconds) in ``adata_new.uns``. ``adata_new`` is written to
    ``f"{benchmarking_folder_path}/{dataset}_{model_name}.h5ad"`` after every
    run and once more at the end.

    The function relies on the notebook-level variables ``device``,
    ``data_folder_path``, ``benchmarking_folder_path``, ``figure_folder_path``,
    ``current_timestamp``, ``model_name``, ``latent_key``,
    ``leiden_resolution`` and ``random_seed``; in particular ``device`` has to
    be defined before this function is called.

    Parameters
    ----------
    dataset:
        Name of the dataset file (without the ``.h5ad`` extension).
    cell_type_key:
        Key in ``adata.obs`` where the cell type annotations are stored.
    adata_new:
        Existing AnnData object to which results are added. If ``None``, a
        new storage-efficient AnnData object is created from the dataset.
    n_start_run:
        Number of the first run (1-based, inclusive).
    n_end_run:
        Number of the last run (inclusive).
    n_neighbor_list:
        Number of spatial neighbors per run; the i-th entry is used for the
        i-th run counted from ``n_start_run``. Needs at least one entry per
        run; surplus entries are ignored.
    plot_latent_umaps:
        If ``True``, plot and save UMAPs of the latent space for every run.

    Raises
    ------
    ValueError
        If ``n_start_run`` is smaller than 1 or if ``n_neighbor_list`` has
        fewer entries than there are runs.
    """
    # Validate the run configuration up front: ``zip`` would otherwise
    # silently skip runs that have no matching number of neighbors
    run_numbers = range(n_start_run, n_end_run + 1)
    if n_start_run < 1:
        raise ValueError(
            f"'n_start_run' must be at least 1, got {n_start_run}.")
    if len(n_neighbor_list) < len(run_numbers):
        raise ValueError(
            f"'n_neighbor_list' has {len(n_neighbor_list)} entries but runs "
            f"{n_start_run} to {n_end_run} require {len(run_numbers)}.")

    # Configure figure folder path
    dataset_figure_folder_path = (
        f"{figure_folder_path}/{dataset}/single_sample_method_benchmarking/"
        f"{model_name}/{current_timestamp}")
    os.makedirs(dataset_figure_folder_path, exist_ok=True)

    # Create new adata to store results from training runs in
    # storage-efficient way
    if adata_new is None:
        adata_new = _create_graphst_results_adata(dataset, cell_type_key)

    for run_number, n_neighbors in zip(run_numbers, n_neighbor_list):
        # Load data
        adata = sc.read_h5ad(data_folder_path + f"{dataset}.h5ad")

        # Store raw counts in adata.X
        adata.X = adata.layers["counts"]
        # NOTE(review): presumably removed so that downstream preprocessing
        # does not treat the raw counts as already log-transformed - confirm
        if "log1p" in adata.uns:
            del adata.uns["log1p"]

        # Compute spatial neighborhood graph
        sq.gr.spatial_neighbors(adata,
                                coord_type="generic",
                                spatial_key="spatial",
                                n_neighs=n_neighbors)
        adata.obsm["graph_neigh"] = adata.obsp["spatial_connectivities"]

        # Make adjacency matrix symmetric
        adata.obsm["adj"] = adata.obsm["graph_neigh"].maximum(
            adata.obsm["graph_neigh"].T)

        start_time = time.time()

        # Define model; run i uses random seed i - 1, which also works for
        # more than 10 runs
        model = GraphST.GraphST(adata,
                                device=device,
                                random_seed=run_number - 1)

        # Train model
        adata = model.train()

        # Measure time for model training
        elapsed_time = time.time() - start_time
        hours, rem = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(rem, 60)
        print(f"Duration of model training in run {run_number}: "
              f"{int(hours)} hours, {int(minutes)} minutes and "
              f"{int(seconds)} seconds.")
        adata_new.uns[
            f"{model_name}_model_training_duration_run{run_number}"] = (
                elapsed_time)

        if plot_latent_umaps:
            _plot_graphst_latent_space(
                adata=adata,
                cell_type_key=cell_type_key,
                n_neighbors=n_neighbors,
                run_number=run_number,
                dataset_figure_folder_path=dataset_figure_folder_path)

        adata_new.obsm[latent_key + f"_run{run_number}"] = adata.obsm["emb"]

        # Store intermediate adata to disk
        adata_new.write(
            f"{benchmarking_folder_path}/{dataset}_{model_name}.h5ad")

    # Store final adata to disk (also covers the case without any run)
    adata_new.write(f"{benchmarking_folder_path}/{dataset}_{model_name}.h5ad")

### 2.2 Train Models on Benchmarking Datasets

In [None]:
# Device on which GraphST is run: the first CUDA GPU if one is available,
# otherwise the CPU (the package default; using a GPU is recommended)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train_graphst_models(
    dataset="seqfish_mouse_organogenesis_embryo2",
    cell_type_key="celltype_mapped_refined",
    adata_new=None,
    n_start_run=1,
    n_end_run=8,
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
)

In [33]:
# Device on which GraphST is run: the first CUDA GPU if one is available,
# otherwise the CPU (the package default; using a GPU is recommended)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Train on progressively smaller subsamples of the seqFISH dataset
for subsample_pct in (50, 25, 10, 5, 1):
    train_graphst_models(
        dataset=f"seqfish_mouse_organogenesis_subsample_{subsample_pct}pct_embryo2",
        cell_type_key="celltype_mapped_refined",
        adata_new=None,
        n_start_run=1,
        n_end_run=8,
        n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
    )

AnnData object with n_obs × n_vars = 7092 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|██████████████████████████████████████████████████████████| 600/600 [00:06<00:00, 94.41it/s]


Optimization finished for ST data!
Duration of model training in run 1: 0 hours, 0 minutes and 7 seconds.
AnnData object with n_obs × n_vars = 7092 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|██████████████████████████████████████████████████████████| 600/600 [00:06<00:00, 96.84it/s]


Optimization finished for ST data!
Duration of model training in run 2: 0 hours, 0 minutes and 7 seconds.
AnnData object with n_obs × n_vars = 7092 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|██████████████████████████████████████████████████████████| 600/600 [00:06<00:00, 96.71it/s]


Optimization finished for ST data!
Duration of model training in run 3: 0 hours, 0 minutes and 7 seconds.
AnnData object with n_obs × n_vars = 7092 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|██████████████████████████████████████████████████████████| 600/600 [00:06<00:00, 96.34it/s]


Optimization finished for ST data!
Duration of model training in run 4: 0 hours, 0 minutes and 7 seconds.
AnnData object with n_obs × n_vars = 7092 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|██████████████████████████████████████████████████████████| 600/600 [00:06<00:00, 96.19it/s]


Optimization finished for ST data!
Duration of model training in run 5: 0 hours, 0 minutes and 7 seconds.
AnnData object with n_obs × n_vars = 7092 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|██████████████████████████████████████████████████████████| 600/600 [00:06<00:00, 96.04it/s]


Optimization finished for ST data!
Duration of model training in run 6: 0 hours, 0 minutes and 7 seconds.
AnnData object with n_obs × n_vars = 7092 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|██████████████████████████████████████████████████████████| 600/600 [00:06<00:00, 95.73it/s]


Optimization finished for ST data!
Duration of model training in run 7: 0 hours, 0 minutes and 7 seconds.
AnnData object with n_obs × n_vars = 7092 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|██████████████████████████████████████████████████████████| 600/600 [00:06<00:00, 95.37it/s]


Optimization finished for ST data!
Duration of model training in run 8: 0 hours, 0 minutes and 7 seconds.
AnnData object with n_obs × n_vars = 3546 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 153.59it/s]


Optimization finished for ST data!
Duration of model training in run 1: 0 hours, 0 minutes and 4 seconds.
AnnData object with n_obs × n_vars = 3546 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 153.60it/s]


Optimization finished for ST data!
Duration of model training in run 2: 0 hours, 0 minutes and 4 seconds.
AnnData object with n_obs × n_vars = 3546 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 153.48it/s]


Optimization finished for ST data!
Duration of model training in run 3: 0 hours, 0 minutes and 4 seconds.
AnnData object with n_obs × n_vars = 3546 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 153.55it/s]


Optimization finished for ST data!
Duration of model training in run 4: 0 hours, 0 minutes and 4 seconds.
AnnData object with n_obs × n_vars = 3546 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 153.29it/s]


Optimization finished for ST data!
Duration of model training in run 5: 0 hours, 0 minutes and 4 seconds.
AnnData object with n_obs × n_vars = 3546 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 153.59it/s]


Optimization finished for ST data!
Duration of model training in run 6: 0 hours, 0 minutes and 4 seconds.
AnnData object with n_obs × n_vars = 3546 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 153.27it/s]


Optimization finished for ST data!
Duration of model training in run 7: 0 hours, 0 minutes and 4 seconds.
AnnData object with n_obs × n_vars = 3546 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 152.91it/s]


Optimization finished for ST data!
Duration of model training in run 8: 0 hours, 0 minutes and 4 seconds.
AnnData object with n_obs × n_vars = 1418 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 155.16it/s]


Optimization finished for ST data!
Duration of model training in run 1: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 1418 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 154.83it/s]


Optimization finished for ST data!
Duration of model training in run 2: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 1418 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 154.35it/s]


Optimization finished for ST data!
Duration of model training in run 3: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 1418 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 152.11it/s]


Optimization finished for ST data!
Duration of model training in run 4: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 1418 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 154.43it/s]


Optimization finished for ST data!
Duration of model training in run 5: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 1418 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 154.60it/s]


Optimization finished for ST data!
Duration of model training in run 6: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 1418 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 153.50it/s]


Optimization finished for ST data!
Duration of model training in run 7: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 1418 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 154.65it/s]


Optimization finished for ST data!
Duration of model training in run 8: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 709 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 154.16it/s]


Optimization finished for ST data!
Duration of model training in run 1: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 709 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 154.37it/s]


Optimization finished for ST data!
Duration of model training in run 2: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 709 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 154.81it/s]


Optimization finished for ST data!
Duration of model training in run 3: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 709 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 154.07it/s]


Optimization finished for ST data!
Duration of model training in run 4: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 709 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 155.27it/s]


Optimization finished for ST data!
Duration of model training in run 5: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 709 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 152.96it/s]


Optimization finished for ST data!
Duration of model training in run 6: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 709 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 153.90it/s]


Optimization finished for ST data!
Duration of model training in run 7: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 709 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 155.42it/s]


Optimization finished for ST data!
Duration of model training in run 8: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 141 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 159.34it/s]


Optimization finished for ST data!
Duration of model training in run 1: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 141 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 159.63it/s]


Optimization finished for ST data!
Duration of model training in run 2: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 141 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 159.59it/s]


Optimization finished for ST data!
Duration of model training in run 3: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 141 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 158.23it/s]


Optimization finished for ST data!
Duration of model training in run 4: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 141 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 158.51it/s]


Optimization finished for ST data!
Duration of model training in run 5: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 141 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 157.64it/s]


Optimization finished for ST data!
Duration of model training in run 6: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 141 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 158.09it/s]


Optimization finished for ST data!
Duration of model training in run 7: 0 hours, 0 minutes and 3 seconds.
AnnData object with n_obs × n_vars = 141 × 351
    obs: 'Area', 'celltype_mapped_refined', 'sample', 'batch'
    uns: 'X_name', 'spatial_neighbors'
    obsm: 'spatial', 'graph_neigh', 'adj'
    layers: 'counts'
    obsp: 'spatial_connectivities', 'spatial_distances'
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 158.50it/s]


Optimization finished for ST data!
Duration of model training in run 8: 0 hours, 0 minutes and 3 seconds.


In [9]:
# Device on which GraphST is run: the first CUDA GPU if one is available,
# otherwise the CPU (the package default; using a GPU is recommended)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Only the 5 and 1 pct subsamples are run; 50 and 25 pct exhaust memory
for subsample_pct in (5, 1):
    train_graphst_models(
        dataset=f"vizgen_merfish_mouse_liver_subsample_{subsample_pct}pct",
        cell_type_key="Cell_Type",
        adata_new=None,
        n_start_run=1,
        n_end_run=8,
        n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
    )

Begin to train ST data...


100%|██████████████████████████████████████████████████████████| 600/600 [00:35<00:00, 17.05it/s]


Optimization finished for ST data!
Duration of model training in run 1: 0 hours, 0 minutes and 39 seconds.
Begin to train ST data...


100%|██████████████████████████████████████████████████████████| 600/600 [00:35<00:00, 16.88it/s]


Optimization finished for ST data!
Duration of model training in run 2: 0 hours, 0 minutes and 39 seconds.
Begin to train ST data...


100%|██████████████████████████████████████████████████████████| 600/600 [00:35<00:00, 16.79it/s]


Optimization finished for ST data!
Duration of model training in run 3: 0 hours, 0 minutes and 40 seconds.
Begin to train ST data...


100%|██████████████████████████████████████████████████████████| 600/600 [00:35<00:00, 16.73it/s]


Optimization finished for ST data!
Duration of model training in run 4: 0 hours, 0 minutes and 40 seconds.
Begin to train ST data...


100%|██████████████████████████████████████████████████████████| 600/600 [00:35<00:00, 16.71it/s]


Optimization finished for ST data!
Duration of model training in run 5: 0 hours, 0 minutes and 40 seconds.
Begin to train ST data...


100%|██████████████████████████████████████████████████████████| 600/600 [00:35<00:00, 16.71it/s]


Optimization finished for ST data!
Duration of model training in run 6: 0 hours, 0 minutes and 40 seconds.
Begin to train ST data...


100%|██████████████████████████████████████████████████████████| 600/600 [00:35<00:00, 16.69it/s]


Optimization finished for ST data!
Duration of model training in run 7: 0 hours, 0 minutes and 40 seconds.
Begin to train ST data...


100%|██████████████████████████████████████████████████████████| 600/600 [00:35<00:00, 16.70it/s]


Optimization finished for ST data!
Duration of model training in run 8: 0 hours, 0 minutes and 40 seconds.
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:04<00:00, 149.05it/s]


Optimization finished for ST data!
Duration of model training in run 1: 0 hours, 0 minutes and 4 seconds.
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 156.27it/s]


Optimization finished for ST data!
Duration of model training in run 2: 0 hours, 0 minutes and 3 seconds.
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 156.37it/s]


Optimization finished for ST data!
Duration of model training in run 3: 0 hours, 0 minutes and 3 seconds.
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 156.40it/s]


Optimization finished for ST data!
Duration of model training in run 4: 0 hours, 0 minutes and 3 seconds.
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 156.36it/s]


Optimization finished for ST data!
Duration of model training in run 5: 0 hours, 0 minutes and 4 seconds.
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 156.71it/s]


Optimization finished for ST data!
Duration of model training in run 6: 0 hours, 0 minutes and 3 seconds.
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 155.64it/s]


Optimization finished for ST data!
Duration of model training in run 7: 0 hours, 0 minutes and 4 seconds.
Begin to train ST data...


100%|█████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 154.98it/s]


Optimization finished for ST data!
Duration of model training in run 8: 0 hours, 0 minutes and 4 seconds.


In [None]:
# Vizgen MERFISH mouse liver: ~10% sample of the original dataset
train_graphst_models(
    dataset="vizgen_merfish_mouse_liver_sample",
    cell_type_key="Cell_Type",
    adata_new=None,
    n_start_run=1,
    n_end_run=10,
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16, 20, 20],
)

In [None]:
# STARmap PLUS mouse CNS: ~20% sample of the original dataset
train_graphst_models(
    dataset="starmap_plus_mouse_cns_sample",
    cell_type_key="Main_molecular_cell_type",
    adata_new=None,
    n_start_run=1,
    n_end_run=10,
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16, 20, 20],
)

In [None]:
# NanoString CosMx human NSCLC: ~20% sample of the original dataset
train_graphst_models(
    dataset="nanostring_cosmx_human_nsclc_sample",
    cell_type_key="cell_type",
    adata_new=None,
    n_start_run=1,
    n_end_run=10,
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16, 20, 20],
)

In [None]:
# Slide-seqV2 mouse hippocampus dataset
train_graphst_models(
    dataset="slideseqv2_mouse_hippocampus",
    cell_type_key="cell_type",
    adata_new=None,
    n_start_run=1,
    n_end_run=10,
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16, 20, 20],
)

In [None]:
# Slide-seqV2 mouse hippocampus: ~50% sample of the original dataset
train_graphst_models(
    dataset="slideseqv2_mouse_hippocampus_sample",
    cell_type_key="cell_type",
    adata_new=None,
    n_start_run=1,
    n_end_run=10,
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16, 20, 20],
)