In [1]:
import pickle as pkl
from pathlib import Path

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import scipy as sp
import scvi

import mrvi
import utils


Global seed set to 0
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


# Data Preprocessing

In [2]:
# Load the MGH66 BacDrop AnnData (obs carries 'sample' and 'replicate').
adata = sc.read(filename='./../data/MGH66_bacdrop.h5ad')

In [4]:
# Drop genes detected in fewer than 20 cells. scanpy edits `adata` in place
# and records per-gene detection counts in adata.var['n_cells'].
sc.pp.filter_genes(data=adata, min_cells=20)
adata

AnnData object with n_obs × n_vars = 1692542 × 3867
    obs: 'sample', 'replicate'
    var: 'n_cells'

In [17]:
# Keep only cells with at least MIN_GENES_PER_CELL detected genes (in place;
# adds adata.obs['n_genes']). The cutoff is deliberately permissive:
# thresholds above ~50 discard most cells in this dataset.
MIN_GENES_PER_CELL = 5
sc.pp.filter_cells(adata, min_genes=MIN_GENES_PER_CELL)
# Fail here rather than deep inside MrVI setup if the filter was too strict.
assert adata.n_obs > 0, "cell filter removed every cell"
adata

AnnData object with n_obs × n_vars = 34352 × 3867
    obs: 'sample', 'replicate', 'n_genes', '_scvi_sample', '_scvi_labels'
    var: 'n_cells'
    uns: 'log1p', '_scvi_uuid', '_scvi_manager_uuid'
    obsm: '_scvi_categorical_nuisance_keys', 'X_mrvi_z', 'X_mrvi_u'

In [7]:
# NOTE(review): dead exploratory code, kept for reference only. Do NOT
# uncomment as written:
# - np.linalg.svd defaults to full_matrices=True, so `u` would be n_obs x n_obs
#   (hundreds of thousands squared).
# - .toarray() densifies the entire count matrix.
# A truncated SVD on the sparse matrix (e.g. scipy.sparse.linalg.svds with a
# small k) or sc.pp.pca would be the feasible alternative.
# u, s, vh = np.linalg.svd(adata.X.toarray()) # TODO: hyperparam tuning

# Running bacdrop data through MrVI

In [8]:
# Set the seed explicitly rather than relying on scvi's import-time side
# effect ("Global seed set to 0"), so model init and training reproduce no
# matter which cells ran in between.
scvi.settings.seed = 0

# setup_anndata reads these obs columns; fail early with a clear message.
assert {"sample", "replicate"} <= set(adata.obs.columns), "missing obs columns"

# Register the data with MrVI: 'sample' is the sample key and 'replicate' is
# passed as a categorical nuisance covariate.
mrvi.MrVI.setup_anndata(adata, sample_key="sample", categorical_nuisance_keys=["replicate"])
mrvi_model = mrvi.MrVI(adata)



In [9]:
# Train MrVI. No arguments are passed, so epochs, batch size, etc. are library
# defaults; the logged run used 14 epochs (~11 min on one GPU).
# NOTE(review): consider pinning max_epochs explicitly so the epoch count does
# not silently change when the cell filter changes — confirm train() signature.
mrvi_model.train()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 14/14: 100%|██████████| 14/14 [10:52<00:00, 46.59s/it, loss=78.6, v_num=1]


In [10]:
# Per-cell MrVI latent embedding from the z space (give_z=True),
# stored on the AnnData under obsm["X_mrvi_z"].
adata.obsm["X_mrvi_z"] = mrvi_model.get_latent_representation(
    give_z=True,
)

100%|██████████| 4478/4478 [03:39<00:00, 20.36it/s]


In [11]:
# Per-cell MrVI latent embedding from the u space (give_z=False),
# stored on the AnnData under obsm["X_mrvi_u"].
adata.obsm["X_mrvi_u"] = mrvi_model.get_latent_representation(
    give_z=False,
)

100%|██████████| 4478/4478 [03:40<00:00, 20.32it/s]


In [12]:
# Local sample representations: every cell represented within each sample.
# Shape: cells x n_sample x n_latent.
cell_sample_representations = (
    mrvi_model.get_local_sample_representation()
)

100%|██████████| 2239/2239 [00:23<00:00, 97.07it/s] 


In [13]:
# Local sample-sample distances (section 3.1): for each cell, an S x S matrix
# quantifying differences in gene expression across the S biological samples.
# Shape: cells x n_sample x n_sample.
cell_sample_sample_distances = mrvi_model.get_local_sample_representation(
    return_distances=True,
)

100%|██████████| 2239/2239 [00:23<00:00, 96.89it/s] 


In [14]:
# Inspect the AnnData after training and embedding (obsm now holds X_mrvi_z /
# X_mrvi_u); this is the object written to disk below.
# NOTE(review): the filter_cells cell was last executed after this one
# (In [17] vs In [14]), so its displayed cell count does not match this object.
adata

AnnData object with n_obs × n_vars = 573099 × 3867
    obs: 'sample', 'replicate', 'n_genes', '_scvi_sample', '_scvi_labels'
    var: 'n_cells'
    uns: 'log1p', '_scvi_uuid', '_scvi_manager_uuid'
    obsm: '_scvi_categorical_nuisance_keys', 'X_mrvi_z', 'X_mrvi_u'

# Saving for further analysis

In [15]:
# Persist the processed AnnData (including the MrVI embeddings in obsm) so
# later notebooks can reload it without retraining.
outpath = './../data/MrVIoutputs/bacdrop_pp.h5ad'
# write_h5ad opens the file directly and does not create missing parent
# directories, so make sure the output folder exists first.
Path(outpath).parent.mkdir(parents=True, exist_ok=True)
adata.write_h5ad(outpath)

In [16]:
# Serialize the trained model and the per-cell local-sample outputs so later
# notebooks can reload them without re-running training or inference.
# NOTE(review): pickled model objects are fragile across library versions, and
# unpickling runs arbitrary code — only load these files from trusted sources.
# For the model, the library's own save/load (presumably inherited from
# scvi-tools' BaseModelClass.save) may be more robust — confirm.
# Make sure the target folder exists; utils.write_pickle's behaviour for a
# missing directory is not visible here.
Path('./../data/pickles').mkdir(parents=True, exist_ok=True)

utils.write_pickle(mrvi_model, './../data/pickles/mrvi_model.pickle')

utils.write_pickle(cell_sample_representations, './../data/pickles/sample_representations.pickle')

utils.write_pickle(cell_sample_sample_distances, './../data/pickles/sample_distances.pickle')