In [None]:
import os

from aavomics import database
from aavomics import aavomics
import scrublet

import scvi
import scanpy
import numpy
import pandas
import statsmodels
from statsmodels.stats import multitest

from skimage.filters import threshold_otsu
from statsmodels.stats.proportion import proportions_ztest

import anndata

In [None]:
# Random seed passed to scVI training and to the scanpy neighbors / t-SNE / Leiden calls below.
SEED = 1042

In [None]:
# Which alignment to use. Set to None to use the first available
ALIGNMENT_NAME = "cellranger_5.0.1_gex_mm10_2020_A"

# obs column read when loading each cell set; only rows equal to "Non-Neurons" are kept.
TAXONOMY_NAME = "CCN202105041"
# obs column this notebook fills with the per-cluster cell subtype label and writes back to disk.
SUBTYPE_TAXONOMY_NAME = "CCN202105051"

# obs column that receives the Leiden cluster assignment.
CLUSTER_OBS_NAME = "leiden_scVI"

# obsm key of the 2D embedding used for every plot in this notebook.
TRANSFORM_TO_PLOT = "X_tsne"

In [None]:
# Collect, per cell set, the called cells that the parent taxonomy labelled "Non-Neurons".
adatas = []

for cell_set in database.CELL_SETS:

    anndata_file_path = cell_set.get_anndata_file_path(alignment_name=ALIGNMENT_NAME)

    # Cell sets without an AnnData file for this alignment are reported and skipped.
    if not os.path.exists(anndata_file_path):
        print("Skipping %s, anndata doesn't exist" % cell_set.name)
        continue

    cell_set_adata = anndata.read(anndata_file_path)

    # "Cell Called" is compared against the string "True", not a boolean.
    is_non_neuron = cell_set_adata.obs[TAXONOMY_NAME] == "Non-Neurons"
    is_called_cell = cell_set_adata.obs["Cell Called"] == "True"

    non_neurons = cell_set_adata[is_non_neuron & is_called_cell].copy()
    # Record the origin so it can serve as the batch key and for the write-back step.
    non_neurons.obs["Cell Set"] = cell_set.name

    print(cell_set.name, non_neurons.shape[0])

    adatas.append(non_neurons)

In [None]:
# Merge all per-cell-set objects into a single AnnData.
first_adata, remaining_adatas = adatas[0], adatas[1:]
adata = first_adata.concatenate(remaining_adatas)

# Keep only genes whose maximum count over all cells is positive.
per_gene_max = adata.X.max(axis=0)
gene_mask = (per_gene_max > 0).toarray().flatten()
adata = adata[:, gene_mask].copy()

In [None]:
# Register the merged AnnData with scvi-tools, using the "Cell Set" obs column as the batch key.
# The "_scvi_batch" obs column read later is presumably created by this call — TODO confirm.
scvi.data.setup_anndata(adata, batch_key="Cell Set")

In [None]:
# KL warm-up length expressed in iterations rather than epochs.
# Based on documentation at https://www.scvi-tools.org/en/stable/api/reference/scvi.core.trainers.UnsupervisedTrainer.html
kl_warmup_iterations = 128*5000/400

vae = scvi.model.SCVI(adata, n_latent=20, n_layers=2, n_hidden=256)

# n_epochs_kl_warmup is explicitly None; the iteration-based warm-up above is passed instead.
vae.train(
    seed=SEED,
    frequency=1,
    n_epochs_kl_warmup=None,
    n_iter_kl_warmup=kl_warmup_iterations,
)

In [None]:
# Model-normalized expression; indexed by gene (column) further down for the expression plots.
normalized_gene_expression = vae.get_normalized_expression(vae.adata)
# NOTE(review): the latent representation is stored on `adata` while the scanpy calls below read
# from `vae.adata`; this relies on both names referring to the same object — confirm.
adata.obsm["X_scVI"] = vae.get_latent_representation(vae.adata)
# Order matters: neighbors must be computed on the latent space before t-SNE / Leiden are run.
scanpy.pp.neighbors(vae.adata, use_rep="X_scVI", random_state=SEED)
scanpy.tl.tsne(vae.adata, use_rep="X_scVI", n_jobs=16, random_state=SEED)
scanpy.tl.leiden(vae.adata, key_added=CLUSTER_OBS_NAME, random_state=SEED, resolution=2) # Resolution 2 to distinguish between doublet clusters

In [None]:
embedding = adata.obsm[TRANSFORM_TO_PLOT]

# Categorical overlays on the embedding: Leiden cluster, then cell set of origin.
for obs_column, html_name in (
    (CLUSTER_OBS_NAME, "clusters_non_neuronal.html"),
    ("Cell Set", "samples_non_neuronal.html"),
):
    aavomics.plot_clusters(embedding, adata.obs[obs_column], filename=os.path.join("out", html_name))

# Row sums of the count matrix, i.e. total counts per cell.
total_transcript_counts = numpy.array(adata.X.sum(axis=1)).flatten()
aavomics.plot_gene_expression(embedding, total_transcript_counts, filename=os.path.join("out", "total_transcript_counts_non_neuronal.html"))

In [None]:
import torch
import numpy as np
from scvi import _CONSTANTS
from scvi.core.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial
from typing import Dict, Iterable, Optional, Sequence, Union
from anndata import AnnData

from scvi.model._utils import (
    _get_batch_code_from_category,
    _get_var_names_from_setup_anndata,
    scrna_raw_counts_properties,
)

@torch.no_grad()
def posterior_predictive_sample(
    self,
    adata: Optional[AnnData] = None,
    indices: Optional[Sequence[int]] = None,
    n_samples: int = 1,
    gene_list: Optional[Sequence[str]] = None,
    batch_size: Optional[int] = None,
    transform_batch: Optional[int] = None
) -> np.ndarray:
    r"""
    Generate observation samples from the posterior predictive distribution.

    The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`.

    Parameters
    ----------
    self
        The trained model object; this function is called with the model passed explicitly
        as the first argument.
    adata
        AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
        AnnData object used to initialize the model.
    indices
        Indices of cells in adata to use. If `None`, all cells are used.
    n_samples
        Number of samples for each cell.
    gene_list
        Names of genes of interest. If `None`, all genes are returned.
    batch_size
        Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
    transform_batch
        Forwarded unchanged to ``self.model.inference``. Presumably the batch code the
        generative parameters are conditioned on instead of each cell's own batch —
        confirm against the scvi-tools version in use.

    Returns
    -------
    x_new : :py:class:`numpy.ndarray`
        Sampled counts concatenated over minibatches. When ``n_samples > 1`` the sample axis
        is moved last, giving (n_cells, n_genes, n_samples); otherwise the array has whatever
        shape ``dist.sample()`` yields (presumably (n_cells, n_genes)).

    Raises
    ------
    ValueError
        If the model's ``gene_likelihood`` is not one of "zinb", "nb" or "poisson".
    """
    gene_likelihood = self.model.gene_likelihood
    if gene_likelihood not in ["zinb", "nb", "poisson"]:
        raise ValueError("Invalid gene_likelihood.")

    adata = self._validate_anndata(adata)
    scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)

    if gene_list is None:
        gene_mask = None
    else:
        # Set membership keeps the mask construction linear in the number of genes.
        wanted_genes = set(gene_list)
        gene_mask = [gene in wanted_genes for gene in _get_var_names_from_setup_anndata(adata)]

    x_new = []
    for tensors in scdl:
        outputs = self.model.inference(
            tensors[_CONSTANTS.X_KEY],
            batch_index=tensors[_CONSTANTS.BATCH_KEY],
            y=tensors[_CONSTANTS.LABELS_KEY],
            n_samples=n_samples,
            transform_batch=transform_batch,
        )
        px_r = outputs["px_r"]
        px_rate = outputs["px_rate"]
        px_dropout = outputs["px_dropout"]

        if gene_likelihood == "poisson":
            # Clamp the rate before building the Poisson distribution.
            dist = torch.distributions.Poisson(
                torch.clamp(px_rate, max=1e8)
            )  # Shape : (n_samples, n_cells_batch, n_genes)
        elif gene_likelihood == "nb":
            dist = NegativeBinomial(mu=px_rate, theta=px_r)
        else:
            # Only "zinb" can reach this branch: the likelihood was validated above.
            dist = ZeroInflatedNegativeBinomial(
                mu=px_rate, theta=px_r, zi_logits=px_dropout
            )

        exprs = dist.sample()
        if n_samples > 1:
            # Move the sample axis last. Shape : (n_cells_batch, n_genes, n_samples)
            exprs = exprs.permute([1, 2, 0])

        if gene_mask is not None:
            exprs = exprs[:, gene_mask, ...]

        x_new.append(exprs.cpu())

    # Concatenate minibatches along the cell axis.
    return torch.cat(x_new).numpy()

In [None]:
# Cluster labels are numeric strings; sort them numerically rather than lexicographically.
clusters = sorted(adata.obs[CLUSTER_OBS_NAME].unique(), key=int)

cluster_marker_p_non_zeros = {}

# Number of cells drawn (with replacement) from a cluster for each batch.
NUM_SAMPLES = 5000

# Seed the resampling so the detection fractions, and the labels derived from them, are
# reproducible. posterior_predictive_sample draws through torch, so torch is seeded as well.
sampling_rng = numpy.random.RandomState(SEED)
torch.manual_seed(SEED)

for cluster in clusters:

    print(cluster)

    adata_copy = adata[adata.obs[CLUSTER_OBS_NAME] == cluster].copy()

    p_non_zeros = []

    # One posterior predictive draw per batch present in this cluster, each conditioned on
    # that batch via transform_batch.
    for batch_id in adata_copy.obs["_scvi_batch"].unique():

        sampled_rows = sampling_rng.choice(adata_copy.shape[0], size=NUM_SAMPLES, replace=True)
        random_sample = adata_copy[sampled_rows].copy()

        samples = posterior_predictive_sample(vae, random_sample, n_samples=1, transform_batch=batch_id)

        # Fraction of sampled cells with a non-zero count, per gene.
        non_zero_counts = (samples != 0).sum(axis=0)
        p_non_zeros.append(non_zero_counts/NUM_SAMPLES)

    # Average the per-batch fractions into one value per gene for this cluster.
    p_non_zeros = numpy.array(p_non_zeros).mean(axis=0)

    cluster_marker_p_non_zeros[cluster] = p_non_zeros

# Rows: clusters; columns: genes (adata.var index).
cluster_gene_non_zeros_df = pandas.DataFrame.from_dict(cluster_marker_p_non_zeros, orient="index", columns=adata.var.index)

In [None]:
# Marker panel: cell type -> cell subtype -> {gene name: expected state}.
# The labelling cell treats 1 as "detection fraction above the gene's threshold" and
# 0 as "at or below it"; every marker listed here uses 1.
CELL_TYPE_MARKER_GENES = {
    "Vascular Cells": {
        "Endothelial Cells": dict.fromkeys(["Slc2a1", "Cldn5"], 1),
        "Pericytes": dict.fromkeys(["Pdgfrb", "Rgs5", "Abcc9"], 1),
        "Red Blood Cells": dict.fromkeys(["Hba-a1", "Hba-a2"], 1),
        "Vascular SMCs": dict.fromkeys(["Acta2", "Myh11", "Tagln"], 1),
        "VLMCs": dict.fromkeys(["Vtn", "Lum", "Col1a2", "Tbx18"], 1),
    },
    "Immune Cells": {
        "Microglia": dict.fromkeys(["Tmem119", "Cx3cr1", "P2ry12"], 1),
        "Leukocytes": dict.fromkeys(["Itgal", "Gzma"], 1),
        "Perivascular Macrophages": dict.fromkeys(["Mrc1"], 1),
    },
    "Oligodendrocytes": {
        "OPCs": dict.fromkeys(["Pdgfra", "Cspg4"], 1),
        "Mature Oligodendrocytes": dict.fromkeys(["Mag", "Mbp"], 1),
        "Committed Oligodendrocytes": dict.fromkeys(["Ptprz1", "Bmp4", "Nkx2-2", "Vcan"], 1),
        # Same panel as the VLMCs entry under "Vascular Cells".
        "VLMCs": dict.fromkeys(["Vtn", "Lum", "Col1a2", "Tbx18"], 1),
    },
    "Astrocytes": {
        "Astrocytes": dict.fromkeys(["Sox9"], 1),
    },
}

In [None]:
# Resolve every marker gene once: its adata.var index entry and the Otsu threshold over the
# per-cluster detection fractions. Both depend only on the gene, not on the cluster.
marker_gene_info = {}

for cell_type in CELL_TYPE_MARKER_GENES:
    for cell_subtype in CELL_TYPE_MARKER_GENES[cell_type]:
        for marker_gene in CELL_TYPE_MARKER_GENES[cell_type][cell_subtype]:

            if marker_gene in marker_gene_info:
                continue

            matching_ids = adata.var.index[(adata.var['Gene Name'] == marker_gene).values]

            if len(matching_ids) == 0:
                raise KeyError("Marker gene %s is not present in adata.var['Gene Name']" % marker_gene)

            ensembl_id = matching_ids[0]
            values = cluster_gene_non_zeros_df.loc[:, ensembl_id].values.flatten()

            marker_gene_info[marker_gene] = (ensembl_id, threshold_otsu(numpy.array(values)))

cluster_cell_type_label = {}

for cluster in numpy.unique(cluster_gene_non_zeros_df.index):

    possible_cell_types = set()

    for cell_type, subtype_markers in CELL_TYPE_MARKER_GENES.items():

        # Within one cell type, keep only the fully matching subtypes with the most markers.
        possible_cell_subtypes = set()
        max_num_conditions_met = 0

        for cell_subtype, markers in subtype_markers.items():

            num_conditions_met = 0

            for marker_gene, expected_state in markers.items():

                ensembl_id, threshold = marker_gene_info[marker_gene]
                cluster_value = cluster_gene_non_zeros_df.loc[cluster, ensembl_id]

                # A condition is met when the cluster is above threshold for an expected-present
                # marker (1) or at/below threshold for an expected-absent marker (0).
                if cluster_value > threshold and expected_state == 1:
                    num_conditions_met += 1
                    print("Cluster %s is expressing %s" % (cluster, marker_gene))
                elif cluster_value <= threshold and expected_state == 0:
                    num_conditions_met += 1

            # Only subtypes whose markers are all satisfied are candidates.
            if num_conditions_met == len(markers):
                if num_conditions_met > max_num_conditions_met:
                    max_num_conditions_met = num_conditions_met
                    possible_cell_subtypes = set([cell_subtype])
                elif num_conditions_met == max_num_conditions_met:
                    possible_cell_subtypes.add(cell_subtype)

        possible_cell_types.update(possible_cell_subtypes)

    print("Cluster %s could be" % cluster, possible_cell_types)

    # More than one candidate -> "Multiplets"; none -> "Unknown"; otherwise the single match.
    if len(possible_cell_types) > 1:
        cluster_cell_type_label[cluster] = "Multiplets"
    elif len(possible_cell_types) == 0:
        cluster_cell_type_label[cluster] = "Unknown"
    else:
        cluster_cell_type_label[cluster] = list(possible_cell_types)[0]

In [None]:
# For clusters sharing a label with other clusters, look for genes detected in more than half
# of this cluster's sampled cells and at less than a tenth of that rate in every sibling.
# If any exist, the cluster is renamed "<gene>+ <label>" and its siblings "<gene>- <label>".
new_cluster_cell_type_label = {}

separate_clusters = set()

cluster_marker_genes = {}

for cluster, cell_type_label in cluster_cell_type_label.items():

    print(cluster, cell_type_label)

    if cell_type_label in ("Multiplets", "Unknown"):
        continue

    # Sibling clusters: same label, different cluster.
    other_clusters = [
        other_cluster
        for other_cluster, other_cell_type_label in cluster_cell_type_label.items()
        if other_cell_type_label == cell_type_label and other_cluster != cluster
    ]

    if len(other_clusters) == 0:
        continue

    genes_above_threshold = cluster_marker_p_non_zeros[cluster] > 0.5

    other_cluster_num_non_zero = 0
    num_cells_other_clusters = 0

    for other_cluster in other_clusters:
        genes_above_threshold = genes_above_threshold & (cluster_marker_p_non_zeros[other_cluster] < cluster_marker_p_non_zeros[cluster]/10)

        # Accumulate a cell-count-weighted detection fraction over all siblings.
        num_cells_other_cluster = (adata.obs[CLUSTER_OBS_NAME] == other_cluster).sum()
        num_cells_other_clusters += num_cells_other_cluster
        other_cluster_num_non_zero += num_cells_other_cluster * cluster_marker_p_non_zeros[other_cluster]

    genes = adata.var.copy()
    genes["p_non_zero"] = cluster_marker_p_non_zeros[cluster]
    genes["other_clusters_p_non_zero"] = other_cluster_num_non_zero/num_cells_other_clusters

    # Explicit copy: a column is added below, and assigning into a boolean-filtered slice
    # triggers pandas' chained-assignment warning.
    genes = genes[genes_above_threshold].copy()

    if len(genes) == 0:
        continue

    num_cluster_cells = (adata.obs[CLUSTER_OBS_NAME] == cluster).sum()

    # Two-proportion z statistic (this cluster vs. pooled siblings) for each selected gene.
    genes["2proportionz"] = [
        proportions_ztest(
            [row["p_non_zero"] * num_cluster_cells, row["other_clusters_p_non_zero"] * num_cells_other_clusters],
            [num_cluster_cells, num_cells_other_clusters]
        )[0]
        for _, row in genes.iterrows()
    ]

    separate_clusters.add(cluster)

    genes = genes.sort_values(by="p_non_zero", ascending=False)

    display(genes)

    # The gene with the highest detection fraction in this cluster names the split.
    top_gene = genes.iloc[0]["Gene Name"]

    # NOTE(review): entries written here can be overwritten when a sibling cluster is processed
    # later and also has distinguishing genes; the final labels depend on iteration order.
    new_cluster_cell_type_label[cluster] = "%s+ %s" % (top_gene, cell_type_label)

    for other_cluster in other_clusters:
        new_cluster_cell_type_label[other_cluster] = "%s- %s" % (top_gene, cell_type_label)

In [None]:
# Display the refined "<gene>+/- <label>" cluster names for inspection.
new_cluster_cell_type_label

In [None]:
# Hand-curated replacements for two of the automatically generated split labels;
# every other refined label is copied over unchanged.
label_overrides = {
    "Cldn5+ Pericytes": "Multiplets",
    "Cldn5- Pericytes": "Pericytes",
}

for cluster, label in new_cluster_cell_type_label.items():
    cluster_cell_type_label[cluster] = label_overrides.get(label, label)

In [None]:
# Display the final cluster -> cell type label mapping for inspection.
cluster_cell_type_label

In [None]:
# Gather every marker gene name from the panel.
GENES_OF_INTEREST = set()

for subtype_markers in CELL_TYPE_MARKER_GENES.values():
    for marker_genes in subtype_markers.values():
        GENES_OF_INTEREST.update(marker_genes)

# NOTE(review): this reassignment discards the marker set collected above, so only this one
# gene is plotted. Looks like a leftover from ad-hoc exploration — confirm it is intended.
GENES_OF_INTEREST = ["C1ql1"]

embedding = adata.obsm[TRANSFORM_TO_PLOT]

for gene in GENES_OF_INTEREST:

    ensembl_id = adata.var.loc[adata.var['Gene Name']==gene].index[0]

    # Model-normalized expression, plotted on a log2 scale.
    normalized_gene_counts = numpy.asarray(normalized_gene_expression.loc[:, ensembl_id].values).ravel()
    normalized_path = os.path.join("out", "gene_expression_%s_normalized_non_neuronal.html" % gene)
    aavomics.plot_gene_expression(embedding, numpy.log2(normalized_gene_counts), filename=normalized_path)

    # Raw counts, plotted as log2(count + 1).
    raw_gene_counts = numpy.asarray(adata[:, ensembl_id].X.todense()).ravel()
    raw_path = os.path.join("out", "gene_expression_%s_raw_non_neuronal.html" % gene)
    aavomics.plot_gene_expression(embedding, numpy.log2(raw_gene_counts+1), filename=raw_path)

In [None]:
# Sorted categories give a stable category order across runs (set iteration order is not).
cell_type_categorical_type = pandas.CategoricalDtype(categories=sorted(set(cluster_cell_type_label.values())))

# Map each cell's cluster to its label in a single assignment. The previous chained form
# (`adata.obs[col].loc[mask] = value`) can write to a temporary object and leave the column
# unchanged. Cells whose cluster has no entry in the mapping end up as NaN.
adata.obs[SUBTYPE_TAXONOMY_NAME] = (
    adata.obs[CLUSTER_OBS_NAME]
    .astype(str)
    .map(cluster_cell_type_label)
    .astype(cell_type_categorical_type)
)

aavomics.plot_clusters(adata.obsm[TRANSFORM_TO_PLOT], adata.obs[SUBTYPE_TAXONOMY_NAME], filename=os.path.join("out", "cell_types_non_neuronal.html"))

In [None]:
# Companion "p_<taxonomy>" column set to the constant 1 for every cell.
# Presumably a placeholder confidence for the subtype label — TODO confirm with downstream readers.
adata.obs["p_%s" % SUBTYPE_TAXONOMY_NAME] = 1

In [None]:
# Write the subtype label and its companion "p_" column back into each cell set's AnnData file.
# NOTE: this overwrites the source file in place.
COLUMNS_TO_DROP = [SUBTYPE_TAXONOMY_NAME, "p_%s" % SUBTYPE_TAXONOMY_NAME]

for cell_set_name in adata.obs["Cell Set"].unique():

    print(cell_set_name)

    cell_set = database.CELL_SETS_DICT[cell_set_name]

    anndata_file_path = cell_set.get_anndata_file_path(alignment_name=ALIGNMENT_NAME)

    if not os.path.exists(anndata_file_path):
        print("Missing %s, skipping" % cell_set.name)
        continue

    cell_set_adata = anndata.read(anndata_file_path)

    # Remove results of a previous run so the columns are rebuilt from scratch.
    for column in COLUMNS_TO_DROP:

        if column in cell_set_adata.obs.columns:
            cell_set_adata.obs.drop(column, axis=1, inplace=True)

    # Work on an explicit copy of the two obs columns instead of mutating the index of an
    # AnnData view.
    cell_set_labels = adata.obs.loc[adata.obs["Cell Set"] == cell_set_name, COLUMNS_TO_DROP].copy()

    # Strip the trailing "-<suffix>" from each obs name so it matches the names in the
    # per-cell-set file (the suffix is presumably added by the earlier concatenate — confirm).
    cell_set_labels.index = [x.rsplit("-", 1)[0] for x in cell_set_labels.index]

    # Assignment aligns on obs names; cells absent from the non-neuronal subset get NaN.
    for column in COLUMNS_TO_DROP:
        cell_set_adata.obs[column] = cell_set_labels[column]

    cell_set_adata.write_h5ad(anndata_file_path)