In [None]:
import os

from aavomics import aavomics
from aavomics import database
from pepars.plotting import plotting
import anndata
import numpy
import statsmodels
from statsmodels.stats import multitest

import pandas
import scvi
import scanpy
from skimage.filters import threshold_otsu

In [None]:
# Only get this many cells, tops. Leave to None if getting all cells
MAX_NUM_CELLS = 200000

# Which alignment to use. Set to None to use the first available
ALIGNMENT_NAME = "cellranger_5.0.1_gex_mm10_2020_A"

# Base random seed: passed to numpy, the scVI trainer and the scanpy steps
# below (the scANVI trainers use SEED + 1 and SEED + 2)
SEED = 1042

# Name of the obs column that holds the droplet-type label
# (one of "Debris", "Multiplets", "Neurons", "Non-Neurons", "")
TAXONOMY_NAME = "CCN202105041"

# Name of the obs column that receives the Leiden cluster assignment
CLUSTER_OBS_NAME = "leiden_scVI"

# Key in obsm of the embedding used as coordinates for every plot
TRANSFORM_TO_PLOT = "X_tsne"

In [None]:
# Total number of cells targeted across every cell set in the database; used
# to scale each cell set's contribution to the training data.
total_num_target_cells = 0

for cell_set in database.CELL_SETS:
    total_num_target_cells += cell_set.target_num_cells

# Fraction of each cell set's target count to keep so that the total stays at
# or below MAX_NUM_CELLS. MAX_NUM_CELLS = None means "keep everything" (as the
# configuration comment promises), and an empty database must not divide by
# zero, so both cases fall back to a factor of 1.
if MAX_NUM_CELLS is None or total_num_target_cells <= 0:
    downsample_factor = 1
else:
    downsample_factor = min(1, MAX_NUM_CELLS/total_num_target_cells)

In [None]:
# Load every available cell set, keep droplets above the background cutoff,
# downsample them (weighted by transcript count) and merge into one AnnData.
adatas = []
cell_set_names = []

genes_df = None

# Seed the global numpy RNG so the weighted downsampling is reproducible.
numpy.random.seed(SEED)

for cell_set_index, cell_set in enumerate(database.CELL_SETS):

    anndata_file_path = cell_set.get_anndata_file_path(alignment_name=ALIGNMENT_NAME)

    if not os.path.exists(anndata_file_path):
        print("Missing %s, skipping" % cell_set.name)
        continue

    adata = anndata.read(anndata_file_path)

    print(cell_set.name)

    # Total transcript count per droplet barcode.
    barcode_total_transcript_counts = numpy.array(adata.X.sum(axis=1)).flatten()

    # Droplets at or above the cutoff returned by get_background_trough are kept.
    cutoff = aavomics.get_background_trough(barcode_total_transcript_counts)
    signal_mask = barcode_total_transcript_counts >= cutoff

    print("%s: %i droplets above threshold of %i" % (cell_set.name, signal_mask.sum(), cutoff))

    adata = adata[signal_mask]
    barcode_total_transcript_counts = barcode_total_transcript_counts[signal_mask]

    # Sampling probability of each droplet is proportional to its transcript count.
    barcode_probabilities = barcode_total_transcript_counts / barcode_total_transcript_counts.sum()
    barcode_indices = numpy.arange(len(barcode_probabilities))

    num_cells = int(numpy.round(cell_set.target_num_cells * downsample_factor))

    # Sampling without replacement cannot return more droplets than have a
    # non-zero probability; numpy raises ValueError in that case. Clamp so a
    # small cell set contributes everything it has instead of aborting the run.
    num_available = int(numpy.count_nonzero(barcode_probabilities))

    if num_cells > num_available:
        print("%s: only %i droplets available for %i requested, using all of them" % (cell_set.name, num_available, num_cells))
        num_cells = num_available

    weighted_random_barcode_indices = numpy.random.choice(
        barcode_indices,
        size=num_cells,
        p=barcode_probabilities,
        replace=False
    )

    # Note: the selection is in random order, so the rows of adata are shuffled.
    adata = adata[weighted_random_barcode_indices]

    adatas.append(adata)
    cell_set_names.append(cell_set.name)

if not adatas:
    raise RuntimeError("No anndata files were found for alignment %s; nothing to merge" % ALIGNMENT_NAME)

# Merge all cell sets; the originating cell set is recorded in obs["Cell Set"].
merged_adata = adatas[0].concatenate(adatas[1:], batch_key="Cell Set", batch_categories=numpy.array(cell_set_names))

In [None]:
# Drop every gene whose maximum count over all retained droplets is zero.
gene_maxes = numpy.asarray(merged_adata.X.max(axis=0).todense()).ravel()
gene_mask = gene_maxes > 0
merged_adata = merged_adata[:, gene_mask].copy()

In [None]:
# Closed set of droplet labels; the empty string is the category later handed
# to scANVI as its "unlabeled" category.
cell_type_categorical_type = pandas.CategoricalDtype(categories=["Debris", "Multiplets", "Neurons", "Non-Neurons", ""])

# Start with every droplet labelled "Debris". The column is assigned in one
# step: writing through `obs[col].loc[:] = ...` is chained assignment, which
# only works by accident on older pandas and is a silent no-op under
# Copy-on-Write.
merged_adata.obs[TAXONOMY_NAME] = pandas.Categorical(
    ["Debris"] * merged_adata.shape[0],
    dtype=cell_type_categorical_type
)

In [None]:
# Register merged_adata with scvi: obs["Cell Set"] is the batch key and the
# taxonomy column (all "Debris" at this point) is the labels key.
scvi.data.setup_anndata(merged_adata, batch_key="Cell Set", labels_key=TAXONOMY_NAME)

In [None]:
# scVI model on the merged counts (n_latent=20, n_layers=2, n_hidden=256).
vae = scvi.model.SCVI(
    merged_adata,
    n_latent=20,
    n_layers=2,
    n_hidden=256
)

# KL warm-up is specified in iterations (128*5000/400 = 1600) rather than in
# epochs, hence n_epochs_kl_warmup=None. The trainer is seeded with SEED.
vae.train(
    frequency=1,
    n_epochs_kl_warmup=None,
    n_iter_kl_warmup=128*5000/400, # Based on documentation at https://www.scvi-tools.org/en/stable/api/reference/scvi.core.trainers.UnsupervisedTrainer.html
    seed=SEED
)

In [None]:
# Store the scVI latent representation, then build the neighbour graph, the
# Leiden clustering and the t-SNE embedding on top of it (all seeded with SEED).
merged_adata.obsm["X_scVI"] = vae.get_latent_representation(merged_adata)
scanpy.pp.neighbors(merged_adata, use_rep="X_scVI", random_state=SEED)
scanpy.tl.leiden(merged_adata, key_added=CLUSTER_OBS_NAME, random_state=SEED, resolution=2) # Resolution 2 to distinguish between doublet clusters
scanpy.tl.tsne(merged_adata, use_rep="X_scVI", n_jobs=16, random_state=SEED)

In [None]:
# Model-normalized expression for every droplet; indexed later by Ensembl ID
# (column) when plotting marker genes.
normalized_gene_expression = vae.get_normalized_expression(merged_adata)

In [None]:
# Overview plots on the chosen embedding: by cell set, by Leiden cluster, and
# coloured by total transcript count per droplet.
for obs_column, plot_file_name in (("Cell Set", "samples.html"), (CLUSTER_OBS_NAME, "clusters.html")):
    aavomics.plot_clusters(
        merged_adata.obsm[TRANSFORM_TO_PLOT],
        merged_adata.obs[obs_column],
        filename=os.path.join("out", plot_file_name)
    )

total_transcript_counts = numpy.array(merged_adata.X.sum(axis=1)).flatten()
aavomics.plot_gene_expression(
    merged_adata.obsm[TRANSFORM_TO_PLOT],
    total_transcript_counts,
    filename=os.path.join("out", "total_transcript_counts.html")
)

In [None]:
# Empty clusters x genes frame (rows: cluster ids as integers, columns: var
# index). NOTE(review): the sampling cell further down rebuilds this frame
# from scratch before filling it, so this allocation is only a placeholder.
cluster_gene_non_zeros_df = pandas.DataFrame(index=sorted(merged_adata.obs[CLUSTER_OBS_NAME].unique().astype(numpy.uint16)), columns=merged_adata.var.index)

In [None]:
import torch
import numpy as np
from scvi import _CONSTANTS
from scvi.core.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial
from typing import Dict, Iterable, Optional, Sequence, Union
from anndata import AnnData

from scvi.model._utils import (
    _get_batch_code_from_category,
    _get_var_names_from_setup_anndata,
    scrna_raw_counts_properties,
)

@torch.no_grad()
def posterior_predictive_sample(
    self,
    adata: Optional[AnnData] = None,
    indices: Optional[Sequence[int]] = None,
    n_samples: int = 1,
    gene_list: Optional[Sequence[str]] = None,
    batch_size: Optional[int] = None,
    transform_batch: Optional[int] = None
) -> np.ndarray:
    r"""
    Generate observation samples from the posterior predictive distribution.

    The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`.

    This is a free function taking the model as ``self`` so that
    ``transform_batch`` can be forwarded to ``self.model.inference``.

    Parameters
    ----------
    self
        Trained model exposing ``model``, ``_validate_anndata`` and ``_make_scvi_dl``.
    adata
        AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
        AnnData object used to initialize the model.
    indices
        Indices of cells in adata to use. If `None`, all cells are used.
    n_samples
        Number of samples for each cell.
    gene_list
        Names of genes of interest. If `None`, all genes are returned.
    batch_size
        Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
    transform_batch
        Forwarded unchanged to ``self.model.inference``; presumably the batch
        index the samples are conditioned on -- confirm against the scvi
        version in use.

    Returns
    -------
    x_new : :py:class:`numpy.ndarray`
        Array with shape (n_cells, n_genes, n_samples) when ``n_samples > 1``.
        With ``n_samples == 1`` no permutation is applied, so the sample axis
        is presumably absent, i.e. (n_cells, n_genes).

    Raises
    ------
    ValueError
        If the model's gene likelihood is not one of "zinb", "nb", "poisson".
    """
    if self.model.gene_likelihood not in ["zinb", "nb", "poisson"]:
        raise ValueError("Invalid gene_likelihood.")

    adata = self._validate_anndata(adata)
    scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)

    if gene_list is None:
        gene_mask = slice(None)
    else:
        all_genes = _get_var_names_from_setup_anndata(adata)
        # Set membership keeps the mask construction linear in the gene count.
        genes_of_interest = set(gene_list)
        gene_mask = [gene in genes_of_interest for gene in all_genes]

    x_new = []
    for tensors in scdl:
        x = tensors[_CONSTANTS.X_KEY]
        batch_idx = tensors[_CONSTANTS.BATCH_KEY]
        labels = tensors[_CONSTANTS.LABELS_KEY]
        outputs = self.model.inference(
            x, batch_index=batch_idx, y=labels, n_samples=n_samples, transform_batch=transform_batch
        )
        px_r = outputs["px_r"]
        px_rate = outputs["px_rate"]
        px_dropout = outputs["px_dropout"]

        if self.model.gene_likelihood == "poisson":
            # Cap the rate before sampling.
            l_train = torch.clamp(px_rate, max=1e8)
            dist = torch.distributions.Poisson(
                l_train
            )  # Shape : (n_samples, n_cells_batch, n_genes)
        elif self.model.gene_likelihood == "nb":
            dist = NegativeBinomial(mu=px_rate, theta=px_r)
        else:
            # "zinb": the only remaining value after the validation above.
            dist = ZeroInflatedNegativeBinomial(
                mu=px_rate, theta=px_r, zi_logits=px_dropout
            )

        if n_samples > 1:
            exprs = dist.sample().permute(
                [1, 2, 0]
            )  # Shape : (n_cells_batch, n_genes, n_samples)
        else:
            exprs = dist.sample()

        if gene_list is not None:
            exprs = exprs[:, gene_mask, ...]

        x_new.append(exprs.cpu())
    x_new = torch.cat(x_new)  # Shape (n_cells, n_genes, n_samples)

    return x_new.numpy()

In [None]:
# For every scVI batch and every Leiden cluster, estimate per gene the
# probability of observing a non-zero count: resample NUM_SAMPLES droplets of
# the cluster (with replacement), draw one posterior predictive sample each
# with the batch fixed to batch_id, and take the fraction of non-zero draws.
NUM_SAMPLES = 5000

# Cluster labels are strings; iterate over them in numeric order.
clusters = sorted(merged_adata.obs[CLUSTER_OBS_NAME].unique(), key=int)

batch_cluster_gene_non_zeros_dfs = {}

for batch_id in merged_adata.obs["_scvi_batch"].unique():

    # Rows are the cluster ids as integers, columns are the genes.
    cluster_gene_non_zeros_df = pandas.DataFrame(
        index=sorted(merged_adata.obs[CLUSTER_OBS_NAME].unique().astype(numpy.uint16)),
        columns=merged_adata.var.index
    )

    for cluster in clusters:

        print("Batch %i cluster %s" % (batch_id, cluster))

        cluster_adata = merged_adata[merged_adata.obs[CLUSTER_OBS_NAME] == cluster].copy()

        resampled_rows = numpy.random.choice(cluster_adata.shape[0], size=NUM_SAMPLES, replace=True)
        resampled_adata = cluster_adata[resampled_rows].copy()

        predictive_samples = posterior_predictive_sample(vae, resampled_adata, n_samples=1, transform_batch=batch_id)

        # Fraction of the resampled droplets with a non-zero draw, per gene.
        non_zero_counts = (predictive_samples != 0).sum(axis=0)
        cluster_gene_non_zeros_df.loc[int(cluster), :] = numpy.array(non_zero_counts/NUM_SAMPLES)

    batch_cluster_gene_non_zeros_dfs[batch_id] = cluster_gene_non_zeros_df

In [None]:
# Per-droplet fraction of transcripts that come from genes whose name starts
# with "mt-", plotted on the embedding.
mt_ratios = numpy.array(merged_adata[:, merged_adata.var["Gene Name"].str.startswith("mt-")].X.sum(axis=1)/merged_adata.X.sum(axis=1)).flatten()
aavomics.plot_gene_expression(merged_adata.obsm[TRANSFORM_TO_PLOT], mt_ratios, filename=os.path.join("out", "mt_ratios.html"))

In [None]:
# Maps each marker gene name (matched against var["Gene Name"]) to the broad
# cell type it indicates. Several genes may map to the same cell type; the
# trailing comments name the more specific population each gene marks.
MARKER_GENE_CELL_TYPES = {

    "Sox9": "Astrocytes",

    "Cldn5": "Vascular Cells", # Endothelial Cells
    "Pdgfrb": "Vascular Cells", # Pericytes
    "Hba-a1": "Vascular Cells", # Red Blood Cells

    "Rbfox3": "Neurons", # Neurons

    "Cx3cr1": "Immune Cells", # Microglia
    "Mrc1": "Immune Cells", #"Perivascular Macrophages"

    "Olig1": "Oligodendrocytes", # Oligodendrocytes
}

In [None]:
# For every marker gene, count in how many batches each cluster's non-zero
# probability for that gene exceeds the Otsu threshold computed across
# clusters, then derive a per-gene Otsu threshold on the batch ratio.
all_counts = []

clusters = merged_adata.obs[CLUSTER_OBS_NAME].unique().astype(numpy.uint16)
batches = merged_adata.obs["_scvi_batch"].unique()

# Row labels in ascending order, as plain ints for label-based lookup.
sorted_cluster_ids = [int(cluster) for cluster in sorted(clusters)]

marker_gene_cluster_batch_counts = {}
batch_ratio_thresholds = {}

for marker_gene in MARKER_GENE_CELL_TYPES:

    # The frames are indexed by Ensembl ID, so translate the gene name first.
    ensembl_id = merged_adata.var.loc[merged_adata.var['Gene Name']==marker_gene].index[0]

    cluster_batch_counts = {cluster: 0 for cluster in clusters}
    marker_gene_cluster_batch_counts[marker_gene] = cluster_batch_counts

    for batch_id in batches:

        # Read the one needed column straight from the per-batch frame instead
        # of copying the whole clusters x genes frame row by row. The name is
        # kept bound because a later cell reads this frame's index.
        cluster_gene_non_zeros_df = batch_cluster_gene_non_zeros_dfs[batch_id]

        marker_non_zeros = cluster_gene_non_zeros_df.loc[sorted_cluster_ids, ensembl_id]

        values = marker_non_zeros.astype(numpy.float32).values

        # Clusters without an estimate are excluded from the threshold.
        nan_filter = ~numpy.isnan(values)
        values = values[nan_filter].reshape((-1, 1))

        threshold = threshold_otsu(values)

        clusters_above_threshold = marker_non_zeros.index[nan_filter][values.flatten() > threshold]

        for cluster in clusters_above_threshold:
            cluster_batch_counts[cluster] += 1

    all_counts.extend(cluster_batch_counts.values())

    batch_ratio_threshold = threshold_otsu(numpy.array(list(cluster_batch_counts.values()))/len(batches))
    batch_ratio_thresholds[marker_gene] = batch_ratio_threshold

# Global threshold across all marker genes (the per-gene thresholds above are
# the ones stored in batch_ratio_thresholds).
batch_ratio_threshold = threshold_otsu(numpy.array(all_counts)/len(batches))

In [None]:
# A cluster expresses a marker gene when the share of batches in which it
# passed the per-batch threshold exceeds that gene's batch-ratio threshold.
clusters = merged_adata.obs[CLUSTER_OBS_NAME].unique().astype(numpy.uint16)

marker_gene_clusters = {}

for marker_gene in MARKER_GENE_CELL_TYPES:

    ratio_threshold = batch_ratio_thresholds[marker_gene]

    marker_gene_clusters[marker_gene] = [
        cluster
        for cluster, batch_count in marker_gene_cluster_batch_counts[marker_gene].items()
        if batch_count/len(batches) > ratio_threshold
    ]

    print("%s clusters above threshold: " % marker_gene, marker_gene_clusters[marker_gene])

In [None]:
# Assign every cluster to one label based on which marker genes it expresses:
#   - no marker at all                      -> "Debris"
#   - only neuronal markers                 -> "Neurons"
#   - neuronal plus any other cell type     -> "Debris"
#   - one or more non-neuronal cell types   -> "Non-Neurons"
cell_type_clusters = {cell_type: set() for cell_type in cell_type_categorical_type.categories}

for cluster in cluster_gene_non_zeros_df.index.values:

    cluster_cell_types = set()

    cell_type_counts = {}
    max_cell_type_counts = {}

    for marker_gene, cell_type in MARKER_GENE_CELL_TYPES.items():

        # Number of marker genes defined per cell type (tallied only).
        max_cell_type_counts[cell_type] = max_cell_type_counts.get(cell_type, 0) + 1

        # Number of that cell type's markers this cluster expresses.
        if cluster in marker_gene_clusters[marker_gene]:
            cell_type_counts[cell_type] = cell_type_counts.get(cell_type, 0) + 1

    if not cell_type_counts:
        assigned_cell_type = "Debris"
    elif "Neurons" not in cell_type_counts:
        assigned_cell_type = "Non-Neurons"
    elif len(cell_type_counts) == 1:
        assigned_cell_type = "Neurons"
    else:
        assigned_cell_type = "Debris"

    cell_type_clusters[assigned_cell_type].add(cluster)

In [None]:
# Per-droplet label derived from the cluster-level assignment; droplets whose
# cluster received no assignment stay "Debris". The builtin `object` replaces
# the `numpy.object` alias, which was deprecated in NumPy 1.20 and removed in 1.24.
automatic_clusters = numpy.empty((merged_adata.shape[0]), dtype=object)
automatic_clusters[:] = "Debris"

for cell_type, clusters in cell_type_clusters.items():
    for cluster in clusters:
        # Cluster ids are integers here but strings in obs, hence str().
        automatic_clusters[merged_adata.obs[CLUSTER_OBS_NAME] == str(cluster)] = cell_type

In [None]:
# Write the automatic labels into obs as a categorical with the fixed label
# set. Assigning the whole column at once avoids the chained
# `obs[col].loc[:] = ...` write, which is a silent no-op under pandas
# Copy-on-Write.
merged_adata.obs[TAXONOMY_NAME] = pandas.Categorical(automatic_clusters, dtype=cell_type_categorical_type)

In [None]:
# For each marker gene, plot the log2 model-normalized expression and the raw
# counts on the embedding.
genes_of_interest = MARKER_GENE_CELL_TYPES

for gene in list(genes_of_interest):

    ensembl_id = merged_adata.var.loc[merged_adata.var['Gene Name']==gene].index[0]

    normalized_values = numpy.array(normalized_gene_expression.loc[:, ensembl_id].values).reshape((-1,))
    normalized_plot_path = os.path.join("out", "gene_expression_%s_normalized.html" % gene)

    aavomics.plot_gene_expression(
        merged_adata.obsm[TRANSFORM_TO_PLOT],
        numpy.log2(normalized_values),
        filename=normalized_plot_path
    )

    raw_gene_counts = numpy.array(merged_adata[:, ensembl_id].X.todense()).reshape((-1,))
    raw_plot_path = os.path.join("out", "gene_expression_%s_raw.html" % gene)

    aavomics.plot_gene_expression(
        merged_adata.obsm[TRANSFORM_TO_PLOT],
        raw_gene_counts,
        filename=raw_plot_path
    )

In [None]:
# Plot the marker-based labels. NOTE: a later cell writes to this same file
# name after multiplets have been flagged, replacing this plot.
aavomics.plot_clusters(merged_adata.obsm[TRANSFORM_TO_PLOT], merged_adata.obs[TAXONOMY_NAME], filename=os.path.join("out", "%s.html" % TAXONOMY_NAME))

In [None]:
import scrublet

# Run Scrublet separately on each cell set and record, per droplet, the
# doublet call and the doublet score. Columns start as "not a doublet" /
# score 0 so that cell sets without a call keep those defaults. The builtin
# `bool` replaces the `numpy.bool` alias, which was removed in NumPy 1.24.
merged_adata.obs["doublet"] = numpy.zeros((merged_adata.shape[0],), dtype=bool)
merged_adata.obs["p_doublet"] = numpy.zeros((merged_adata.shape[0]))

for cell_set in merged_adata.obs["Cell Set"].unique():

    print(cell_set)

    mask = merged_adata.obs["Cell Set"] == cell_set

    cell_set_counts = merged_adata[mask].X

    scrub = scrublet.Scrublet(cell_set_counts)

    doublet_scores, predicted_doublets = scrub.scrub_doublets()

    # predicted_doublets is None when Scrublet made no call; the defaults
    # are kept in that case.
    if predicted_doublets is not None:
        merged_adata.obs.loc[mask, "doublet"] = predicted_doublets
        merged_adata.obs.loc[mask, "p_doublet"] = doublet_scores

In [None]:
# Plot the per-droplet Scrublet doublet calls on the embedding.
aavomics.plot_clusters(merged_adata.obsm[TRANSFORM_TO_PLOT], merged_adata.obs["doublet"], filename=os.path.join("out", "doublet.html"))

In [None]:
# Percentage of predicted doublets in every Leiden cluster, and the Otsu
# threshold separating doublet-rich clusters from the rest.
cluster_labels = merged_adata.obs[CLUSTER_OBS_NAME]

percents = []

for cluster in cluster_labels.unique():

    cluster_mask = cluster_labels == cluster
    num_cluster_doublets = (cluster_mask & merged_adata.obs["doublet"]).sum()

    percents.append(num_cluster_doublets/cluster_mask.sum()*100)

percent_threshold = threshold_otsu(numpy.array(percents))
percent_threshold

In [None]:
# Reset the labels to the marker-based assignment before flagging multiplets,
# so re-running the following cells starts from a clean state. The column is
# assigned in one step rather than through `obs[col].loc[:] = ...`, which is
# chained assignment and a silent no-op under pandas Copy-on-Write.
merged_adata.obs[TAXONOMY_NAME] = pandas.Categorical(automatic_clusters, dtype=cell_type_categorical_type)

In [None]:
# Every predicted doublet that is not already labelled "Debris" becomes "Multiplets".
merged_adata.obs.loc[(merged_adata.obs["doublet"]) & (merged_adata.obs[TAXONOMY_NAME] != "Debris"), TAXONOMY_NAME] = "Multiplets"

In [None]:
# Clusters whose doublet percentage exceeds the Otsu threshold are relabelled
# "Multiplets" as a whole, including their droplets previously called "Debris".
doublet_flags = merged_adata.obs["doublet"]

for cluster in merged_adata.obs[CLUSTER_OBS_NAME].unique():

    cluster_mask = merged_adata.obs[CLUSTER_OBS_NAME] == cluster
    cluster_multiplet_mask = cluster_mask & doublet_flags

    percent = cluster_multiplet_mask.sum()/cluster_mask.sum()*100

    if percent <= percent_threshold:
        continue

    merged_adata.obs.loc[cluster_mask, TAXONOMY_NAME] = "Multiplets"

In [None]:
# Plot the labels after multiplet flagging. NOTE: this is the same file name
# as the earlier marker-only label plot, so that file is replaced.
aavomics.plot_clusters(merged_adata.obsm[TRANSFORM_TO_PLOT], merged_adata.obs[TAXONOMY_NAME], filename=os.path.join("out", "%s.html" % TAXONOMY_NAME))

In [None]:
# Register merged_adata with scvi again now that the taxonomy column holds the
# final training labels instead of the all-"Debris" placeholder.
scvi.data.setup_anndata(merged_adata, batch_key="Cell Set", labels_key=TAXONOMY_NAME)

In [None]:
# scANVI classifier initialised from the trained scVI model, with the same
# n_latent / n_layers / n_hidden; the empty-string category is declared as the
# unlabeled category.
scanvi = scvi.model.SCANVI(
    merged_adata,
    unlabeled_category="",
    pretrained_model=vae,
    n_latent=20,
    n_layers=2,
    n_hidden=256
)

# The two trainers get their own seeds (SEED + 1 and SEED + 2); KL warm-up is
# again given in iterations (128*5000/400 = 1600) instead of epochs.
results = scanvi.train(
    unsupervised_trainer_kwargs={
        "seed": SEED + 1
    },
    semisupervised_trainer_kwargs={
        "seed": SEED + 2,
        "n_iter_kl_warmup": 128*5000/400,
        "n_epochs_kl_warmup": None
    },
    balanced_sampling=True,
    frequency=1,
    n_epochs_kl_warmup=None,
    n_iter_kl_warmup=128*5000/400, # Based on documentation at https://www.scvi-tools.org/en/stable/api/reference/scvi.core.trainers.UnsupervisedTrainer.html
)

# Persist the trained classifier under the relative path "droplet_classifier".
scanvi.save("droplet_classifier")

# Persist the labelled training data under database.DATA_PATH.
merged_adata.write_h5ad(os.path.join(database.DATA_PATH, "aavomics_mouse_cortex_2021_droplet_training_data.h5"))

In [None]:
# Build the split once and read both index sets from it, so that the train and
# test indices are guaranteed to belong to the same partition (two separate
# calls may each produce their own split).
# NOTE(review): this derives a split from the trainer after training --
# confirm it matches the split actually used during scanvi.train().
train_test_validation_sets = scanvi.trainer.train_test_validation()

train_indices = train_test_validation_sets[0].indices
test_indices = train_test_validation_sets[1].indices

In [None]:
# Hard label per droplet, the per-label scores (soft=True), and for every
# droplet the highest score across labels.
predicted_labels = scanvi.predict(merged_adata)
prediction_scores = scanvi.predict(merged_adata, soft=True)
prediction_scores_max = prediction_scores.max(axis=1)

In [None]:
# Agreement between scANVI predictions and the training labels: overall, on
# the train indices and on the test indices.
num_correct = (predicted_labels == merged_adata.obs[TAXONOMY_NAME]).sum()
print("Accuracy: %.2f%%" % (100*num_correct/merged_adata.shape[0]))

num_train_correct = (predicted_labels[train_indices] == merged_adata[train_indices].obs[TAXONOMY_NAME]).sum()
print("Train Accuracy: %.2f%%" % (100*num_train_correct/train_indices.shape[0]))

num_test_correct = (predicted_labels[test_indices] == merged_adata[test_indices].obs[TAXONOMY_NAME]).sum()
print("Test Accuracy: %.2f%%" % (100*num_test_correct/test_indices.shape[0]))

In [None]:
# Plot the scANVI-predicted labels on the embedding.
aavomics.plot_clusters(merged_adata.obsm[TRANSFORM_TO_PLOT], predicted_labels, filename=os.path.join("out", "%s_predicted.html" % TAXONOMY_NAME))

In [None]:
# Apply the trained classifier to every cell set in full (not only the
# downsampled training droplets), call doublets among the confidently
# classified non-debris droplets, and write the results back into each cell
# set's anndata file IN PLACE.

# Columns written by a previous run of this cell; removed first so that a
# re-run starts clean.
COLUMNS_TO_DROP = [TAXONOMY_NAME, "p_%s" % TAXONOMY_NAME, "doublet", "p_doublet"]

for cell_set_index, cell_set in enumerate(database.CELL_SETS):

    anndata_file_path = cell_set.get_anndata_file_path(alignment_name=ALIGNMENT_NAME)

    if not os.path.exists(anndata_file_path):
        print("Missing %s, skipping" % cell_set.name)
        continue

    print(cell_set.name)

    adata = anndata.read(anndata_file_path)

    for column in COLUMNS_TO_DROP:

        if column in adata.obs.columns:
            adata.obs.drop(column, axis=1, inplace=True)

    # Same background filter as used when building the training data.
    barcode_total_transcript_counts = numpy.array(adata.X.sum(axis=1)).flatten()
    cutoff = aavomics.get_background_trough(barcode_total_transcript_counts)
    signal_mask = barcode_total_transcript_counts >= cutoff

    # Restrict to the genes the model was trained on, in the same order.
    filtered_adata = adata[signal_mask].copy()
    filtered_adata = filtered_adata[:, merged_adata.var.index.values].copy()
    filtered_adata.X = filtered_adata.X.tocsr()

    # Batch and label columns reuse the training dtypes. They are assigned as
    # complete columns: the previous `obs[col].loc[:] = ...` pattern is
    # chained assignment and a silent no-op under pandas Copy-on-Write.
    filtered_adata.obs["Cell Set"] = pandas.Series(
        [cell_set.name] * filtered_adata.shape[0],
        index=filtered_adata.obs.index,
        dtype=merged_adata.obs["Cell Set"].dtype
    )

    # Placeholder label, overwritten with the prediction below.
    filtered_adata.obs[TAXONOMY_NAME] = pandas.Series(
        ["Debris"] * filtered_adata.shape[0],
        index=filtered_adata.obs.index,
        dtype=merged_adata.obs[TAXONOMY_NAME].dtype
    )

    scvi.data.setup_anndata(filtered_adata, batch_key="Cell Set", labels_key=TAXONOMY_NAME)

    test_predicted_labels = scanvi.predict(filtered_adata)
    test_prediction_scores = scanvi.predict(filtered_adata, soft=True)
    test_prediction_scores_max = test_prediction_scores.max(axis=1)
    filtered_adata.obs[TAXONOMY_NAME] = test_predicted_labels

    # Treat (1 - top score) as a p-value and keep the droplets that pass
    # Benjamini-Hochberg correction at alpha = 0.05.
    p_values_corrected = statsmodels.stats.multitest.multipletests(1-test_prediction_scores_max, method="fdr_bh", alpha=0.05)
    p_values_corrected_mask = p_values_corrected[0]

    not_debris_adata = filtered_adata[p_values_corrected_mask & (test_predicted_labels != "Debris") & (test_predicted_labels != "Multiplets")].copy()

    cell_set_counts = not_debris_adata.X

    scrub = scrublet.Scrublet(cell_set_counts)

    doublet_scores, predicted_doublets = scrub.scrub_doublets()

    # No call from Scrublet: treat every droplet as a singlet. The builtin
    # `bool` replaces the `numpy.bool` alias, removed in NumPy 1.24.
    if predicted_doublets is None:
        predicted_doublets = numpy.zeros((not_debris_adata.shape[0],), dtype=bool)

    adata.obs.loc[not_debris_adata.obs.index, "doublet"] = predicted_doublets
    adata.obs.loc[not_debris_adata.obs.index, "p_doublet"] = doublet_scores

    # A droplet is a called cell when it was confidently classified as
    # something other than debris/multiplets and is not a predicted doublet.
    adata.obs["Cell Called"] = False

    adata.obs.loc[not_debris_adata[~predicted_doublets].obs.index, "Cell Called"] = True

    num_cells = (adata.obs["Cell Called"] == True).sum()

    # Droplets below the background cutoff get no label or score.
    adata.obs.loc[filtered_adata.obs.index, TAXONOMY_NAME] = filtered_adata.obs[TAXONOMY_NAME]
    adata.obs.loc[filtered_adata.obs.index, "p_%s" % TAXONOMY_NAME] = test_prediction_scores_max

    print(cell_set.name, num_cells)
    adata.write_h5ad(anndata_file_path)