# scVI

- **Creator**: Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Date of Creation:** 05.01.2023
- **Date of Last Modification:** 18.07.2024 (Sebastian Birk; <sebastian.birk@helmholtz-munich.de>)

- The scVI source code is available at https://github.com/scverse/scvi-tools.
- The corresponding publication is "Lopez, R., Regier, J., Cole, M. B., Jordan, M. I. & Yosef, N. Deep generative modeling for single-cell transcriptomics. Nat. Methods 15, 1053–1058 (2018)".
- The workflow of this notebook follows the tutorial from https://docs.scvi-tools.org/en/stable/tutorials/notebooks/harmonization.html.

- Run this notebook in the cellcharter environment, installable from `../../../envs/environment_cellcharter.yaml`.

## 1. Setup

### 1.1 Import Libraries

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import gc
import os
import time
from datetime import datetime

import anndata as ad
import scvi
import scanpy as sc
import scipy.sparse as sp
import squidpy as sq
import matplotlib.pyplot as plt
import numpy as np

### 1.2 Define Parameters

In [None]:
# Method identifier; used to build result keys and output file names
model_name = "scvi"
latent_key = model_name + "_latent"

# Keys of the columns / layers / slots accessed in the AnnData objects
counts_key = "counts"
condition_key = "batch"
mapping_entity_key = "reference"

# Keys related to the spatial coordinates and the spatial neighborhood graph
spatial_key = "spatial"
adj_key = "spatial_connectivities"

### 1.3 Run Notebook Setup

In [None]:
# Configure scanpy's global figure settings with a figure size of (6, 6)
sc.set_figure_params(figsize=(6, 6))

In [None]:
# Record when the notebook is executed; the formatted string
# (day, month, year, underscore, hour, minute, second) is meant for
# timestamping saved artifacts
now = datetime.now()
current_timestamp = f"{now:%d%m%Y_%H%M%S}"

### 1.4 Configure Paths and Directories

In [None]:
# Folder from which the training function reads the input `.h5ad` datasets
st_data_gold_folder_path = "../../../datasets/st_data/gold"
st_data_results_folder_path = "../../../datasets/st_data/results"
figure_folder_path = "../../../figures"
# Folder into which the training function writes the benchmarking results
benchmarking_folder_path = "../../../artifacts/sample_integration_method_benchmarking"

# Create required directories. The benchmarking folder is included because
# the results are written there and writing fails if the folder is missing.
os.makedirs(st_data_gold_folder_path, exist_ok=True)
os.makedirs(st_data_results_folder_path, exist_ok=True)
os.makedirs(benchmarking_folder_path, exist_ok=True)

## 2. scVI Model

### 2.1 Define Training Function

In [None]:
def _compute_symmetric_spatial_graph(adata, n_neighbors):
    """
    Compute a spatial neighborhood graph on ``adata`` (in place) and make the
    resulting adjacency matrix in ``adata.obsp[adj_key]`` symmetric.
    """
    sq.gr.spatial_neighbors(adata,
                            coord_type="generic",
                            spatial_key=spatial_key,
                            n_neighs=n_neighbors)
    # The element-wise maximum with the transpose keeps an edge if it exists
    # in either direction
    adjacency = adata.obsp[adj_key]
    adata.obsp[adj_key] = adjacency.maximum(adjacency.T)


def _create_results_adata(dataset, reference_batches, cell_type_key):
    """
    Create a storage-efficient AnnData object (empty sparse count matrix) that
    only carries the metadata needed to store the results of the training
    runs: cell type, spatial coordinates, condition and mapping entity.
    """
    if reference_batches is not None:
        adata_batch_list = []
        for batch in reference_batches:
            adata_batch = ad.read_h5ad(
                f"{st_data_gold_folder_path}/{dataset}_{batch}.h5ad")
            adata_batch.obs[mapping_entity_key] = "reference"
            adata_batch_list.append(adata_batch)
        adata_original = ad.concat(adata_batch_list, join="inner")
    else:
        # NOTE(review): in this branch the file itself has to provide the
        # `mapping_entity_key` column in `.obs`; it is not set here.
        adata_original = ad.read_h5ad(f"{st_data_gold_folder_path}/{dataset}.h5ad")

    adata_new = sc.AnnData(sp.csr_matrix(
        (adata_original.shape[0], adata_original.shape[1]),
        dtype=np.float32))
    adata_new.var_names = adata_original.var_names
    adata_new.obs_names = adata_original.obs_names
    adata_new.obs["cell_type"] = adata_original.obs[cell_type_key].values
    adata_new.obsm["spatial"] = adata_original.obsm["spatial"]
    adata_new.obs[condition_key] = adata_original.obs[condition_key]
    adata_new.obs[mapping_entity_key] = adata_original.obs[mapping_entity_key]
    return adata_new


def _load_training_adata(dataset, reference_batches, n_neighbors, filter_genes):
    """
    Load the data for one training run. If ``filter_genes`` is set, a
    symmetric spatial neighborhood graph is additionally stored in
    ``.obsp[adj_key]`` (computed separately per batch if batches are given).
    """
    if reference_batches is None:
        adata = ad.read_h5ad(f"{st_data_gold_folder_path}/{dataset}.h5ad")
        if filter_genes:
            _compute_symmetric_spatial_graph(adata, n_neighbors)
        return adata

    adata_batch_list = []
    for batch in reference_batches:
        print(f"Processing batch {batch}...")
        print("Loading data...")
        adata_batch = ad.read_h5ad(
            f"{st_data_gold_folder_path}/{dataset}_{batch}.h5ad")
        adata_batch.obs[mapping_entity_key] = "reference"

        if filter_genes:
            print("Computing spatial neighborhood graph...\n")
            # Compute (separate) spatial neighborhood graphs
            _compute_symmetric_spatial_graph(adata_batch, n_neighbors)
        adata_batch_list.append(adata_batch)
    adata = ad.concat(adata_batch_list, join="inner")

    if filter_genes:
        # Combine spatial neighborhood graphs as disconnected components:
        # the per-batch adjacency matrices form the diagonal blocks (in
        # concatenation order) and there are no edges between batches
        adata.obsp[adj_key] = sp.block_diag(
            [adata_batch.obsp[adj_key] for adata_batch in adata_batch_list],
            format="csr")
    return adata


def _keep_spatially_variable_genes(adata, n_svg):
    """
    Return a copy of ``adata`` restricted to (at most) ``n_svg`` genes, taken
    from the top of the Moran's I result table computed by squidpy.
    """
    sc.pp.filter_genes(adata,
                       min_cells=0)
    sq.gr.spatial_autocorr(adata, mode="moran", genes=adata.var_names)
    # NOTE(review): this relies on `adata.uns["moranI"]` being ordered from
    # most to least spatially autocorrelated - confirm against squidpy docs
    sv_genes = adata.uns["moranI"].index[:n_svg].tolist()
    adata.var["spatially_variable"] = adata.var_names.isin(sv_genes)
    adata = adata[:, adata.var["spatially_variable"]].copy()
    print(f"Keeping {len(adata.var_names)} spatially variable genes.")
    return adata


def train_scvi_models(dataset,
                      reference_batches,
                      cell_type_key,
                      adata_new=None,
                      n_start_run=1,
                      n_end_run=8,
                      n_neighbor_list=(4, 4, 8, 8, 12, 12, 16, 16),
                      filter_genes: bool=False,
                      n_svg: int=3000):
    """
    Train one scVI model per run and store the latent representations, UMAP
    embeddings and training durations of all runs in a single AnnData object
    that is written to
    ``{benchmarking_folder_path}/{dataset}_{model_name}.h5ad``.

    Parameters
    ----------
    dataset:
        Name of the dataset; determines the input file name(s) in
        ``st_data_gold_folder_path`` and the output file name.
    reference_batches:
        Batch names. One file ``{dataset}_{batch}.h5ad`` is read per batch and
        the batches are concatenated (inner join). If ``None``, the single
        file ``{dataset}.h5ad`` is read instead.
    cell_type_key:
        Key in ``.obs`` of the input data whose values are stored as
        ``.obs["cell_type"]`` in the results object.
    adata_new:
        Existing results object to which the runs are added. If ``None``, a
        new results object is created from the input data.
    n_start_run:
        Number of the first run (inclusive).
    n_end_run:
        Number of the last run (inclusive). Run ``i`` uses seed ``i - 1``;
        only 10 seeds are defined.
    n_neighbor_list:
        Numbers of neighbors, paired positionally with the run numbers
        (iteration stops at the shorter of the two). scVI is not a spatial
        method, so the values are only used for the spatial neighborhood graph
        that is computed when ``filter_genes`` is ``True``.
    filter_genes:
        If ``True``, restrict the data to spatially variable genes before
        training.
    n_svg:
        Maximum number of spatially variable genes to keep.
    """
    # Create new adata to store results from training runs in storage-efficient way
    if adata_new is None:
        adata_new = _create_results_adata(dataset=dataset,
                                          reference_batches=reference_batches,
                                          cell_type_key=cell_type_key)

    # Make sure the output folder exists before spending time on training
    os.makedirs(benchmarking_folder_path, exist_ok=True)
    results_file_path = f"{benchmarking_folder_path}/{dataset}_{model_name}.h5ad"

    model_seeds = list(range(10))
    for run_number, n_neighbors in zip(range(n_start_run, n_end_run + 1),
                                       n_neighbor_list):
        run_key = f"{latent_key}_run{run_number}"

        # Load data
        adata = _load_training_adata(dataset=dataset,
                                     reference_batches=reference_batches,
                                     n_neighbors=n_neighbors,
                                     filter_genes=filter_genes)

        if filter_genes:
            adata = _keep_spatially_variable_genes(adata, n_svg)

        start_time = time.time()

        scvi.settings.seed = model_seeds[run_number-1]

        # Setup adata
        scvi.model.SCVI.setup_anndata(adata,
                                      layer=counts_key,
                                      batch_key=condition_key)

        # Initialize model
        # Use hyperparams that provenly work well on integration tasks
        model = scvi.model.SCVI(adata,
                                n_layers=2,
                                n_latent=30,
                                gene_likelihood="nb")

        # Train model
        model.train()

        # Retrieve latent representation (included in the measured duration)
        latent = model.get_latent_representation()

        # Measure time for model training
        elapsed_time = time.time() - start_time
        hours, rem = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(rem, 60)
        print(f"Duration of model training in run {run_number}: "
              f"{int(hours)} hours, {int(minutes)} minutes and {int(seconds)} seconds.")
        adata_new.uns[f"{model_name}_model_training_duration_run{run_number}"] = (
            elapsed_time)

        # Store latent representation
        adata_new.obsm[run_key] = latent

        # Use latent representation for UMAP generation (the latent neighbor
        # graph is built with scanpy's default number of neighbors)
        sc.pp.neighbors(adata_new,
                        use_rep=run_key,
                        key_added=run_key)
        sc.tl.umap(adata_new,
                   neighbors_key=run_key)
        # Move the UMAP to a run-specific key so that later runs do not
        # overwrite it
        adata_new.obsm[f"{run_key}_X_umap"] = adata_new.obsm["X_umap"]
        del adata_new.obsm["X_umap"]

        # Store intermediate adata to disk
        adata_new.write(results_file_path)

        # Free memory
        del adata
        del model
        gc.collect()

    # Store final adata to disk
    adata_new.write(results_file_path)

### 2.2 Train Models on Benchmarking Datasets

In [None]:
# Train scVI (runs 1 to 8) on the seqFISH mouse organogenesis dataset,
# using batch1 to batch6 as reference batches
train_scvi_models(dataset="seqfish_mouse_organogenesis",
                  reference_batches=[f"batch{i}" for i in range(1,7)],
                  cell_type_key="celltype_mapped_refined",
                  adata_new=None,
                  n_start_run=1,
                  n_end_run=8,
                  n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16])

In [None]:
# Train scVI (runs 1 to 8) on each subsampled version (50%, 25%, 10%, 5%, 1%)
# of the seqFISH mouse organogenesis dataset; each subsample is stored in its
# own result file
for subsample_pct in [50, 25, 10, 5, 1]:
    train_scvi_models(dataset=f"seqfish_mouse_organogenesis_subsample_{subsample_pct}pct",
                      reference_batches=[f"batch{i}" for i in range(1,7)],
                      cell_type_key="celltype_mapped_refined",
                      adata_new=None,
                      n_start_run=1,
                      n_end_run=8,
                      n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16])

In [None]:
# Train scVI (runs 1 to 8) on the imputed seqFISH mouse organogenesis dataset.
# Gene filtering is enabled, so each run is restricted to at most 3000
# spatially variable genes before training.
train_scvi_models(dataset="seqfish_mouse_organogenesis_imputed",
                  reference_batches=[f"batch{i}" for i in range(1,7)],
                  cell_type_key="celltype_mapped_refined",
                  adata_new=None,
                  n_start_run=1,
                  n_end_run=8,
                  n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
                  filter_genes=True,
                  n_svg=3000)

In [None]:
# Train scVI (runs 1 to 8) on each subsampled version (50%, 25%, 10%, 5%, 1%)
# of the imputed seqFISH mouse organogenesis dataset, restricted to at most
# 3000 spatially variable genes per run
for subsample_pct in [50, 25, 10, 5, 1]:
    train_scvi_models(dataset=f"seqfish_mouse_organogenesis_imputed_subsample_{subsample_pct}pct",
                      reference_batches=[f"batch{i}" for i in range(1,7)],
                      cell_type_key="celltype_mapped_refined",
                      adata_new=None,
                      n_start_run=1,
                      n_end_run=8,
                      n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16],
                      filter_genes=True,
                      n_svg=3000)

In [None]:
# Train scVI (runs 1 to 8) on the NanoString CosMx human NSCLC dataset,
# using batch1 to batch3 as reference batches
train_scvi_models(dataset="nanostring_cosmx_human_nsclc",
                  reference_batches=[f"batch{i}" for i in range(1, 4)],
                  cell_type_key="cell_type",
                  adata_new=None,
                  n_start_run=1,
                  n_end_run=8,
                  n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16])

In [None]:
# Train scVI (runs 1 to 8) on each subsampled version (50%, 25%, 10%, 5%, 1%)
# of the NanoString CosMx human NSCLC dataset
for subsample_pct in [50, 25, 10, 5, 1]: # might be reversed in stored object
    train_scvi_models(dataset=f"nanostring_cosmx_human_nsclc_subsample_{subsample_pct}pct",
                      reference_batches=[f"batch{i}" for i in range(1,4)],
                      cell_type_key="cell_type",
                      adata_new=None,
                      n_start_run=1,
                      n_end_run=8,
                      n_neighbor_list=[4, 4, 8, 8, 12, 12, 16, 16])