### Notebook for identifying gene weights in fetal stem cell populations with `expimap` (comparative model)

- **Developed by:** Anna Maguza
- **Place:** Wuerzburg Institute for System Immunology
- **Created date:** 25th April 2024
- **Last modified date:** 25th April 2024

#### Import required modules

In [None]:
import gdown
import torch
import anndata
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import scarches as sca
import matplotlib.pyplot as plt

In [None]:
def X_is_raw(adata):
    """Return True if every value in ``adata.X`` is a whole number.

    Raw count matrices contain only integers. Checking the entries themselves
    (instead of their column sums) avoids false positives where fractional
    values within a column happen to add up to an integer (e.g. 0.5 + 0.5).

    Parameters
    ----------
    adata
        AnnData-like object whose ``X`` is a dense array or a SciPy sparse
        matrix.

    Returns
    -------
    bool
        ``True`` when all stored values are integer-valued, ``False``
        otherwise (NaN and inf count as non-integer).
    """
    X = adata.X
    # Sparse matrices expose ``nnz`` and keep their explicit values in ``.data``;
    # implicit zeros are integers, so only the stored values need checking.
    # (Dense NumPy arrays also have a ``.data`` buffer, hence the ``nnz`` test.)
    values = X.data if hasattr(X, "nnz") else np.asarray(X)
    return bool(np.all(np.mod(values, 1) == 0))

#### Set up working environment

In [None]:
# Scanpy logging level and a printout of package versions for reproducibility.
sc.settings.verbosity = 3
sc.logging.print_versions()
# Global plotting defaults: 180 dpi inline, 300 dpi SVG when figures are saved,
# 'magma_r' as the default colour map.
sc.settings.set_figure_params(
    dpi=180,
    color_map='magma_r',
    dpi_save=300,
    vector_friendly=True,
    format='svg',
)

In [None]:
# Silence all Python warnings for the rest of the session. Note that this also
# hides any warnings raised by the libraries used in later cells.
warnings.simplefilter(action = 'ignore')
# IPython magics: white background for inline figures and high-DPI rendering.
%config InlineBackend.print_figure_kwargs = {'facecolor' : "w"}
%config InlineBackend.figure_format = 'retina'
# Allow PyTorch to use 'medium' float32 matmul precision (trades accuracy for speed).
torch.set_float32_matmul_precision('medium')

In [None]:
# Sanity check: model training below requests accelerator="gpu", so this should show True.
torch.cuda.is_available()

#### Read data

In [None]:
# Path to the fetal cell dataset (.h5ad). Named `input_path` rather than `input`
# so that the Python builtin `input()` is not shadowed for the whole session.
input_path = 'FetalSC_data/Fetal_cells_scvi.h5ad'
adata = sc.read_h5ad(input_path)

In [None]:
# Display whether adata.X holds only integer (count-like) values. The following
# cell replaces X with the matrix stored in adata.raw regardless of this result.
X_is_raw(adata)

In [None]:
# Rebuild the AnnData from the matrix stored in `.raw`; the expiMap model below
# uses a negative-binomial ('nb') reconstruction loss, which expects count data.
# `.raw` is absent when the file has none or when this cell is re-run (the
# rebuilt object has no `.raw`), so fail with an explicit message.
if adata.raw is None:
    raise ValueError("adata.raw is empty: cannot recover the raw matrix for expiMap.")
adata = adata.raw.to_adata()
adata

### Set the path to the Reactome annotations

In [None]:
# Local path to the Reactome gene-set file in GMT format. Despite the variable
# name, this is a filesystem path, not a URL.
reactome_link = '/mnt/LaCIE/annaM/reactome/reactome.gmt'

### Prepare reference data with ReactomeDB pathways

In [None]:
# Attach the Reactome gene sets to `adata` with scArches. Later cells read the
# resulting annotation matrix from adata.varm['I'] and the term names from
# adata.uns['terms'].
sca.utils.add_annotations(
    adata,
    reactome_link,
    min_genes=12,
    clean=True,
)

- Remove all genes that are not annotated to any of the retained Reactome terms

In [None]:
# Row sums of the gene x term matrix varm['I'] count the terms each gene belongs
# to; keep only genes that belong to at least one term.
gene_in_any_term = adata.varm['I'].sum(1) > 0
adata._inplace_subset_var(gene_in_any_term)

### Calculate HVGs

+ Remove donors with only one cell (each donor is used as a batch for HVG selection below)

In [None]:
# Donor_ID is the batch key for HVG selection below; keep donors with >1 cell.
cells_per_donor = adata.obs['Donor_ID'].value_counts()
donors_to_keep = cells_per_donor[cells_per_donor > 1].index
# .copy() materialises the subset, so the next cell's in-place edits (adding a
# layer) act on a real AnnData instead of implicitly converting a view.
adata = adata[adata.obs['Donor_ID'].isin(donors_to_keep)].copy()

In [None]:
# Keep a full copy of the donor-filtered data before genes are subset.
adata_raw = adata.copy()
# HVG selection below reads its input from the 'counts' layer.
adata.layers['counts'] = adata.X.copy()

hvg_kwargs = dict(
    flavor="seurat_v3",
    n_top_genes=7000,
    layer="counts",
    batch_key="Donor_ID",
    subset=True,  # keep only the selected genes in adata
    span=1,
)
sc.pp.highly_variable_genes(adata, **hvg_kwargs)
adata

- Filter out all annotations (terms) with 12 or fewer genes, i.e. keep only terms with more than 12 genes.

In [None]:
# Column sums of varm['I'] give the number of remaining genes per term; keep
# terms with MORE than 12 genes (terms with exactly 12 are dropped).
genes_per_term = adata.varm['I'].sum(0)
select_terms = genes_per_term > 12
# Apply the same column mask to the term names and to the annotation matrix.
adata.uns['terms'] = np.asarray(adata.uns['terms'])[select_terms].tolist()
adata.varm['I'] = adata.varm['I'][:, select_terms]

- After HVG selection, filter out genes that are not present in any of the remaining terms.

In [None]:
# HVG selection and term filtering can leave genes with no remaining term;
# drop every gene whose row in varm['I'] is empty.
genes_with_terms = adata.varm['I'].sum(1) > 0
adata._inplace_subset_var(genes_with_terms)

### Create the expiMap model and train it on the reference dataset

In [None]:
# expiMap model on the annotated genes, with three hidden layers of 256 units
# and a negative-binomial ('nb') reconstruction loss.
# NOTE(review): the model is conditioned on 'Cell States', the same labels that
# later cells compare with latent_enrich, while HVGs were selected per
# 'Donor_ID'. Confirm that conditioning on cell state is intended.
intr_cvae = sca.models.EXPIMAP(
    adata=adata,
    condition_key='Cell States',
    hidden_layer_sizes=[256] * 3,
    recon_loss='nb',
)

In [None]:
# Passed as `alpha` to intr_cvae.train below (expiMap regularisation strength;
# see the scArches expiMap documentation for its exact meaning).
ALPHA = 0.7

In [None]:
# Early stopping monitors the validation unweighted loss (patience 50) and
# enables learning-rate reduction (lr_patience 13, lr_factor 0.1).
early_stopping_kwargs = dict(
    early_stopping_metric="val_unweighted_loss",
    threshold=0,
    patience=50,
    reduce_lr=True,
    lr_patience=13,
    lr_factor=0.1,
)
# NOTE(review): confirm the installed scArches trainer actually consumes
# `accelerator` / `devices`; unused keyword arguments may be silently ignored.
intr_cvae.train(
    n_epochs=400,
    alpha_epoch_anneal=100,
    alpha=ALPHA,
    alpha_kl=0.5,
    weight_decay=0.0,
    early_stopping_kwargs=early_stopping_kwargs,
    use_early_stopping=True,
    monitor_only_val=False,
    seed=1712,
    accelerator="gpu",
    devices=[0],
)

In [None]:
# Passed as `mean` to intr_cvae.get_latent below. NOTE(review): with False the
# latent values are presumably sampled rather than the posterior mean, so
# X_cvae and the UMAP may differ between runs. Confirm against scArches docs.
MEAN = False

In [None]:
# Latent embedding restricted to active terms; stored for the kNN graph / UMAP.
latent = intr_cvae.get_latent(only_active=True, mean=MEAN)
adata.obsm['X_cvae'] = latent

+ Visualize the latent representation with UMAP

In [None]:
# kNN graph built on the expiMap latent space, then a UMAP embedding of it.
neighbor_params = {'use_rep': "X_cvae", 'n_neighbors': 20, 'metric': 'minkowski'}
umap_params = {'min_dist': 0.5, 'spread': 6, 'random_state': 1712}
sc.pp.neighbors(adata, **neighbor_params)
sc.tl.umap(adata, **umap_params)

In [None]:
# set_figure_params resets any option not passed (e.g. save dpi and file format)
# to scanpy's defaults, so repeat the save/colour-map settings chosen at the top
# of the notebook and change only dpi and figure size here.
sc.set_figure_params(
    dpi=300,
    figsize=(10, 7),
    dpi_save=300,
    color_map='magma_r',
    vector_friendly=True,
    format='svg',
)
sc.pl.umap(adata, color=['Cell States'], ncols=3, frameon=False, size=5)

In [None]:
# Repeat the global save/colour-map settings so this call only changes dpi and
# figure size; set_figure_params otherwise resets omitted options to defaults.
sc.set_figure_params(
    dpi=300,
    figsize=(10, 7),
    dpi_save=300,
    color_map='magma_r',
    vector_friendly=True,
    format='svg',
)
# Colour by protocol and donor to inspect residual batch structure in the latent space.
sc.pl.umap(
    adata,
    color=['Library_Preparation_Protocol', 'Donor_ID'],
    ncols=3,
    frameon=False,
    size=5,
)

+ Extract meaningful latent directions

In [None]:
# Compute latent directions for adata. Later cells call latent_enrich(...,
# use_directions=True) on adata and on copies of it, which presumably rely on this.
intr_cvae.latent_directions(adata=adata)

In [None]:
# Large canvas for the enrichment plot. Repeat the global save/colour-map
# settings because set_figure_params resets omitted options to defaults.
sc.set_figure_params(
    dpi=300,
    figsize=(20, 70),
    dpi_save=300,
    color_map='magma_r',
    vector_friendly=True,
    format='svg',
)
# Latent enrichment of each 'Cell States' group against the FXYD3+_CKB+_SC state.
intr_cvae.latent_enrich(
    groups='Cell States',
    comparison='FXYD3+_CKB+_SC',
    adata=adata,
    use_directions=True,
)
fig = sca.plotting.plot_abs_bfs(adata, yt_step=0.3, scale_y=2, fontsize=9)

+ Compare cell states, first within selected cell types and then across the full dataset

In [None]:
# Number of cells per 'Cell Type' label (the types compared in the next cells).
adata.obs['Cell Type'].value_counts()

In [None]:
# Latent enrichment across 'Cell States' groups, restricted to Epithelial cells.
selected_cell_types = ['Epithelial']
is_selected = adata.obs['Cell Type'].isin(selected_cell_types)
adata_filtered = adata[is_selected].copy()

intr_cvae.latent_enrich(
    groups='Cell States',
    adata=adata_filtered,
    n_sample=10000,
    use_directions=True,
)

fig = sca.plotting.plot_abs_bfs(
    adata_filtered, n_cols=3, scale_y=2.6, yt_step=0.6, fontsize=9
)
fig.set_size_inches(20, 70)

In [None]:
# Latent enrichment across 'Cell States' groups, restricted to Mesenchymal cells.
selected_cell_types = ['Mesenchymal']
is_selected = adata.obs['Cell Type'].isin(selected_cell_types)
adata_filtered = adata[is_selected].copy()

intr_cvae.latent_enrich(
    groups='Cell States',
    adata=adata_filtered,
    n_sample=10000,
    use_directions=True,
)

fig = sca.plotting.plot_abs_bfs(
    adata_filtered, n_cols=3, scale_y=2.6, yt_step=0.6, fontsize=9
)
fig.set_size_inches(20, 70)

In [None]:
# Cells per 'Cell States' label. Evaluate `df` on the last line so the table is
# displayed, which lets the exact labels used in the next cell be checked.
df = adata.obs['Cell States'].value_counts()
df

In [None]:
selected_cell_types = ['FXYD3+_CKB+_SC', 'MTRNR2L12+ASS1+_SC', 'RPS10+_RPS17+_SC']
# Fail loudly if a requested label is absent: isin() would otherwise silently
# drop it, and warnings are suppressed in this notebook.
# NOTE(review): 'MTRNR2L12+ASS1+_SC' lacks the '+_' separator used by the other
# labels; confirm it matches adata.obs['Cell States'].
available_states = set(adata.obs['Cell States'].unique())
missing_states = [s for s in selected_cell_types if s not in available_states]
if missing_states:
    raise ValueError(f"Cell states not found in adata.obs['Cell States']: {missing_states}")
adata_filtered = adata[adata.obs['Cell States'].isin(selected_cell_types)].copy()

# Latent enrichment across the selected stem-cell states only.
intr_cvae.latent_enrich(groups='Cell States', adata=adata_filtered, n_sample=10000, use_directions=True)

fig = sca.plotting.plot_abs_bfs(adata_filtered, n_cols=3, scale_y=2.6, yt_step=0.6, fontsize=6)
fig.set_size_inches(15, 5)

In [None]:
# Latent enrichment across all 'Cell States' groups on the full dataset.
enrich_kwargs = dict(groups='Cell States', n_sample=10000, use_directions=True)
intr_cvae.latent_enrich(adata=adata, **enrich_kwargs)

fig = sca.plotting.plot_abs_bfs(
    adata, n_cols=5, scale_y=2.6, yt_step=0.6, fontsize=3
)
fig.set_size_inches(25, 95)