### Notebook for joint modelling of Protein (CITE) and GEX for AMC Mouse Immune project with `TotalVI`

- **Developed by:** Carlos Talavera-López Ph.D
- **Institute of Computational Biology - Computational Health Department - Helmholtz Munich**
- **Created on**: 240510
- **Last modified**: 240510

### Import required modules

In [37]:
import scvi
import muon
import torch
import anndata
import warnings
import numpy as np
import mudata as md
import scanpy as sc
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

### Set up working environment

In [38]:
# Set scanpy verbosity to 3 and print the versions of the loaded packages for provenance.
sc.settings.verbosity = 3
sc.logging.print_versions()

# Global figure defaults for every scanpy plot produced in this notebook.
sc.settings.set_figure_params(
    dpi = 180,
    dpi_save = 300,
    color_map = 'magma_r',
    format = 'svg',
    vector_friendly = True,
)

-----
anndata     0.10.7
scanpy      1.10.1
-----
PIL                 10.3.0
absl                NA
asttokens           NA
attr                23.2.0
chex                0.1.86
colorama            0.4.6
comm                0.2.2
contextlib2         NA
cycler              0.12.1
cython_runtime      NA
dateutil            2.9.0.post0
debugpy             1.8.1
decorator           5.1.1
docrep              0.3.2
etils               1.8.0
executing           2.0.1
flax                0.8.3
fsspec              2024.3.1
h5py                3.11.0
importlib_resources NA
ipykernel           6.29.4
ipywidgets          8.1.2
jax                 0.4.28
jaxlib              0.4.28
jedi                0.19.1
joblib              1.4.2
kiwisolver          1.4.5
legacy_api_wrap     NA
lightning           2.1.4
lightning_utilities 0.11.2
llvmlite            0.42.0
matplotlib          3.8.4
matplotlib_inline   0.1.7
ml_collections      NA
ml_dtypes           0.4.0
mpl_toolkits        NA
mpmath            

In [39]:
# Session-wide configuration: warnings, seed, plotting theme and inline figure rendering.

# Ignore every warning raised from here on.
warnings.simplefilter(action = 'ignore')
# Fix the scvi-tools seed so that training is repeatable (logged below as "Seed set to 1712").
scvi.settings.seed = 1712
sns.set_theme()
torch.set_float32_matmul_precision("high")
# Inline figures: white facecolor, retina format.
%config InlineBackend.print_figure_kwargs = {'facecolor' : "w"}
%config InlineBackend.figure_format = 'retina'
print("Last run with scvi-tools version:", scvi.__version__)
# NOTE(review): `backend` is not read by any cell visible in this notebook — confirm it is still needed.
backend = 'pytorch'

Seed set to 1712


Last run with scvi-tools version: 1.1.2


### Read in Healthy data

In [40]:
# Load the MuData object from disk and show its summary.
raw_h5mu_path = '../data/Subsetted_monocytes_ac240507.raw.h5mu'
mudata = muon.read_h5mu(raw_h5mu_path)
mudata

### Format `muon` object for downstream analysis

In [41]:
# Keep a copy of the RNA matrix `X` in a `counts` layer; the HVG selection and the
# TotalVI setup further down both read from this layer.
mudata.mod['rna'].layers['counts'] = mudata.mod['rna'].X.copy()

### Select HVGs

In [42]:
# Select the 7,000 most variable genes from the `counts` layer, using `sample` as the
# batch key. `subset = True` trims the RNA modality to the selected genes in place.
sc.pp.highly_variable_genes(
    mudata.mod["rna"],
    layer = "counts",
    batch_key = "sample",
    flavor = "seurat_v3",
    n_top_genes = 7000,
    span = 1,
    subset = True,
)

extracting highly variable genes
--> added
    'highly_variable', boolean vector (adata.var)
    'highly_variable_rank', float vector (adata.var)
    'means', float vector (adata.var)
    'variances', float vector (adata.var)
    'variances_norm', float vector (adata.var)


In [43]:
# Inspect both modalities after HVG selection (the RNA modality now holds 7000 genes).
mudata.mod

{'rna': AnnData object with n_obs × n_vars = 3316 × 7000
     obs: 'cell_source', 'donor', 'n_counts', 'sample', 'seed_labels', 'condition', 'genotype', 'infection', 'library', 'model', 'n_genes_by_counts', 'total_counts', 'doublet_scores', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'percent_mt', 'percent_chrY', 'XIST-counts', 'S_score', 'G2M_score', '_scvi_batch', '_scvi_labels', 'batch', 'C_scANVI', 'leiden', 'classification'
     var: 'gene_ids', 'feature_types', 'mt', 'ribo', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'
     uns: 'C_scANVI_colors', 'classification_colors', 'leiden', 'leiden_colors', 'log1p', 'neighbors', 'rank_genes_groups', 'hvg'
     obsm: 'X_scANVI', 'X_scVI', 'X_umap'
     layers: 'counts'
     obsp: 'connectivities', 'distances',
 'prot': AnnData object with n_obs × n_vars = 3316 × 99
     obs: 'library', 'batch'
     var: 'gene_ids', 'feature_types'
     uns: 'ne

In [44]:
# Re-synchronise the MuData container after the RNA modality was subset to the HVGs.
mudata.update()

### Set up `muon` dataset for input to `TotalVI`.

In [45]:
# Tell TotalVI which modality holds each registered field: RNA counts and the batch
# column live in `rna`, protein counts live in `prot`.
field_to_modality = {
    "rna_layer": "rna",
    "protein_layer": "prot",
    "batch_key": "rna",
}

# RNA counts are read from the `counts` layer; no protein layer is named
# (`protein_layer = None`); batches are defined by `sample`.
scvi.model.TOTALVI.setup_mudata(
    mudata,
    rna_layer = "counts",
    protein_layer = None,
    batch_key = "sample",
    modalities = field_to_modality,
)

INFO     Found batches with missing protein expression                                                             


In [46]:
# Build the TotalVI model on the registered MuData, with the empirical protein
# background prior switched off.
model = scvi.model.TOTALVI(mudata, empirical_protein_background_prior = False)

In [47]:
# Train for 400 epochs on GPU device 1, validating after every epoch.
model.train(
    max_epochs = 400,
    accelerator = 'gpu',
    devices = [1],
    check_val_every_n_epoch = 1,
    enable_progress_bar = True,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Epoch 31/400:   8%|▊         | 30/400 [00:04<01:03,  5.82it/s, v_num=1, train_loss_step=5.57e+3, train_loss_epoch=5.71e+3]

In [None]:
# Plot the training and validation ELBO recorded in `model.history`.
# The y-axis is left to autoscale: a fixed window of (1200, 1400) does not fit this
# dataset, whose training loss is in the thousands (about 5.7e3 in the training log),
# so the curves would fall outside the visible range.
fig, ax = plt.subplots(1, 1)
model.history["elbo_train"].plot(ax=ax, label="train")
model.history["elbo_validation"].plot(ax=ax, label="validation")
ax.set(
    title="Negative ELBO over training epochs",
    xlabel="Epoch",
    ylabel="Negative ELBO",
)
ax.legend()

### Analyze outputs

In [None]:
# Short handles to the two modalities used by the downstream cells.
rna, protein = mudata.mod["rna"], mudata.mod["prot"]

In [None]:
# Store the TotalVI outputs that the plotting cells read back from the MuData object.
TOTALVI_LATENT_KEY = "X_totalVI"
# Number of posterior samples averaged for the denoised values and foreground probabilities.
N_POSTERIOR_SAMPLES = 25

# Joint latent space, kept on the RNA modality.
rna.obsm[TOTALVI_LATENT_KEY] = model.get_latent_representation()

# The embedding plots request `layer="denoised_protein"` and
# `layer="protein_foreground_prob"`; neither layer exists unless it is created here,
# so a fresh Restart & Run All would fail at the first plot without these lines.
_, protein_denoised = model.get_normalized_expression(
    n_samples=N_POSTERIOR_SAMPLES,
    return_mean=True,
    return_numpy=True,
)
protein.layers["denoised_protein"] = protein_denoised
protein.layers["protein_foreground_prob"] = model.get_protein_foreground_probability(
    n_samples=N_POSTERIOR_SAMPLES,
    return_mean=True,
    return_numpy=True,
)

# Propagate the new layers and embedding to the MuData container.
mudata.update()

In [None]:
# UMAP panels coloured by every protein, drawn from the `denoised_protein` layer
# (the layer must already exist on the protein modality when this cell runs).
muon.pl.embedding(
    mudata,
    basis="rna:X_umap",
    layer="denoised_protein",
    color=protein.var_names,
    vmax="p99",
    ncols=6,
    wspace=0.1,
    frameon=False,
)

### Visualize probability of foreground

In [None]:
# Display the MuData summary to check which layers and embeddings are available for plotting.
mudata

In [None]:
# Panels to draw: the antibody features followed by four RNA-side annotations.
foreground_panels = ['Ly6G_TotalA', 'CD11b_TotalA', 'CD62L_TotalA', 'IAIE_TotalA', 'ICAM1_TotalA', 'Ly6C_TotalA', 'CD115_TotalA', 'CXCR4_TotalA', 'MSR1_TotalA', 'CD64_TotalA', 'FCeRIa_TotalA', 'CCR3_TotalA', 'CD49d_TotalA', 'CD80_TotalA', 'CD117_TotalA',
       'Sca1_TotalA', 'CD11c_TotalA', 'TIM4_TotalA', 'CX3CR1_TotalA', 'XCR1_TotalA', 'F480_TotalA', 'CD86_TotalA', 'CD135_TotalA', 'CD103_TotalA', 'CD169_TotalA', 'CD8a_TotalA', 'SiglecH_TotalA', 'CD19_TotalA', 'CD3_TotalA', 'CD63_TotalA', 'CD9_TotalA',
       'CD163_TotalA', 'NK11_TotalA', 'CD279_TotalA', 'CD127_TotalA', 'CD68_TotalA', 'Sirpa_TotalA', 'CD274_TotalA', 'ITGB7_TotalA', 'CD4_TotalA', 'CD26_TotalA', 'MGL2_TotalA', 'TCRgd_TotalA', 'CCR2_TotalA', 'CD44_TotalA', 'CD21_35_TotalA', 'CD43_TotalA',
       'Hamster_TotalA', 'Rat_IgG1_TotalA', 'Rat_IgG2a_TotalA', 'Rat_IgG2b_TotalA', 'CD47_TotalA', 'SiglecF_TotalA', 'CD137_TotalA', 'CD36_TotalA', 'CCR5_TotalA', 'CD278_TotalA', 'PIRAB_TotalA', 'CD5_TotalA', 'CD304_TotalA', 'CD40_TotalA', 'CD14_TotalA',
       'CD95_TotalA', 'CD300cd_TotalA', 'IL1RL1_TotalA', 'TCRbeta_TotalA', 'Mac2_TotalA', 'CD137L_TotalA', 'CD178_TotalA', 'CD55_TotalA', 'TIGIT_TotalA', 'CD226_TotalA', 'CD39_TotalA', 'JAML_TotalA', 'CXCR5_TotalA', 'MGL1_TotalA', 'CD24_TotalA', 'CD88_TotalA',
       'CD11a_TotalA', 'CD81_TotalA', 'CD83_TotalA', 'Pdpn_TotalA', 'IgM_TotalA', 'TIM3_TotalA', 'BTLA_TotalA', 'CD223_TotalA', 'CD25_TotalA', 'CD152_TotalA', 'KLRG1_TotalA', 'rna:condition', 'rna:genotype', 'rna:infection', 'rna:classification']

# UMAP panels drawn from the `protein_foreground_prob` layer
# (the layer must already exist on the protein modality when this cell runs).
muon.pl.embedding(
    mudata,
    basis="rna:X_umap",
    layer="protein_foreground_prob",
    color = foreground_panels,
    color_map="RdPu",
    vmax="p99",
    size = 1.5,
    ncols=6,
    wspace=0.1,
    frameon=False,
)

### Differential protein expression

In [None]:
# Differential expression across the groups in the `rna:classification` column,
# with batch correction enabled and `delta = 0.5`; the first rows are shown as a sanity check.
de_df = model.differential_expression(
    groupby = "rna:classification", 
    delta = 0.5, 
    batch_correction = True
)
de_df.head(5)

In [None]:
# For every cell class, pick the top up-regulated markers from its "<class> vs Rest"
# comparison: up to 3 proteins and up to 2 genes, ranked by median log fold change.
filtered_pro = {}
filtered_rna = {}
cats = rna.obs['classification'].cat.categories

# Proteins are recognised by membership in the protein modality, not by a name pattern.
# The antibody features in this dataset end in "_TotalA", so a "TotalSeqB" substring
# match never hits: `filtered_pro` would stay empty and proteins would be mixed into
# the RNA marker lists.
protein_features = set(protein.var_names)

for c in cats:
    cid = f"{c} vs Rest"
    cell_type_df = de_df.loc[de_df.comparison == cid]
    cell_type_df = cell_type_df.sort_values("lfc_median", ascending=False)

    # Keep only features that are up-regulated in this class.
    cell_type_df = cell_type_df[cell_type_df.lfc_median > 0]

    pro_rows = cell_type_df.index.isin(protein_features)

    # Protein markers: a lenient Bayes-factor cut-off.
    data_pro = cell_type_df.loc[pro_rows]
    data_pro = data_pro[data_pro["bayes_factor"] > 0.7]

    # RNA markers: a stricter Bayes-factor cut-off, and non-zero in more than 10% of
    # the cells of the class.
    data_rna = cell_type_df.loc[~pro_rows]
    data_rna = data_rna[data_rna["bayes_factor"] > 3]
    data_rna = data_rna[data_rna["non_zeros_proportion1"] > 0.1]

    filtered_pro[c] = data_pro.index.tolist()[:3]
    filtered_rna[c] = data_rna.index.tolist()[:2]

In [None]:
# Mirror the class labels and the TotalVI embedding onto the protein modality so that
# both modalities can be clustered on the same latent space.
protein.obs['classification'] = rna.obs['classification']
protein.obsm[TOTALVI_LATENT_KEY] = rna.obsm[TOTALVI_LATENT_KEY]

# Dendrogram of the classes on the TotalVI latent space, once per modality.
for modality in (rna, protein):
    sc.tl.dendrogram(modality, groupby='classification', use_rep=TOTALVI_LATENT_KEY)

In [None]:
# Dot plot of the selected RNA markers per class: genes on rows (`swap_axes`),
# expression scaled per gene, classes ordered by the precomputed dendrogram.
sc.pl.dotplot(
    rna,
    filtered_rna,
    groupby='classification',
    swap_axes=True,
    standard_scale="var",
    dendrogram=True,
)