NOTE: Results change from run to run, so the results shown here (and those of any future run) will not exactly match the ones reported in the bioRxiv preprint.

## Import, load dataset, load saved model

In [None]:
import csv
import gzip
import os
import scipy.io
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import torch

from scvi.dataset import GeneExpressionDataset, Dataset10X
from scvi.models import VAE, TOTALVI
from scvi.inference import TotalPosterior, TotalTrainer, Posterior, UnsupervisedTrainer
from totalppc import TotalPosteriorPredictiveCheck as totalPPC

from scipy.special import softmax
from sklearn.decomposition import PCA
import umap

# Seaborn styling applied to the figures drawn in this notebook.
sns.set(context="notebook", font_scale=1.15, style="ticks")
# Directory handed to Dataset10X as `save_path` when the dataset is loaded.
save_path = "../data/10X"
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Load the 10X PBMC dataset (RNA + protein measurements) as a dense matrix.
dataset_options = dict(
    dataset_name="pbmc_10k_protein_v3",
    save_path=save_path,
    measurement_names_column=1,
    dense=True,
)
dataset = Dataset10X(**dataset_options)

In [None]:
# Code to run doubletdetection
# import doubletdetection
# clf = doubletdetection.BoostClassifier(n_iters=25, use_phenograph=False, standard_scaling=True)
# doublets = clf.fit(dataset.X).predict(p_thresh=1e-16, voter_thresh=0.5)
# np.save("pbmc10kdoublets.npy", doublets)

In [None]:
# Load the precomputed per-cell doublet calls (produced by the commented-out
# doubletdetection code above).
# NOTE(review): that code saves to "pbmc10kdoublets.npy" in the working
# directory, while this reads from "data/" — confirm where the file lives.
doublets = np.load("data/pbmc10kdoublets.npy")

In [None]:
def filter_dataset(dataset, doublet_calls=None):
    """Filter genes, control proteins and low-quality / doublet cells in place.

    Steps, in order:

    1. Keep genes detected (count > 0) in more than 1% of cells, then call
       ``dataset.subsample_genes(new_n_genes=5000)``.
    2. Drop proteins whose name starts with ``"IgG"`` (their names are printed).
    3. Keep cells with more than 500 detected genes, whose total protein count
       is at or above the 1st percentile, and that are not flagged as doublets.

    Args:
        dataset: Object exposing ``X``, ``protein_expression``,
            ``protein_names``, ``update_genes``, ``subsample_genes`` and
            ``update_cells``. It is modified in place.
        doublet_calls: Optional per-cell doublet flags (truthy = doublet). When
            omitted, the notebook-level ``doublets`` array is used, which is
            the original behaviour.

    Returns:
        Tuple ``(dataset, inds_to_keep)`` where ``inds_to_keep`` is the boolean
        mask of retained cells that was passed to ``dataset.update_cells``.
    """
    if doublet_calls is None:
        # Backward-compatible default: fall back to the notebook global.
        doublet_calls = doublets

    n_cells = dataset.X.shape[0]
    high_count_genes = (dataset.X > 0).sum(axis=0).ravel() > 0.01 * n_cells
    dataset.update_genes(high_count_genes)
    dataset.subsample_genes(new_n_genes=5000)

    # Filter control proteins
    non_control_proteins = []
    for i, p in enumerate(dataset.protein_names):
        if not p.startswith("IgG"):
            non_control_proteins.append(i)
        else:
            print(p)
    dataset.protein_expression = dataset.protein_expression[:, non_control_proteins]
    dataset.protein_names = dataset.protein_names[non_control_proteins]

    high_gene_count_cells = (dataset.X > 0).sum(axis=1).ravel() > 500
    # Totals are computed after the control proteins were removed.
    protein_totals = dataset.protein_expression.sum(axis=1)
    high_protein_cells = protein_totals >= np.percentile(protein_totals, 1)
    inds_to_keep = np.logical_and(high_gene_count_cells, high_protein_cells)
    # `np.bool` was removed from NumPy (1.24+); the builtin `bool` is equivalent.
    is_doublet = np.asarray(doublet_calls).astype(bool)
    inds_to_keep = np.logical_and(inds_to_keep, ~is_doublet)
    dataset.update_cells(inds_to_keep)
    return dataset, inds_to_keep

In [None]:
# Apply the gene / control-protein / cell filters; `inds_to_keep` is the boolean
# mask of cells that were retained.
dataset, inds_to_keep = filter_dataset(dataset)

In [None]:
# Training hyperparameters for totalVI.
use_cuda = True
lr = 5e-3

# totalVI model over the filtered genes and proteins with a 20-d latent space.
totalvae = TOTALVI(dataset.nb_genes, len(dataset.protein_names), n_latent=20)

# Early stopping / LR scheduling settings forwarded to the trainer.
early_stopping_kwargs = dict(
    early_stopping_metric="elbo",
    save_best_state_metric="elbo",
    patience=150,
    threshold=0,
    reduce_lr_on_plateau=True,
    lr_patience=30,
    lr_factor=0.6,
    posterior_class=TotalPosterior,
)

# 90% of cells for training, 4% for testing; metrics are recorded every epoch.
trainer = TotalTrainer(
    totalvae,
    dataset,
    train_size=0.90,
    test_size=0.04,
    use_cuda=use_cuda,
    frequency=1,
    data_loader_kwargs=dict(batch_size=256),
    n_epochs_kl_warmup=200,
    n_epochs_back_kl_warmup=200,
    early_stopping_kwargs=early_stopping_kwargs,
    seed=3,
)

In [None]:
# Train totalVI for up to 500 epochs with the learning rate defined above.
trainer.train(lr=lr, n_epochs=500)

In [None]:
# Print the validation-set marginal log-likelihood estimate (n_samples_mc=100);
# no gradients are needed for evaluation.
with torch.no_grad():
    print(trainer.validation_set.compute_marginal_log_likelihood(n_samples_mc=100, batch_size=64))

In [None]:
# Display the negated value of compute_elbo on the validation set.
# NOTE(review): unlike the neighbouring evaluation cells, this call is not
# wrapped in torch.no_grad().
-trainer.validation_set.compute_elbo(totalvae)

In [None]:
# Print the validation-set reconstruction error of totalVI without tracking gradients.
with torch.no_grad():
    print(trainer.validation_set.compute_reconstruction_error(totalvae))

In [None]:
# Stack RNA and protein counts column-wise (genes first, then proteins) so the
# scVI baseline can treat both as one set of features.
full_data = np.hstack((dataset.X, dataset.protein_expression))
full_dataset = GeneExpressionDataset()
full_dataset.populate_from_data(full_data)

In [None]:
# scVI baseline trained on the concatenated gene + protein matrix.
vae = VAE(full_dataset.nb_genes, n_latent=20, reconstruction_loss="nb")

# Same split sizes and seed as the totalVI trainer; metrics every 10 epochs.
vae_trainer_options = dict(
    train_size=0.90,
    test_size=0.04,
    use_cuda=True,
    frequency=10,
    seed=3,
    n_epochs_kl_warmup=200,
)
trainer_vae = UnsupervisedTrainer(vae, full_dataset, **vae_trainer_options)
trainer_vae.train(n_epochs=500, lr=3e-3)

In [None]:
# Module-level likelihood helpers, used here for the scVI baseline.
from scvi.models.log_likelihood import compute_marginal_log_likelihood, compute_elbo, compute_reconstruction_error
# Print the scVI validation-set marginal log-likelihood estimate (n_samples_mc=100).
with torch.no_grad():
    print(compute_marginal_log_likelihood(vae, trainer_vae.validation_set, n_samples_mc=100))

In [None]:
# totalVI ELBO history for each split, skipping the first 50 recorded values.
for split in ("validation", "train", "test"):
    plt.plot(trainer.history["elbo_{}_set".format(split)][50:], label=split)
sns.despine()
plt.legend()

In [None]:
# totalVI ELBO history for the held-out splits, skipping the first 10 recorded values.
for split in ("validation", "test"):
    plt.plot(trainer.history["elbo_{}_set".format(split)][10:], label=split)
sns.despine()
plt.legend()

In [None]:
# scVI ELBO history for the held-out splits, skipping the first 10 recorded values.
for split in ("validation", "test"):
    plt.plot(trainer_vae.history["elbo_{}_set".format(split)][10:], label=split)
sns.despine()
plt.legend()

In [None]:
# Print the negated compute_elbo value on the validation set for totalVI (first
# line) and for the scVI baseline (second line).
with torch.no_grad():
    print(-trainer.validation_set.compute_elbo(totalvae))
    print(-compute_elbo(vae, trainer_vae.validation_set))

## Create posterior

In [None]:
def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)) applied elementwise."""
    return 1 / (1 + np.exp(-x))


# Number of posterior samples averaged for the mixing and denoised estimates.
N_SAMPLES = 50

# Posterior over every cell of the filtered dataset (not only the training split).
full_posterior = trainer.create_posterior(
    totalvae, dataset, indices=np.arange(len(dataset)), type_class=TotalPosterior
)
latent_mean, batch_index, label, library_gene = full_posterior.sequential().get_latent()

# Protein names without the "_TotalSeqB"-style suffix, used as column labels.
parsed_protein_names = [p.split("_")[0] for p in dataset.protein_names]

# The sampled mixing values are squashed to (0, 1) with a sigmoid; the plots
# below label this quantity as the probability of background.
py_mixing = sigmoid(
    full_posterior.sequential().get_sample_mixing(n_samples=N_SAMPLES, give_mean=True)
)
# Fixed: this previously referenced the undefined name `py_dropout`.
protein_pi = pd.DataFrame(data=py_mixing, columns=parsed_protein_names)

# get_normalized_denoised_expression returns a tuple (denoised_gene, denoised_pro);
# concatenate along the feature axis so genes come first, then proteins.
denoised_data = np.concatenate(
    full_posterior.sequential().get_normalized_denoised_expression(
        n_samples=N_SAMPLES, give_mean=True
    ),
    axis=-1,
)

In [None]:
# Release cached GPU memory held by PyTorch before the following cells run.
torch.cuda.empty_cache()

In [None]:
# 2-D UMAP embedding of the totalVI latent means; random_state is fixed at 42.
umap_dr = umap.UMAP(n_neighbors=15, random_state=42, min_dist=0.1).fit_transform(latent_mean)

In [None]:
# Side-by-side UMAPs for one protein: mixing probability (left) vs. log(1 + count) (right).
gene = "CD127_TotalSeqB"
protein_col = np.where(dataset.protein_names == gene)[0]
umap_x, umap_y = umap_dr[:, 0], umap_dr[:, 1]

fig, ax = plt.subplots(1, 2, figsize=(8, 4), dpi=300)
for panel in ax:
    panel.axis('off')

background_prob = py_mixing[:, protein_col].ravel()
cax = ax[0].scatter(umap_x, umap_y, s=3, c=background_prob)
ax[0].set_title("Probability(in background)")

log_counts = np.log(dataset.protein_expression[:, protein_col].ravel() + 1)
ax[1].scatter(umap_x, umap_y, s=3, c=log_counts)
ax[1].set_title("Log Normalized Expression")

# The colorbar describes the left (probability) panel only.
fig.colorbar(cax, ax=ax[0], orientation='vertical')

In [None]:
# CD4 protein: histogram of log(1 + count) next to a UMAP coloured by the
# mixing probability; the figure is saved to monocyte.pdf.
gene = "CD4_TotalSeqB"
protein_col = np.where(dataset.protein_names == gene)[0]

fig, ax = plt.subplots(1, 2, figsize=(8, 4))
log_counts = np.log(1 + dataset.protein_expression[:, protein_col])
ax[0].hist(log_counts, bins=50)
sns.despine()

background_prob = py_mixing[:, protein_col].ravel()
cax = ax[1].scatter(
    umap_dr[:, 0], umap_dr[:, 1], s=3, c=background_prob, cmap=plt.cm.viridis
)
ax[1].axis('off')

ax[0].set_title("Distribution of CD4 protein counts")
ax[0].set_ylabel("Number of Cells")
ax[0].set_xlabel("log(UMI)")
ax[1].set_title("Probability(Background)")

# Hand-placed population labels in UMAP coordinates.
for text, position in (("CD4+ T", (-5, 2)), ("CD8+ T", (1, -5)), ("Monocytes", (-5, -15))):
    ax[1].annotate(text, xy=position)

fig.colorbar(cax, ax=ax[1], orientation='vertical')
plt.tight_layout()
plt.savefig("monocyte.pdf")

In [None]:
import matplotlib.gridspec as gridspec

# Layout: two full-height panels (scatter, UMAP) plus two stacked histograms.
fig = plt.figure(figsize=(12, 4))
gs = gridspec.GridSpec(2, 3)
ax = [plt.subplot(cell) for cell in (gs[0:, 0], gs[0:, 1], gs[0, 2], gs[1, 2])]

# NOTE(review): protein column 5 is presumably CD16 (the histogram x-label
# below is hard-coded to CD16) — confirm against dataset.protein_names.
ind = 5
gene = np.where(dataset.gene_names == "FCGR3A")[0][0]

log_protein = np.log(dataset.protein_expression[:, ind] + 1)
prob_background = py_mixing[:, ind]

# c: high background probability with 3.5 < log count < 6.8 (drawn in red).
c = np.logical_and(prob_background > 0.9, log_protein < 6.8)
c = np.logical_and(c, log_protein > 3.5)
# d: low background probability with log count < 6.8 (drawn in blue).
d = np.logical_and(prob_background < 0.1, log_protein < 6.8)

# Panel 0: background probability against log protein count.
for mask, colour in ((~c, "grey"), (c, "red"), (d, "blue")):
    ax[0].scatter(log_protein[mask], prob_background[mask], c=colour, s=3)

# Dashed guides at the thresholds used to define the two selections.
for x_cut in (6.8, 3.5):
    ax[0].axvline(x_cut, linestyle="--", c="black")
for y_cut in (0.9, 0.1):
    ax[0].axhline(y_cut, linestyle="--", c="black")

# Panel 1: the same selections highlighted on the UMAP.
ax[1].scatter(umap_dr[~c, 0], umap_dr[~c, 1], s=3, c="grey")
for mask, colour in ((c, "red"), (d, "blue")):
    ax[1].scatter(umap_dr[mask, 0], umap_dr[mask, 1], s=1, c=colour, alpha=0.3)
ax[1].tick_params(axis='both', which='both', length=0)
ax[1].set_xticklabels([])
ax[1].set_yticklabels([])

# Panels 2 and 3: FCGR3A RNA counts inside each selection versus all other cells.
bins = np.linspace(0, 9, 10)
for panel, mask, colour in ((ax[2], c, "red"), (ax[3], d, "blue")):
    panel.hist(dataset.X[mask, gene].ravel(), alpha=0.8,
               color=colour, density=True, bins=bins, label="Selected")
    panel.hist(dataset.X[~mask, gene].ravel(), alpha=0.5,
               color="grey", density=True, bins=bins, label="Rest")
    panel.set_ylabel("Density")
    panel.legend()
ax[3].set_xlabel("CD16 Gene Expression (UMI count)")

sns.despine()
ax[0].set_xlabel("{} Protein log(UMI count)".format(dataset.protein_names[ind].split("_")[0]))
ax[0].set_ylabel("Probability(Background)")
ax[1].set_xlabel("UMAP 1")
ax[1].set_ylabel("UMAP 2")

# Hand-placed population labels in UMAP coordinates.
for text, position in (("NK", (-9, 6)), ("CD16+ Mono", (6, -9.3)), ("CD14+ Mono", (-5, -15))):
    ax[1].annotate(text, xy=position)

plt.tight_layout()
plt.savefig("cd16_pbmc10k.pdf")

In [None]:
# Draw 5 posterior predictive samples for every cell; only the first element of
# the returned tuple is kept. The code below indexes it as
# pp_sample[:, feature, sample].
pp_sample = full_posterior.generate(n_samples=5, batch_size=64)[0]

In [None]:
from scipy.stats import pearsonr, spearmanr
import scanpy as sc
import anndata


def _mean_pp_spearman(samples, col_a, col_b):
    """Average the Spearman correlation of two feature columns over the last axis of `samples`."""
    n_draws = samples.shape[-1]
    rho_sum = 0
    for j in range(n_draws):
        rho_sum += spearmanr(samples[:, col_a, j], samples[:, col_b, j])[0]
    return rho_sum / n_draws


# Raw counts with genes first and proteins appended, matching the column layout
# of `denoised_data` and `pp_sample`.
adata = anndata.AnnData(X=np.concatenate([dataset.X, dataset.protein_expression], axis=1))

# Map each protein (name prefix) to the gene symbol encoding it.
identifiers = [pn.split('_')[0] for pn in dataset.protein_names]
translation = {"CD8a":"CD8A", "CD3":"CD3G", "CD127":"IL7R", "CD25":"IL2RA", "CD16":"FCGR3A", "CD4":"CD4", 
               "CD14":"CD14", "CD15":"FUT4", "CD56":"NCAM1", "CD19":"CD19", "CD45RA":"PTPRC", "CD45RO":"PTPRC", 
               "PD-1":"PDCD1", "TIGIT":"TIGIT"}

# Spearman correlations between each protein and its encoding gene, computed on
# denoised values (`total`), posterior predictive samples (`total_pp`) and raw
# counts (`scan`).
total = []
total_pp = []
scan = []
for identifier in identifiers:
    # Protein columns sit after the gene columns, hence the offset.
    pro = np.where(dataset.protein_names == identifier + '_TotalSeqB')[0][0] + dataset.nb_genes
    gene_symbol = translation[identifier]
    try:
        gene = np.where(dataset.gene_names == gene_symbol)[0][0]
    except IndexError:
        # The gene is not among the retained genes: report it and skip the pair.
        print(gene_symbol)
        continue
    total_pp.append(_mean_pp_spearman(pp_sample, gene, pro))
    total.append(spearmanr(denoised_data[:, gene], denoised_data[:, pro])[0])
    scan.append(spearmanr(adata.X[:, gene], adata.X[:, pro])[0])


# Background distribution: every protein against 500 randomly drawn genes.
# np.random.choice samples with replacement and is not seeded here, so this set
# changes from run to run.
selected_genes = np.random.choice(dataset.nb_genes, 500)
total_random = []
total_pp_random = []
scan_random = []
for i, _ in enumerate(dataset.protein_names):
    pro = dataset.nb_genes + i
    for g in selected_genes:
        total_random.append(spearmanr(denoised_data[:, g], denoised_data[:, pro])[0])
        total_pp_random.append(_mean_pp_spearman(pp_sample, g, pro))
        scan_random.append(spearmanr(adata.X[:, g], adata.X[:, pro])[0])

In [None]:
# Compare protein-RNA Spearman correlations from raw counts (x axis) with those
# from totalVI posterior predictive samples (left) and denoised values (right).
fig, axarr = plt.subplots(1, 2, figsize=(8, 4))

plot_total = [total_pp, total]
plot_total_random = [total_pp_random, total_random]
for i, ax in enumerate(axarr):
    ax.scatter(scan, plot_total[i], s=20, c='red', label="Same gene", zorder=2)
    ax.scatter(scan_random, plot_total_random[i], s=2, alpha=0.7, c='grey', label="Other", zorder=1)
    legend = ax.legend(loc=4)

    # Give both legend markers the same readable size. `legendHandles` was
    # renamed to `legend_handles` in Matplotlib 3.7, so support both names.
    handles = getattr(legend, "legend_handles", None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles[:2]:
        handle.set_sizes([15])

    # Identity line over the full correlation range [-1, 1].
    ax.plot([-1, 1], [-1, 1], 'k-', alpha=0.75, zorder=0)
    ax.set_aspect('equal')
sns.despine()
axarr[1].set_xlabel("Raw")
axarr[1].set_ylabel("totalVI Denoised")
axarr[1].set_title("Protein-RNA Spearman Correlations")
axarr[0].set_xlabel("Raw")
axarr[0].set_ylabel("totalVI Posterior Predictive")
axarr[0].set_title("Protein-RNA Spearman Correlations")
plt.tight_layout()
plt.savefig("correlations_noised.pdf")

## PPC

In [None]:
# scVI posterior over every cell, mirroring `full_posterior` for totalVI.
vae_post = trainer_vae.create_posterior(vae, full_dataset, indices=np.arange(len(dataset)))
# Posterior predictive checks on the validation split (150 samples per model) ...
ppc_held = totalPPC(posteriors_dict={'totalVI':trainer.validation_set, "scVI":trainer_vae.validation_set}, n_samples=150)
# ... and on all cells (25 samples per model).
ppc_full = totalPPC(posteriors_dict={'totalVI':full_posterior, "scVI":vae_post}, n_samples=25)

In [None]:
# Factor analysis baselines for the held-out PPC, built from the totalVI
# training cells. Note that `test_indices` holds the *validation* split indices.
train_indices = trainer.train_set.indices
test_indices = trainer.validation_set.indices
train_data = full_data[train_indices]
for normalization in ("log", "log_rate"):
    ppc_held.store_fa_samples(
        train_data,
        train_indices,
        test_indices,
        n_components=totalvae.n_latent,
        normalization=normalization,
    )

In [None]:
def calibration_error(ppc, key):
    """Print squared calibration errors of central posterior predictive intervals.

    For each symmetric pair of percentile levels (2.5/97.5 up to 17.5/82.5),
    the fraction of raw counts falling inside the interval is compared with the
    interval's nominal width, and the squared differences are summed.

    Args:
        ppc: Object exposing ``posterior_predictive_samples`` (indexable by
            ``key``; samples lie along axis 2) and ``raw_counts``.
        key: Name of the model whose samples are evaluated.

    Prints three numbers: the summed squared error over the gene columns, the
    protein columns, and all columns. Columns are split at the notebook-level
    ``dataset.nb_genes``. Returns nothing.
    """
    levels = [2.5, 5, 7.5, 10, 12.5, 15, 17.5, 82.5, 85, 87.5, 90, 92.5, 95, 97.5]
    bounds = np.percentile(ppc.posterior_predictive_samples[key], levels, axis=2)
    n_genes = dataset.nb_genes

    gene_err, protein_err, overall_err = 0, 0, 0
    for k, lower_level in enumerate(levels):
        # Pair the k-th lowest level with the k-th highest; stop once they cross.
        if lower_level > levels[-1 - k]:
            break
        above_lower = ppc.raw_counts >= bounds[k]
        below_upper = ppc.raw_counts <= bounds[-1 - k]
        inside = np.logical_and(above_lower, below_upper)

        nominal_width = (100 - lower_level * 2) / 100
        gene_err += (np.mean(inside[:, :n_genes]) - nominal_width) ** 2
        protein_err += (np.mean(inside[:, n_genes:]) - nominal_width) ** 2
        overall_err += (np.mean(inside) - nominal_width) ** 2
    print(gene_err, protein_err, overall_err)

# Report held-out calibration errors for every model stored in the PPC.
for model_key in ("totalVI", "scVI", "Factor Analysis (Log)", "Factor Analysis (Log Rate)"):
    calibration_error(ppc_held, model_key)

In [None]:
def mean_squared_log_error(self):
    """Store per-model mean squared log error in ``self.metrics["msle"]``.

    For every entry of ``self.posterior_predictive_samples`` the samples are
    averaged over their last axis, and the mean of
    ``(log(pred + 1) - log(raw + 1)) ** 2`` is computed separately for the
    first ``self.dataset.nb_genes`` columns ("genes") and the remaining
    columns ("proteins").

    Args:
        self: PPC-like object with ``posterior_predictive_samples``,
            ``raw_counts``, ``dataset.nb_genes`` and a ``metrics`` dict.

    The result is a DataFrame with one column per model and rows
    ``["genes", "proteins"]``. Returns nothing.
    """

    def _msle(predicted, observed):
        return np.mean(np.square(np.log(predicted + 1) - np.log(observed + 1)))

    table = pd.DataFrame()
    for model_name, samples in self.posterior_predictive_samples.items():
        n_genes = self.dataset.nb_genes
        predicted = np.mean(samples, axis=-1)
        gene_score = _msle(predicted[:, :n_genes], self.raw_counts[:, :n_genes])
        protein_score = _msle(predicted[:, n_genes:], self.raw_counts[:, n_genes:])
        table[model_name] = [gene_score, protein_score]

    table.index = ["genes", "proteins"]
    self.metrics["msle"] = table
# Compute the held-out MSLE table; it is stored in ppc_held.metrics["msle"].
mean_squared_log_error(ppc_held)

In [None]:
# Display the MSLE table (rows: genes / proteins, one column per model).
ppc_held.metrics["msle"]

In [None]:
# The held-out PPC is no longer needed; drop the reference before building the
# factor analysis baselines on the full dataset.
del ppc_held

# Fit and evaluate factor analysis on all cells (train indices == test indices).
n_cells = len(dataset)
for normalization in ("log", "log_rate"):
    ppc_full.store_fa_samples(
        ppc_full.raw_counts,
        np.arange(n_cells),
        np.arange(n_cells),
        n_components=totalvae.n_latent,
        normalization=normalization,
    )

In [None]:
# Alias the full-data PPC and compute the coefficient of variation with
# cell_wise=False; the following cells read the result from ppc.metrics["cv_gene"].
ppc = ppc_full
ppc.coeff_of_variation(cell_wise=False)

In [None]:
from scipy.stats import ks_2samp

# Coefficient of variation for the protein features (rows after the genes).
key = "cv_gene"
protein_cv = ppc.metrics[key].iloc[dataset.nb_genes:]

fig, ax = plt.subplots(1, 1)
sns.boxplot(data=protein_cv, ax=ax, showfliers=False)
plt.title("Coefficient of Variation (Proteins)")
sns.despine()

# Median absolute deviation of each model's CV from the raw-data CV.
for model_name in ("totalVI", "scVI", "Factor Analysis (Log)", "Factor Analysis (Log Rate)"):
    print(np.median(np.abs(protein_cv[model_name] - protein_cv["raw"])))

In [None]:
# Coefficient of variation for the RNA features (the first nb_genes rows).
key = "cv_gene"
rna_cv = ppc.metrics[key].iloc[:dataset.nb_genes]

fig, ax = plt.subplots(1, 1)
sns.boxplot(data=rna_cv, ax=ax, showfliers=False)
plt.title("Coefficient of Variation (RNA)")
sns.despine()

# Median absolute deviation of each model's CV from the raw-data CV.
for model_name in ("totalVI", "Factor Analysis (Log)", "Factor Analysis (Log Rate)", "scVI"):
    print(np.median(np.abs(rna_cv[model_name] - rna_cv["raw"])))

