## Imports, dataset loading, and model training

In [None]:
import csv
import gzip
import os
import scipy.io
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import torch

from scvi.dataset import GeneExpressionDataset, Dataset10X
from scvi.models import VAE, TOTALVI
from scvi.inference import TotalPosterior, TotalTrainer, Posterior, UnsupervisedTrainer
from totalppc import TotalPosteriorPredictiveCheck as totalPPC

import umap

# Global seaborn plotting style for every figure in this notebook.
sns.set(context="notebook", font_scale=1.15, style="ticks")
# Directory handed to Dataset10X as its save_path.
save_path = "../data/10X"
# IPython magics (notebook-only, not plain Python): automatically reload
# edited modules before execution, and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Load the 10x "malt_10k_protein_v3" dataset (RNA counts in dataset.X, protein
# counts in dataset.protein_expression) from save_path.
# dense=True presumably loads X as a dense array, which filter_dataset below
# relies on for elementwise comparisons — TODO confirm.
# measurement_names_column=1 picks the column used for feature names — confirm.
dataset = Dataset10X(dataset_name="malt_10k_protein_v3", save_path=save_path, 
                     measurement_names_column=1, dense=True)

In [None]:
def filter_dataset(
    dataset,
    min_cell_fraction=0.01,
    n_top_genes=5000,
    min_genes_per_cell=500,
    control_prefix="IgG",
    protein_percentile=1,
):
    """Filter genes, control proteins and low-quality cells of ``dataset``.

    Steps, in order (``dataset`` is modified through its own update methods):
      1. keep genes with a non-zero count in more than ``min_cell_fraction``
         of all cells;
      2. reduce to ``n_top_genes`` genes via ``dataset.subsample_genes``;
      3. flag cells with more than ``min_genes_per_cell`` non-zero genes
         (counted over the genes kept in steps 1-2);
      4. drop proteins whose name starts with ``control_prefix`` (their names
         are printed);
      5. flag cells whose total protein count is at least the
         ``protein_percentile``-th percentile of per-cell protein totals;
      6. keep only cells passing both step 3 and step 5.

    The defaults reproduce the original hard-coded thresholds.

    Returns:
        ``(dataset, inds_to_keep)``: the same dataset object that was passed
        in, and the boolean per-cell mask applied in step 6.
    """
    n_cells = dataset.X.shape[0]
    cells_detecting_gene = (dataset.X > 0).sum(axis=0).ravel()
    dataset.update_genes(cells_detecting_gene > min_cell_fraction * n_cells)
    dataset.subsample_genes(new_n_genes=n_top_genes)
    # Computed after gene subsampling, so only the retained genes are counted.
    high_gene_count_cells = (dataset.X > 0).sum(axis=1).ravel() > min_genes_per_cell

    # Filter control proteins.
    non_control_proteins = []
    for i, name in enumerate(dataset.protein_names):
        if name.startswith(control_prefix):
            print(name)
        else:
            non_control_proteins.append(i)
    dataset.protein_expression = dataset.protein_expression[:, non_control_proteins]
    dataset.protein_names = dataset.protein_names[non_control_proteins]

    # Protein totals are taken after control proteins have been removed.
    protein_totals = dataset.protein_expression.sum(axis=1)
    high_protein_cells = protein_totals >= np.percentile(protein_totals, protein_percentile)

    inds_to_keep = np.logical_and(high_gene_count_cells, high_protein_cells)
    dataset.update_cells(inds_to_keep)
    return dataset, inds_to_keep

In [None]:
# Apply gene / control-protein / cell filtering; inds_to_keep is the boolean
# per-cell mask that was applied.
dataset, inds_to_keep = filter_dataset(dataset)

In [None]:
# totalVI model over the filtered genes and non-control proteins, 20-d latent space.
totalvae = TOTALVI(dataset.nb_genes, len(dataset.protein_names), n_latent=20)
# Use the GPU only when one is present, so the notebook also runs on CPU-only
# machines instead of failing when training starts.
use_cuda = torch.cuda.is_available()
lr = 5e-3
# Early stopping and LR-on-plateau scheduling, both driven by the ELBO.
early_stopping_kwargs = {
    "early_stopping_metric": "elbo",
    "save_best_state_metric": "elbo",
    "patience": 150,
    "threshold": 0,
    "reduce_lr_on_plateau": True,
    "lr_patience": 30,
    "lr_factor": 0.6,
    "posterior_class": TotalPosterior,
}

# 90% train / 4% test; the remaining cells are used as trainer.validation_set
# elsewhere in this notebook. frequency=1 records metrics every epoch.
trainer = TotalTrainer(
    totalvae,
    dataset,
    train_size=0.90,
    test_size=0.04,
    use_cuda=use_cuda,
    frequency=1,
    data_loader_kwargs={"batch_size": 256},
    n_epochs_kl_warmup=200,
    n_epochs_back_kl_warmup=200,
    early_stopping_kwargs=early_stopping_kwargs,
    seed=5,
)

In [None]:
# Train totalVI for at most 500 epochs; the early-stopping settings passed to
# the trainer presumably may stop it earlier.
trainer.train(lr=lr, n_epochs=500)

In [None]:
# Marginal log-likelihood of totalVI on the validation cells (100 Monte Carlo
# samples, mini-batches of 64); no_grad avoids building autograd graphs.
with torch.no_grad():
    print(trainer.validation_set.compute_marginal_log_likelihood(n_samples_mc=100, batch_size=64))

In [None]:
# Negated validation ELBO of totalVI. Evaluated under no_grad like the other
# evaluation cells so no autograd graph is built (saves memory).
with torch.no_grad():
    print(-trainer.validation_set.compute_elbo(totalvae))

In [None]:
# Reconstruction error of totalVI on the validation cells (no autograd graph).
with torch.no_grad():
    print(trainer.validation_set.compute_reconstruction_error(totalvae))

In [None]:
# Stack RNA counts and protein counts column-wise (genes first, then proteins)
# so the scVI baseline can treat every protein as an additional "gene".
full_data = np.concatenate([dataset.X, dataset.protein_expression], axis=1)
full_dataset = GeneExpressionDataset()
full_dataset.populate_from_data(full_data)

In [None]:
# scVI baseline (negative binomial likelihood) on the combined RNA+protein
# matrix, with the same latent size, split sizes and seed as totalVI.
vae = VAE(full_dataset.nb_genes, n_latent=20, reconstruction_loss="nb")
trainer_vae = UnsupervisedTrainer(
    vae,
    full_dataset,
    train_size=0.90,
    test_size=0.04,
    # Fall back to CPU when no GPU is available instead of crashing.
    use_cuda=torch.cuda.is_available(),
    # Metrics are recorded every 10 epochs (vs every epoch for totalVI).
    frequency=10,
    seed=5,
    n_epochs_kl_warmup=200,
)
trainer_vae.train(n_epochs=500, lr=3e-3)

In [None]:
# Standalone log-likelihood helpers; compute_elbo is also used in a later cell.
from scvi.models.log_likelihood import compute_marginal_log_likelihood, compute_elbo, compute_reconstruction_error
# Marginal log-likelihood of scVI on its validation cells (100 MC samples).
with torch.no_grad():
    print(compute_marginal_log_likelihood(vae, trainer_vae.validation_set, n_samples_mc=100))

In [None]:
# totalVI ELBO curves for the validation, train and test splits. The first 50
# recorded values are skipped, presumably so early large values don't
# dominate the y-axis.
curves = (
    ("elbo_validation_set", "validation"),
    ("elbo_train_set", "train"),
    ("elbo_test_set", "test"),
)
for history_key, curve_label in curves:
    plt.plot(trainer.history[history_key][50:], label=curve_label)
# Optional zoom: plt.ylim(2500, 2600)
sns.despine()
plt.legend()

In [None]:
# totalVI validation vs test ELBO curves, skipping the first 10 recorded values
# (x-axis is the index into the history list).
plt.plot(trainer.history['elbo_validation_set'][10:], label="validation")
plt.plot(trainer.history['elbo_test_set'][10:], label="test")
# plt.ylim(2500, 2700)
sns.despine()
plt.legend()

In [None]:
# scVI validation vs test ELBO curves, skipping the first 10 recorded values.
# NOTE(review): trainer_vae was built with frequency=10, so presumably each
# history entry covers 10 epochs; [10:] then skips far more epochs than the
# same slice does for totalVI (frequency=1) — confirm before comparing plots.
plt.plot(trainer_vae.history['elbo_validation_set'][10:], label="validation")
plt.plot(trainer_vae.history['elbo_test_set'][10:], label="test")
# plt.ylim(2500, 2700)
sns.despine()
plt.legend()

In [None]:
# Negated validation ELBO for totalVI (first line) and scVI (second line).
with torch.no_grad():
    print(-trainer.validation_set.compute_elbo(totalvae))
    print(-compute_elbo(vae, trainer_vae.validation_set))

## Create the totalVI posterior over all cells

In [None]:
# totalVI posterior over every cell of the filtered dataset.
full_posterior = trainer.create_posterior(totalvae, dataset, indices=np.arange(len(dataset)), type_class=TotalPosterior)
# Latent means plus per-cell batch index, label and library value
# (presumably the RNA library size, judging by the name — confirm).
latent_mean, batch_index, label, library_gene = full_posterior.sequential().get_latent()
# One sampled latent per cell (sample=True); only the first return value is kept.
latent = full_posterior.sequential().get_latent(sample=True)[0]
def sigmoid(x):
    """Numerically stable elementwise logistic function ``1 / (1 + exp(-x))``.

    The naive formula computes ``exp(-x)``, which overflows for large negative
    ``x`` (emitting RuntimeWarnings, or raising under strict ``np.errstate``).
    Here each sign is handled with an exponent that is never positive, so
    nothing overflows.

    Scalars return a NumPy scalar; arrays return an array of the same shape.
    Floating inputs keep their precision (float32 stays float32); integer
    inputs are promoted to float64.
    """
    x = np.asarray(x)
    x = x.astype(np.result_type(x.dtype, np.float32), copy=False)
    out = np.empty_like(x)
    nonneg = x >= 0
    # x >= 0: exp(-x) <= 1, safe to use the textbook form.
    out[nonneg] = 1.0 / (1.0 + np.exp(-x[nonneg]))
    # x < 0 (and NaN): rewrite as exp(x) / (1 + exp(x)) so exp(x) <= 1.
    exp_x = np.exp(x[~nonneg])
    out[~nonneg] = exp_x / (1.0 + exp_x)
    # [()] turns a 0-d result back into a scalar and is a no-op view otherwise.
    return out[()]
# Number of posterior samples averaged in the calls below.
N_SAMPLES = 50
# Short protein names: everything before the first "_".
parsed_protein_names = [p.split("_")[0] for p in dataset.protein_names]
# Sample-averaged protein mixing values passed through sigmoid
# (presumably get_sample_mixing returns logits — confirm).
py_mixing = sigmoid(full_posterior.sequential().get_sample_mixing(n_samples=N_SAMPLES, give_mean=True))
protein_pi = pd.DataFrame(data=py_mixing, columns=parsed_protein_names)
# Normalized denoised expression; the returned pieces (presumably RNA and
# protein parts) are concatenated along the last axis.
denoised_data = np.concatenate(full_posterior.sequential().get_normalized_denoised_expression(n_samples=N_SAMPLES, give_mean=True), axis=-1)

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

## Posterior predictive checks (PPC)

In [None]:
# scVI posterior over all cells of the combined RNA+protein dataset
# (full_dataset has the same cells as dataset, so len(dataset) indexes it).
vae_post = trainer_vae.create_posterior(vae, full_dataset, indices=np.arange(len(dataset)))
# PPC on the validation cells (150 samples per model) ...
ppc_held = totalPPC(posteriors_dict={'totalVI':trainer.validation_set, "scVI":trainer_vae.validation_set}, n_samples=150)
# ... and on all cells (25 samples per model).
ppc_full = totalPPC(posteriors_dict={'totalVI':full_posterior, "scVI":vae_post}, n_samples=25)

In [None]:
# Factor-analysis baselines for the held-out PPC: fit on totalVI's training
# cells, evaluate on its validation cells, once per normalization.
train_indices = trainer.train_set.indices
test_indices = trainer.validation_set.indices
train_data = full_data[train_indices]
for normalization in ("log", "log_rate"):
    ppc_held.store_fa_samples(
        train_data,
        train_indices,
        test_indices,
        n_components=totalvae.n_latent,
        normalization=normalization,
    )

In [None]:
def calibration_error(ppc, key, n_genes=None, lower_ps=(2.5, 5, 7.5, 10, 12.5, 15, 17.5)):
    """Squared calibration error of central posterior predictive intervals.

    For each lower percentile ``p`` in ``lower_ps``, the interval
    ``[percentile p, percentile 100 - p]`` of ``ppc.posterior_predictive_samples[key]``
    (computed over axis 2) is compared with ``ppc.raw_counts``. The fraction of
    observed counts inside the interval should equal its nominal width
    ``(100 - 2p) / 100``; squared differences are summed over all ``p``.

    Args:
        ppc: object with ``posterior_predictive_samples`` (dict of arrays whose
            axis 2 indexes samples) and ``raw_counts`` (same shape without that axis).
        key: model name in ``ppc.posterior_predictive_samples``.
        n_genes: number of leading feature columns that are genes; remaining
            columns are proteins. Defaults to the notebook-global
            ``dataset.nb_genes`` (original behavior).
        lower_ps: lower percentiles of the central intervals. The default
            reproduces the original intervals; e.g. ``(25, 30, 35, 40, 45)``
            gives the narrower set that was previously commented out.

    Prints and returns ``(cal_error_genes, cal_error_proteins, cal_error_total)``.
    """
    if n_genes is None:
        n_genes = dataset.nb_genes
    lower_ps = list(lower_ps)
    upper_ps = [100 - p for p in lower_ps]
    # One percentile pass for both bounds, then split.
    bounds = np.percentile(ppc.posterior_predictive_samples[key], lower_ps + upper_ps, axis=2)
    lower_bounds, upper_bounds = bounds[: len(lower_ps)], bounds[len(lower_ps):]

    cal_error_genes = 0
    cal_error_proteins = 0
    cal_error_total = 0
    for p, lo, hi in zip(lower_ps, lower_bounds, upper_bounds):
        inside = np.logical_and(ppc.raw_counts >= lo, ppc.raw_counts <= hi)
        nominal_width = (100 - p * 2) / 100
        cal_error_genes += (np.mean(inside[:, :n_genes]) - nominal_width) ** 2
        cal_error_proteins += (np.mean(inside[:, n_genes:]) - nominal_width) ** 2
        cal_error_total += (np.mean(inside) - nominal_width) ** 2
    print(cal_error_genes, cal_error_proteins, cal_error_total)
    return cal_error_genes, cal_error_proteins, cal_error_total

# Calibration errors on the held-out PPC for each model, one printed line each.
for model_key in ("totalVI", "scVI", "Factor Analysis (Log)", "Factor Analysis (Log Rate)"):
    calibration_error(ppc_held, model_key)

In [None]:
def mean_squared_log_error(self):
    """Store per-model mean squared log error (MSLE) in ``self.metrics["msle"]``.

    Written in method style so it can be called on a totalPPC instance. For
    each ``model -> samples`` in ``self.posterior_predictive_samples``, the
    samples are averaged over their last axis and compared with
    ``self.raw_counts`` on a ``log(1 + x)`` scale. The error is averaged
    separately over gene columns (the first ``self.dataset.nb_genes``) and the
    remaining protein columns.

    The stored DataFrame has index ``["genes", "proteins"]`` and one column per
    model. Returns None.
    """
    n_genes = self.dataset.nb_genes
    log_raw = np.log1p(self.raw_counts)  # loop-invariant, computed once
    # Index built up front, so an empty samples dict yields an empty but
    # correctly indexed frame instead of a length-mismatch error.
    df = pd.DataFrame(index=["genes", "proteins"])
    for model, samples in self.posterior_predictive_samples.items():
        sq_log_err = np.square(np.log1p(np.mean(samples, axis=-1)) - log_raw)
        df[model] = [np.mean(sq_log_err[:, :n_genes]), np.mean(sq_log_err[:, n_genes:])]
    self.metrics["msle"] = df
# Store the held-out MSLE table in ppc_held.metrics["msle"].
mean_squared_log_error(ppc_held)

In [None]:
# Display the MSLE table (rows: genes / proteins, one column per model).
ppc_held.metrics["msle"]

In [None]:
# Free the held-out PPC, then add factor-analysis baselines to the full PPC,
# fit and evaluated on all cells, once per normalization.
del ppc_held
n_cells = len(dataset)
for normalization in ("log", "log_rate"):
    ppc_full.store_fa_samples(
        ppc_full.raw_counts,
        np.arange(n_cells),
        np.arange(n_cells),
        n_components=totalvae.n_latent,
        normalization=normalization,
    )

In [None]:
# Short alias for the all-cells PPC used by the plotting cells below.
ppc = ppc_full
# Coefficient-of-variation metrics; the cells below read ppc.metrics["cv_gene"]
# (presumably one row per feature when cell_wise=False — confirm).
ppc.coeff_of_variation(cell_wise=False)

In [None]:
from scipy.stats import ks_2samp

# Boxplot of per-protein coefficients of variation (protein rows follow the
# dataset.nb_genes gene rows in ppc.metrics["cv_gene"]).
fig, ax = plt.subplots(1, 1)
sns.boxplot(data=ppc.metrics["cv_gene"].iloc[dataset.nb_genes:], ax=ax, showfliers=False)
plt.title("Coefficient of Variation (Proteins)")
sns.despine()
key = "cv_gene"
# Median absolute difference between each model's per-protein CV and the raw
# data's CV (smaller = closer to the observed CVs). Printed with labels, in the
# same model order as the RNA cell, so the numbers cannot be misattributed.
protein_cv = ppc.metrics[key].iloc[dataset.nb_genes:]
for model in ("totalVI", "scVI", "Factor Analysis (Log)", "Factor Analysis (Log Rate)"):
    print(model, np.median(np.abs(protein_cv[model] - protein_cv["raw"])))

In [None]:
# Boxplot of per-gene coefficients of variation (first dataset.nb_genes rows of
# ppc.metrics["cv_gene"]).
fig, ax = plt.subplots(1, 1)
sns.boxplot(data=ppc.metrics["cv_gene"].iloc[:dataset.nb_genes], ax=ax, showfliers=False)
plt.title("Coefficient of Variation (RNA)")
sns.despine()
key = "cv_gene"
# Median absolute difference between each model's per-gene CV and the raw
# data's CV. Labelled and in the same model order as the protein cell (the
# original printed unlabeled values in a different order).
gene_cv = ppc.metrics[key].iloc[:dataset.nb_genes]
for model in ("totalVI", "scVI", "Factor Analysis (Log)", "Factor Analysis (Log Rate)"):
    print(model, np.median(np.abs(gene_cv[model] - gene_cv["raw"])))

