### Notebook for joint modelling of protein (CITE-seq) and gene expression (GEX) data for the AMC Mouse Immune project with `totalVI`

- **Developed by:** Carlos Talavera-López Ph.D
- **Institute of Computational Biology - Computational Health Department - Helmholtz Munich**
- **Created on**: 240510
- **Last modified**: 240510

### Import required modules

In [1]:
import scvi
import muon
import torch
import anndata
import warnings
import numpy as np
import mudata as md
import scanpy as sc
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

### Set up working environment

In [None]:
# Scanpy session settings: verbose logging (level 3), a printout of package
# versions for provenance, and figure defaults (180 dpi on screen, magma_r
# colormap, 300-dpi vector-friendly SVG exports).
sc.settings.verbosity = 3
sc.logging.print_versions()
sc.settings.set_figure_params(
    dpi = 180,
    dpi_save = 300,
    color_map = 'magma_r',
    format = 'svg',
    vector_friendly = True,
)

In [None]:
warnings.simplefilter(action = 'ignore')
scvi.settings.seed = 1712
sns.set_theme()
torch.set_float32_matmul_precision("high")
%config InlineBackend.print_figure_kwargs = {'facecolor' : "w"}
%config InlineBackend.figure_format = 'retina'
print("Last run with scvi-tools version:", scvi.__version__)
backend = 'pytorch'

### Read in Healthy data

In [None]:
# Load the subsetted-monocyte MuData (presumably raw counts, judging by the
# `.raw` suffix). The bare name on the last line makes the notebook render it.
h5mu_path = '../data/Subsetted_monocytes_ac240507.raw.h5mu'
mudata = muon.read_h5mu(h5mu_path)
mudata

### Format the `MuData` object for downstream analysis

In [None]:
# Keep a copy of the RNA matrix in a 'counts' layer. HVG selection and the
# totalVI setup below both read from this layer.
rna_mod = mudata.mod['rna']
rna_mod.layers['counts'] = rna_mod.X.copy()

### Select highly variable genes (HVGs)

In [None]:
# Keep the top 7,000 highly variable genes of the RNA modality.
# - Ranked with the count-based 'seurat_v3' flavour on the 'counts' layer.
# - 'sample' is used as the batch key.
# subset=True drops the other genes in place, so the MuData needs an update()
# afterwards.
hvg_kwargs = dict(
    layer = "counts",
    flavor = "seurat_v3",
    n_top_genes = 7000,
    batch_key = "sample",
    span = 1,
    subset = True,
)
sc.pp.highly_variable_genes(mudata.mod["rna"], **hvg_kwargs)

In [None]:
# Display the MuData modalities; the RNA modality now holds only the selected HVGs.
mudata.mod

In [None]:
# Re-sync the MuData-level obs/var with its modalities after the in-place HVG subset.
mudata.update()

### Set up the `MuData` object as input for `totalVI`

In [None]:
# Register the MuData with totalVI.
# - RNA counts come from the 'counts' layer of the 'rna' modality.
# - Protein counts come from `.X` of the 'prot' modality (protein_layer=None).
# - The batch covariate is the 'sample' column of the 'rna' modality.
scvi.model.TOTALVI.setup_mudata(
    mudata,
    modalities = dict(
        rna_layer = "rna",
        protein_layer = "prot",
        batch_key = "rna",
    ),
    rna_layer = "counts",
    protein_layer = None,
    batch_key = "sample",
)

In [None]:
# Build the totalVI model on the registered MuData. The protein background
# prior is explicitly not initialised empirically.
model = scvi.model.TOTALVI(
    mudata,
    empirical_protein_background_prior = False,
)

In [None]:
# Train totalVI for up to 400 epochs, validating every epoch.
# The original cell hard-coded the Apple-silicon backend ('mps', device 0),
# which fails on any machine without MPS. Keep that choice when MPS is
# available; otherwise let Lightning pick the best accelerator and devices
# ('auto'). A device list such as [0] is not valid for a CPU-only fallback.
if torch.backends.mps.is_available():
    train_accelerator, train_devices = 'mps', [0]
else:
    train_accelerator, train_devices = 'auto', 'auto'

model.train(
    max_epochs = 400,
    check_val_every_n_epoch = 1,
    enable_progress_bar = True,
    accelerator = train_accelerator,
    devices = train_devices,
)

### Label transfer with `scANVI`

In [None]:
# Initialise a scANVI model from a pretrained model; 'Unknown' is passed as the
# unlabeled category.
# NOTE(review): this cell fails as written.
# - `scvi_model` is never defined in this notebook; the model trained above is
#   the TOTALVI `model`.
# - `SCANVI.from_scvi_model` expects a trained SCVI model, so passing `model`
#   would not work either.
# This scANVI section looks copied from another notebook. TODO: train an SCVI
# model on the RNA modality whose labels column contains 'Unknown', or drop
# this section.
scanvi_model = scvi.model.SCANVI.from_scvi_model(scvi_model, 'Unknown')

In [None]:
# Train the scANVI model for 25 epochs.
scanvi_model.train(max_epochs = 25)

In [None]:
# Store scANVI's predicted label for each cell in obs['C_scANVI'].
# NOTE(review): `adata` is never defined in this notebook (the data lives in
# `mudata`). Confirm which AnnData was meant, e.g. mudata.mod['rna'].
adata.obs["C_scANVI"] = scanvi_model.predict(adata)

- Extract the scANVI latent representation

In [None]:
# Store the scANVI latent embedding in obsm['X_scANVI'] for the neighbour graph below.
# NOTE(review): `adata` is undefined in this notebook; confirm the intended object.
adata.obsm["X_scANVI"] = scanvi_model.get_latent_representation(adata)

- Visualise the dataset in the scANVI latent space (UMAP)

In [None]:
# Build a 50-neighbour kNN graph on the scANVI embedding, then compute and plot
# a seeded UMAP.
# NOTE(review): 'minkowski' with default metric parameters is presumably
# Euclidean (p=2); confirm.
# NOTE(review): `adata` is undefined in this notebook. The colour keys 'donor',
# 'infection', 'disease' and 'scNym' are not created in any visible cell;
# confirm they exist in the intended object.
sc.pp.neighbors(adata, use_rep = "X_scANVI", n_neighbors = 50, metric = 'minkowski')
sc.tl.umap(adata, min_dist = 0.2, spread = 8, random_state = 1712)
sc.pl.umap(adata, frameon = False, color = ['donor', 'infection', 'disease', 'C_scANVI', 'scNym'], size = 1, legend_fontsize = 5, ncols = 3)