# Test Spot2Vector on the DLPFC data

In [None]:
import warnings
warnings.filterwarnings("ignore")
import pandas as pd
import numpy as np
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt

## Load data

In [None]:
# Which DLPFC section to analyse. data_name is reused to locate the input
# .h5ad file and to name the saved figure.
slice_id = '151672'
data_name = f"DLPFC_{slice_id}"

In [None]:
# Number of target domains for clustering, derived from the section so that
# changing slice_id alone keeps it right: sections 151669-151672 are annotated
# with 5 domains, all other DLPFC sections with 7.
n_clusters = 5 if slice_id in {'151669', '151670', '151671', '151672'} else 7

In [None]:
# Load the AnnData object stored as preprocess_<data_name>.h5ad.
h5ad_path = f"./data/DLPFC/preprocess_{data_name}.h5ad"
adatast = sc.read_h5ad(h5ad_path)

In [None]:
# 10-component PCA computed in place on adatast (scanpy stores it in
# obsm['X_pca']); presumably consumed by Spot2Vector's expression graph — confirm.
sc.pp.pca(data=adatast, n_comps=10)

## Run Spot2Vector

In [None]:
import Spot2Vector

# Two graphs are attached to adatast, in this order:
#   1. a spatial graph with a radius cutoff of 150,
#   2. an expression graph with a neighbour-count cutoff of 4.
graph_specs = [
    dict(radius_cutoff=150, cutoff_type='radius', graph_type='spatial'),
    dict(neighbors_cutoff=4, cutoff_type='neighbors', graph_type='expression'),
]
for spec in graph_specs:
    Spot2Vector.Build_Graph(adatast, **spec)

In [None]:
# Visual sanity check of the graphs built in the previous cell
# (presumably per-spot neighbour statistics — TODO confirm against Spot2Vector).
Spot2Vector.Graph_Stat_Plot(adatast)

In [None]:
# NOTE(review): the device is hard-coded to the first CUDA GPU; change it on
# machines without a GPU (if Spot2Vector supports another device).
device = 'cuda:0'

# Train the model with a fixed seed for up to 1500 epochs.
Spot2Vector.Fit(
    adatast,
    device=device,
    seed=6,
    max_epochs_st=1500,
    verbose=False,
)

In [None]:
# Training curve: total loss per epoch from the history table stored by Fit.
# The DataFrame is passed as data= explicitly; a positional DataFrame is only
# accepted by seaborn >= 0.12 (older versions bind it to x and clash with x=).
sns.lineplot(data=adatast.uns['training_history_df_st'], x='epoch', y='loss_total')

## Clustering

In [None]:
# Clustering backend passed to Spot2Vector.Clustering. It is also part of the
# result column names (e.g. "embeddings_mclust") used by the plots below.
# NOTE(review): 'mclust' typically needs R + rpy2 installed — confirm.
clust_method = 'mclust'

In [None]:
# Cluster the expression-only and spatial-only embeddings separately, so they
# can be compared against the fused embedding later.
for emb_key in ('exp_embeddings', 'spa_embeddings'):
    Spot2Vector.Clustering(adatast, obsm_data=emb_key, method=clust_method,
                           n_cluster=n_clusters, verbose=False)

In [None]:
# Build the fused embedding (clustered next as obsm_data='embeddings').
# lamda weights the two views: lamda = 1 -> expression only,
# lamda = 0 -> spatial only; 0.2 leans towards the spatial view.
Spot2Vector.Infer(adatast, device=device, lamda=0.2)

In [None]:
# Cluster the fused embedding into n_clusters domains.
Spot2Vector.Clustering(
    adatast,
    obsm_data='embeddings',
    method=clust_method,
    n_cluster=n_clusters,
    verbose=False,
)

In [None]:
# Score the fused-embedding clustering. The final plot reads the resulting ARI
# from adatast.uns (presumably written here — confirm).
Spot2Vector.Clustering_Metrics(adatast, 'embeddings_' + clust_method)

## Visualization

In [None]:
# Side-by-side panels: fused, expression-only and spatial-only clusters next to
# the domain_annotation labels. Each panel is titled with its own obs key.
panel_keys = [
    f"embeddings_{clust_method}",
    f"exp_embeddings_{clust_method}",
    f"spa_embeddings_{clust_method}",
    "domain_annotation",
]
sc.pl.spatial(adatast, color=panel_keys, title=list(panel_keys), size=1.3, alpha=0.7)

In [None]:
# Final figure: fused-embedding domains, titled with the ARI and saved as SVG.
# The ARI key follows clust_method instead of hard-coding "mclust", so switching
# the clustering backend still finds the right entry in adatast.uns.
ari_key = f"embeddings_{clust_method}_ARI"
# rc_context scopes the small figure size to this plot instead of permanently
# changing the global matplotlib default.
with plt.rc_context({"figure.figsize": (3, 3)}):
    sc.pl.spatial(
        adatast,
        color=f"embeddings_{clust_method}",
        title=f"Spot2vector (ARI={np.round(adatast.uns[ari_key], 2)})",
        save=data_name + '_Spot2vector.svg',
    )