# STACI

- **Creator**: Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Date of Creation:** 16.11.2023
- **Date of Last Modification:** 22.07.2024 (Sebastian Birk; <sebastian.birk@helmholtz-munich.de>)

- The STACI source code is available at https://github.com/uhlerlab/STACI/blob/master.
- The corresponding publication is "Zhang, X., Wang, X., Shivashankar, G. V. & Uhler, C. Graph-based autoencoder integrates spatial transcriptomics with chromatin images and identifies joint biomarkers for Alzheimer’s disease. Nat. Commun. 13, 7480 (2022)".
- The workflow of this notebook follows the notebook from https://github.com/uhlerlab/STACI/blob/master/train_gae_starmap_multisamples.ipynb.

- Run this notebook in the nichecompass-reproducibility environment, which can be installed from `../../../envs/environment.yaml`. In addition, the STACI repository must be cloned from GitHub as follows:
    - ```cd analysis/benchmarking```
    - ```git clone https://github.com/uhlerlab/STACI.git```

## 1. Setup

### 1.1 Import Libraries

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("../STACI")

In [None]:
import gc
import os
import pickle
import sys
import time
from datetime import datetime

import scanpy as sc
import squidpy as sq
import numpy as np
import scipy.sparse as sp
import torch
from torch import optim
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.preprocessing import scale
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import pairwise_distances

import gae.gae.optimizer as optimizer
import gae.gae.model
import gae.gae.preprocessing as preprocessing

### 1.2 Define Parameters

In [None]:
# Identifiers shared by every training run in this notebook:
# the model tag, the obsm key prefix for latents, and the spatial/graph keys.
model_name = "staci"
latent_key = "_".join([model_name, "latent"])
spatial_key = "spatial"
adj_key = "spatial_connectivities"

### 1.3 Run Notebook Setup

In [None]:
# Set scanpy's default figure size (6, 6) for plots made in this notebook
sc.set_figure_params(figsize=(6, 6))

In [None]:
# Get time of notebook execution for timestamping saved artifacts
# Timestamp (ddmmYYYY_HHMMSS) of this notebook execution, used to tag saved artifacts
now = datetime.now()
current_timestamp = f"{now:%d%m%Y_%H%M%S}"

### 1.4 Configure Paths and Directories

In [None]:
# Relative locations of the input datasets and of the produced artifacts/figures.
# data_folder_path keeps its trailing slash because file names are appended
# to it by plain string concatenation.
data_folder_path = "../../../datasets/st_data/gold/"
benchmarking_folder_path = "../../../artifacts/single_sample_method_benchmarking"
figure_folder_path = "../../../figures"

## 2. STACI Model

### 2.1 Define Training Function

In [None]:
def _build_results_adata(dataset, cell_type_key, niche_type_key=None):
    """Create a storage-efficient AnnData mirroring ``dataset`` for results.

    Only var/obs names, the cell type labels (and niche labels if the column
    exists) and ``obsm["spatial"]`` are copied; ``X`` is an all-zero sparse
    matrix so that latents of many runs can be stored cheaply.
    """
    adata_original = sc.read_h5ad(data_folder_path + f"{dataset}.h5ad")

    adata_new = sc.AnnData(sp.csr_matrix(
        (adata_original.shape[0], adata_original.shape[1]),
        dtype=np.float32))
    adata_new.var_names = adata_original.var_names
    adata_new.obs_names = adata_original.obs_names
    adata_new.obs["cell_type"] = adata_original.obs[cell_type_key].values
    if niche_type_key in adata_original.obs.columns:
        adata_new.obs["niche_type"] = adata_original.obs[niche_type_key].values
    adata_new.obsm["spatial"] = adata_original.obsm["spatial"]
    del adata_original
    gc.collect()
    return adata_new


def _load_adata_with_graph(dataset, n_neighbors):
    """Load ``dataset`` with dense raw counts in ``X`` and a symmetric kNN graph.

    Returns the AnnData (with ``obs["batch"]`` set and ``obsp[adj_key]``
    symmetrized) and the list of its sample labels in order of appearance.
    """
    adata = sc.read_h5ad(data_folder_path + f"{dataset}.h5ad")

    if "seqfish" in dataset:
        adata.obs["batch"] = adata.obs["sample"]
    training_samples = adata.obs["batch"].unique().tolist()

    # STACI works on raw counts; accept both sparse and dense count layers
    counts = adata.layers["counts"]
    adata.X = counts.toarray() if sp.issparse(counts) else np.asarray(counts)
    if "log1p" in adata.uns:
        del adata.uns["log1p"]

    sq.gr.spatial_neighbors(adata,
                            coord_type="generic",
                            spatial_key="spatial",
                            n_neighs=n_neighbors)

    # kNN connectivities are not symmetric; keep an edge if either cell has it
    adata.obsp[adj_key] = adata.obsp[adj_key].maximum(
        adata.obsp[adj_key].T)
    return adata, training_samples


def _extract_sample_inputs(adata, training_samples):
    """Split ``adata`` into per-sample model inputs.

    Returns four dicts keyed by sample label:
    - log2(counts + 1/2), min-max scaled per cell (torch tensors);
    - raw counts as float32 torch tensors (ZINB reconstruction target);
    - sparse adjacency matrices of the sample's cells;
    - boolean masks locating the sample's cells within ``adata``.
    """
    features, features_raw, adjs, obs_masks = {}, {}, {}, {}
    for sample in training_samples:
        obs_mask = (adata.obs["batch"] == sample).to_numpy()
        adata_sample = adata[obs_mask]

        featurelog = np.log2(adata_sample.X + 1 / 2)
        # MinMaxScaler scales columns; transpose so that each cell is scaled
        featurelog_minmax = MinMaxScaler().fit_transform(featurelog.T).T
        features[sample] = torch.tensor(featurelog_minmax)
        # Use the same raw counts the features were derived from (previously
        # the file's X was re-read, which may hold normalized values)
        features_raw[sample] = torch.tensor(adata_sample.X, dtype=torch.float32)
        adjs[sample] = adata_sample.obsp[adj_key]
        obs_masks[sample] = obs_mask

        del adata_sample
        gc.collect()
    return features, features_raw, adjs, obs_masks


def _preprocess_adjacencies(adjs):
    """Compute the STACI graph inputs for each sparse adjacency in ``adjs``.

    Returns dicts (keyed like ``adjs``) of normalized adjacencies from
    ``preprocessing.preprocess_graph``, dense label adjacencies with self
    loops, positive-edge weights and loss normalization constants.
    """
    adj_norm, adj_label, pos_weight, norm = {}, {}, {}, {}
    for sample, adj in adjs.items():
        n_cells = adj.shape[0]
        n_pairs = n_cells * n_cells
        n_edges = adj.sum()

        adj_norm[sample] = preprocessing.preprocess_graph(adj)
        # Both constants use the full, unmasked adjacency
        pos_weight[sample] = torch.tensor(float(n_pairs - n_edges) / n_edges)
        norm[sample] = n_pairs / float((n_pairs - n_edges) * 2)
        # Dense n_cells x n_cells matrix: very memory intensive
        adj_label[sample] = torch.tensor((adj + sp.eye(n_cells)).todense())
    return adj_norm, adj_label, pos_weight, norm


def _load_or_create_node_split(maskpath, num_nodes, test_nodes, val_nodes,
                               seed, use_saved, device):
    """Return the (train, val, test) node index tensors, moved to ``device``.

    If ``use_saved`` is set and ``maskpath`` exists the split is unpickled
    from it; otherwise it is created with ``preprocessing.mask_nodes_edges``
    and pickled to ``maskpath``.
    """
    if use_saved and os.path.exists(maskpath):
        # NOTE: only unpickle splits written by this notebook; pickle is
        # unsafe on untrusted files.
        with open(maskpath, 'rb') as handle:
            split = pickle.load(handle)
    else:
        split = preprocessing.mask_nodes_edges(num_nodes,
                                               testNodeSize=test_nodes,
                                               valNodeSize=val_nodes,
                                               seed=seed)
        os.makedirs(os.path.dirname(maskpath), exist_ok=True)
        with open(maskpath, 'wb') as handle:
            pickle.dump(split, handle, pickle.HIGHEST_PROTOCOL)
    return tuple(idx.to(device) for idx in split)


def train_staci_models(dataset,
                       cell_type_key,
                       niche_type_key=None,
                       adata_new=None,
                       n_start_run=1,
                       n_end_run=8,
                       n_neighbor_list=None):
    """Train STACI models on ``dataset`` over several runs and store latents.

    For each run ``r`` in ``[n_start_run, n_end_run]`` a kNN spatial graph is
    built with the matching entry of ``n_neighbor_list``, a STACI graph VAE
    (``GCNModelVAE_XA_e2_d1_DCA``) is trained with seed ``r - 1``, and the
    latent means (computed in eval mode, i.e. without dropout) are stored in
    ``adata_new.obsm[f"{latent_key}_run{r}"]``. The training duration is
    stored in ``adata_new.uns``. ``adata_new`` is written to
    ``{benchmarking_folder_path}/{dataset}_{model_name}.h5ad`` after every
    run and once more at the end.

    Parameters
    ----------
    dataset : str
        File name (without ``.h5ad``) in ``data_folder_path``. The file must
        provide ``layers["counts"]``, ``obsm["spatial"]`` and ``obs["batch"]``
        (``obs["sample"]`` is used instead for seqFISH datasets).
    cell_type_key : str
        ``obs`` column with the cell type labels copied to ``adata_new``.
    niche_type_key : str, optional
        ``obs`` column with niche labels; copied only if it exists.
    adata_new : AnnData, optional
        Results object to extend; built from ``dataset`` if ``None``.
    n_start_run, n_end_run : int
        First and last (inclusive) run numbers; run ``r`` uses seed ``r - 1``,
        so run numbers must lie in 1..10.
    n_neighbor_list : list of int, optional
        Number of spatial neighbors per run, consumed in order starting with
        ``n_start_run``. Defaults to ``[4, 4, 8, 8, 12, 12, 16, 16]``.

    Raises
    ------
    ValueError
        If ``n_neighbor_list`` has fewer entries than requested runs.
    """
    if n_neighbor_list is None:
        n_neighbor_list = [4, 4, 8, 8, 12, 12, 16, 16]
    run_numbers = np.arange(n_start_run, n_end_run + 1)
    if len(n_neighbor_list) < len(run_numbers):
        raise ValueError(
            f"n_neighbor_list has {len(n_neighbor_list)} entries but "
            f"{len(run_numbers)} runs were requested.")

    # Settings (the adversarial-loss options of STACI are not used here)
    use_cuda = True  # use the GPU if one is available
    fastmode = True  # if True, validation losses reuse the training forward pass
    useSavedMaskedEdges = False  # reuse pickled node splits if they exist
    epochs = 1000  # STACI default is 10000; no improvement observed after 1000
    saveFreq = 30  # save the model parameters every saveFreq epochs
    lr = 0.001  # initial learning rate
    weight_decay = 0  # regularization term
    dropout = 0.01  # neural network dropout rate
    testNodes = 0.1  # fraction of cells used for testing
    valNodes = 0.05  # fraction of cells used for validation
    XreconWeight = 20  # reconstruction weight of the gene expression
    ridgeL = 0.01  # regularization weight of the gene dropout parameter
    switchFreq = 10  # epochs spent on one sample before switching to the next
    name = 'newModel'  # name of the model

    # Paths for the training log, trained models, plots and node splits
    logsavepath = f'../STACI/log/{dataset}/' + name
    modelsavepath = f'../STACI/models/{dataset}/' + name
    plotsavepath = f'../STACI/plots/{dataset}/' + name
    savedir = f'../STACI/adjacencies/{dataset}/' + name
    for path in (logsavepath, modelsavepath, plotsavepath, savedir):
        os.makedirs(path, exist_ok=True)

    # Figure folder (created, but nothing is written to it in this function)
    dataset_figure_folder_path = f"{figure_folder_path}/{dataset}/single_sample_method_benchmarking/" \
                                 f"{model_name}/{current_timestamp}"
    os.makedirs(dataset_figure_folder_path, exist_ok=True)

    if adata_new is None:
        adata_new = _build_results_adata(dataset, cell_type_key, niche_type_key)

    if use_cuda and not torch.cuda.is_available():
        print('cuda not available')
        use_cuda = False
    device = torch.device("cuda" if use_cuda else "cpu")

    loss_kl = optimizer.optimizer_kl
    loss_x = optimizer.optimizer_zinb
    loss_a = optimizer.optimizer_CE

    model_seeds = list(range(0, 10))
    for run_number, n_neighbors in zip(run_numbers, n_neighbor_list):
        seed = model_seeds[run_number - 1]
        maskedgeName = f'knn{n_neighbors}_connectivity'

        adata, training_samples = _load_adata_with_graph(dataset, n_neighbors)

        # Timed region starts after loading and graph construction
        start_time = time.time()

        features, features_raw, adjs, obs_masks = _extract_sample_inputs(
            adata, training_samples)
        num_features = adata.shape[1]
        del adata
        gc.collect()
        adj_norm, adj_label, pos_weight, norm = _preprocess_adjacencies(adjs)
        del adjs

        hidden1 = 3 * num_features  # units in hidden layer 1
        hidden2 = 3 * num_features  # units in hidden layer 2
        fc_dim1 = 3 * num_features  # units in the decoder's fully connected layer

        np.random.seed(seed)
        torch.manual_seed(seed)
        if use_cuda:
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.enabled = True

        model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCA(
            num_features, hidden1, hidden2, fc_dim1, dropout)
        model.to(device)
        optimizerVAEXA = optim.Adam(model.parameters(), lr=lr,
                                    weight_decay=weight_decay)

        # One train/val/test node split per sample, created once per run
        # (assumes mask_nodes_edges is deterministic for a fixed seed)
        node_idx = {}
        for sample in training_samples:
            maskpath = os.path.join(
                savedir, 'trainMask',
                f"{sample}_{maskedgeName}_seed{seed}.pkl")
            node_idx[sample] = _load_or_create_node_split(
                maskpath, features[sample].shape[0], testNodes, valNodes,
                seed, useSavedMaskedEdges, device)

        def train(epoch, sample, batch, train_idx, val_idx):
            """Run one optimization step on ``sample`` and print its losses."""
            t = time.time()
            model.train()
            optimizerVAEXA.zero_grad()

            adj_recon, mu, logvar, _, features_recon = model(
                batch["features"], batch["adj_norm"])

            loss_kl_train = loss_kl(mu, logvar, train_idx)
            loss_x_train = loss_x(features_recon, batch["features"], train_idx,
                                  XreconWeight, ridgeL, batch["features_raw"])
            loss_a_train = loss_a(adj_recon, batch["adj_label"],
                                  pos_weight[sample], norm[sample], train_idx)
            loss = loss_kl_train + loss_x_train + loss_a_train

            loss.backward()
            optimizerVAEXA.step()

            # Validation losses need no gradients
            with torch.no_grad():
                if not fastmode:
                    # Separate pass without dropout and without variation in z
                    model.eval()
                    adj_recon, mu, logvar, _, features_recon = model(
                        batch["features"], batch["adj_norm"])
                loss_x_val = loss_x(features_recon, batch["features"], val_idx,
                                    XreconWeight, ridgeL, batch["features_raw"])
                loss_a_val = loss_a(adj_recon, batch["adj_label"],
                                    pos_weight[sample], norm[sample], val_idx)
                loss_val = loss_x_val + loss_a_val

            print(f'{sample} Epoch: {epoch:04d}',
                  f'loss_train: {loss.item():.4f}',
                  f'loss_kl_train: {loss_kl_train.item():.4f}',
                  f'loss_x_train: {loss_x_train.item():.4f}',
                  f'loss_a_train: {loss_a_train.item():.4f}',
                  f'loss_val: {loss_val.item():.4f}',
                  f'loss_x_val: {loss_x_val.item():.4f}',
                  f'loss_a_val: {loss_a_val.item():.4f}',
                  f'time: {time.time() - t:.4f}s')

        current_sample, batch = None, None
        t_ep = time.time()
        for ep in range(epochs):
            sample = training_samples[(ep // switchFreq) % len(training_samples)]
            if sample != current_sample:
                # Move the active sample to the device only when it changes;
                # drop the previous sample's tensors first to limit peak memory
                batch = None
                batch = {
                    "adj_norm": adj_norm[sample].to(device).float(),
                    "adj_label": adj_label[sample].to(device).float(),
                    "features": features[sample].to(device).float(),
                    "features_raw": features_raw[sample].to(device),
                }
                current_sample = sample

            train_idx, val_idx, _ = node_idx[sample]
            train(ep, sample, batch, train_idx, val_idx)

            if ep % saveFreq == 0:
                torch.save(model.cpu().state_dict(),
                           os.path.join(modelsavepath, f"{ep}.pt"))
                model.to(device)
                if use_cuda:
                    torch.cuda.empty_cache()
        print(' total time: {:.4f}s'.format(time.time() - t_ep))

        # Measure time for model training
        elapsed_time = time.time() - start_time
        hours, rem = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(rem, 60)
        print(f"Duration of model training in run {run_number}: "
              f"{int(hours)} hours, {int(minutes)} minutes and {int(seconds)} seconds.")
        adata_new.uns[f"{model_name}_model_training_duration_run{run_number}"] = (
            elapsed_time)

        # Latent representation: eval mode (no dropout), no autograd graph,
        # one forward pass per sample placed at that sample's cell positions
        batch = None  # release the dense training adjacency before inference
        model.eval()
        latent = None
        with torch.no_grad():
            for sample in training_samples:
                _, mu, _, _, _ = model(features[sample].to(device).float(),
                                       adj_norm[sample].to(device).float())
                mu = mu.cpu().numpy()
                if latent is None:
                    latent = np.zeros((obs_masks[sample].shape[0], mu.shape[1]),
                                      dtype=mu.dtype)
                latent[obs_masks[sample]] = mu
        adata_new.obsm[latent_key + f"_run{run_number}"] = latent

        # Store intermediate adata to disk
        adata_new.write(f"{benchmarking_folder_path}/{dataset}_{model_name}.h5ad")

        # Free this run's large inputs before the next run loads new data
        del features, features_raw, adj_norm, adj_label, node_idx
        gc.collect()
        if use_cuda:
            torch.cuda.empty_cache()

    # Store final adata to disk
    adata_new.write(f"{benchmarking_folder_path}/{dataset}_{model_name}.h5ad")

### 2.2 Train Models on Benchmarking Datasets

In [None]:
# Full seqFISH mouse organogenesis dataset (embryo 2), 8 runs
train_staci_models(
    dataset="seqfish_mouse_organogenesis_embryo2",
    cell_type_key="celltype_mapped_refined",
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
    n_start_run=1,
    n_end_run=8,
    adata_new=None,
)

In [None]:
# Subsampled seqFISH mouse organogenesis datasets (embryo 2)
for subsample_pct in (50, 25, 10, 5, 1):
    train_staci_models(
        dataset=f"seqfish_mouse_organogenesis_subsample_{subsample_pct}pct_embryo2",
        cell_type_key="celltype_mapped_refined",
        n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
        n_start_run=1,
        n_end_run=8,
        adata_new=None,
    )

In [None]:
# Subsampled NanoString CosMx human NSCLC datasets (batch 5);
# the 50 and 25 pct subsamples exhaust memory, so they are skipped
for subsample_pct in (10, 5, 1):
    train_staci_models(
        dataset=f"nanostring_cosmx_human_nsclc_subsample_{subsample_pct}pct_batch5",
        cell_type_key="cell_type",
        n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
        n_start_run=1,
        n_end_run=8,
        adata_new=None,
    )

In [None]:
# Subsampled Vizgen MERFISH mouse liver datasets;
# the 50, 25 and 10 pct subsamples exhaust memory, so they are skipped
for subsample_pct in (5, 1):
    train_staci_models(
        dataset=f"vizgen_merfish_mouse_liver_subsample_{subsample_pct}pct",
        cell_type_key="Cell_Type",
        n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
        n_start_run=1,
        n_end_run=8,
        adata_new=None,
    )

In [None]:
# Subsampled Slide-seqV2 mouse hippocampus datasets;
# the 50 pct subsample exhausts memory, so it is skipped
for subsample_pct in (25, 10, 5, 1):
    train_staci_models(
        dataset=f"slideseqv2_mouse_hippocampus_subsample_{subsample_pct}pct",
        cell_type_key="cell_type",
        n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
        n_start_run=1,
        n_end_run=8,
        adata_new=None,
    )

In [None]:
# Simulated dataset with ground-truth cell and niche types
train_staci_models(
    dataset="sim1_1105genes_10000locs_strongincrements",
    cell_type_key="cell_types",
    niche_type_key="niche_types",
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
    n_start_run=1,
    n_end_run=8,
    adata_new=None,
)

In [None]:
# STARmap mouse mPFC dataset with cell and niche type annotations
train_staci_models(
    dataset="starmap_mouse_mpfc",
    cell_type_key="cell_type",
    niche_type_key="niche_type",
    n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
    n_start_run=1,
    n_end_run=8,
    adata_new=None,
)