In [1]:
import numpy as np
import scanpy as sc
import scvi
import pandas as pd
import anndata as ad
import time

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Wall-clock start time (seconds since the epoch); presumably used later to report total runtime.
t = time.time()

In [None]:
# Set scvi-tools' global random seed early, before any model is built, so runs are reproducible.
scvi.settings.seed = 420

In [10]:
def _to_dense(matrix):
    """Return ``matrix`` as a dense numpy array, whether it is scipy-sparse or already dense."""
    return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)


def combine(adata_RNA, adata_peak):
    """Concatenate an RNA AnnData and a peak (ATAC) AnnData feature-wise into one AnnData.

    Cells are matched by position: row ``i`` of ``adata_RNA`` and row ``i`` of
    ``adata_peak`` are treated as the same cell/spot. The combined object takes its
    ``obs_names`` (and ``obsm['spatial']``, when present) from ``adata_RNA``.

    Parameters
    ----------
    adata_RNA : AnnData
        Gene-expression modality. ``X`` may be dense or scipy-sparse.
    adata_peak : AnnData
        Peak modality. ``X`` may be dense or scipy-sparse.

    Returns
    -------
    AnnData
        Dense ``X`` with the RNA columns first, then the peak columns;
        ``var['modality']`` is ``'Gene Expression'`` or ``'Peaks'`` per feature.

    Raises
    ------
    ValueError
        If the two inputs do not have the same number of cells.

    Notes
    -----
    Side effect (kept for backward compatibility): ``var['modality']`` is also
    written onto both input objects.
    """
    if adata_RNA.n_obs != adata_peak.n_obs:
        raise ValueError(
            f"adata_RNA has {adata_RNA.n_obs} cells but adata_peak has {adata_peak.n_obs}; "
            "both modalities must contain the same cells in the same order"
        )

    adata_RNA.var['modality'] = 'Gene Expression'
    adata_peak.var['modality'] = 'Peaks'

    # Either modality may be stored sparse or dense; densify both the same way,
    # since np.hstack cannot mix a dense array with a scipy sparse matrix.
    exp = np.hstack([_to_dense(adata_RNA.X), _to_dense(adata_peak.X)])

    cell_name = list(adata_RNA.obs_names)
    gene_name = list(adata_RNA.var_names) + list(adata_peak.var_names)
    modality = ['Gene Expression'] * adata_RNA.n_vars + ['Peaks'] * adata_peak.n_vars

    obs = pd.DataFrame(index=cell_name)
    var = pd.DataFrame(index=gene_name)
    adata_RNA_peak = ad.AnnData(X=exp, obs=obs, var=var)
    adata_RNA_peak.var['modality'] = modality

    # Spatial coordinates come from the RNA object; non-spatial inputs are allowed.
    if 'spatial' in adata_RNA.obsm:
        adata_RNA_peak.obsm['spatial'] = adata_RNA.obsm['spatial']

    return adata_RNA_peak

In [None]:
# Load both modalities of the P22 mouse brain dataset (RNA first, then ATAC peaks),
# de-duplicate feature names within each, then merge them feature-wise.
adata_RNA, adata_peak = (
    sc.read_h5ad(path)
    for path in (
        './data/P22 mouse brain/P22_mouse_brain_adata_RNA.h5ad',
        './data/P22 mouse brain/P22_mouse_brain_adata_atac.h5ad',
    )
)
for modality_adata in (adata_RNA, adata_peak):
    modality_adata.var_names_make_unique()

adata = combine(adata_RNA, adata_peak)

In [None]:
# Show the summary of the combined RNA + peak AnnData built by combine().
adata

In [13]:
# Split the cells into three groups (RNA-only, paired/multiome, ATAC-only) and
# simulate unpaired data: RNA-only cells keep only gene features and ATAC-only
# cells keep only peak features. Without this feature subsetting every group
# would still carry both modalities, i.e. all cells would effectively be paired.
n = 3000
assert adata.n_obs > 2 * n, f"need more than {2 * n} cells for the 3-way split, got {adata.n_obs}"
is_gene = (adata.var["modality"] == "Gene Expression").to_numpy()
adata_rna = adata[:n, is_gene].copy()
adata_paired = adata[n:2*n].copy()
adata_atac = adata[2*n:, ~is_gene].copy()

In [None]:
# Merge paired, RNA-only and ATAC-only cells into a single AnnData with scvi's helper.
adata_mvi = scvi.data.organize_multiome_anndatas(adata_paired, adata_rna, adata_atac)
# Cells that lack a modality get padded columns, and for a dense X that padding
# can be NaN. Zero it here so that feature filtering and MultiVI setup below only
# ever see finite values (a sparse X is padded with implicit zeros instead).
if not hasattr(adata_mvi.X, "toarray"):
    adata_mvi.X = np.where(np.isnan(adata_mvi.X), 0, adata_mvi.X)

In [15]:
# Order features genes-first, then peaks (the layout the n_genes / n_regions counts
# passed to MULTIVI below describe). Sort on an explicit "is a peak" flag with a
# stable sort, so the result does not depend on how the modality labels compare
# alphabetically and the original feature order within each modality is preserved.
adata_mvi = adata_mvi[
    :, np.argsort((adata_mvi.var["modality"] != "Gene Expression").to_numpy(), kind="stable")
].copy()

In [None]:
# Keep only features found in at least 1% of cells (scanpy's min_cells criterion),
# printing the (cells, features) shape before and after filtering.
print(adata_mvi.shape)
sc.pp.filter_genes(adata_mvi, min_cells=int(0.01 * adata_mvi.n_obs))
print(adata_mvi.shape)

In [17]:
# Register adata_mvi with scvi-tools for MultiVI, using the "modality" column as the batch covariate.
scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key="modality")

In [None]:
# Build the MultiVI model, telling it how many features are genes and how many
# are peak regions, then print the registered AnnData setup.
feature_modality = adata_mvi.var["modality"]
mvi = scvi.model.MULTIVI(
    adata_mvi,
    n_genes=(feature_modality == "Gene Expression").sum(),
    n_regions=(feature_modality == "Peaks").sum(),
)
mvi.view_anndata_setup()

In [19]:
# Replace any NaN in X with 0 before training. This uses numpy directly: no
# duplicate pandas import, and no DataFrame round-trip (which copies the whole
# matrix and is not meant for a scipy sparse X).
if hasattr(adata_mvi.X, "toarray"):
    # Sparse X: absent entries are implicit zeros; only explicitly stored NaNs need clearing.
    adata_mvi.X.data[np.isnan(adata_mvi.X.data)] = 0
else:
    adata_mvi.X = np.where(np.isnan(adata_mvi.X), 0, adata_mvi.X)

In [None]:
# Train MultiVI with scvi-tools' default training settings (no arguments are passed).
mvi.train()

In [21]:
# Obtain the per-cell MultiVI latent representation and store it in obsm under
# "X_MultiVI" (e.g. for later use via use_rep="X_MultiVI").
adata_mvi.obsm["X_MultiVI"] = mvi.get_latent_representation()