### Notebook for prediction of stem cells population in adult healthy gut data from fetal data with `scANVI` 

- **Developed by:** Anna Maguza
- **Place:** Wuerzburg Institute for System Immunology
- **Created date:** 10th April 2024
- **Last modified date:** 6th May 2024

##### Import packages

In [None]:
import scvi
import torch
import anndata
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import plotnine as p
from pywaffle import Waffle
import matplotlib.pyplot as plt
from scib_metrics.benchmark import Benchmarker

##### Setup working environment

In [None]:
# scanpy logging level, followed by a printout of the package versions in use.
sc.settings.verbosity = 3
sc.logging.print_versions()

# Default figure parameters for the plots produced below.
sc.settings.set_figure_params(
    dpi=180,
    dpi_save=300,
    color_map='magma_r',
    format='svg',
    vector_friendly=True,
)

In [None]:
# Ignore every warning for the rest of the session.
warnings.simplefilter(action = 'ignore')
# Fix the scvi-tools seed (presumably so that training runs are repeatable).
scvi.settings.seed = 1712
# Inline figures: white facecolor and 'retina' figure format.
%config InlineBackend.print_figure_kwargs = {'facecolor' : "w"}
%config InlineBackend.figure_format = 'retina'
# Use torch's 'medium' float32 matmul precision setting.
torch.set_float32_matmul_precision('medium')

In [None]:
# Architecture options kept in one place.
# NOTE(review): no model call in the visible part of the notebook reads this dict.
arches_params = {
    "use_layer_norm": "both",
    "use_batch_norm": "none",
    "encode_covariates": True,
    "dropout_rate": 0.2,
    "n_layers": 3,
}

In [None]:
def X_is_raw(adata):
    """Return True when every value stored in ``adata.X`` is a whole number.

    The check looks at the individual entries rather than at column sums:
    non-integer values can still add up to an integer per column (for
    example two cells with 0.5 each), which would wrongly be reported as
    raw counts.

    Parameters
    ----------
    adata
        Any object exposing the expression matrix as ``.X`` (dense array or
        scipy sparse matrix).

    Returns
    -------
    bool
        True if all entries are integer-valued, False otherwise.
    """
    from scipy import sparse

    X = adata.X
    if sparse.issparse(X):
        # Implicit zeros are integers already; only stored entries need checking.
        values = X.tocsr().data
    else:
        values = np.asarray(X)
    return bool(np.all(np.mod(values, 1) == 0))

### Read in data

+ Reference data (fetal stem cells)

In [None]:
# Location of the fetal stem-cell reference object. Named `ref_path` so the
# builtin `input` is not shadowed for the rest of the session.
ref_path = '/mnt/LaCIE/annaM/gut_project/Processed_data/Gut_data/Fetal_stem_cells/Fetal_healthy_stem_cells_leiden.h5ad'
adata_ref = sc.read_h5ad(ref_path)

In [None]:
# Run the raw-count check on the reference object.
X_is_raw(adata_ref)

+ Query data (adult stem cells)

In [None]:
# Integrated object (four datasets) that the query cells are taken from.
adata_query = sc.read_h5ad(
    '/mnt/LaCIE/annaM/gut_project/Processed_data/Gut_data/Healthy_reference/Integrated/Integrated_4_datasets_05042024.h5ad'
)

In [None]:
# Run the raw-count check on the integrated (query source) object.
X_is_raw(adata_query)

#### Datasets preparation

+ Prepare obs in reference data

In [None]:
# Work on a separate copy so `adata_query` itself stays untouched while the
# reference matrix is rebuilt from it.
adata_query2 = adata_query.copy()

In [None]:
# Keep only the cells of the integrated copy whose barcode also occurs in the reference.
in_reference = adata_query2.obs_names.isin(adata_ref.obs_names)
adata_query2 = adata_query2[in_reference, :]

In [None]:
# Rebuild the reference from the integrated matrix while keeping the reference
# annotations. `adata_query2` is in the row order of the integrated object, so
# its rows are selected by the reference barcodes first; this guarantees that
# every row of X belongs to the cell described by the same row of
# `adata_ref.obs`. A KeyError here means a reference cell is missing from the
# integrated object.
ref_rows = adata_query2[adata_ref.obs_names, :]
adata_ref = anndata.AnnData(X = ref_rows.X, var = ref_rows.var, obs = adata_ref.obs)

In [None]:
# Reference cells carry their 'cluster' annotation as the seed label.
adata_ref.obs['seed_labels'] = adata_ref.obs['cluster'].copy()

In [None]:
# Snapshot of the reference before any further processing.
adata_ref_raw = adata_ref.copy()

+ Prepare query data

In [None]:
# Restrict the query to cells annotated with Age_group == 'Adult'.
adult_mask = adata_query.obs['Age_group'] == 'Adult'
adata_query = adata_query[adult_mask, :]

In [None]:
# 'Cell States' values that are kept for the query.
stem_cells = [
    'Stem cells OLFM4',
    'Stem cells OLFM4 LGR5',
    'Stem cells OLFM4 PCNA',
    'Stem_Cells_GCA',
    'Stem cells OLFM4 GSTA1',
    'Stem_Cells_ext',
]

# `.copy()` turns the filtered view into a standalone object, so the write to
# `.obs` that follows does not act on a view of the integrated data.
adata_query = adata_query[adata_query.obs['Cell States'].isin(stem_cells), :].copy()

In [None]:
# Every query cell starts as 'Unknown'; the same string is later passed to
# scANVI as the unlabeled category.
adata_query.obs['seed_labels'] = 'Unknown'

In [None]:
# Snapshot of the filtered query before any further processing.
adata_query_raw = adata_query.copy() 

In [None]:
# Xenium cell-feature matrix; only its gene list is used below.
xenium = sc.read_10x_h5(
    '/mnt/LaCIE/annaM/gut_project/raw_data/Xenium_10X_datasets/Gut_samples/Non-diseased_pre-designed_and_add-on_panel/outs/cell_feature_matrix.h5'
)

In [None]:
# Gene names of the Xenium panel as a plain Python list.
xenium_genes = list(xenium.var_names)

In [None]:
# The matrix itself is no longer needed once the gene list is extracted.
del xenium

#### Extract HVGs

In [None]:
# HVGs are selected on the reference object: `adata` does not exist yet at this
# point (it is only created by the concatenation further down), and the next
# cell reads the flag from `adata_ref.var['highly_variable']`.
adata_ref.layers['counts'] = adata_ref.X.copy()

# `subset = False` only annotates `adata_ref.var`. Dropping the non-HVG columns
# here would also remove the Xenium panel genes that are merged into
# `genes_to_keep` and filtered on the concatenated object later.
# NOTE(review): assumes 'Donor_ID' is present in the reference obs - confirm.
sc.pp.highly_variable_genes(
    adata_ref,
    flavor = "seurat_v3",
    n_top_genes = 2000,
    layer = "counts",
    batch_key = "Donor_ID",
    subset = False,
    span = 1
)

In [None]:
# Names of the genes flagged as highly variable in the reference.
is_hvg = adata_ref.var['highly_variable']
hvg = adata_ref.var.index[is_hvg].tolist()

In [None]:
# Union candidates: highly variable genes plus the Xenium panel genes
# (may still contain repeats at this point).
genes_to_keep = hvg + xenium_genes

In [None]:
# remove duplicates (going through a set does not preserve the list order;
# the list is only used for membership tests below)
genes_to_keep = list(set(genes_to_keep))

+ Concatenate datasets

In [None]:
# Stack reference and query on their shared genes; the 'dataset' column records
# which object each cell came from, and barcodes are left unchanged.
# NOTE(review): recent anndata releases flag `AnnData.concatenate` as deprecated
# in favour of `anndata.concat` - confirm against the installed version.
adata = adata_ref.concatenate(
    adata_query,
    join='inner',
    batch_key='dataset',
    batch_categories=['reference', 'query'],
    index_unique=None,
)

In [None]:
# Display the summary of the concatenated object.
adata

In [None]:
# Full-gene snapshot taken before the gene filter; it supplies X and var for
# the exported object at the end of the notebook.
adata_raw = adata.copy()

+ filter genes

In [None]:
# Keep only the columns whose gene name is in `genes_to_keep`.
keep_gene = adata.var_names.isin(genes_to_keep)
adata = adata[:, keep_gene]

+ check if there are duplicates in the dataset

In [None]:
# Dense cell-by-gene table used to look for identical expression profiles.
expression_df = adata.to_df()

# keep=False marks every member of a duplicated group, not only the repeats.
duplicated_cells = expression_df.duplicated(keep=False)

if not duplicated_cells.any():
    print("No cells have exactly the same gene expression profiles.")
else:
    print("There are cells with exactly the same gene expression profiles.")
    print(expression_df[duplicated_cells])

* delete duplicated cell

In [None]:
# Barcode removed by hand after the duplicate check above.
duplicate_barcodes = ['N17_LP_B-TGCACAGAACCTCC']
is_duplicate = adata.obs_names.isin(duplicate_barcodes)
adata = adata[~is_duplicate, :]

#### Visualise uncorrected UMAP to identify batches

In [None]:
# Neighbour graph on the uncorrected data (10 neighbours, 40 PCs).
# NOTE(review): no explicit PCA call is visible before this point - confirm
# that a PCA representation exists or is computed implicitly here.
sc.pp.neighbors(adata, n_neighbors=10, n_pcs=40)

In [None]:
# UMAP embedding of the uncorrected neighbour graph.
sc.tl.umap(adata)

In [None]:
# Uncorrected UMAP coloured by study, age group, library protocol and donor.
sc.set_figure_params(dpi=180, figsize=(10, 10))
sc.pl.umap(
    adata,
    color=[
        'Study_name',
        'Age_group',
        'Library_Preparation_Protocol',
        'Donor_ID',
    ],
    size=7,
)

In [None]:
# Uncorrected UMAP coloured by the sample-level annotations.
sc.set_figure_params(dpi=180, figsize=(12, 10))
sc.pl.umap(
    adata,
    color=[
        'Region code',
        'Fraction',
        'Sex',
        'Location',
    ],
    size=7,
)

In [None]:
# Uncorrected UMAP coloured by the per-cell QC columns.
sc.set_figure_params(dpi=180, figsize=(12, 10))
sc.pl.umap(
    adata,
    color=[
        'n_genes_by_counts',
        'total_counts',
        'total_counts_mito',
        'pct_counts_mito',
        'total_counts_ribo',
        'pct_counts_ribo',
    ],
    size=7,
)

#### Run scVI

In [None]:
# Store a copy of X as the 'counts' layer that the scVI setup below points to.
adata.layers['counts'] = adata.X.copy()

In [None]:
# Register the counts layer, the seed labels and the two categorical covariates
# with scvi-tools.
scvi.model.SCVI.setup_anndata(
    adata,
    layer='counts',
    labels_key="seed_labels",
    categorical_covariate_keys=['Donor_ID', 'Library_Preparation_Protocol'],
)

In [None]:
# scVI model with 1000 hidden units, a 256-dimensional latent space and a
# negative-binomial likelihood.
scvi_model = scvi.model.SCVI(
    adata,
    n_hidden=1000,
    n_latent=256,
    gene_likelihood='nb',
    dispersion='gene-batch',
)

In [None]:
# Train for 500 epochs on GPU 0, validating after every epoch so the
# validation curve plotted below has one point per epoch.
scvi_model.train(
    max_epochs=500,
    accelerator="gpu",
    devices=[0],
    check_val_every_n_epoch=1,
    enable_progress_bar=True,
)

In [None]:
# Store the scVI latent representation of all cells under obsm['X_scVI'].
adata.obsm["X_scVI"] = scvi_model.get_latent_representation(adata)

#### Evaluate model performance using the [_Svensson_](https://www.nxn.se/valent/2023/8/10/training-scvi-posterior-predictive-distributions-over-epochs) method

In [None]:
# Training and validation ELBO per epoch, reshaped to long format for plotting.
train_elbo = scvi_model.history['elbo_train'].astype(float)
validation_elbo = scvi_model.history['elbo_validation'].astype(float)
history_df = train_elbo.join(validation_elbo).reset_index().melt(id_vars = ['epoch'])

p.options.figure_size = 12, 6

# Epoch 0 is left out of the plot.
elbo_colors = {'elbo_train': 'black', 'elbo_validation': 'red'}
p_ = p.ggplot(p.aes(x = 'epoch', y = 'value', color = 'variable'), history_df.query('epoch > 0'))
p_ = p_ + p.geom_line() + p.geom_point()
p_ = p_ + p.scale_color_manual(elbo_colors) + p.theme_minimal()

print(p_)

#### Label transfer with `scANVI` 

In [None]:
# Initialise scANVI from the trained scVI model; 'Unknown' is the label given
# to the query cells above and is passed as the unlabeled category.
# NOTE(review): n_hidden / n_latent repeat the values the scVI model was built
# with - confirm whether from_scvi_model uses or ignores them.
scanvi_model = scvi.model.SCANVI.from_scvi_model(scvi_model, 'Unknown', n_hidden = 1000, n_latent = 256)

In [None]:
# Train scANVI for 250 epochs on GPU 0, validating after every epoch.
scanvi_model.train(
    max_epochs=250,
    accelerator="gpu",
    devices=[0],
    check_val_every_n_epoch=1,
    enable_progress_bar=True,
)

In [None]:
# scANVI prediction for every cell in `adata` (reference and query alike).
adata.obs["C_scANVI"] = scanvi_model.predict(adata)

- Extract latent representation

In [None]:
# Store the scANVI latent representation of all cells under obsm['X_scANVI'].
adata.obsm["X_scANVI"] = scanvi_model.get_latent_representation(adata)

### Explore model performance using the [_Svensson_](https://www.nxn.se/valent/2023/8/10/training-scvi-posterior-predictive-distributions-over-epochs) method

In [None]:
# Training and validation ELBO of the scANVI run, reshaped to long format.
train_elbo = scanvi_model.history['elbo_train'].astype(float)
validation_elbo = scanvi_model.history['elbo_validation'].astype(float)
history_df = train_elbo.join(validation_elbo).reset_index().melt(id_vars = ['epoch'])

p.options.figure_size = 12, 6

# Epoch 0 is left out of the plot.
elbo_colors = {'elbo_train': 'black', 'elbo_validation': 'red'}
p_ = p.ggplot(p.aes(x = 'epoch', y = 'value', color = 'variable'), history_df.query('epoch > 0'))
p_ = p_ + p.geom_line() + p.geom_point()
p_ = p_ + p.scale_color_manual(elbo_colors) + p.theme_minimal()

# Written to the current working directory.
p_.save('fig1.png', dpi = 300)

print(p_)

+ Visualize dataset

In [None]:
# Neighbour graph and UMAP computed on the scANVI latent space.
sc.pp.neighbors(
    adata,
    n_neighbors=256,
    use_rep="X_scANVI",
    metric='minkowski',
)
sc.tl.umap(
    adata,
    min_dist=0.2,
    spread=2,
    random_state=1712,
)

In [None]:
# Number of cells per predicted label (reference and query together).
adata.obs['C_scANVI'].value_counts()

In [None]:
# scANVI-based UMAP coloured by batch annotations, seed labels and predictions.
sc.set_figure_params(dpi=300, figsize=(10, 7))
sc.pl.umap(
    adata,
    color=[
        'Study_name',
        'seed_labels',
        'C_scANVI',
        'Age_group',
        'Library_Preparation_Protocol',
        'Donor_ID',
    ],
    ncols=3,
    size=5,
    frameon=False,
)

+ Export anndata object

In [None]:
# Keep only the cells whose 'dataset' value is 'query'.
adata_filtered = adata[adata.obs['dataset'] == 'query']

In [None]:
# Display the summary of the query-only object.
adata_filtered

In [None]:
# Same 'query' selection on the full-gene snapshot taken before gene filtering.
adata_raw = adata_raw[adata_raw.obs['dataset'] == 'query']

In [None]:
# Display the summary of the query-only full-gene object.
adata_raw

In [None]:
# Number of query cells per predicted label.
adata_filtered.obs['C_scANVI'].value_counts()

In [None]:
# Drop the same barcode that was removed from `adata` earlier.
# NOTE(review): `adata_query_raw` is not used again in the visible part of the
# notebook - confirm whether this object or `adata_raw` was meant for export.
adata_query_raw = adata_query_raw[~adata_query_raw.obs.index.isin(['N17_LP_B-TGCACAGAACCTCC']), :]

In [None]:
# Combine the full-gene matrix with the annotated obs table (predictions
# included). `adata_raw` was copied before the duplicated cell was dropped from
# `adata`, so it can hold rows that `adata_filtered` no longer has; selecting
# the raw rows by barcode makes X match `adata_filtered.obs` in both length and
# order.
# NOTE(review): relies on unique barcodes in `adata_raw` - confirm.
raw_query = adata_raw[adata_filtered.obs_names, :]
adata_filtered = anndata.AnnData(X = raw_query.X, obs = adata_filtered.obs, var = raw_query.var)

In [None]:
# Display the summary of the object that is about to be written.
adata_filtered

In [None]:
# Write the query-only object (full-gene X, obs with the scANVI predictions) to disk.
adata_filtered.write_h5ad('/mnt/LaCIE/annaM/gut_project/Processed_data/Gut_data/Adult_stem_cells/Adult_stem_cells_round8.h5ad')

+ Check the scIB metrics

In [None]:
from rich import print
import scib
import scib.metrics
from scib_metrics.benchmark import Benchmarker, BioConservation
from scib_metrics import nearest_neighbors
from scib_metrics.nearest_neighbors import NeighborsResults
import faiss

In [None]:
def faiss_hnsw_nn(X: np.ndarray, k: int):
    """Nearest-neighbour search with a faiss HNSW index passed to GPU 0.

    Builds an ``IndexHNSWFlat`` (M = 32, L2 metric), hands it to
    ``faiss.index_cpu_to_gpu`` on device 0, adds ``X`` and searches ``X``
    against itself.

    See https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md
    for index param details.

    NOTE(review): confirm that the installed faiss build actually runs this
    index type on the GPU.

    Parameters
    ----------
    X
        Data matrix; converted to a C-contiguous float32 array. The second
        axis is used as the index dimensionality.
    k
        Number of neighbours requested per row.

    Returns
    -------
    NeighborsResults
        Neighbour indices and (non-squared) L2 distances.
    """
    X = np.ascontiguousarray(X, dtype=np.float32)
    res = faiss.StandardGpuResources()
    M = 32
    index = faiss.IndexHNSWFlat(X.shape[1], M, faiss.METRIC_L2)
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
    gpu_index.add(X)
    distances, indices = gpu_index.search(X, k)
    del index
    del gpu_index
    # distances are squared; clamp at zero first so that tiny negative values
    # left by floating-point rounding do not turn into NaN under the root
    return NeighborsResults(indices=indices, distances=np.sqrt(np.maximum(distances, 0)))


def faiss_brute_force_nn(X: np.ndarray, k: int):
    """Gpu brute force nearest neighbor search using faiss.

    Builds an ``IndexFlatL2``, moves it to GPU 0, adds ``X`` and searches
    ``X`` against itself.

    Parameters
    ----------
    X
        Data matrix; converted to a C-contiguous float32 array. The second
        axis is used as the index dimensionality.
    k
        Number of neighbours requested per row.

    Returns
    -------
    NeighborsResults
        Neighbour indices and (non-squared) L2 distances.
    """
    X = np.ascontiguousarray(X, dtype=np.float32)
    res = faiss.StandardGpuResources()
    index = faiss.IndexFlatL2(X.shape[1])
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
    gpu_index.add(X)
    distances, indices = gpu_index.search(X, k)
    del index
    del gpu_index
    # distances are squared; clamp at zero first so that tiny negative values
    # left by floating-point rounding do not turn into NaN under the root
    return NeighborsResults(indices=indices, distances=np.sqrt(np.maximum(distances, 0)))

In [None]:
# Benchmark the four embeddings with the study as batch and the scANVI
# predictions as labels.
embedding_keys = ['X_scANVI', 'X_scVI', 'X_pca', 'X_umap']
bm = Benchmarker(
    adata,
    batch_key='Study_name',
    label_key='C_scANVI',
    embedding_obsm_keys=embedding_keys,
    n_jobs=-1,
)

In [None]:
# Prepare the benchmark, using the faiss brute-force function defined above as
# the neighbour computer.
bm.prepare(neighbor_computer=faiss_brute_force_nn)

In [None]:
# Run the benchmark configured with 'Study_name' as batch key.
bm.benchmark()

In [None]:
# Results table for the 'Study_name' benchmark, without min-max scaling.
sc.set_figure_params(dpi = 300, figsize=(11,7))
bm.plot_results_table(min_max_scale=False)

In [None]:
# Same benchmark as above, this time with the donor as batch.
embedding_keys = ['X_scANVI', 'X_scVI', 'X_pca', 'X_umap']
bm = Benchmarker(
    adata,
    batch_key='Donor_ID',
    label_key='C_scANVI',
    embedding_obsm_keys=embedding_keys,
    n_jobs=-1,
)

In [None]:
# Prepare the benchmark, using the faiss brute-force function defined above as
# the neighbour computer.
bm.prepare(neighbor_computer=faiss_brute_force_nn)

In [None]:
# Run the benchmark configured with 'Donor_ID' as batch key.
bm.benchmark()

In [None]:
# Results table for the 'Donor_ID' benchmark, without min-max scaling.
sc.set_figure_params(dpi = 300, figsize=(11,7))
bm.plot_results_table(min_max_scale=False)