### Notebook for the generation of an integrated manifold with `scANVI`

- **Developed by:** Carlos Talavera-López Ph.D
- **Würzburg Institute for Systems Immunology & Julius-Maximilian-Universität Würzburg**
- v230811

### Import required modules

In [1]:
import torch
import scvi
import anndata
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.utils import check_random_state
from scib_metrics.benchmark import Benchmarker

### Set up working environment

In [None]:
# Set scanpy's verbosity level to 3 and print the versions of the loaded packages.
sc.settings.verbosity = 3
sc.logging.print_versions()

# Figure defaults applied through scanpy's settings for the plots in this notebook.
sc.settings.set_figure_params(
    dpi=180,
    dpi_save=300,
    color_map='magma_r',
    format='svg',
    vector_friendly=True,
)

In [None]:
# Install an 'ignore' filter so warnings are not shown for the rest of the session.
warnings.simplefilter(action = 'ignore')
# Fixed scvi-tools seed (presumably so model training is repeatable between runs).
scvi.settings.seed = 1712
# IPython magics configuring the inline backend: white figure face colour, 'retina' format.
%config InlineBackend.print_figure_kwargs = {'facecolor' : "w"}
%config InlineBackend.figure_format = 'retina'
# Select torch's 'medium' float32 matmul precision setting.
torch.set_float32_matmul_precision('medium')

In [None]:
# Architecture keyword arguments collected in one place.
# NOTE(review): this dict is not referenced anywhere else in the visible notebook;
# confirm whether it was meant to be unpacked into the model constructor.
arches_params = {
    "use_layer_norm": "both",
    "use_batch_norm": "none",
    "encode_covariates": True,
    "dropout_rate": 0.2,
    "n_layers": 3,
}

### Read in Healthy data

In [None]:
# Read the .h5ad file at this relative path and display the resulting object.
healthy_h5ad = '../data/Marburg_cell_states_locked_ctl230811.raw.h5ad'
adata_raw = sc.read_h5ad(healthy_h5ad)
adata_raw

### Randomly select 5,000 cells to use as test set for `scANVI`

In [None]:
# Draw a seeded random subsample of 5000 cells as a separate copy, then overwrite
# its compartment label with 'Unknown' so these cells can serve as the unlabelled set.
adata_subset = sc.pp.subsample(
    adata_raw,
    n_obs=5000,
    copy=True,
    random_state=1712,
)
adata_subset.obs['cell_compartment'] = 'Unknown'
adata_subset

In [None]:
# Cast the label column to plain strings on both objects (presumably so that
# 'Unknown' can be written even if the column was categorical without that level).
for _adata in (adata_raw, adata_subset):
    _adata.obs['cell_compartment'] = _adata.obs['cell_compartment'].astype(str)

# Write the masked labels of the subsampled cells back onto the full object,
# matching rows by obs index.
subset_compartment_values = adata_subset.obs['cell_compartment']
held_out_cells = adata_subset.obs.index
adata_raw.obs.loc[held_out_cells, 'cell_compartment'] = subset_compartment_values

# Convert the column back to a categorical and show how many cells carry each label.
adata_raw.obs['cell_compartment'] = pd.Categorical(adata_raw.obs['cell_compartment'])
adata_raw.obs['cell_compartment'].value_counts()

In [None]:
# Duplicate the (partly masked) compartment labels into 'seed_labels', the column
# later passed as labels_key to scvi's setup_anndata, then display the label counts.
adata_raw.obs['seed_labels'] = adata_raw.obs['cell_compartment'].copy()
adata_raw.obs['seed_labels'].value_counts()

### Select HVGs

In [None]:
# Snapshot the object before gene selection: the call below uses subset=True,
# so adata_raw itself is reduced to the selected genes. This must stay first.
raw_adata = adata_raw.copy()

# Keep a copy of the current X in a 'counts' layer; it feeds the HVG selection here
# and is the layer registered with scvi-tools further down.
adata_raw.layers['counts'] = adata_raw.X.copy()

sc.pp.highly_variable_genes(
    adata_raw,
    layer="counts",
    flavor="seurat_v3",
    n_top_genes=7000,
    batch_key="donor",
    subset=True,
)
adata_raw

### Train the `scVI` base model for annotation transfer with `scANVI`

In [None]:
# Register adata_raw with scvi-tools: the 'counts' layer as input, the seed labels
# as labels, and donor as a categorical covariate.
scvi.model.SCVI.setup_anndata(
    adata_raw,
    layer='counts',
    labels_key="seed_labels",
    categorical_covariate_keys=["donor"],
)

In [None]:
# Build the scVI model on the registered object: 30 latent dimensions, 3 layers,
# negative-binomial ('nb') gene likelihood.
# NOTE(review): the setup_anndata call in this notebook passes no batch_key;
# confirm that 'gene-batch' dispersion is intended in that configuration.
# NOTE(review): arches_params defined earlier in the notebook is not passed here;
# confirm whether it should be.
scvi_model = scvi.model.SCVI(
    adata_raw,
    n_layers=3,
    n_latent=30,
    gene_likelihood='nb',
    dispersion='gene-batch',
)

In [None]:
# Train the scVI model with the library's default training arguments.
scvi_model.train()

In [None]:
# Store the scVI latent representation of every cell under obsm['X_scVI'].
adata_raw.obsm["X_scVI"] = scvi_model.get_latent_representation(adata_raw)

### Label transfer with `scANVI` 

In [None]:
# Derive a scANVI model from the trained scVI model; cells whose label equals
# 'Unknown' are the ones treated as unlabelled.
scanvi_model = scvi.model.SCANVI.from_scvi_model(
    scvi_model,
    unlabeled_category='Unknown',
)

In [None]:
# Train the scANVI model with the library's default training arguments.
scanvi_model.train()

In [None]:
# Store the scANVI prediction for every cell in obs['C_scANVI'].
adata_raw.obs["C_scANVI"] = scanvi_model.predict(adata_raw)

- Extract latent representation

In [None]:
# Store the scANVI latent representation of every cell under obsm['X_scANVI'].
adata_raw.obsm["X_scANVI"] = scanvi_model.get_latent_representation(adata_raw)

### Visualise corrected dataset

In [None]:
# Neighbour graph on the scANVI latent space, followed by a seeded UMAP embedding.
sc.pp.neighbors(
    adata_raw,
    use_rep="X_scANVI",
    n_neighbors=30,
    metric='minkowski',
)
sc.tl.umap(
    adata_raw,
    min_dist=0.4,
    spread=8,
    random_state=1712,
)

# One UMAP panel per metadata / label column, laid out four to a row.
sc.pl.umap(
    adata_raw,
    color=[
        'group',
        'disease',
        'infection',
        'C_scANVI',
        'seed_labels',
        'donor',
        'cell_states',
    ],
    frameon=False,
    size=0.4,
    legend_fontsize=5,
    ncols=4,
)

In [None]:
# UMAP panels coloured by further obs columns, drawn with the 'plasma' colour map.
sc.pl.umap(
    adata_raw,
    color=[
        'SMK',
        'n_genes',
        'doublet_scores',
        'batch',
        'n_genes_by_counts',
        'total_counts',
        'total_counts_mt',
        'pct_counts_mt',
        'total_counts_ribo',
        'pct_counts_ribo',
        'n_counts',
        'sample_group',
        'IAV_score',
    ],
    frameon=False,
    size=0.4,
    legend_fontsize=5,
    ncols=4,
    cmap='plasma',
)

### Compute integration metrics

In [None]:
# Benchmark the listed embeddings with scib-metrics, using donor as the batch
# variable and the scANVI predictions as the label column.
# batch_key is passed as a single obs column name: the Benchmarker API documents
# it as a string key, not a list of keys.
# NOTE(review): "X_pca" is not computed in the visible notebook; confirm it is
# present in adata_raw.obsm (or created by Benchmarker) before this runs.
bm = Benchmarker(
    adata_raw,
    batch_key="donor",
    label_key="C_scANVI",
    embedding_obsm_keys=["X_pca", "X_scVI", "X_scANVI"],
    n_jobs=-1,
)
bm.benchmark()

In [None]:
# Render the benchmark results table with min-max scaling turned off.
bm.plot_results_table(min_max_scale = False)

### Export annotated sample object 

In [None]:
# Keep only the first three '-'-separated fields of every cell name.
# NOTE(review): truncation could yield duplicate names if cells differ only in
# later fields; confirm the result is still unique.
trimmed_names = ['-'.join(name.split('-')[:3]) for name in adata_raw.obs.index]
adata_raw.obs.index = pd.Index(trimmed_names)
adata_raw.obs.index

In [None]:
# Apply the same three-field truncation to the cell names of the full-gene snapshot.
trimmed_names = ['-'.join(name.split('-')[:3]) for name in raw_adata.obs.index]
raw_adata.obs.index = pd.Index(trimmed_names)
raw_adata.obs.index

In [None]:
# Display the cell names of adata_raw after trimming.
adata_raw.obs_names

In [None]:
# Display the categories of the predicted-label column.
adata_raw.obs['C_scANVI'].cat.categories

In [None]:
# Display the number of cells assigned to each predicted label.
adata_raw.obs['C_scANVI'].value_counts()

### Export annotated object with raw counts

In [None]:
# Display the summary of the HVG-subset object carrying the model outputs.
adata_raw

In [None]:
# Display the summary of the snapshot taken before HVG subsetting.
raw_adata

In [None]:
# Assemble the export object: X and var from the pre-HVG snapshot, obs (including
# the scANVI predictions) from the processed object.
# NOTE(review): this relies on both objects holding the same cells in the same
# order; nothing visible in the notebook reorders rows, but confirm.
adata_export = anndata.AnnData(
    X=raw_adata.X,
    obs=adata_raw.obs,
    var=raw_adata.var,
)

# Carry over independent copies of the embeddings computed on the processed object.
for embedding_key in ('X_scVI', 'X_umap', 'X_scANVI'):
    adata_export.obsm[embedding_key] = adata_raw.obsm[embedding_key].copy()
adata_export