# Annotation

## Environment setup

We'll filter out some deprecation and performance warnings that do not affect our code:

In [None]:
import os
import re
import warnings
import numba
import pandas as pd
import scanpy as sc
import anndata as ad
import torch
import h5py
import scarches as sca
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import gdown

from scipy.sparse import csr_matrix
from scarches.dataset.trvae.data_handling import remove_sparsity

# If the notebook is launched from its own "python" folder, move up three
# levels (presumably the project root) so the relative "data/..." and
# "analysis/..." paths used below resolve.
current_dir = os.getcwd()
if os.path.basename(current_dir) == "python":
    os.chdir("../../../")
    print(f"Changed working directory to: {os.getcwd()}")
else:
    print(f"Current working directory: {os.getcwd()}")

# Silence warning categories that are only noise for this analysis.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
warnings.simplefilter("ignore", category=NumbaDeprecationWarning)
# The pending-deprecation category was imported but previously never filtered.
warnings.simplefilter("ignore", category=NumbaPendingDeprecationWarning)

In [None]:
# Configure plotting in ONE call: every sc.set_figure_params call resets each
# argument it is not given to its default (e.g. dpi=80, frameon=True), so the
# previous three separate calls ended with the dpi/frameon settings undone.
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
# Compact, non-scientific tensor printing for inspecting model outputs.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

## Load data

Load the Allen Brain Atlas reference data (GEO accession GSE185862)

In [None]:
# Read the Allen Brain Atlas 10x reference (GSE185862) from HDF5. The file is
# opened in a context manager so it is closed even if one of the reads fails.
with h5py.File("data/resources/GSE185862_expression_matrix_10x.hdf5", "r") as f:
    # Names and object types (group or dataset) of the root-level entries.
    for key in f.keys():
        print(key)
        print(type(f[key]))

    # Use the last root-level key listed above; it must name a group.
    # NOTE(review): assumes the file has a single root group - confirm.
    group = f[key]

    # Show what is inside that group.
    for key in group.keys():
        print(key)

    # [()] reads each dataset completely into memory as a NumPy array.
    source_adata = group["counts"][()]
    genes = group["gene"][()]
    cells = group["samples"][()]

In [None]:
# Wrap the counts in an AnnData. The transpose makes rows correspond to `cells`
# (obs) and columns to `genes` (var), as the name assignments below require.
source_adata = ad.AnnData(source_adata.T)
# astype("str") converts the names read from HDF5 (presumably bytes) to str.
source_adata.obs_names = cells.astype("str")
source_adata.var_names = genes.astype("str")
source_adata

Load in metadata and match to expression matrix rows

In [None]:
# Load the per-cell metadata, indexed by "sample_name" so it can be aligned to
# the expression matrix. The actual alignment happens further below, after
# cells missing from either table have been dropped.
# (Removed: a `metadata.reindex(index=cells)` whose result was discarded, and a
# reset_index/set_index round-trip that left the frame unchanged.)
metadata = pd.read_csv("data/resources/GSE185862_metadata_10x.csv")
metadata = metadata.set_index("sample_name")
metadata

Remove metadata cells that are not in the expression matrix, and vice versa

In [None]:
# Drop metadata rows for cells that are not in the expression matrix.
in_matrix = metadata.index.isin(source_adata.obs_names)
metadata = metadata.loc[in_matrix]
del in_matrix

In [None]:
# Keep only expression-matrix cells that have metadata. .copy() turns the
# subset view into an independent AnnData, so the in-place edits that follow
# (assigning .obs, filtering, adding layers) work on real data instead of
# implicitly converting a view.
source_adata = source_adata[source_adata.obs_names.isin(metadata.index)].copy()
source_adata

Now let's reorder the metadata rows so they match the order of the cells in the expression matrix

In [None]:
# Order the metadata rows exactly like the cells in the expression matrix.
metadata = metadata.reindex(index=source_adata.obs_names)

In [None]:
# Display the aligned metadata table.
metadata

Now let's add this metadata to the AnnData object

In [None]:
# Use the aligned metadata as the cell annotations (obs) and display them.
source_adata.obs = metadata
source_adata.obs

Let's do some basic filtering and check QC metrics

In [None]:
# In-place filtering: drop cells with fewer than 100 detected genes, then genes
# detected in fewer than 100 cells.
sc.pp.filter_cells(source_adata, min_genes=100)
sc.pp.filter_genes(source_adata, min_cells=100)

In [None]:
# Flag gene groups for QC using mouse gene symbols.
# Mitochondrial genes start with "mt-" in mouse ("MT-" in human).
source_adata.var["mt"] = source_adata.var_names.str.startswith("mt-")
# Ribosomal protein genes.
source_adata.var["ribo"] = source_adata.var_names.str.startswith(("Rps", "Rpl"))
# Hemoglobin genes (Hba-*, Hbb-*, Hbq*). The human-style pattern "^Hb[^(P)]"
# also matched mouse non-globin genes such as Hbegf, Hbs1l and Hbp1.
source_adata.var["hb"] = source_adata.var_names.str.contains("^Hb[abq]")

# Adds per-cell totals and pct_counts_<group> columns to obs (and per-gene
# metrics to var), including log1p-transformed versions.
sc.pp.calculate_qc_metrics(
    source_adata, qc_vars=["mt", "ribo", "hb"], inplace=True, log1p=True
)

In [None]:
# Distributions of the main QC metrics, one panel per metric (no per-cell dots).
sc.pl.violin(
    source_adata,
    ["n_genes_by_counts", "total_counts", "pct_counts_mt"],
    stripplot=False,
    multi_panel=True,
)

In [None]:
# Genes detected vs. total counts per cell, colored by mitochondrial fraction.
sc.pl.scatter(source_adata, "total_counts", "n_genes_by_counts", color="pct_counts_mt")

## Find Variable Genes

In [None]:
# Keep a copy of the raw counts; .X is normalized next and restored afterwards.
source_adata.layers["counts"] = source_adata.X.copy()

In [None]:
# Normalizing to median total counts
sc.pp.normalize_total(source_adata)
# Logarithmize the data
sc.pp.log1p(source_adata)
# Flag the 3000 most variable genes (computed on the log-normalized data).
sc.pp.highly_variable_genes(source_adata, n_top_genes=3000)

Subset to HVGs

In [None]:
# Keep only the highly variable genes as an independent (non-view) AnnData.
source_adata = source_adata[:, source_adata.var["highly_variable"]].copy()

Revert to raw counts

In [None]:
# Put the raw counts (HVG subset) back into .X and drop the now-redundant layer.
source_adata.X = source_adata.layers["counts"]
del source_adata.layers["counts"]

Now let's make a cell type column (derived from the cluster labels) whose annotations we will transfer

In [None]:
# Cell type = everything after the first "_" in the cluster label
# (e.g. a label "<id>_<name>" yields "<name>").
source_adata.obs["cell_type"] = source_adata.obs["cluster_label"].str.extract(
    r'[^_]*_(.*)', expand=False
)

In [None]:
# Save the processed reference (raw counts, HVGs, cell_type) to reload later.
source_adata.write_h5ad("data/resources/source_adata.h5ad")

## Create scVI model and train it on reference dataset

Set settings

In [None]:
# The whole reference is treated as a single batch (constant value 1).
source_adata.obs["batch"] = 1
# Register the AnnData with scVI; no layer is given, so counts are read from .X.
sca.models.SCVI.setup_anndata(source_adata, batch_key="batch")

Create the scVI model instance. scvi-tools' `SCVI` uses a zero-inflated negative binomial likelihood by default (`gene_likelihood="zinb"`); pass e.g. `gene_likelihood="nb"` or `gene_likelihood="poisson"` to change the reconstruction loss.

In [None]:
# Reference scVI model. Architecture choices (covariates fed to the encoder,
# layer norm instead of batch norm) follow the scArches recommendations for
# models that will later be extended to query data.
vae = sca.models.SCVI(
    source_adata,
    n_layers=2,
    encode_covariates=True,
    deeply_inject_covariates=False,
    use_layer_norm="both",
    use_batch_norm="none",
)

Train

In [None]:
# Train the reference model with default training settings.
vae.train()

Let's store the reference latent embedding for later

In [None]:
# AnnData holding the reference latent embedding in .X plus the reference
# cell annotations, used later for joint visualization and label transfer.
ref_emb = sc.AnnData(X=vae.get_latent_representation(), obs=source_adata.obs)
ref_emb.obs["reference_or_query"] = "reference"

In [None]:
# Persist the trained reference model; an existing model at ref_path is replaced.
ref_path="analysis/refscVI"
vae.save(ref_path, overwrite=True)

Let's read in the data pre-processed in R

## Automated annotation

### Annotation by mapping to a reference

scArches, which we will use for reference-mapping-based label transfer, takes as its basis an existing (variational autoencoder-based) model that embeds the reference data in a low-dimensional, batch-corrected space. It then slightly extends that model to enable the mapping of an unseen dataset into the same "latent space" (i.e. the low-dimensional embedding). This model extension also enables the learning and removal of batch effects present in the mapped dataset.

Let's start by preparing our data for the mapping to a reference. scArches, the method that enables us to adapt an existing reference model to new data, requires raw, non-normalized counts. We will therefore keep our counts layer and remove all other layers from our adata to map. We will set our `.X` to those raw counts as well.

In [None]:
# Reload the processed reference so this part can run without re-training.
source_adata = sc.read_h5ad("data/resources/source_adata.h5ad")

In [None]:
# Location of the saved reference model (same path used when saving it).
ref_path="analysis/refscVI"

In [None]:
# The query dataset to annotate.
adata = sc.read_h5ad("data/cellAnnotation/Multiome.h5ad")

In [None]:
# Build the object to map. scArches needs raw, non-normalized counts, like the
# reference model, which was trained on raw counts in .X.
adata_to_map = adata.copy()
# Drop every layer except the raw "counts" layer.
for layer in list(adata_to_map.layers.keys()):
    if layer != "counts":
        del adata_to_map.layers[layer]
# The model reads expression from .X (setup_anndata was called without a
# layer), so .X must hold the raw counts as well.
if "counts" in adata_to_map.layers:
    adata_to_map.X = adata_to_map.layers["counts"].copy()
else:
    print("No 'counts' layer found; assuming .X already holds raw counts.")

Let's see if we are missing any needed genes in the query dataset

In [None]:
# Share of the reference (model) genes already present in the query. Counting
# over the reference genes keeps duplicated query gene names from pushing the
# value past 100%.
print(
    "Percent of needed genes found in query dataset:",
    100 * source_adata.var_names.isin(adata_to_map.var_names).mean(),
)

Let's add all-zero columns for the missing genes (genes are the columns of an AnnData object)

In [None]:
# Reference genes absent from the query, kept in reference order.
missing_genes = source_adata.var_names[
    ~source_adata.var_names.isin(adata_to_map.var_names)
].tolist()

In [None]:
# Placeholder AnnData with an all-zero column for every missing reference gene.
# The empty float32 sparse matrix is built directly from its shape, so no dense
# (n_obs x n_missing) zeros array is allocated. dtype is now passed to
# csr_matrix; previously it went to the deprecated AnnData(dtype=...) argument.
missing_gene_adata = sc.AnnData(
    X=csr_matrix((adata.n_obs, len(missing_genes)), dtype="float32"),
    obs=adata.obs.iloc[:, :],
)
missing_gene_adata.var.index = missing_genes

Concatenate our original adata to the missing genes adata. To make sure we can do this concatenation without errors, we’ll remove the PCA matrix from varm.

In [None]:
# Gene loadings from an earlier PCA would not line up after adding genes.
if "PCs" in adata_to_map.varm:
    del adata_to_map.varm["PCs"]

In [None]:
# Append the zero-filled missing genes as extra columns (axis=1). Cell names
# are kept as-is (index_unique=None); merge="unique" keeps cell annotations
# whose values agree across both objects.
adata_to_map = sc.concat(
    [adata_to_map, missing_gene_adata],
    axis=1,
    join="outer",
    index_unique=None,
    merge="unique",
)

In [None]:
# Re-check the share of reference genes present in the query; it should now be
# 100%. Counted over the reference genes so duplicated query names cannot
# inflate it.
print(
    "Percent of needed genes found in query dataset:",
    100 * source_adata.var_names.isin(adata_to_map.var_names).mean(),
)

In [None]:
# Restrict the query to the reference genes, in the reference gene order.
adata_to_map = adata_to_map[:, source_adata.var_names].copy()

Check genes match

In [None]:
# The reference model was set up on source_adata's genes; the query must have
# exactly the same genes in the same order. Stop here if it does not (the
# previous elementwise `==` raised an unrelated error on a length mismatch and
# otherwise only displayed False).
genes_match = adata_to_map.var_names.equals(source_adata.var_names)
if not genes_match:
    raise ValueError("Query genes do not match the reference model genes")
genes_match

In [None]:
# Remove source adata as we no longer need it
# Free the reference expression data; later steps only need ref_emb.
del source_adata

Add a dummy constant batch column (just need a column)

In [None]:
# Query cells get a batch value (2) different from the reference's (1).
adata_to_map.obs["batch"] = 2

In [None]:
# Extend the saved reference model (scArches "surgery") so it accepts the
# query; dropout is frozen during the subsequent query training.
new_vae = sca.models.SCVI.load_query_data(
    adata=adata_to_map,
    reference_model=ref_path,
    freeze_dropout=True,
)

We will now update this reference model so that we can embed our own data (the "query") in the same latent space as the reference. This requires training on our query data using scArches:

In [None]:
# Train on the query; weight_decay=0 keeps the reference weights, and therefore
# the reference latent space, unchanged.
new_vae.train(max_epochs=500, plan_kwargs=dict(weight_decay=0.0))

Now that we have updated the model, we can calculate the (ideally batch-corrected) latent representation of our query:

In [None]:
# Query latent embedding, stored on the original query object.
adata.obsm["scVI"] = new_vae.get_latent_representation()

We can now use this newly calculated low-dimensional embedding as a basis for visualization and clustering. Let's calculate the new UMAP using the scVI-based representation of the data.

In [None]:
# Neighbor graph and UMAP of the query computed on the scVI embedding.
sc.pp.neighbors(adata, use_rep="scVI")
sc.tl.umap(adata)

To see if the mapping-based UMAP makes general sense, let's look at a few markers and check whether their expression is localized to specific parts of the UMAP:

In [None]:
# Marker genes on the query UMAP to sanity-check the embedding.
# NOTE(review): Slc17a5 encodes sialin; if an excitatory-neuron marker was
# intended, this may be a typo for Slc17a7 (VGLUT1) - confirm.
sc.pl.umap(
    adata,
    color=["Cx3cr1", "Slc17a5", "Mbp"],
    vmin=0,
    vmax="p99",  # set vmax to the 99th percentile of the gene count instead of the maximum, to prevent outliers from making expression in other cells invisible. Note that this can cause problems for extremely lowly expressed genes.
    sort_order=False,  # do not plot highest expression on top, to not get a biased view of the mean expression among cells
    frameon=False,
    cmap="Reds",  # or choose another color map e.g. from here: https://matplotlib.org/stable/tutorials/colors/colormaps.html
)

Now the essential step is that we can combine the inferred latent space embedding of our query data with the existing reference embedding. Using this joint embedding, we will not only be able to e.g. visualize and cluster the two together, but we can also transfer labels from the reference to the query.<br> 
We already computed the reference embedding (`ref_emb`) earlier in this notebook; for existing atlases, such an embedding is often made publicly available.

To perform the label transfer, we will first concatenate the reference and query data using the 10-dimensional embedding. To get there, we will create the same type of AnnData object from our query data as we have from the reference (with the embedding under `.X`) and concatenate the two. With that, we can jointly analyze reference and query including doing transfer from one to the other.

In [None]:
# Query counterpart of ref_emb: latent embedding in .X, query annotations in obs.
adata_emb = sc.AnnData(X=adata.obsm["scVI"], obs=adata.obs)

In [None]:
# Tag query cells so they can be told apart after concatenation.
adata_emb.obs["reference_or_query"] = "query"

In [None]:
# Stack reference and query cells (axis=0) into one joint embedding object;
# obs columns present in only one object are kept (join="outer").
emb_ref_query = sc.concat(
    [ref_emb, adata_emb],
    axis=0,
    join="outer",
    index_unique=None,
    merge="unique",
)

Let's visualize the joint embedding with a UMAP.

In [None]:
# No use_rep given: .X here is the low-dimensional latent embedding, which
# scanpy uses directly when it has few variables (presumably the scVI default
# of 10 latent dimensions) instead of computing a PCA first.
sc.pp.neighbors(emb_ref_query)
sc.tl.umap(emb_ref_query)

We can visually get a first impression of whether the reference and query integrated well based on the UMAP:

In [None]:
# Joint UMAP colored by origin, to judge how well query and reference mix.
sc.pl.umap(
    emb_ref_query,
    color=["reference_or_query"],
    sort_order=False,
    frameon=False,
)

The (partial) mixing of query and reference in this UMAP is a good sign! When mapping completely fails, you will often see a full separation of query and reference in the UMAP.

Now let's look at the cell type annotations from the reference. All cells from the query are NA here, as they don't have annotations yet, and are shown in black.

We'll make this figure a bit bigger so that we can read the legend well:

In [None]:
# Enlarge figures for the following plots. Only the figure size is changed:
# sc.set_figure_params(figsize=...) would also reset its other arguments
# (dpi, frameon, ...) to their defaults.
plt.rcParams["figure.figsize"] = (10, 10)

In [None]:
# Reference cell types on the joint UMAP; query cells have no cell_type and
# are drawn with na_color (black).
sc.pl.umap(
    emb_ref_query,
    color=["cell_type"],
    sort_order=False,
    frameon=False,
    legend_loc="on data",
    legend_fontsize=10,
    na_color="black",
)

As you can already tell from the UMAP, we can guess the cell type of each of our own cells (in black) by looking at which cell types from the reference surround it. This is exactly what a nearest-neighbor-graph-based label transfer approach does: for each query cell it checks what is the most common cell type among its neighboring reference cells. The higher the fraction of reference cells coming from a single cell type, the more confident the label transfer is.

Let's perform the KNN-based label transfer. 

First we set up the label transfer model:

In [None]:
# Fit a 15-nearest-neighbor model on the reference embedding only.
knn_transformer = sca.utils.knn.weighted_knn_trainer(
    train_adata=ref_emb,
    train_adata_emb="X",  # location of the reference embedding, ref_emb.X in this case
    n_neighbors=15,
)

Now we perform the label transfer:

In [None]:
# Transfer reference cell_type labels to each query cell from its reference
# neighbors; returns the labels and a per-cell uncertainty for each label key.
labels, uncert = sca.utils.knn.weighted_knn_transfer(
    query_adata=adata_emb,
    query_adata_emb="X",  # location of our embedding, query_adata.X in this case
    label_keys="cell_type",  # (start of) obs column name(s) for which to transfer labels
    knn_model=knn_transformer,
    ref_adata_obs=ref_emb.obs,
)

And store the results in our adata:

In [None]:
# Store the transferred label and its uncertainty, aligned by query cell name.
adata_emb.obs["transf_cell_type"] = labels.loc[adata_emb.obs.index, "cell_type"]
adata_emb.obs["transf_cell_type_unc"] = uncert.loc[adata_emb.obs.index, "cell_type"]

Let's transfer the results to our query adata object which also has our UMAP and gene counts, so that we can visualize all of those together.

In [None]:
# Copy the transferred label and its uncertainty onto the query object, which
# also carries the query UMAP and expression values for plotting.
for column in ("transf_cell_type", "transf_cell_type_unc"):
    adata.obs.loc[adata_emb.obs.index, column] = adata_emb.obs[column]

We can now visualize the transferred labels in our previously calculated UMAP of our own data:

Let's make the figures even larger so the labels stay readable:

In [None]:
# Enlarge figures further for the labeled UMAPs. Only the figure size is
# changed: sc.set_figure_params(figsize=...) would also reset its other
# arguments (dpi, frameon, ...) to their defaults.
plt.rcParams["figure.figsize"] = (15, 15)

In [None]:
# Transferred labels on the query's own UMAP.
sc.pl.umap(adata, color="transf_cell_type", frameon=False, legend_loc="on data")

Based on the neighbors of each of our query cells we can not only guess the cell type these cells belong to, but also generate a measure for certainty of that label: if a cell has neighbors from several different cell types, our guess will be highly uncertain. This is relevant to assess to what extent we can "trust" the transferred labels! Let's visualize the uncertainty scores:

In [None]:
# Label-transfer uncertainty per query cell.
sc.pl.umap(adata, color="transf_cell_type_unc", frameon=False)

Let's check for each cell type label how high the label transfer uncertainty levels were. This gives us a first impression of which annotations are more contentious/need more manual checks.

In [None]:
fig, ax = plt.subplots(figsize=(15, 7))
# Cell types ordered from highest to lowest median transfer uncertainty
# (kept as a one-column DataFrame whose index gives the plotting order).
ct_order = (
    adata.obs.groupby("transf_cell_type")["transf_cell_type_unc"]
    .median()
    .sort_values(ascending=False)
    .to_frame()
)
# One box of uncertainty values per transferred cell type, in that order.
sns.boxplot(
    data=adata.obs,
    x="transf_cell_type",
    y="transf_cell_type_unc",
    order=ct_order.index,
    color="grey",
    ax=ax,
)
ax.tick_params(axis="x", rotation=90)

You'll notice that some cell types are considerably harder to distinguish than others, for example transitional populations or broad, unspecific categories in the reference annotation. Transcriptionally distinct cell types, by contrast, are usually easier to recognize and label; they appear towards the right of the plot with low uncertainty.

To incorporate this uncertainty information in our transferred labels, we can set cells with an uncertainty score above e.g. 0.2 to "Unknown":

In [None]:
# Copy the transferred labels (as a plain list, so the column is not tied to
# the original column's categories), then relabel cells whose uncertainty
# exceeds 0.2 as "Unknown".
adata.obs["transf_cell_type_certain"] = adata.obs.transf_cell_type.tolist()
adata.obs.loc[
    adata.obs.transf_cell_type_unc > 0.2, "transf_cell_type_certain"
] = "Unknown"

Let's see what our annotations look like after this filtering. Note the Unknown color in the legend and the UMAP.

In [None]:
# Labels after uncertainty filtering, including the "Unknown" category.
sc.pl.umap(adata, color="transf_cell_type_certain", frameon=False)

To ease legibility, we can color *only* the "Unknown" cells. This will make it easier for us to see how many of those there are. You can do the same with any of the other cell type labels.

In [None]:
# Highlight only the "Unknown" category; other cells are drawn without color.
sc.pl.umap(adata, color="transf_cell_type_certain", groups="Unknown")

There are quite a few of them! These cells will need particularly careful manual review. However, the low-uncertainty annotations surrounding the "Unknown" cells will already give us a first idea of which cell type each of them is likely to belong to.

Finally, store your adata object:

In [None]:
# Ensure the uncertainty column is numeric before writing.
# NOTE(review): presumably it can end up with object dtype after the .loc
# assignment above, which h5ad writing handles poorly - confirm.
adata.obs["transf_cell_type_unc"] = adata.obs["transf_cell_type_unc"].astype("float")

In [None]:
# Save the annotated query object and its cell metadata table.
adata.write("data/cellAnnotation/MultiomeAnnotated.h5ad")
adata.obs.to_csv("data/cellAnnotation/MultiomeMetadataAnnotated.csv")