ref: https://docs.scvi-tools.org/en/stable/user_guide/notebooks/MultiVI_tutorial.html

## Import modules

In [None]:
import scvi
import anndata
import scipy
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt

# Fix scvi-tools' global random seed so model initialisation/training can be
# reproduced across runs of this notebook.
RANDOM_SEED = 420
scvi.settings.seed = RANDOM_SEED

In [None]:
# Print the versions of the packages loaded in this session, so the run can be
# traced back to a specific software environment.
import session_info
session_info.show()

## Read in data

In [None]:
# Load the downsized AnnData prepared for MultiVI (features from both
# modalities in one object; see the 'modality' columns used below).
adata_mvi = sc.read_h5ad('/nfs/team205/kk18/data/6region_v2/MultiVI/adata_mvi_downsized.h5ad')

# Peek at a few stored values (presumably to check whether .X holds raw counts).
# `.data` is the array of stored entries only for scipy sparse matrices; on a
# dense ndarray it is a 2-D memoryview that cannot be sliced with [:10], so the
# dense case previews the first non-zero values of the first rows instead.
import scipy.sparse

_X = adata_mvi.X
if scipy.sparse.issparse(_X):
    print(_X.data[:10])
else:
    _head = np.asarray(_X[:100]).ravel()
    print(_head[_head != 0][:10])
adata_mvi

In [None]:
# Number of cells (obs rows) in each 'modality' category; this same column is
# passed as batch_key when registering the data with scvi-tools.
modality_counts = adata_mvi.obs['modality'].value_counts()
modality_counts

## Set up and train MultiVI

In [None]:
# Register adata_mvi with scvi-tools: 'modality' is used as the batch key and
# 'donor_cellnuc' as an additional categorical covariate.
# Newer scvi-tools releases (the ones the MultiVI tutorial now targets) register
# data through the model class and no longer ship the module-level
# scvi.data.setup_anndata, so prefer the model-level API when it exists and
# fall back to the original call on older installs.
_setup_kwargs = dict(batch_key='modality', categorical_covariate_keys=['donor_cellnuc'])
if hasattr(scvi.model.MULTIVI, 'setup_anndata'):
    scvi.model.MULTIVI.setup_anndata(adata_mvi, **_setup_kwargs)
else:
    scvi.data.setup_anndata(adata_mvi, **_setup_kwargs)

In [None]:
# Show scvi-tools' summary of how adata_mvi was registered (presumably the
# batch / categorical covariate mappings from the setup call above) as a
# sanity check before building the model.
# NOTE(review): this module-level helper only exists in older scvi-tools
# releases — confirm the installed version provides it.
scvi.data.view_anndata_setup(adata_mvi)

In [None]:
# When creating the object, we need to specify how many of the features are
# genes and how many are genomic regions, so MultiVI can determine the exact
# architecture for each modality.
# MultiVI treats the first `n_genes` columns of .X as genes and the remaining
# `n_regions` columns as regions, so the features must be ordered genes-first
# and every feature must belong to one of the two modalities. Validate that
# layout here; otherwise training would silently use the wrong feature slices.
_feature_modality = adata_mvi.var['modality']
_is_gene = (_feature_modality == 'Gene Expression').to_numpy()
_is_region = (_feature_modality == 'Peaks').to_numpy()
n_genes = int(_is_gene.sum())
n_regions = int(_is_region.sum())

if n_genes + n_regions != adata_mvi.n_vars:
    raise ValueError(
        f"adata_mvi.var['modality'] must be 'Gene Expression' or 'Peaks' for every "
        f"feature; found {adata_mvi.n_vars - n_genes - n_regions} other feature(s)."
    )
if not _is_gene[:n_genes].all():
    # With the counts matching, genes filling the first n_genes columns implies
    # regions fill the rest.
    raise ValueError(
        "Gene Expression features must precede Peaks features in adata_mvi; "
        "reorder the features (genes first) before registering the data with scvi-tools."
    )

mvi = scvi.model.MULTIVI(
    adata_mvi,
    n_genes=n_genes,
    n_regions=n_regions,
)
mvi.train()

In [None]:
# Persist the trained model to disk so it does not have to be retrained.
# NOTE: scvi-tools' save() raises if this directory already exists unless
# overwrite=True is passed, so re-running this cell as-is will fail.
model_dir = "/nfs/team205/kk18/data/6region_v2/MultiVI/model/trained_multivi_downsized_20210922"
mvi.save(model_dir)

In [None]:
# Attach the model's latent representation of each cell to .obsm so scanpy can
# build the neighbour graph / UMAP from it.
latent = mvi.get_latent_representation()
adata_mvi.obsm["MultiVI_latent"] = latent

In [None]:
# Save the AnnData with the MultiVI latent space in .obsm. This runs before the
# neighbours/UMAP cell, so the written file does not contain those results.
output_path = '/nfs/team205/kk18/data/6region_v2/MultiVI/adata_post-multivi_downsized.h5ad'
adata_mvi.write(output_path)

In [None]:
# Build the kNN graph on the MultiVI latent space, embed it with UMAP, and
# colour the embedding by modality and by cell state.
latent_key = "MultiVI_latent"
sc.pp.neighbors(adata_mvi, use_rep=latent_key)
sc.tl.umap(adata_mvi, min_dist=0.2)

umap_colors = ['modality', 'cell_states']
sc.pl.umap(adata_mvi, color=umap_colors)