### Notebook for the functional comparison of CMC genotypes with `expimap`

- **Developed by:** Carlos Talavera-López Ph.D
- **Würzburg Institute for Systems Immunology - Faculty of Medicine - Julius-Maximilians-Universität Würzburg**
- **Created on**: 240216
- **Last modified**: 240216

### Import required modules

In [1]:
import gdown
import torch
import anndata
import warnings
import numpy as np
import scipy as sp
import scanpy as sc
import pandas as pd
import scarches as sca
import matplotlib.pyplot as plt



### Set up working environment

In [None]:
# Verbose scanpy logging, a version report for reproducibility, and global figure defaults
sc.settings.verbosity = 3
sc.logging.print_versions()
sc.settings.set_figure_params(
    dpi = 180,
    dpi_save = 300,
    color_map = 'magma_r',
    vector_friendly = True,
    format = 'svg',
)

In [None]:
# Silence every Python warning for the rest of the session. This keeps notebook output
# readable, but it also hides library deprecation and implicit-copy warnings.
warnings.simplefilter(action = 'ignore')
# Inline matplotlib output: white figure background and high-resolution ('retina') rendering.
%config InlineBackend.print_figure_kwargs = {'facecolor' : "w"}
%config InlineBackend.figure_format = 'retina'
# Let PyTorch use reduced-precision internal math for float32 matmuls ('medium').
torch.set_float32_matmul_precision('medium')

### Read in the full dataset and split it into WT (reference) and other genotypes (query)

In [None]:
# All CMC nuclei across genotypes (file suffix `.raw` suggests unnormalised counts — confirm)
cmc_h5ad = '../../data/heart_mm_nuclei-23-0092_CMC_states_ctl240131.raw.h5ad'
DMD_CMC = sc.read_h5ad(cmc_h5ad)
DMD_CMC

In [None]:
# Store the expression matrix as a dense float32 array.
# The cast happens on the sparse matrix first, so the dense array is allocated
# once, directly in float32. Densifying first would build a full-size dense
# copy in the original dtype (possibly float64) and then a second one for the cast.
if sp.sparse.issparse(DMD_CMC.X):
    DMD_CMC.X = DMD_CMC.X.astype(np.float32).toarray()

In [None]:
# Number of cells per genotype
genotype_counts = DMD_CMC.obs['genotype'].value_counts()
genotype_counts

In [None]:
# Reference = WT cells only.
# The subset is materialised as an independent AnnData rather than kept as a view
# of DMD_CMC, because it is modified in place afterwards (varm/uns annotations,
# a counts layer, in-place gene subsetting). On a view, each of those writes
# would trigger an implicit view-to-copy conversion.
WT_CMC = DMD_CMC[DMD_CMC.obs['genotype'].isin(['WT'])].copy()
WT_CMC

In [None]:
# Query = every non-WT genotype. This is left as a view on purpose: it is
# restricted to the reference genes with `[:, ...].copy()` later, so a full
# copy at this point would only cost memory.
is_wt = DMD_CMC.obs['genotype'].isin(['WT'])
Mdx_CMC = DMD_CMC[~is_wt]
Mdx_CMC

### Read the Reactome annotations

In [None]:
# Fetch the Reactome gene-set file (GMT format) from Google Drive into the working directory
reactome_url = 'https://drive.google.com/uc?id=1136LntaVr92G1MphGeMVcmpE0AqcqM6c'
reactome_gmt = 'reactome.gmt'
gdown.download(reactome_url, reactome_gmt, quiet = False)

### Prepare reference data with ReactomeDB pathways

In [None]:
# Attach Reactome term membership to the reference. Downstream cells read it back
# as a gene-by-term matrix in varm['I'] and the term names in uns['terms'].
sca.utils.add_annotations(
    WT_CMC,
    'reactome.gmt',
    min_genes = 6,
    clean = True,
)

- Remove all genes that are present in the data but absent from the ReactomeDB annotations

In [None]:
# Keep only genes that belong to at least one annotation term (row sum of varm['I'] > 0)
genes_in_any_term = WT_CMC.varm['I'].sum(axis = 1) > 0
WT_CMC._inplace_subset_var(genes_in_any_term)

### Calculate HVGs

In [None]:
# Snapshot of the annotated reference before HVG subsetting.
# NOTE(review): `ref_raw` is not referenced anywhere else in the visible notebook;
# drop it if it is truly unused, since it is a full copy of the data.
ref_raw = WT_CMC.copy()

# seurat_v3 selection works on counts, so keep them in a dedicated layer
WT_CMC.layers['counts'] = WT_CMC.X.copy()

hvg_params = dict(
    flavor = "seurat_v3",
    n_top_genes = 7000,
    layer = "counts",
    batch_key = "sample",
    subset = True,   # WT_CMC is reduced to the selected genes in place
    span = 1,
)
sc.pp.highly_variable_genes(WT_CMC, **hvg_params)
WT_CMC

- Filter out all annotations (terms) with 12 or fewer genes, i.e. keep only terms containing more than 12 of the selected genes.

In [None]:
# Drop annotation terms that cover too few of the remaining genes
# (strictly more than MIN_TERM_GENES genes are required to keep a term).
MIN_TERM_GENES = 12
genes_per_term = WT_CMC.varm['I'].sum(axis = 0)
select_terms = genes_per_term > MIN_TERM_GENES

kept_terms = np.asarray(WT_CMC.uns['terms'])[select_terms]
WT_CMC.uns['terms'] = kept_terms.tolist()
WT_CMC.varm['I'] = WT_CMC.varm['I'][:, select_terms]

- Filter out genes that are no longer part of any term after HVG selection and term filtering.

In [None]:
# After HVG selection and term filtering, keep only genes still in at least one kept term
still_annotated = WT_CMC.varm['I'].sum(axis = 1) > 0
WT_CMC._inplace_subset_var(still_annotated)

### Create expiMap model and train it on reference dataset

In [None]:
# Reference expiMap model: conditioned on genotype, three hidden layers of 256 units,
# negative-binomial ('nb') reconstruction loss
intr_cvae = sca.models.EXPIMAP(
    adata = WT_CMC,
    condition_key = 'genotype',
    hidden_layer_sizes = [256] * 3,
    recon_loss = 'nb',
)

In [None]:
ALPHA: float = 0.7  # regularisation coefficient passed as `alpha` to the reference training

In [None]:
# Early stopping monitors the unweighted validation loss; the learning rate is
# reduced (factor 0.1) when the loss plateaus for lr_patience epochs.
early_stopping_kwargs = dict(
    early_stopping_metric = "val_unweighted_loss",
    threshold = 0,
    patience = 50,
    reduce_lr = True,
    lr_patience = 13,
    lr_factor = 0.1,
)

reference_train_params = dict(
    n_epochs = 400,
    alpha_epoch_anneal = 100,
    alpha = ALPHA,
    alpha_kl = 0.5,
    weight_decay = 0.,
    early_stopping_kwargs = early_stopping_kwargs,
    use_early_stopping = True,
    monitor_only_val = False,
    seed = 1712,
)
intr_cvae.train(**reference_train_params)

In [None]:
MEAN: bool = False  # forwarded as `mean=` to the model's get_latent calls

In [None]:
# Latent embedding of the reference restricted to active terms, stored for neighbours/UMAP
reference_latent = intr_cvae.get_latent(mean = MEAN, only_active = True)
WT_CMC.obsm['X_cvae'] = reference_latent

### Plot latent space of the reference dataset

In [None]:
# kNN graph + UMAP on the reference latent space
sc.pp.neighbors(
    WT_CMC, use_rep = "X_cvae", n_neighbors = 50, metric = 'minkowski'
)
sc.tl.umap(WT_CMC, spread = 4, min_dist = 0.3, random_state = 1712)
sc.pl.umap(
    WT_CMC,
    color = ['C_scANVI', 'donor', 'genotype'],
    frameon = False,
    size = 3,
    legend_fontsize = 5,
    ncols = 4,
)

### Format the query dataset to match the reference

In [None]:
# Restrict the query to exactly the reference genes, in reference order,
# and detach it from DMD_CMC as a standalone copy
reference_genes = WT_CMC.var_names
Mdx_CMC = Mdx_CMC[:, reference_genes].copy()
Mdx_CMC

In [None]:
# The query carries the same (shared) term list as the reference
reference_terms = WT_CMC.uns['terms']
Mdx_CMC.uns['terms'] = reference_terms

### Initialize the model for query training

In [None]:
# Query model initialised from the trained reference model and the query data
q_intr_cvae = sca.models.EXPIMAP.load_query_data(
    Mdx_CMC,
    intr_cvae,
)

In [None]:
# Query training; unlike the reference run, no explicit `alpha` is passed here
query_train_params = dict(
    n_epochs = 400,
    alpha_epoch_anneal = 100,
    weight_decay = 0.,
    alpha_kl = 0.1,
    seed = 1712,
    use_early_stopping = True,
)
q_intr_cvae.train(**query_train_params)

### Get latent representation of reference + query dataset

In [None]:
# Reference + query in one AnnData; 'batch_join' records the source object,
# and uns entries identical in both inputs (e.g. 'terms') are kept.
# NOTE(review): AnnData.concatenate is deprecated in recent anndata in favour of anndata.concat.
HHH_CMC = WT_CMC.concatenate(
    Mdx_CMC,
    batch_key = 'batch_join',
    uns_merge = 'same',
)
HHH_CMC

In [None]:
# Embed reference + query cells with the query-adapted model (active terms only).
# The condition labels must come from the obs column the model was built with
# (condition_key = 'genotype'); 'region' is not the model's condition key.
HHH_CMC.obsm['X_cvae'] = q_intr_cvae.get_latent(
    HHH_CMC.X,
    HHH_CMC.obs['genotype'],
    mean = MEAN,
    only_active = True,
)

In [None]:
# kNN graph (20 neighbours) and UMAP on the joint reference + query latent space
sc.pp.neighbors(
    HHH_CMC, use_rep = "X_cvae", n_neighbors = 20, metric = 'minkowski'
)
sc.tl.umap(HHH_CMC, spread = 6, min_dist = 0.5, random_state = 1712)

In [None]:
# Per-cell condition label used as `groups` in the Bayes-factor enrichment tests.
# Those tests compare genotype labels ('WT', 'Mdx', 'MdxSCID'), so the label is
# taken from the 'genotype' column. Missing values, which astype(str) renders as
# 'nan', fall back to 'control'.
HHH_CMC.obs['condition_joint'] = HHH_CMC.obs['genotype'].astype(str)
is_missing = HHH_CMC.obs['condition_joint'] == 'nan'
# Write through .loc: chained assignment (obs[col][mask] = ...) may modify a
# temporary copy and leave obs unchanged.
HHH_CMC.obs.loc[is_missing, 'condition_joint'] = 'control'

In [None]:
# Joint UMAP coloured by C_scANVI, donor, genotype and the enrichment condition label.
# The obs column is 'genotype' (singular); 'genotypes' does not exist and raises a KeyError.
sc.pl.umap(HHH_CMC, frameon = False, color = ['C_scANVI', 'donor', 'genotype', 'condition_joint'], size = 1, legend_fontsize = 5, ncols = 4)

### Calculate the direction of upregulation for each latent score and store them in `HHH_CMC.uns['directions']`

In [None]:
# Per-term upregulation directions, stored on HHH_CMC and read back as uns['directions']
q_intr_cvae.latent_directions(
    adata = HHH_CMC,
)

### Run gene set enrichment tests for each condition in reference + query using Bayes factors

In [None]:
# Bayes-factor enrichment of the condition_joint groups with 'WT' as the comparison group
q_intr_cvae.latent_enrich(
    groups = 'condition_joint',
    comparison = 'WT',
    use_directions = True,
    adata = HHH_CMC,
)
fig = sca.plotting.plot_abs_bfs(
    HHH_CMC, yt_step = 0.3, scale_y = 2, fontsize = 6
)

In [None]:
# Bayes-factor enrichment of the condition_joint groups with 'Mdx' as the comparison group
q_intr_cvae.latent_enrich(
    groups = 'condition_joint',
    comparison = 'Mdx',
    use_directions = True,
    adata = HHH_CMC,
)
fig = sca.plotting.plot_abs_bfs(
    HHH_CMC, yt_step = 0.3, scale_y = 3, fontsize = 6
)

In [None]:
# Bayes-factor enrichment of the condition_joint groups with 'MdxSCID' as the comparison group
q_intr_cvae.latent_enrich(
    groups = 'condition_joint',
    comparison = 'MdxSCID',
    use_directions = True,
    adata = HHH_CMC,
)
fig = sca.plotting.plot_abs_bfs(
    HHH_CMC, yt_step = 0.3, scale_y = 3, fontsize = 6
)

### Plot the latent scores of reference + query for selected annotation terms

In [None]:
# Positions of the terms to plot within the full term list (ValueError if a term was filtered out)
terms = HHH_CMC.uns['terms']
select_terms = [
    'PLATELET_HOMEOSTASIS',
    'DEVELOPMENTAL_BIOLOGY',
    'GPCR_DOWNSTREAM_SIGNALING',
]
idx = list(map(terms.index, select_terms))

In [None]:
# Direction-corrected latent scores for the selected terms.
# only_active is not passed here (unlike the X_cvae embedding), presumably so
# that the columns line up with uns['terms'] / uns['directions'] and `idx`
# picks the right terms — TODO confirm.
# Conditions come from 'genotype', the model's condition key.
latents = (q_intr_cvae.get_latent(HHH_CMC.X, HHH_CMC.obs['genotype'], mean = MEAN) * HHH_CMC.uns['directions'])[:, idx]

# One obs column per selected term, named after the term (same order as `idx`)
for column, term in enumerate(select_terms):
    HHH_CMC.obs[term] = latents[:, column]

sc.pl.scatter(HHH_CMC, x = 'PLATELET_HOMEOSTASIS', y = 'DEVELOPMENTAL_BIOLOGY', color = 'condition_joint', size = 6)

In [None]:
# Same term-score scatter, coloured by C_scANVI instead of condition
sc.pl.scatter(
    HHH_CMC,
    x = 'PLATELET_HOMEOSTASIS',
    y = 'DEVELOPMENTAL_BIOLOGY',
    color = 'C_scANVI',
    size = 8,
)