### Notebook for identification of gene weights in fetal stem cell populations with `expimap`

- **Developed by:** Anna Maguza
- **Place:** Wuerzburg Institute for System Immunology
- **Created date:** 17th April 2024
- **Last modified date:** 22nd April 2024

#### Import required modules

In [1]:
import gdown
import torch
import anndata
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import scarches as sca
import matplotlib.pyplot as plt

  self.seed = seed
  self.dl_pin_memory_gpu_training = (
 captum (see https://github.com/pytorch/captum).


In [3]:
def X_is_raw(adata):
    """Heuristically report whether ``adata.X`` looks like integer counts.

    Only the per-column sums of ``adata.X`` are inspected: the function
    returns True when every column sum is unchanged by truncation to int.
    This is a cheap check, not a proof -- non-integer entries whose column
    sum happens to be a whole number are not detected.
    """
    column_totals = adata.X.sum(axis=0)
    return np.array_equal(column_totals.astype(int), column_totals)

#### Set up working environment

In [4]:
# Scanpy logging level; NOTE(review): 3 is presumably the "hints" level -- confirm in the scanpy docs.
sc.settings.verbosity = 3
# Print the versions of the loaded packages (captured in the output below) for reproducibility.
sc.logging.print_versions()
# Global figure defaults: 180 dpi on screen, 300 dpi when saved, reversed magma colour map, SVG output.
sc.settings.set_figure_params(dpi = 180, color_map = 'magma_r', dpi_save = 300, vector_friendly = True, format = 'svg')

-----
anndata     0.10.5.post1
scanpy      1.9.8
-----
PIL                         10.2.0
absl                        NA
aiohttp                     3.9.3
aiosignal                   1.3.1
annotated_types             0.6.0
anyio                       NA
array_api_compat            1.4.1
arrow                       1.3.0
asttokens                   NA
async_timeout               4.0.3
attr                        23.2.0
attrs                       23.2.0
babel                       2.14.0
backoff                     2.2.1
bs4                         4.12.3
certifi                     2024.02.02
cffi                        1.16.0
charset_normalizer          3.3.2
chex                        0.1.7
click                       8.1.7
comm                        0.2.1
contextlib2                 NA
croniter                    NA
cycler                      0.12.1
cython_runtime              NA
dateutil                    2.9.0
debugpy                     1.8.1
decorator                   5.1.1

In [5]:
# Suppress every Python warning from here on (this also hides anndata/scanpy deprecation and view warnings).
warnings.simplefilter(action = 'ignore')
# IPython inline backend: white figure background, retina-resolution rendering.
%config InlineBackend.print_figure_kwargs = {'facecolor' : "w"}
%config InlineBackend.figure_format = 'retina'
# NOTE(review): 'medium' presumably lets PyTorch use reduced-precision float32 matmuls for speed -- confirm in the torch docs.
torch.set_float32_matmul_precision('medium')

In [6]:
# Check that PyTorch can see a CUDA device; the training cells below request accelerator = "gpu".
torch.cuda.is_available()

True

#### Read data

In [54]:
# Location of the fetal dataset on disk. The variable is called `input_path`
# (not `input`) so that the Python builtin `input()` is not shadowed for the
# rest of the session.
input_path = 'FetalSC_data/Fetal_cells_scvi.h5ad'
adata = sc.read_h5ad(input_path)

In [55]:
# Sanity check: do the column sums of adata.X look like integer counts?
X_is_raw(adata)

True

In [56]:
# Replace the working object with the AnnData rebuilt from its `.raw` slot.
# NOTE(review): the raw-count check above ran on the previous `adata.X`, not on this rebuilt object.
adata = adata.raw.to_adata()
# Display the object summary.
adata

AnnData object with n_obs × n_vars = 231646 × 26442
    obs: 'Sample_ID', 'Cell Type', 'Study_name', 'Donor_ID', 'Diagnosis', 'Age', 'Region code', 'Fraction', 'Sex', 'Library_Preparation_Protocol', 'batch', 'Age_group', 'Location', 'Cell States', 'Cell States GCA', 'Chem', 'Layer', 'Cell States Kong', 'dataset', 'n_genes_by_counts', 'total_counts', 'total_counts_mito', 'pct_counts_mito', 'total_counts_ribo', 'pct_counts_ribo', 'Cell_ID', '_scvi_batch', '_scvi_labels', 'n_genes', 'n_counts'
    var: 'feature_types-0-0-0', 'gene_name-1-0-0', 'gene_id-0-0', 'GENE-1-0'
    uns: '_scvi_manager_uuid', '_scvi_uuid', 'hvg', 'neighbors', 'umap'
    obsm: 'X_scVI', 'X_umap', '_scvi_extra_categorical_covs', '_scvi_extra_continuous_covs'
    obsp: 'connectivities', 'distances'

In [10]:
# Reference set: cells labelled with the 'FXYD3+_CKB+_SC' cell state.
# Boolean indexing returns a view of `adata`; the following cells modify this
# object in place (annotations, gene subsetting), so materialise it as an
# independent copy instead of relying on anndata's implicit view-to-copy
# conversion (whose warning is silenced in this notebook).
adataSC = adata[adata.obs['Cell States']=='FXYD3+_CKB+_SC'].copy()

### Read the Reactome annotations

In [11]:
# Filesystem path (despite the name, not a URL) of the Reactome gene-set file in GMT format.
reactome_link = '/mnt/LaCIE/annaM/reactome/reactome.gmt'

### Prepare reference data with ReactomeDB pathways

In [12]:
# Attach the Reactome terms to adataSC in place; later cells read the result from
# adataSC.varm['I'] (gene x term matrix) and adataSC.uns['terms'] (term names).
# NOTE(review): min_genes = 12 presumably drops terms with fewer than 12 genes present -- confirm in the scArches docs.
sca.utils.add_annotations(adataSC, reactome_link, min_genes = 12, clean = True)

- Remove all genes that are present in the data but absent in ReactomeDB

In [13]:
# Keep only genes whose row in the annotation matrix has a positive sum, i.e. genes assigned to at least one term.
# NOTE(review): `_inplace_subset_var` is a private AnnData method and may change between anndata releases.
adataSC._inplace_subset_var(adataSC.varm['I'].sum(1)>0)

### Calculate HVGs

+ Delete donors with only one cell

In [14]:
# Number of reference cells contributed by each donor (computed once and reused).
donor_counts = adataSC.obs['Donor_ID'].value_counts()
# Donors with more than one cell; Donor_ID is used as the batch key for HVG selection below.
donors_to_keep = donor_counts[donor_counts > 1].index
# Subsetting yields a view; copy it because the next cell writes a new layer
# into this object and subsets its genes in place.
adataSC = adataSC[adataSC.obs['Donor_ID'].isin(donors_to_keep)].copy()

In [15]:
# Snapshot of the reference set before HVG subsetting (not referenced again in the cells shown here).
adataSC_raw = adataSC.copy()
# Keep an explicit copy of X in the 'counts' layer; HVG selection below reads from this layer.
adataSC.layers['counts'] = adataSC.X.copy()

# Select the 3000 most variable genes, computed per donor (batch_key) on the counts layer.
# subset = True drops all other genes from adataSC (the summary below shows 3000 variables).
# NOTE(review): span = 1 is presumably raised from the default because some donors have very few cells -- confirm.
sc.pp.highly_variable_genes(
    adataSC,
    flavor = "seurat_v3",
    n_top_genes = 3000,
    layer = "counts",
    batch_key = "Donor_ID",
    subset = True,
    span = 1
)
# Display the object summary.
adataSC

If you pass `n_top_genes`, all cutoffs are ignored.
extracting highly variable genes
--> added
    'highly_variable', boolean vector (adata.var)
    'highly_variable_rank', float vector (adata.var)
    'means', float vector (adata.var)
    'variances', float vector (adata.var)
    'variances_norm', float vector (adata.var)


AnnData object with n_obs × n_vars = 293 × 3000
    obs: 'Sample_ID', 'Cell Type', 'Study_name', 'Donor_ID', 'Diagnosis', 'Age', 'Region code', 'Fraction', 'Sex', 'Library_Preparation_Protocol', 'batch', 'Age_group', 'Location', 'Cell States', 'Cell States GCA', 'Chem', 'Layer', 'Cell States Kong', 'dataset', 'n_genes_by_counts', 'total_counts', 'total_counts_mito', 'pct_counts_mito', 'total_counts_ribo', 'pct_counts_ribo', 'Cell_ID', '_scvi_batch', '_scvi_labels', 'n_genes', 'n_counts'
    var: 'feature_types-0-0-0', 'gene_name-1-0-0', 'gene_id-0-0', 'GENE-1-0', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'
    uns: '_scvi_manager_uuid', '_scvi_uuid', 'hvg', 'neighbors', 'umap', 'terms'
    obsm: 'X_scVI', 'X_umap', '_scvi_extra_categorical_covs', '_scvi_extra_continuous_covs'
    varm: 'I'
    layers: 'counts'
    obsp: 'connectivities', 'distances'

- Filter out all annotations (terms) with 12 or fewer genes.

In [16]:
# Boolean mask over terms: True where the term's column sum in the annotation matrix exceeds 12,
# i.e. the term still has more than 12 genes after HVG selection.
select_terms = adataSC.varm['I'].sum(0)>12
# Apply the same mask to the term names and to the columns of the annotation matrix so both stay aligned.
adataSC.uns['terms'] = np.array(adataSC.uns['terms'])[select_terms].tolist()
adataSC.varm['I'] = adataSC.varm['I'][:, select_terms]

- Filter out genes not present in any of the terms after selection of HVGs.

In [17]:
# After dropping small terms, remove genes that no longer belong to any remaining term.
# NOTE(review): `_inplace_subset_var` is a private AnnData method and may change between anndata releases.
adataSC._inplace_subset_var(adataSC.varm['I'].sum(1)>0)

### Create expiMap model and train it on reference dataset

In [18]:
# Reference expiMap model on the stem-cell subset.
# The condition is the 'Cell States' label; the architecture printout below reports a single condition.
# Three hidden layers of width 256; recon_loss = 'nb' (NOTE(review): presumably negative binomial, which expects raw counts in X -- confirm).
intr_cvae = sca.models.EXPIMAP(
    adata = adataSC,
    condition_key = 'Cell States',
    hidden_layer_sizes = [256, 256, 256],
    recon_loss = 'nb'
)


INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 2992 256 1
	Hidden Layer 1 in/out: 256 256
	Hidden Layer 2 in/out: 256 256
	Mean/Var Layer in/out: 256 399
Decoder Architecture:
	Masked linear layer in, ext_m, ext, cond, out:  399 0 0 1 2992
	with hard mask.
Last Decoder layer: softmax


In [19]:
# Value passed as `alpha` to intr_cvae.train below.
# NOTE(review): presumably the group-lasso regularisation strength for the terms -- confirm in the scArches docs.
ALPHA = 0.7

In [20]:
# Early-stopping and learning-rate-reduction options forwarded to the trainer.
# NOTE(review): the exact meaning of each key is defined by scArches -- confirm there before tuning.
early_stopping_kwargs = {
    "early_stopping_metric": "val_unweighted_loss",  # metric monitored for early stopping
    "threshold": 0,
    "patience": 50,
    "reduce_lr": True,
    "lr_patience": 13,
    "lr_factor": 0.1,
}
# Train the reference model on GPU 0 for at most 400 epochs with a fixed seed.
# The training log below shows several "ADJUSTED LR" events.
intr_cvae.train(
    n_epochs = 400,
    alpha_epoch_anneal = 100,
    alpha = ALPHA,
    alpha_kl = 0.5,
    weight_decay = 0.,
    early_stopping_kwargs = early_stopping_kwargs,
    use_early_stopping = True,
    monitor_only_val = False,
    seed = 1712,
    accelerator = "gpu",
    devices = [0]
)

Preparing (293, 2992)
Instantiating dataset
Init the group lasso proximal operator for the main terms.
 |██████████████------| 70.2%  - epoch_loss: 2826.4805501302 - epoch_recon_loss: 2662.2540690104 - epoch_kl_loss: 328.4529825846 - val_loss: 2376.0605468750 - val_recon_loss: 2237.8078613281 - val_kl_loss: 276.505523681661
ADJUSTED LR
 |████████████████----| 82.8%  - epoch_loss: 2810.1414388021 - epoch_recon_loss: 2654.0559895833 - epoch_kl_loss: 312.1708780924 - val_loss: 2368.1330566406 - val_recon_loss: 2225.3942871094 - val_kl_loss: 285.4774169922
ADJUSTED LR
 |█████████████████---| 86.0%  - epoch_loss: 2805.5184733073 - epoch_recon_loss: 2642.5289713542 - epoch_kl_loss: 325.9790445964 - val_loss: 2366.8515625000 - val_recon_loss: 2224.8706054688 - val_kl_loss: 283.9620666504
ADJUSTED LR
 |██████████████████--| 90.5%  - epoch_loss: 2786.3281250000 - epoch_recon_loss: 2629.6144205729 - epoch_kl_loss: 313.4275919596 - val_loss: 2366.4731445312 - val_recon_loss: 2224.5529785156 - val

In [21]:
# Passed as `mean` to every get_latent call below.
# NOTE(review): False presumably means latents are sampled rather than taken as the posterior mean -- confirm.
MEAN = False

In [22]:
# Latent representation of the reference cells from the reference model, stored in obsm['X_cvae'].
# NOTE(review): only_active = True presumably keeps only terms not switched off during training -- confirm.
adataSC.obsm['X_cvae'] = intr_cvae.get_latent(mean = MEAN, only_active = True)

### Initializing the model for query training

+ Prepare query dataset

In [23]:
# Query set: every cell that is NOT in the 'FXYD3+_CKB+_SC' reference state (this is a view of `adata`).
adata_query = adata[adata.obs['Cell States']!='FXYD3+_CKB+_SC']

In [24]:
# Restrict the query to exactly the genes (and gene order) of the reference set.
# The result would otherwise be a view of a view; copy it here (after the gene
# subset, so only the reduced matrix is duplicated) because later cells write
# to `adata_query.uns` and hand the object to the model.
adata_query = adata_query[:, adataSC.var_names].copy()

+ Add terms to query dataset

In [25]:
# Give the query object the same term list as the reference (the list object is shared, not copied).
adata_query.uns['terms'] = adataSC.uns['terms']

+ Train the model

In [26]:
# Build the query model from the trained reference model and the query cells.
# The architecture printouts show the condition input growing from 1 (reference) to 101 (query model).
# NOTE(review): no new extension nodes are requested in this call; confirm that
# unfreeze_ext / use_hsic / hsic_one_vs_all have the intended effect without them.
q_intr_cvae = sca.models.EXPIMAP.load_query_data(adata_query, intr_cvae,
                                                 unfreeze_ext=True,
                                                 use_hsic=True,
                                                 hsic_one_vs_all=True
                                                )

AnnData object with n_obs × n_vars = 231352 × 2992
    obs: 'Sample_ID', 'Cell Type', 'Study_name', 'Donor_ID', 'Diagnosis', 'Age', 'Region code', 'Fraction', 'Sex', 'Library_Preparation_Protocol', 'batch', 'Age_group', 'Location', 'Cell States', 'Cell States GCA', 'Chem', 'Layer', 'Cell States Kong', 'dataset', 'n_genes_by_counts', 'total_counts', 'total_counts_mito', 'pct_counts_mito', 'total_counts_ribo', 'pct_counts_ribo', 'Cell_ID', '_scvi_batch', '_scvi_labels', 'n_genes', 'n_counts'
    var: 'feature_types-0-0-0', 'gene_name-1-0-0', 'gene_id-0-0', 'GENE-1-0'
    uns: '_scvi_manager_uuid', '_scvi_uuid', 'hvg', 'neighbors', 'umap', 'terms'
    obsm: 'X_scVI', 'X_umap', '_scvi_extra_categorical_covs', '_scvi_extra_continuous_covs'
    obsp: 'connectivities', 'distances'

INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 2992 256 101
	Hidden Layer 1 in/out: 256 256
	Hidden Layer 2 in/out: 256 256
	Mean/Var Layer in/out: 256 399
Decoder Archi

In [27]:
# Train the query model on GPU 0 with the same seed as the reference training.
# No early_stopping_kwargs are passed here, unlike the reference training call.
# NOTE(review): alpha_l1, gamma_ext, gamma_epoch_anneal and beta presumably regularise
# extension nodes / HSIC -- confirm they are relevant for this configuration.
q_intr_cvae.train(n_epochs = 400, 
                  alpha_epoch_anneal = 100, 
                  weight_decay = 0.,
                  alpha_kl = 0.1, 
                  seed = 1712, 
                  alpha_l1=0.96,
                  gamma_ext=0.7,
                  gamma_epoch_anneal=50, 
                  beta=3.,
                  use_early_stopping = True,
                  accelerator = "gpu",
                  devices = [0])

Preparing (231352, 2992)
Instantiating dataset
 |████████████████████| 100.0%  - val_loss: 1722.4457097343 - val_recon_loss: 1705.0819442496 - val_kl_loss: 173.6376127802
Saving best state of network...
Best State was in Epoch 398


### Get latent representation of reference + query dataset

In [76]:
# Stack reference cells (first) and query cells (second) into one object.
# index_unique = None keeps the original cell names; 'batch_join' records which input each cell came from.
# Resulting row order therefore differs from `adata` (reference cells come first).
# NOTE(review): AnnData.concatenate is marked deprecated in recent anndata in favour of anndata.concat -- confirm before upgrading.
adata_joined = sc.AnnData.concatenate(adataSC, adata_query, batch_key = 'batch_join', uns_merge = 'same', index_unique = None)
# Display the object summary.
adata_joined

AnnData object with n_obs × n_vars = 231645 × 2992
    obs: 'Sample_ID', 'Cell Type', 'Study_name', 'Donor_ID', 'Diagnosis', 'Age', 'Region code', 'Fraction', 'Sex', 'Library_Preparation_Protocol', 'batch', 'Age_group', 'Location', 'Cell States', 'Cell States GCA', 'Chem', 'Layer', 'Cell States Kong', 'dataset', 'n_genes_by_counts', 'total_counts', 'total_counts_mito', 'pct_counts_mito', 'total_counts_ribo', 'pct_counts_ribo', 'Cell_ID', '_scvi_batch', '_scvi_labels', 'n_genes', 'n_counts', 'batch_join'
    var: 'feature_types-0-0-0', 'gene_name-1-0-0', 'gene_id-0-0', 'GENE-1-0', 'highly_variable-0', 'highly_variable_rank-0', 'means-0', 'variances-0', 'variances_norm-0', 'highly_variable_nbatches-0'
    uns: '_scvi_manager_uuid', '_scvi_uuid', 'hvg', 'neighbors', 'umap', 'terms'
    obsm: 'X_scVI', 'X_umap', '_scvi_extra_categorical_covs', '_scvi_extra_continuous_covs'

+ Get latent representation for the concatenated dataset

In [77]:
# Encode all cells (reference + query) with the query model, passing each cell's
# 'Cell States' label as its condition, and store the result in obsm['X_cvae'].
adata_joined.obsm['X_cvae'] = q_intr_cvae.get_latent(adata_joined.X, 
                                                adata_joined.obs['Cell States'], 
                                                mean = MEAN, 
                                                only_active = True)

+ Visualize latent representation on the UMAP

In [78]:
# Neighbour graph (20 neighbours, Minkowski metric) on the expiMap latent space.
sc.pp.neighbors(adata_joined, use_rep = "X_cvae", n_neighbors = 20, metric = 'minkowski')
# UMAP embedding of that graph with a fixed random state.
sc.tl.umap(adata_joined, min_dist = 0.5, spread = 6, random_state = 1712)

computing neighbors
    finished: added to `.uns['neighbors']`
    `.obsp['distances']`, distances for each pair of neighbors
    `.obsp['connectivities']`, weighted adjacency matrix (0:00:33)
computing UMAP
    finished: added
    'X_umap', UMAP coordinates (adata.obsm) (0:01:42)


+ Extract meaningful latent directions

In [79]:
# Compute latent directions for the joined data; a later cell reads them from adata_joined.uns['directions'].
q_intr_cvae.latent_directions(adata=adata_joined)

In [32]:
# Epithelial cells only. Subset first and copy the result: this avoids
# duplicating the whole joined object just to discard most of it, and leaves
# `adata_joined_epi` as an independent object (not a view) for the enrichment
# calls below, which are handed this object to annotate.
adata_joined_epi = adata_joined[adata_joined.obs['Cell Type'].isin(['Epithelial'])].copy()

In [None]:
# Large, tall canvas for the enrichment plot.
sc.set_figure_params(dpi=300, figsize=(20, 70))
# Term enrichment per cell state among epithelial cells, each compared against the 'FXYD3+_CKB+_SC' state.
q_intr_cvae.latent_enrich(groups = 'Cell States', comparison = 'FXYD3+_CKB+_SC', adata = adata_joined_epi, use_directions = True)
# Plot the enrichment scores stored on adata_joined_epi (NOTE(review): presumably absolute Bayes factors -- confirm).
fig = sca.plotting.plot_abs_bfs(adata_joined_epi, yt_step = 0.3, scale_y = 2, fontsize = 9)

+ Compare by cell states

In [None]:
# Same enrichment without an explicit comparison group (NOTE(review): presumably each state vs. all others -- confirm),
# using 10000 samples; this overwrites the enrichment results of the previous cell.
q_intr_cvae.latent_enrich(groups='Cell States', adata=adata_joined_epi, n_sample=10000, use_directions=True)
# Three-column grid of per-state plots, resized to 20 x 70 inches.
fig = sca.plotting.plot_abs_bfs(adata_joined_epi, n_cols=3, scale_y=2.6, yt_step=0.6, fontsize = 9)
fig.set_size_inches(20, 70)

### Visualize gene programs of interest

In [80]:
# Flag every cell as stem cell ('True') or not ('False'); the flag is stored as
# strings so that it is treated as a two-level label when colouring plots.
adata_joined.obs['comparison'] = (
    adata_joined.obs['Cell States'] == 'FXYD3+_CKB+_SC'
).astype(str)

In [81]:
# Term names of the joined object (the next cell repeats this assignment).
terms = adata_joined.uns['terms']

In [82]:
# Term names stored on the joined object, in model order.
terms = adata_joined.uns['terms']
# Gene programs of interest (names as they appear in `terms`).
select_terms = [
    'METABOLISM_OF_LIPIDS_AND_LIPOP',
    'METABOLISM_OF_PROTEINS',
    'IMMUNE_SYSTEM',
    'METABOLISM_OF_MRNA',
    'TRANSMEMBRANE_TRANSPORT_OF_SMA',
    'METABOLISM_OF_AMINO_ACIDS_AND_',
]
# Position of each selected program within `terms`; fails loudly if a name is missing.
idx = list(map(terms.index, select_terms))

In [83]:
# Latent scores of all cells, multiplied element-wise by the stored directions, then reduced to the selected terms.
# only_active is not passed here, unlike the X_cvae call (NOTE(review): presumably so columns line up with `terms` -- confirm).
latents = (q_intr_cvae.get_latent(adata_joined.X, adata_joined.obs['Cell States'], mean=MEAN) * adata_joined.uns['directions'])[:, idx]

In [84]:
# Store each selected program's score as an obs column named after the program.
# The columns of `latents` follow the order of `select_terms` (it was sliced with `idx`).
for column, term_name in enumerate(select_terms):
    adata_joined.obs[term_name] = latents[:, column]

In [None]:
# Lipid vs. protein metabolism scores, coloured by the stem-cell flag.
sc.set_figure_params(dpi=120, figsize=(4, 4))
sc.pl.scatter(adata_joined, x='METABOLISM_OF_LIPIDS_AND_LIPOP', y='METABOLISM_OF_PROTEINS', color='comparison', size=4)

In [None]:
# Lipid vs. protein metabolism scores, coloured by cell state.
sc.set_figure_params(dpi=120, figsize=(4, 4))
sc.pl.scatter(adata_joined, x='METABOLISM_OF_LIPIDS_AND_LIPOP', y='METABOLISM_OF_PROTEINS', color='Cell States', size=4)

In [None]:
# Immune system vs. mRNA metabolism scores, coloured by the stem-cell flag.
sc.set_figure_params(dpi=120, figsize=(4, 4))
sc.pl.scatter(adata_joined, x='IMMUNE_SYSTEM', y='METABOLISM_OF_MRNA', color='comparison', size=4)

In [None]:
# Immune system vs. mRNA metabolism scores, coloured by cell state.
sc.set_figure_params(dpi=120, figsize=(4, 4))
sc.pl.scatter(adata_joined, x='IMMUNE_SYSTEM', y='METABOLISM_OF_MRNA', color='Cell States', size=4)

In [None]:
# Transmembrane transport vs. amino-acid metabolism scores, coloured by the stem-cell flag.
sc.set_figure_params(dpi=120, figsize=(4, 4))
sc.pl.scatter(adata_joined, x='TRANSMEMBRANE_TRANSPORT_OF_SMA', y='METABOLISM_OF_AMINO_ACIDS_AND_', color='comparison', size=4)

In [None]:
# Transmembrane transport vs. amino-acid metabolism scores, coloured by cell state.
sc.set_figure_params(dpi=120, figsize=(4, 4))
sc.pl.scatter(adata_joined, x='TRANSMEMBRANE_TRANSPORT_OF_SMA', y='METABOLISM_OF_AMINO_ACIDS_AND_', color='Cell States', size=4)

### Extract gene weights

In [46]:
# Gene programs whose per-gene weights are written to disk, one CSV per program.
terms_to_process = [
    'METABOLISM_OF_LIPIDS_AND_LIPOP',
    'METABOLISM_OF_PROTEINS',
    'IMMUNE_SYSTEM',
    'METABOLISM_OF_MRNA',
    'TRANSMEMBRANE_TRANSPORT_OF_SMA',
    'METABOLISM_OF_AMINO_ACIDS_AND_',
]

for term in terms_to_process:
    # Destination file is named after the program.
    file_path = f'/mnt/LaCIE/annaM/gut_project/Processed_data/Gut_data/Fetal_stem_cells/FetalSC_and_other_fetal_cells/expimap/{term}_gene_weights.csv'
    # Table of genes and weights for this program, looked up by name in the joined term list.
    df = q_intr_cvae.term_genes(term, terms=adata_joined.uns['terms'])
    df.to_csv(file_path)


#### Export anndata object

In [88]:
# Combine the full-gene expression matrix with the annotations computed on the
# joined object. `adata_joined` was built by stacking the reference cells before
# the query cells and has lost the cells removed during reference filtering
# (the summaries above show 231645 vs. 231646 cells), so its rows neither match
# `adata` in number nor in order. Re-index `adata` by cell name first so that
# every row of X belongs to the cell described by the same row of obs/obsm/obsp.
# Cell names are comparable because the concatenation used index_unique = None.
# NOTE(review): this requires unique obs_names in `adata` -- confirm.
adata_aligned = adata[adata_joined.obs_names]
adata_export = anndata.AnnData(
    X = adata_aligned.X,
    obs = adata_joined.obs,
    var = adata_aligned.var,
    uns = adata_joined.uns,
    obsm = adata_joined.obsm,
    obsp = adata_joined.obsp,
)

In [89]:
# Save the combined object (full gene matrix + expiMap latent space, UMAP and term scores) as .h5ad.
adata_export.write('/mnt/LaCIE/annaM/gut_project/Processed_data/Gut_data/Fetal_stem_cells/FetalSC_and_other_fetal_cells/expimap/Fetal_cells_after_expimap_Focus_Model.h5ad')