In [None]:
import anndata as ad
import matplotlib.pyplot as plt
import mudata as md
import muon
import scanpy as sc
import scvi
import time

# Wall-clock start of the notebook run (presumably read by a later cell to report runtime).
t = time.time()

# Global seed for scvi-tools, fixed so model training is repeatable between runs.
SEED = 1234
scvi.settings.seed = SEED

In [None]:
# Default figure size for scanpy plots in this notebook.
sc.set_figure_params(figsize=(4, 4))

# Inline figures: white face colour, rendered at retina (high-DPI) resolution.
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

In [None]:
# Input AnnData files for the DBiT-seq mouse-embryo RNA and protein (ADT) modalities.
rna_h5ad = './data/spatial DBIT-seq mouse embryo/mouse_embro_0713_rna.h5ad'
adt_h5ad = './data/spatial DBIT-seq mouse embryo/mouse_embro_0713_protein.h5ad'

adata_rna = sc.read_h5ad(rna_h5ad)
# Align protein observations to the RNA object: same obs_names, same order.
# Every RNA barcode is expected to be present in the protein file.
adata_adt = sc.read_h5ad(adt_h5ad)[adata_rna.obs_names].copy()

# Deduplicate feature names in both modalities.
adata_rna.var_names_make_unique()
adata_adt.var_names_make_unique()

In [None]:
# Drop genes detected in fewer than this many observations (spots); modifies adata_rna in place.
MIN_CELLS_PER_GENE = 10
sc.pp.filter_genes(adata_rna, min_cells=MIN_CELLS_PER_GENE)

In [None]:
# Snapshot .X into a "counts" layer; the seurat_v3 HVG step and totalVI setup read this
# layer explicitly (assumes .X currently holds raw counts — confirm upstream).
adata_rna.layers["counts"] = adata_rna.X.copy()

# Every spot gets the same batch label, so the data is treated as a single batch.
# NOTE(review): "Thymus" looks copied from another dataset (this is a mouse embryo);
# kept unchanged in case later cells refer to it — confirm and rename.
adata_rna.obs["batch"] = "Thymus"

In [None]:
# Combine the RNA and protein AnnData objects into one MuData container and show it.
mdata = md.MuData(dict(rna=adata_rna, protein=adata_adt))
mdata

In [None]:
# Flag the top 4000 highly variable genes with the Seurat v3 method, computed on the raw
# "counts" layer; the boolean flag is written to the RNA modality's var["highly_variable"].
rna_all = mdata.mod["rna"]
sc.pp.highly_variable_genes(
    rna_all,
    layer="counts",
    flavor="seurat_v3",
    n_top_genes=4000,
    batch_key="batch",
)

# Keep the full RNA modality and store an HVG-only copy as its own modality.
is_hvg = rna_all.var["highly_variable"]
mdata.mod["rna_subset"] = rna_all[:, is_hvg].copy()

# Drop the cell-local helpers so they do not linger in the notebook namespace.
del rna_all, is_hvg

In [None]:
# Inspect the MuData after adding the "rna_subset" modality (mdata.update() runs in the next cell).
mdata

In [None]:
# Re-synchronise MuData's global obs/var with its modalities now that "rna_subset" was added.
mdata.update()

In [None]:
# Register the MuData with totalVI:
#   - RNA counts come from the "counts" layer of the HVG-only "rna_subset" modality;
#   - protein counts come from the "protein" modality (protein_layer=None -> its .X,
#     per the scvi-tools setup_mudata docs);
#   - the batch column "batch" is read from the "rna_subset" obs.
scvi.model.TOTALVI.setup_mudata(
    mdata,
    batch_key="batch",
    rna_layer="counts",
    protein_layer=None,
    modalities=dict(
        rna_layer="rna_subset",
        protein_layer="protein",
        batch_key="rna_subset",
    ),
)

In [None]:
# Instantiate totalVI on the registered MuData with default hyperparameters (no kwargs given).
vae = scvi.model.TOTALVI(mdata)

In [None]:
# Train with scvi-tools' default settings (no arguments passed); the loss curves plotted
# in the next cell are read from vae.history afterwards.
vae.train()

In [None]:
# Training curves: negative ELBO per epoch on the train and validation splits.
fig, ax = plt.subplots(1, 1)

# vae.history entries are one-column DataFrames; DataFrame.plot names lines after the
# column, so squeeze each to a Series to make the `label=` argument take effect.
vae.history["elbo_train"].squeeze("columns").plot(ax=ax, label="train")
vae.history["elbo_validation"].squeeze("columns").plot(ax=ax, label="validation")

# No fixed ylim: the former (1200, 1400) window was tuned for a different tutorial
# dataset and can clip or entirely hide these curves; let matplotlib autoscale.
ax.set(
    title="Negative ELBO over training epochs",
    xlabel="epoch",
    ylabel="negative ELBO",
)
ax.legend();

In [None]:
# Short aliases for the HVG RNA and protein modalities used by later cells.
rna, protein = mdata.mod["rna_subset"], mdata.mod["protein"]

# Store the totalVI latent embedding on the RNA modality (an arbitrary choice of home).
rna.obsm["X_totalVI"] = vae.get_latent_representation()