# Tutorial: Spatially Variable Gene detection on SeqFISH+ (Mouse cortex)

This tutorial demonstrates how to use RGAST to detect spatially variable genes (SVGs) in a seqFISH+ dataset of the mouse cortex.

## Preparation

In [None]:
import os,sys
import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import matplotlib.colors as clr
import warnings
import RGAST
from RGAST import svg
# Custom diverging colormap (green -> cream -> pink, 256 levels) used for the gene-expression maps.
color_self = clr.LinearSegmentedColormap.from_list(
    'pink_green',
    ["#3AB370", "#EAE7CC", "#FD1593"],
    N=256,
)
# Suppress all Python warnings to keep the tutorial output uncluttered.
warnings.filterwarnings("ignore")

## Read data

In [2]:
# Input AnnData file (seqFISH+ mouse cortex) and the directory that receives models and figures.
dir_input = '../data/seqFISH/cortex_seqFISH.h5ad'
dir_output = './output/SVG'
# exist_ok=True makes this a no-op when the directory already exists and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(dir_output, exist_ok=True)
adata = sc.read_h5ad(dir_input)

## Preprocessing

In [3]:
# Ensure gene names are unique before any name-based lookups.
adata.var_names_make_unique()
# Freeze the pre-normalization expression matrix in .raw; the Moran's I / Geary's C
# step later in this notebook builds its DataFrame from adata.raw.X.
adata.raw = adata.copy()
# Scale each cell to 1e4 total counts, then apply log(1 + x) to the working matrix.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

## Model training

In [None]:
# Build RGAST's input graphs: a radius-based spatial neighbour graph (presumably linking
# cells within distance 80 in the coordinate units of the data) and an expression graph.
# Return values are not used, so both calls presumably store their results on adata — confirm.
RGAST.Cal_Spatial_Net(adata, rad_cutoff=80, model='Radius')
RGAST.Cal_Expression_Net(adata)
# Train the RGAST model; save_path points at dir_output (assumed to receive checkpoints — confirm).
# NOTE(review): the spatial-net settings are repeated here; keep them in sync with the call above.
train_RGAST = RGAST.Train_RGAST(adata, spatial_net_arg={'rad_cutoff':80, 'model':'Radius'})
train_RGAST.train_RGAST(save_path=dir_output)

### Alternatively, load the model parameters trained in our study

In [None]:
# Optional: overwrite the freshly trained weights with the checkpoint used in the study.
train_RGAST.load_model('../model_path/cortex_seqfish.pth')
z, _ = train_RGAST.process()
# Assumes z is a torch tensor: drop autograd history, move it to the CPU and store
# the latent embedding as a NumPy array for scanpy.
adata.obsm['RGAST'] = z.detach().cpu().numpy()

## Clustering

In [4]:
# `import a.b.c as x` can only bind a *module*, but the helper is called as a function
# below, so import the name from its parent instead.
from RGAST.utils import res_search_fixed_clus

# kNN graph on the RGAST embedding, UMAP for visualisation, then a resolution search
# (presumably targeting 7 domains, per the helper's name — confirm).
sc.pp.neighbors(adata, use_rep='RGAST')
sc.tl.umap(adata)
_ = res_search_fixed_clus(adata, 7)
# The labels are expected in obs['leiden'] (presumably written by the helper — confirm).
# DataFrame.rename takes `columns=` (not `column=`) and returns a new frame, so the
# result must be assigned back; the SVG cells below read adata.obs['RGAST'].
adata.obs = adata.obs.rename(columns={'leiden': 'RGAST'})

## SVG detection pipeline

In [7]:
#Set filtering criterials
min_in_group_fraction=0.8
min_in_out_group_ratio=1
min_fold_change=1.5

#Search radius such that each spot in the target domain has approximately 10 neighbors on average
x_array=adata.obs["X"].tolist()
y_array=adata.obs["Y"].tolist()
adj_2d=svg.calculate_adj_matrix(x=x_array, y=y_array)
start, end= np.quantile(adj_2d[adj_2d!=0],q=0.001), np.quantile(adj_2d[adj_2d!=0],q=0.1)

In [8]:
# Per-gene spatial autocorrelation (Moran's I, Geary's C) on the pre-normalization
# matrix saved in .raw during preprocessing.
raw_X = adata.raw.X
# .raw.X may be a scipy sparse matrix, which pd.DataFrame cannot wrap directly;
# densify it when it exposes toarray().
if hasattr(raw_X, 'toarray'):
    raw_X = raw_X.toarray()
df = pd.DataFrame(raw_X, index=adata.obs_names, columns=adata.raw.var_names)
I = svg.Moran_I(df, x=x_array, y=y_array)
C = svg.Geary_C(df, x=x_array, y=y_array)

In [None]:
# Accumulates the SVGs found for every domain (domains visited largest first).
svgene = []

# Loop-invariant inputs shared by every domain's radius / neighbour search.
cell_ids = adata.obs.index.tolist()
domain_labels = adata.obs['RGAST'].tolist()
plt.rcParams["figure.figsize"] = (5, 5)

for target in adata.obs['RGAST'].value_counts().index.to_list():

    # Radius giving spots in `target` between num_min and num_max neighbours on average.
    r = svg.search_radius(target_cluster=target, cell_id=cell_ids, x=x_array, y=y_array,
                          pred=domain_labels, adj_2d=adj_2d, start=start, end=end,
                          num_min=10, num_max=14, max_run=100)
    # Detect domains adjacent to `target` within radius r.
    nbr_domains = svg.find_neighbor_clusters(target_cluster=target,
                                             cell_id=cell_ids,
                                             x=x_array,
                                             y=y_array,
                                             pred=domain_labels,
                                             radius=r,
                                             ratio=0.5)
    if nbr_domains is None:
        print('skip to next domain')
        continue
    # Compare against at most three neighbouring domains (slicing is a no-op when shorter).
    nbr_domains = nbr_domains[:3]
    de_genes_info = svg.rank_genes_groups(input_adata=adata,
                                          target_cluster=target,
                                          nbr_list=nbr_domains,
                                          label_col='RGAST',
                                          adj_nbr=True,
                                          log=True)

    # Keep significant genes that are enriched in the target domain.
    de_genes_info = de_genes_info[de_genes_info["pvals_adj"] < 0.05]
    filtered_info = de_genes_info[(de_genes_info["in_out_group_ratio"] > min_in_out_group_ratio) &
                                  (de_genes_info["in_group_fraction"] > min_in_group_fraction) &
                                  (de_genes_info["fold_change"] > min_fold_change)]
    filtered_info = filtered_info.sort_values(by="in_group_fraction", ascending=False)
    # NOTE: column name "target_dmain" kept verbatim in case downstream code reads it.
    filtered_info["target_dmain"] = target
    filtered_info["neighbors"] = str(nbr_domains)
    genes = filtered_info["genes"].tolist()
    print("SVGs for domain ", str(target), ":", genes)
    svgene.extend(genes)

    # Coordinates of the target domain, overlaid in translucent red on every gene map.
    in_target = (adata.obs['RGAST'] == target).to_numpy()
    target_x = adata.obs["X"].to_numpy()[in_target]
    target_y = adata.obs["Y"].to_numpy()[in_target]

    for g in genes:
        ax = sc.pl.scatter(adata, alpha=1, x="X", y="Y", color=g, color_map=color_self,
                           title=g, show=False, size=100000 / adata.shape[0])
        ax.axis('off')
        ax.scatter(target_x, target_y, marker='o', facecolors='red', alpha=0.02)
        ax.set_aspect('equal', 'box')
        ax.invert_yaxis()
        plt.savefig(os.path.join(dir_output, f'domain{target}-{g}.pdf'), bbox_inches='tight')
        # Display the figure, then release it so figures do not pile up across all
        # genes and domains (closing an already-closed figure is a no-op).
        plt.show()
        plt.close(ax.figure)