## Import, load dataset, load saved model

In [1]:
import csv
import gzip
import os
import scipy.io
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from scvi.dataset import GeneExpressionDataset, Dataset10X
from scvi.models import VAE, TOTALVI
from scvi.inference import TotalPosterior, TotalTrainer, Posterior, UnsupervisedTrainer
from scvi.inference import TotalPosteriorPredictiveCheck as totalPPC

from scipy.special import softmax
from sklearn.decomposition import PCA
import umap

# Directory used as the local download/cache location for the 10X dataset
# (relative to this notebook).
save_path = "../data/10X"
# Automatically reload edited modules before each cell is executed.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

  from numpy.core.umath_tests import inner1d
INFO:hyperopt.utils:Failed to load dill, try installing dill via "pip install dill" for enhanced pickling support.
INFO:hyperopt.fmin:Failed to load dill, try installing dill via "pip install dill" for enhanced pickling support.
INFO:hyperopt.mongoexp:Failed to load dill, try installing dill via "pip install dill" for enhanced pickling support.


In [2]:
# Load the "pbmc_10k_protein_v3" 10X dataset, using save_path as the local cache.
dataset = Dataset10X(
    dataset_name="pbmc_10k_protein_v3",
    save_path=save_path,
    measurement_names_column=1,
)

INFO:scvi.dataset.dataset:File ../data/10X/pbmc_10k_protein_v3/filtered_feature_bc_matrix.tar.gz already downloaded
INFO:scvi.dataset.dataset10X:Preprocessing dataset
INFO:scvi.dataset.dataset10X:Finished preprocessing dataset
INFO:scvi.dataset.dataset:Remapping batch_indices to [0,N]
INFO:scvi.dataset.dataset:Remapping labels to [0,N]
INFO:scvi.dataset.dataset:Computing the library size for the new data
INFO:scvi.dataset.dataset:Downsampled from 7865 to 7865 cells


In [3]:
# Gene-level filtering: keep 5000 genes, then apply filter_genes_by_count(100).
dataset.subsample_genes(new_n_genes=5000)
dataset.filter_genes_by_count(100)

# Filter control proteins: the last three protein columns are dropped from
# both the expression matrix and the matching list of names.
# NOTE: the dataset is modified in place, so re-running this cell removes
# three more proteins each time.
without_controls = slice(None, -3)
dataset.protein_expression = dataset.protein_expression[:, without_controls]
dataset.protein_names = dataset.protein_names[without_controls]

INFO:scvi.dataset.dataset:Downsampling from 33538 to 5000 genes
INFO:scvi.dataset.dataset:Computing the library size for the new data
INFO:scvi.dataset.dataset:Filtering non-expressing cells.
INFO:scvi.dataset.dataset:Computing the library size for the new data
INFO:scvi.dataset.dataset:Downsampled from 7865 to 7865 cells
INFO:scvi.dataset.dataset:Downsampling from 5000 to 4996 genes
INFO:scvi.dataset.dataset:Computing the library size for the new data
INFO:scvi.dataset.dataset:Filtering non-expressing cells.
INFO:scvi.dataset.dataset:Computing the library size for the new data
INFO:scvi.dataset.dataset:Downsampled from 7865 to 7865 cells


In [4]:
# Run-level settings.
use_cuda = False
lr = 1e-2

# Model sized by the number of genes and the number of (non-control) proteins.
totalvae = TOTALVI(dataset.nb_genes, len(dataset.protein_names))

# Early-stopping and learning-rate-scheduling options handed to the trainer.
early_stopping_kwargs = dict(
    early_stopping_metric="elbo",
    save_best_state_metric="elbo",
    patience=50,
    threshold=0,
    reduce_lr_on_plateau=True,
    lr_patience=20,
    lr_factor=0.5,
    posterior_class=TotalPosterior,
)

# 90% train / 5% test split; the trainer also exposes a validation set that
# is used further down in the notebook.
trainer = TotalTrainer(
    totalvae,
    dataset,
    train_size=0.90,
    test_size=0.05,
    use_cuda=use_cuda,
    frequency=1,
    early_stopping_kwargs=early_stopping_kwargs,
)

In [5]:
%%capture
# Restore previously trained weights instead of re-training the model.
# map_location="cpu" loads the stored tensors onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
saved_model = torch.load("../saved_models/pbmc10k_totalVI.pt", map_location="cpu")
totalvae.load_state_dict(saved_model['model_state_dict'])
# Switch the model to evaluation mode; %%capture hides the module repr that
# eval() returns.
totalvae.eval()

In [11]:
# Reconstruction error of the loaded model on the trainer's validation split.
# NOTE(review): the result is a 2-tuple, presumably (gene, protein) terms;
# confirm against the scvi API. The split is drawn when the trainer is
# created, so the numbers may differ between runs unless that is seeded.
trainer.validation_set.compute_reconstruction_error(totalvae)

(3377.2773735687024, 69.13958631520356)

## Create posterior

In [None]:
# Posterior object covering every cell in the dataset rather than one split.
full_posterior = trainer.create_posterior(
    totalvae,
    dataset,
    indices=np.arange(len(dataset)),
    type_class=TotalPosterior,
)

# Default get_latent() call; its six return values are unpacked by name.
(
    latent_mean,
    batch_index,
    label,
    library_gene,
    library_protein,
    library_background,
) = full_posterior.sequential().get_latent()

# Same call with sample=True; only the first returned value is kept.
latent, _, _, _, _, _ = full_posterior.sequential().get_latent(sample=True)
def sigmoid(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))``.

    Evaluated as ``exp(-logaddexp(0, -x))``, which is algebraically the same
    expression but never forms ``exp(-x)`` on its own. Large negative inputs
    therefore no longer overflow to ``inf`` (and no longer emit a
    ``RuntimeWarning``) on the way to the correct result of 0.

    Parameters
    ----------
    x : scalar or numpy.ndarray
        Input logits.

    Returns
    -------
    numpy scalar or numpy.ndarray
        Values in [0, 1] with the same shape as ``x``.
    """
    return np.exp(-np.logaddexp(0, -x))
# Number of posterior samples requested from each call below.
N_SAMPLES = 100

# get_sample_dropout output is passed through sigmoid, so it is treated here
# as logits (presumably averaged over the samples because give_mean=True).
px_dropout = sigmoid(
    full_posterior.sequential().get_sample_dropout(
        n_samples=N_SAMPLES,
        give_mean=True,
    )
)

# NOTE(review): the method name below contains three consecutive "s"
# ("expresssion"); confirm it matches the installed scvi API.
adjusted_px_scale = full_posterior.sequential().get_normalized_denoised_expresssion(
    n_samples=N_SAMPLES,
    give_mean=True,
)

In [None]:
# UMAP embedding of latent_mean; random_state is fixed for repeatability.
umap_dr = umap.UMAP(
    n_neighbors=15,
    random_state=42,
    min_dist=0.1,
).fit_transform(latent_mean)

In [None]:
# One histogram per protein of px_dropout (the sigmoid-transformed dropout
# output computed above), read from the columns that follow the gene columns.
n_proteins = len(dataset.protein_names)
n_cols = 5
# Keep the original 4x5 layout, but add rows when there are more than 20
# proteins so that indexing the axes can never run past the grid.
n_rows = max(4, (n_proteins + n_cols - 1) // n_cols)
fig, axarr = plt.subplots(n_rows, n_cols, figsize=(15, 2.5 * n_rows))

for i, ax in enumerate(axarr.flat):
    if i >= n_proteins:
        # Hide panels that have no protein assigned to them.
        ax.set_visible(False)
        continue
    ax.hist(px_dropout[:, dataset.nb_genes + i])
    ax.set_title(dataset.protein_names[i])
    # Same visual result as seaborn's default despine(): drop the top and
    # right spines. Done with matplotlib because seaborn is never imported
    # in this notebook, so `sns.despine()` raised a NameError.
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

plt.suptitle('Probability in Background')
plt.tight_layout()