In [None]:
import os
import numpy

from aavomics import database
import scanpy
import anndata
import scvi
import pandas

from aavomics import aavomics

from plotly import offline as plotly
from plotly import graph_objects
from plotly.subplots import make_subplots
from skimage.filters import threshold_otsu

In [None]:
# --- Run configuration -------------------------------------------------------

# Alignment whose per-cell-set AnnData files are read (and later rewritten).
ALIGNMENT_NAME = "cellranger_5.0.1_gex_mm10_2020_A"

# Gene names for which per-cluster marker calls are made further down.
MARKER_GENES = [
    "Aldh1l1", "Sox9", "S100b", "Cldn5",
    "Slc2a1", "Pdgfrb", "Rgs5", "Abcc9",
    "Hba-a1", "Hba-a2", "Acta2", "Myh11",
    "Tagln", "Fam180a", "Slc6a13", "Dcn",
    "Ptgds", "Cx3cr1", "Tmem119", "Itgal",
    "Gzma", "Mrc1", "Rbfox3", "Olig2",
    "Pdgfra", "Cspg4", "Mog", "Mbp",
    "Ptprz1", "Bmp4", "Nkx2-2", "Vcan",
]

# Label used to name the obs columns and output files produced by this notebook.
TAXONOMY_NAME = "CCN202105070"

# Seed handed to scVI training, the neighbor graph, Leiden and t-SNE.
SEED = 1042

# obs column that receives the Leiden cluster labels.
CLUSTER_OBS_NAME = "leiden_scVI"

In [None]:
# Gather the called cells of every cell set that has an AnnData file for this
# alignment, then concatenate them with the cell set name as the batch label.
adatas = []
cell_set_names = []

for cell_set_index, cell_set in enumerate(database.CELL_SETS):

    print(cell_set.name)

    anndata_file_path = cell_set.get_anndata_file_path(alignment_name=ALIGNMENT_NAME)

    if os.path.exists(anndata_file_path):
        adata = anndata.read(anndata_file_path)
        # "Cell Called" is compared against the string "True", not a boolean.
        called_mask = adata.obs["Cell Called"] == "True"
        adata = adata[called_mask].copy()

        adatas.append(adata)
        cell_set_names.append(cell_set.name)
    else:
        print("Missing %s, skipping" % cell_set.name)

first_adata, remaining_adatas = adatas[0], adatas[1:]
merged_adata = first_adata.concatenate(
    remaining_adatas,
    batch_key="Cell Set",
    batch_categories=numpy.array(cell_set_names),
)

In [None]:
# Keep only genes whose maximum count over all retained cells is above zero.
gene_maxes = numpy.array(merged_adata.X.max(axis=0).todense()).flatten()
gene_mask = gene_maxes > 0
merged_adata = merged_adata[:, gene_mask].copy()

In [None]:
# Register merged_adata with scvi, using the "Cell Set" obs column as the batch key.
# NOTE(review): later cells read obs["_scvi_batch"]; presumably this call creates it — confirm.
scvi.data.setup_anndata(merged_adata, batch_key="Cell Set")

In [None]:
# Build the scVI model on the merged counts.
scvi_model_kwargs = dict(n_latent=20, n_layers=2, n_hidden=256)
vae = scvi.model.SCVI(merged_adata, **scvi_model_kwargs)

# KL warm-up is specified in iterations (the epoch-based setting is disabled).
# Based on documentation at https://www.scvi-tools.org/en/stable/api/reference/scvi.core.trainers.UnsupervisedTrainer.html
kl_warmup_iterations = 128*5000/400

vae.train(
    frequency=1,
    n_epochs_kl_warmup=None,
    n_iter_kl_warmup=kl_warmup_iterations,
    seed=SEED
)

In [None]:
# Store the trained model's latent representation of merged_adata under obsm["X_scVI"].
merged_adata.obsm["X_scVI"] = vae.get_latent_representation(merged_adata)

In [None]:
# Build the neighbor graph from the X_scVI representation, seeded with SEED.
scanpy.pp.neighbors(merged_adata, use_rep="X_scVI", random_state=SEED)

In [None]:
# Leiden clustering; key_added routes the labels to obs[CLUSTER_OBS_NAME].
scanpy.tl.leiden(merged_adata, key_added=CLUSTER_OBS_NAME, random_state=SEED)

In [None]:
# t-SNE embedding computed from X_scVI; a later cell reads the result from obsm["X_tsne"].
# NOTE(review): n_jobs=8 is hard-coded — adjust to the cores available on the machine.
scanpy.tl.tsne(merged_adata, use_rep="X_scVI", n_jobs=8, random_state=SEED)

In [None]:
# Empty cluster-by-gene frame: rows are the cluster labels cast to uint16 and sorted,
# columns are the var index.
# NOTE(review): every later cell in view reassigns this name before reading it, so this
# allocation looks redundant — confirm before removing.
cluster_gene_non_zeros_df = pandas.DataFrame(index=sorted(merged_adata.obs[CLUSTER_OBS_NAME].unique().astype(numpy.uint16)), columns=merged_adata.var.index)

In [None]:
import torch
import numpy as np
from scvi import _CONSTANTS
from scvi.core.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial
from typing import Dict, Iterable, Optional, Sequence, Union
from anndata import AnnData

from scvi.model._utils import (
    _get_batch_code_from_category,
    _get_var_names_from_setup_anndata,
    scrna_raw_counts_properties,
)

@torch.no_grad()
def posterior_predictive_sample(
    self,
    adata: Optional[AnnData] = None,
    indices: Optional[Sequence[int]] = None,
    n_samples: int = 1,
    gene_list: Optional[Sequence[str]] = None,
    batch_size: Optional[int] = None,
    transform_batch: Optional[int] = None
) -> np.ndarray:
    r"""
    Generate observation samples from the posterior predictive distribution.

    The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`.

    This is a free-function variant that takes the model as ``self`` and adds a
    ``transform_batch`` argument that is forwarded to the model's inference call.

    Parameters
    ----------
    self
        Model object. The function reads ``self.model.gene_likelihood`` and calls
        ``self._validate_anndata``, ``self._make_scvi_dl`` and ``self.model.inference``.
    adata
        AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
        AnnData object used to initialize the model.
    indices
        Indices of cells in adata to use. If `None`, all cells are used.
    n_samples
        Number of samples for each cell.
    gene_list
        Names of genes of interest. If `None`, all genes are returned.
    batch_size
        Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
    transform_batch
        Passed unchanged to ``self.model.inference(..., transform_batch=...)``.
        Presumably the batch code the decoder is conditioned on instead of each
        cell's own batch — verify against the model's ``inference`` signature.

    Returns
    -------
    x_new : :py:class:`numpy.ndarray`
        Sampled counts. When ``n_samples > 1`` the shape is
        (n_cells, n_genes, n_samples); otherwise the un-permuted output of
        ``dist.sample()`` is returned (presumably (n_cells, n_genes)).

    Raises
    ------
    ValueError
        If ``self.model.gene_likelihood`` is not one of "zinb", "nb", "poisson".
    """
    if self.model.gene_likelihood not in ["zinb", "nb", "poisson"]:
        raise ValueError("Invalid gene_likelihood.")

    adata = self._validate_anndata(adata)
    scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)

    if gene_list is None:
        gene_mask = None
    else:
        # Set membership keeps the mask construction linear in the number of genes.
        requested_genes = set(gene_list)
        all_genes = _get_var_names_from_setup_anndata(adata)
        gene_mask = [gene in requested_genes for gene in all_genes]

    x_new = []
    for tensors in scdl:
        x = tensors[_CONSTANTS.X_KEY]
        batch_idx = tensors[_CONSTANTS.BATCH_KEY]
        labels = tensors[_CONSTANTS.LABELS_KEY]
        outputs = self.model.inference(
            x, batch_index=batch_idx, y=labels, n_samples=n_samples, transform_batch=transform_batch
        )
        px_r = outputs["px_r"]
        px_rate = outputs["px_rate"]
        px_dropout = outputs["px_dropout"]

        if self.model.gene_likelihood == "poisson":
            # Clamp the rate before building the Poisson distribution.
            l_train = torch.clamp(px_rate, max=1e8)
            dist = torch.distributions.Poisson(
                l_train
            )  # Shape : (n_samples, n_cells_batch, n_genes)
        elif self.model.gene_likelihood == "nb":
            dist = NegativeBinomial(mu=px_rate, theta=px_r)
        else:
            # Only "zinb" can reach this branch: the likelihood was validated above.
            dist = ZeroInflatedNegativeBinomial(
                mu=px_rate, theta=px_r, zi_logits=px_dropout
            )

        if n_samples > 1:
            exprs = dist.sample().permute(
                [1, 2, 0]
            )  # Shape : (n_cells_batch, n_genes, n_samples)
        else:
            exprs = dist.sample()

        if gene_mask is not None:
            exprs = exprs[:, gene_mask, ...]

        x_new.append(exprs.cpu())
    x_new = torch.cat(x_new)  # Concatenate the minibatches along the cell axis

    return x_new.numpy()

In [None]:
# For every (batch, cluster) pair, estimate per gene the probability of drawing a
# non-zero count from the posterior predictive: resample NUM_SAMPLES cells of the
# cluster with replacement, sample them with transform_batch set to the batch, and
# take the fraction of non-zero draws.
clusters = sorted(merged_adata.obs[CLUSTER_OBS_NAME].unique(), key=lambda x: int(x))

NUM_SAMPLES = 5000

# Dedicated generator seeded with SEED so the resampled cells are the same on every
# run (the global numpy.random state is left untouched).
# NOTE: the draw inside posterior_predictive_sample uses torch's RNG and is not
# seeded by this cell.
sampling_rng = numpy.random.RandomState(SEED)

# Row labels shared by every per-batch frame: cluster labels as sorted uint16.
cluster_index = sorted(merged_adata.obs[CLUSTER_OBS_NAME].unique().astype(numpy.uint16))

batch_cluster_gene_non_zeros_dfs = {}

for batch_id in merged_adata.obs["_scvi_batch"].unique():

    cluster_gene_non_zeros_df = pandas.DataFrame(index=cluster_index, columns=merged_adata.var.index)

    for cluster in clusters:

        print("Batch %i cluster %s" % (batch_id, cluster))

        adata_copy = merged_adata[merged_adata.obs[CLUSTER_OBS_NAME] == cluster].copy()

        # Row positions drawn with replacement, so small clusters still yield NUM_SAMPLES cells.
        sampled_rows = sampling_rng.choice(adata_copy.shape[0], replace=True, size=NUM_SAMPLES)
        random_sample = adata_copy[sampled_rows].copy()

        samples = posterior_predictive_sample(vae, random_sample, n_samples=1, transform_batch=batch_id)

        # Count non-zero draws per gene across the sampled cells.
        non_zero_counts = (samples != 0).sum(axis=0)

        p_non_zeros = numpy.array(non_zero_counts/NUM_SAMPLES)

        cluster_gene_non_zeros_df.loc[int(cluster), :] = p_non_zeros

    batch_cluster_gene_non_zeros_dfs[batch_id] = cluster_gene_non_zeros_df

In [None]:
# For every marker gene, count in how many batches each cluster's non-zero
# probability lies above that batch's Otsu threshold, then derive a per-gene Otsu
# threshold on the fraction of batches. Both result dicts are keyed by the var
# index label (ensembl_id), not by gene name.
all_counts = []

clusters = merged_adata.obs[CLUSTER_OBS_NAME].unique().astype(numpy.uint16)
batches = merged_adata.obs["_scvi_batch"].unique()

marker_gene_cluster_batch_counts = {}
batch_ratio_thresholds = {}

for marker_gene in MARKER_GENES:

    print(len(marker_gene_cluster_batch_counts), marker_gene)

    matching_ids = merged_adata.var.loc[merged_adata.var['Gene Name']==marker_gene].index
    if len(matching_ids) == 0:
        # Fail with the gene name rather than a bare "index 0 is out of bounds".
        raise KeyError("Marker gene %s not found in merged_adata.var['Gene Name']" % marker_gene)
    ensembl_id = matching_ids[0]

    cluster_batch_counts = {cluster: 0 for cluster in clusters}
    marker_gene_cluster_batch_counts[ensembl_id] = cluster_batch_counts

    for batch_id in batches:

        # Read the gene's column straight from the per-batch frame; its rows are the
        # sorted uint16 cluster labels, so no intermediate copy is needed.
        batch_non_zeros_df = batch_cluster_gene_non_zeros_dfs[batch_id]

        values = batch_non_zeros_df[ensembl_id].astype(numpy.float32).values

        # Ignore clusters that have no value for this batch.
        nan_filter = ~numpy.isnan(values)
        values = values[nan_filter].reshape((-1, 1))

        threshold = threshold_otsu(values)

        clusters_above_threshold = batch_non_zeros_df.index[nan_filter][values.flatten() > threshold]

        for cluster in clusters_above_threshold:
            cluster_batch_counts[cluster] += 1

    all_counts.extend(cluster_batch_counts.values())

    batch_ratio_threshold = threshold_otsu(numpy.array(list(cluster_batch_counts.values()))/len(batches))
    batch_ratio_thresholds[ensembl_id] = batch_ratio_threshold

# Single threshold over the batch fractions of all marker genes combined.
batch_ratio_threshold = threshold_otsu(numpy.array(all_counts)/len(batches))

In [None]:
# Select, per marker gene, the clusters whose fraction of supporting batches is above
# that gene's threshold.
#
# marker_gene_cluster_batch_counts and batch_ratio_thresholds are keyed by the var
# index label (ensembl_id), so iterate over those keys instead of the gene names in
# MARKER_GENES; looking them up by gene name raises KeyError. The result is keyed by
# ensembl_id as well, which is what the cell that fills marker_gene_clusters_df expects.
clusters = merged_adata.obs[CLUSTER_OBS_NAME].unique().astype(numpy.uint16)

marker_gene_clusters = {}

for ensembl_id, cluster_batch_counts in marker_gene_cluster_batch_counts.items():

    ratio_threshold = batch_ratio_thresholds[ensembl_id]

    marker_gene_clusters[ensembl_id] = [
        cluster
        for cluster, batch_count in cluster_batch_counts.items()
        if batch_count/len(batches) > ratio_threshold
    ]

    # Report under the readable gene name.
    gene_name = merged_adata.var.loc[ensembl_id, "Gene Name"]

    print("%s clusters above threshold: " % gene_name, marker_gene_clusters[ensembl_id])

In [None]:
# obs columns that are removed from each file before the new values are written.
COLUMNS_TO_DROP = (
    [TAXONOMY_NAME]
    + ["%s_%s" % (prefix, TAXONOMY_NAME) for prefix in ("c", "X", "Y")]
    + ["%s_%s" % (TAXONOMY_NAME, marker_gene) for marker_gene in MARKER_GENES]
)

# Write cluster label and t-SNE coordinates back into each cell set's AnnData file.
for cell_set_index, cell_set in enumerate(database.CELL_SETS):

    anndata_file_path = cell_set.get_anndata_file_path(alignment_name=ALIGNMENT_NAME)

    if not os.path.exists(anndata_file_path):
        print("Missing %s, skipping" % cell_set.name)
        continue

    print(cell_set.name)

    adata = anndata.read(anndata_file_path)

    stale_columns = [column for column in COLUMNS_TO_DROP if column in adata.obs.columns]
    adata.obs.drop(stale_columns, axis=1, inplace=True)

    cell_set_mask = merged_adata.obs["Cell Set"] == cell_set.name
    cell_set_adata = merged_adata[cell_set_mask]

    # Drop the last "-"-separated field of each merged obs name to get the labels
    # used to index this file's obs.
    cell_barcodes = numpy.array(
        ["-".join(merged_name.split("-")[0:-1]) for merged_name in cell_set_adata.obs.index.values]
    )

    tsne_coordinates = cell_set_adata.obsm["X_tsne"]

    adata.obs.loc[cell_barcodes, TAXONOMY_NAME] = cell_set_adata.obs[CLUSTER_OBS_NAME].values
    adata.obs.loc[cell_barcodes, "X_%s" % TAXONOMY_NAME] = tsne_coordinates[:, 0]
    adata.obs.loc[cell_barcodes, "Y_%s" % TAXONOMY_NAME] = tsne_coordinates[:, 1]

    # Overwrites the file that was just read.
    adata.write_h5ad(anndata_file_path)

In [None]:
# Empty marker-gene-by-cluster table: one row per key of marker_gene_clusters, a
# "Gene Name" column followed by one column per (sorted) cluster label.
marker_gene_clusters_df = pandas.DataFrame(index=marker_gene_clusters.keys(), columns=["Gene Name"] + list(sorted(clusters)))
# Every cell starts as False; a later cell sets the selected clusters to True.
marker_gene_clusters_df.fillna(False, inplace=True)

In [None]:
# Differential expression of every cluster against all other cells, one CSV per cluster.

# Make sure the output directory exists so the first to_csv call cannot fail on it.
de_output_dir = "out"
os.makedirs(de_output_dir, exist_ok=True)

# Use CLUSTER_OBS_NAME (the column Leiden wrote to) rather than repeating its literal value.
for cluster in merged_adata.obs[CLUSTER_OBS_NAME].unique():

    cluster_mask = merged_adata.obs[CLUSTER_OBS_NAME] == cluster

    de_df = vae.differential_expression(idx1=cluster_mask, idx2=(~cluster_mask))
    # Attach the readable gene name for each row of the DE table.
    de_df["Gene Name"] = merged_adata.var.loc[de_df.index.values]["Gene Name"]
    de_df.to_csv(os.path.join(de_output_dir, "%s_cluster_%s_de.csv" % (TAXONOMY_NAME, cluster)))

In [None]:
# Fill the marker table: gene name plus a True flag for every selected cluster,
# then write it next to the other database files.
for ensembl_id, selected_clusters in marker_gene_clusters.items():

    marker_gene_clusters_df.loc[ensembl_id, "Gene Name"] = merged_adata.var.loc[ensembl_id, "Gene Name"]

    for cluster in selected_clusters:
        marker_gene_clusters_df.loc[ensembl_id, cluster] = True

marker_csv_path = os.path.join(database.DATA_PATH, "%s_marker_gene_clusters.csv" % TAXONOMY_NAME)
marker_gene_clusters_df.to_csv(marker_csv_path)

In [None]:
# Same cluster-calling procedure as for the marker genes, applied to every gene:
# per batch and gene, clusters above the Otsu threshold of the non-zero probabilities
# get one vote; a single Otsu threshold on the vote fractions then decides which
# (gene, cluster) pairs are flagged in the exported table.

# Cluster labels ordered by their integer value ...
clusters = sorted(merged_adata.obs[CLUSTER_OBS_NAME].unique(), key=lambda x: int(x))
# ... but sorted(clusters) re-sorts them with the labels' default ordering.
# NOTE(review): the labels are passed through int() and str() below, so they are
# presumably strings, which would make this row order lexicographic — confirm.
cluster_gene_batch_counts_df = pandas.DataFrame(index=sorted(clusters), columns=merged_adata.var.index)
# Vote counters start at zero.
cluster_gene_batch_counts_df.fillna(0, inplace=True)

batches = merged_adata.obs["_scvi_batch"].unique()

for batch_id in batches:
    
    print(batch_id)

    cluster_gene_non_zeros_df = pandas.DataFrame(index=sorted(clusters), columns=merged_adata.var.index)
    
    for cluster in cluster_gene_non_zeros_df.index.values:

        # Rows are written under int(cluster) although the index was built from `clusters`.
        # NOTE(review): if the labels are strings, this assignment appears to append new
        # int-labelled rows instead of filling the existing ones; the NaN filter below
        # would then be what discards the untouched rows — confirm.
        cluster_gene_non_zeros_df.loc[int(cluster), :] = batch_cluster_gene_non_zeros_dfs[batch_id].loc[int(cluster)]
        
    for gene_index, ensembl_id in enumerate(cluster_gene_batch_counts_df.columns):
        
        # Progress output every 100 genes.
        if gene_index % 100 == 0:
            print(gene_index)
        
        values = cluster_gene_non_zeros_df[ensembl_id].astype(numpy.float32).values

        # Keep only the rows that hold a value for this gene.
        nan_filter = ~numpy.isnan(values)
        values = values[nan_filter].reshape((-1, 1))
        
        # Per-batch, per-gene Otsu threshold over the clusters' non-zero probabilities.
        threshold = threshold_otsu(values)

        clusters_above_threshold = cluster_gene_non_zeros_df[nan_filter].index[values.flatten() > threshold]
        
        for cluster in clusters_above_threshold:
            # str() converts the row label back to the form used in the vote table's index.
            cluster_gene_batch_counts_df.loc[str(cluster), ensembl_id] += 1

# One global Otsu threshold over the batch fractions of all (cluster, gene) pairs.
batches_threshold = threshold_otsu(cluster_gene_batch_counts_df.values.flatten()/len(batches))

# Gene-by-cluster table: "Gene Name" column plus one column per cluster, all False initially.
all_gene_clusters_df = pandas.DataFrame(index=cluster_gene_batch_counts_df.columns, columns=["Gene Name"] + list(cluster_gene_batch_counts_df.index.values))
all_gene_clusters_df.fillna(False, inplace=True)

for gene_index, ensembl_id in enumerate(cluster_gene_batch_counts_df.columns):
    
    print(gene_index, ensembl_id)
    
    # Fraction of batches that voted for each cluster.
    values = cluster_gene_batch_counts_df[ensembl_id].values / len(batches)
    
    gene_name = merged_adata.var.loc[ensembl_id, "Gene Name"]
        
    all_gene_clusters_df.loc[ensembl_id, "Gene Name"] = gene_name
    
    clusters_above_threshold = cluster_gene_batch_counts_df[ensembl_id][values > batches_threshold].index.values
    
    for cluster in clusters_above_threshold:
        
        all_gene_clusters_df.loc[ensembl_id, cluster] = True

# Written to DATA_PATH alongside the marker-gene table.
all_gene_clusters_df.to_csv(os.path.join(database.DATA_PATH, "%s_gene_clusters.csv" % TAXONOMY_NAME))