### Notebook for label transfer from the small reference dataset (GCA + stem cells) to the rest of the GCA using `scANVI`

- **Developed by:** Anna Maguza
- **Institute of Computational Biology - Computational Health Department - Helmholtz Munich**
- 17th February 2022

The main difference between this notebook and `scANVI_label_transferring_on_GCA.ipynb` is that it uses a different reference dataset, one that contains fewer plasma cells.

### Import required modules

In [None]:
import scvi
import anndata
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt

### Set up working environment

In [None]:
%matplotlib inline

In [None]:
sc.settings.verbosity = 3
sc.logging.print_versions()
sc.settings.set_figure_params(dpi = 180, color_map = 'magma_r', dpi_save = 300, vector_friendly = True, format = 'svg')

In [None]:
warnings.simplefilter(action = 'ignore')
scvi.settings.seed = 1712
%config InlineBackend.print_figure_kwargs = {'facecolor' : "w"}
%config InlineBackend.figure_format = 'retina'

In [None]:
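# scArches-style architecture settings (layer norm instead of batch norm).
# Note: this dict is not passed to scvi.model.SCVI below, so it has no effect in this notebook.
arches_params = dict(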
    use_layer_norm = "both",
    use_batch_norm = "none",
    encode_covariates = True,
    dropout_rate = 0.2,
    n_layers = 2,
)

### Read in Reference object

In [None]:
reference_input = '/lustre/groups/talaveralopez/workspace/anna.maguza/Processed_datasets/expi_map/Reference_map_subset_after_geosketch2_less_plasma.h5ad'
reference_output = '/lustre/groups/talaveralopez/workspace/anna.maguza/Processed_datasets/expi_map/Reference_map_subset_after_geosketch_output.h5ad'

In [None]:
reference = sc.read_h5ad(reference_input)
reference.X

In [None]:
reference.obs

In [None]:
reference.obs['seed_labels'] = reference.obs['CellType'].copy()

### Read query object

In [None]:
query_input = '/lustre/groups/talaveralopez/workspace/anna.maguza/Processed_datasets/expi_map/Query_map_after_geosketch2_less_plasma.h5ad'
query_output = '/lustre/groups/talaveralopez/workspace/anna.maguza/Processed_datasets/expi_map/Query_map_after_geosketch2_less_plasma_output.h5ad'

In [None]:
query = sc.read_h5ad(query_input)
query.X

In [None]:
query.obs['seed_labels'] = 'Unknown'
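scANVI treats every cell whose seed label equals the unlabeled category (here `'Unknown'`) as unlabeled, so that string must not also occur among the reference `CellType` labels. The sketch below uses plain pandas and a hypothetical helper, `build_seed_labels`, to build the combined seed-label column and fail loudly on such a collision. It illustrates the idea and is not part of the original pipeline.

```python
import pandas as pd


def build_seed_labels(reference_labels: pd.Series, query_index: pd.Index,
                      unlabeled: str = "Unknown") -> pd.Series:
    """Hypothetical helper: reference cells keep their annotation, query cells get `unlabeled`."""
    ref = reference_labels.astype(str)
    # A reference label equal to the unlabeled category would silently be treated as unlabeled
    if (ref == unlabeled).any():
        raise ValueError(f"reference labels already contain {unlabeled!r}")
    query = pd.Series(unlabeled, index=query_index)
    return pd.concat([ref, query])


# Made-up labels for three reference cells and two query cells
reference_labels = pd.Series(["B cell", "T cell", "Plasma cell"],
                             index=["r1", "r2", "r3"], dtype="category")
seed = build_seed_labels(reference_labels, pd.Index(["q1", "q2"]))
print(seed.value_counts())
```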

In [None]:
# Concatenate reference and query
adata = reference.concatenate(query, batch_key = 'dataset', batch_categories = ['reference', 'query'])

### Store raw counts (no HVG selection is applied)

In [None]:
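# Keep an untouched copy of the concatenated object
adata_raw = adata.copy()
# scVI/scANVI read their input from this layer (registered via `layer = 'counts'` in setup_anndata); it should hold raw counts
adata.layers['counts'] = adata.X.copy()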

### Pre-train the `scVI` model

In [None]:
scvi.model.SCVI.setup_anndata(adata, batch_key = 'Sample_ID', labels_key = "seed_labels", layer = 'counts')

In [None]:
scvi_model = scvi.model.SCVI(adata, n_latent = 50, n_layers = 3, dispersion = 'gene-batch', gene_likelihood = 'nb')

In [None]:
scvi_model.train(max_epochs = 100)

### Label transfer with `scANVI` 

In [None]:
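# Initialise scANVI from the trained scVI weights; cells labelled 'Unknown' are treated as unlabeled
scanvi_model = scvi.model.SCANVI.from_scvi_model(scvi_model, unlabeled_category = 'Unknown')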

In [None]:
scanvi_model.train(max_epochs = 25)

In [None]:
adata.obs["C_scANVI"] = scanvi_model.predict(adata)
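`SCANVI.predict` also accepts `soft=True`, in which case it returns a cells × labels DataFrame of class probabilities instead of hard labels. Below is a minimal sketch of turning such a table into a per-cell confidence score. The probability values are mocked rather than produced by the model, and both the helper `summarise_soft_predictions` and the 0.5 cut-off are hypothetical choices.

```python
import pandas as pd


def summarise_soft_predictions(probs: pd.DataFrame, threshold: float = 0.5) -> pd.DataFrame:
    """Hypothetical helper: hard label, its probability and a low-confidence flag per cell."""
    out = pd.DataFrame(index=probs.index)
    out["label"] = probs.idxmax(axis=1)          # most probable label
    out["max_prob"] = probs.max(axis=1)          # its probability
    out["low_confidence"] = out["max_prob"] < threshold
    return out


# Mocked probabilities standing in for scanvi_model.predict(adata, soft=True)
probs = pd.DataFrame(
    [[0.9, 0.05, 0.05], [0.4, 0.35, 0.25], [0.1, 0.2, 0.7]],
    index=["cell_1", "cell_2", "cell_3"],
    columns=["T cell", "Plasma cell", "Stem cell"],
)
summary = summarise_soft_predictions(probs)
print(summary)
```

In the notebook this would be fed with `scanvi_model.predict(adata, soft = True)`, and the `max_prob` column could be stored in `adata.obs` to highlight query cells whose transferred label is uncertain.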

- Extract latent representation

In [None]:
adata.obsm["X_scANVI"] = scanvi_model.get_latent_representation(adata)

- Visualise corrected dataset

In [None]:
sc.pp.neighbors(adata, use_rep = "X_scANVI", n_neighbors = 50, metric = 'minkowski')
sc.tl.umap(adata, min_dist = 0.2, spread = 2, random_state = 1712)

In [None]:
adata.obs

In [None]:
sc.pl.umap(adata, frameon = False, color = ['C_scANVI', 'CellType', 'dataset', 'Diagnosis', 'Study_name', 'Sample_ID'], size = 0.6, legend_fontsize = 5, ncols = 3)

In [None]:
# Write AnnData object to file
adata.write_h5ad('/lustre/groups/talaveralopez/workspace/anna.maguza/Processed_datasets/expi_map/GCA_Stem_Cell_after_scanvi_less_plasma_cells.h5ad')

### Model validation

In [None]:
adata.obs['C_scANVI'].value_counts()

In [None]:
adata.obs['CellType'].value_counts()

In [None]:
# Validation of the clusters: cross-tabulate observed vs predicted labels
df = adata.obs.groupby(["CellType", "C_scANVI"]).size().unstack(fill_value=0)
# Normalise each column (predicted label) so that it sums to 1
norm_df = df / df.sum(axis=0)

plt.figure(figsize=(8, 8))
_ = plt.pcolor(norm_df)
plt.colorbar(label="Fraction of predicted label")
_ = plt.xticks(np.arange(0.5, len(df.columns), 1), df.columns, rotation=90)
_ = plt.yticks(np.arange(0.5, len(df.index), 1), df.index)
plt.xlabel("Predicted")
plt.ylabel("Observed")
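`df / df.sum(axis=0)` divides each column by its total, so every *predicted* label sums to 1 (the fraction of cells called X that carried each observed label). Normalising rows instead shows, for each *observed* label, where its cells ended up. A small pandas sketch on made-up labels contrasting the two with `pd.crosstab`:

```python
import numpy as np
import pandas as pd

# Made-up labels for six cells
observed = pd.Series(["B", "B", "T", "T", "T", "Plasma"], name="Observed")
predicted = pd.Series(["B", "T", "T", "T", "T", "Plasma"], name="Predicted")

counts = pd.crosstab(observed, predicted)
# Column-normalised: each predicted label sums to 1 (same as counts / counts.sum(axis=0))
col_norm = pd.crosstab(observed, predicted, normalize="columns")
# Row-normalised: each observed label sums to 1
row_norm = pd.crosstab(observed, predicted, normalize="index")

print(col_norm)
print(row_norm)
```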

In [None]:
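# Calculate accuracy by comparing label strings, not category codes:
# the two columns are separate categoricals whose codes need not refer to the same labels
print(f"Acc: {np.mean(adata.obs['CellType'].astype(str) == adata.obs['C_scANVI'].astype(str))}")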
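Pooling all cells mixes reference cells, whose labels scANVI saw during training, with query cells, so a single pooled accuracy is optimistic. The UMAP above colours all cells by `CellType`, which suggests the query cells also carry an original annotation. If so, a per-dataset breakdown separates the two groups. The minimal pandas sketch below uses a hypothetical helper, `accuracy_by_group`, and made-up labels. It also shows why comparing `.cat.codes` of two independently built categoricals is unreliable.

```python
import pandas as pd


def accuracy_by_group(obs: pd.DataFrame, true_key: str, pred_key: str, group_key: str) -> pd.Series:
    """Hypothetical helper: fraction of cells whose predicted label string equals the observed one, per group."""
    match = obs[true_key].astype(str) == obs[pred_key].astype(str)
    return match.groupby(obs[group_key], observed=True).mean()


# Made-up example; the two label columns deliberately use different category orders
obs = pd.DataFrame({
    "dataset": pd.Categorical(["reference"] * 3 + ["query"] * 3, categories=["reference", "query"]),
    "CellType": pd.Categorical(["B", "T", "T", "B", "T", "Plasma"], categories=["B", "Plasma", "T"]),
    "C_scANVI": pd.Categorical(["B", "T", "T", "B", "B", "Plasma"], categories=["T", "B", "Plasma"]),
})
acc = accuracy_by_group(obs, "CellType", "C_scANVI", "dataset")
# Comparing codes matches integers that stand for different labels in the two columns
naive = (obs["CellType"].cat.codes == obs["C_scANVI"].cat.codes).mean()
print(acc)
print(f"naive code-based accuracy: {naive}")
```

In the notebook the call would be `accuracy_by_group(adata.obs, 'CellType', 'C_scANVI', 'dataset')`.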