# NicheCompass Sample Integration Tutorial

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>).
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 18.05.2023
- **Date of Last Modification:** 18.05.2023

## 1. Setup

### 1.1 Import Libraries

In [None]:
import os
from datetime import datetime

import anndata as ad
import numpy as np
import scanpy as sc
import scipy.sparse as sp
import squidpy as sq

from nichecompass.models import NicheCompass
from nichecompass.utils import (add_gps_from_gp_dict_to_adata,
                                extract_gp_dict_from_mebocost_es_interactions,
                                extract_gp_dict_from_nichenet_ligand_target_mx,
                                extract_gp_dict_from_omnipath_lr_interactions,
                                filter_and_combine_gp_dict_gps,
                                get_unique_genes_from_gp_dict)

### 1.2 Define Parameters

In [None]:
### Dataset ###
# One '<dataset>_<batch>.h5ad' file is read per entry of 'batches'.
dataset = "seqfish_mouse_organogenesis_imputed"
batches = ["batch1", "batch2", "batch3"]
spatial_key = "spatial"  # coordinates key handed to the neighbor graph
n_neighbors = 12  # neighbors per cell in the spatial graph
filter_genes = True  # toggles the gene filtering step
n_hvg = 3000  # number of highly variable genes to select

### Model ###
# Keys under which inputs and outputs are stored in the AnnData object
counts_key = "log_normalized_counts"
adj_key = "spatial_connectivities"
condition_key = "batch"
gp_names_key = "nichecompass_gp_names"
active_gp_names_key = "nichecompass_active_gp_names"
gp_targets_mask_key = "nichecompass_gp_targets"
gp_sources_mask_key = "nichecompass_gp_sources"
latent_key = "nichecompass_latent"

# Architecture
active_gp_thresh_ratio = 0.1
# choose one-hop-attention for more performant model training
node_label_method = "one-hop-norm"
# True if raw counts, False if log normalized counts
log_variational = (counts_key == "counts")

# Trainer
n_epochs = 100
n_epochs_all_gps = 25
n_epochs_no_cond_contrastive = 5
lr = 0.001
# Loss weights
lambda_edge_recon = 500_000.0
lambda_gene_expr_recon = 300.0
lambda_cond_contrastive = 0.0
contrastive_logits_ratio = 0.0
lambda_l1_masked = 30.0
edge_batch_size = 128  # 2048
use_cuda_if_available = True

### Analysis ###
cell_type_key = "celltype_mapped_refined"
random_seed = 0
# Cell type label -> hex color, used as the palette when plotting cell types.
# NOTE: the capital 'L' in the variable name is kept on purpose because the
# plotting cell refers to the palette under exactly this name.
celL_type_colors = {
    "Epiblast": "#FFD8B8",  # Peach
    "Primitive Streak": "#F9CB9C",  # Light Peach
    "Caudal epiblast": "#FFB8A2",  # Apricot
    "PGC": "#F08CAE",  # Pink
    "Anterior Primitive Streak": "#F49FAD",  # Pale Pink
    "Notochord": "#E7F0A4",  # Pale Canary
    "Def. endoderm": "#C7F2C2",  # Light Green
    "Definitive endoderm": "#97E6A1",  # Soft Green
    "Gut": "#FE938C",  # Coral
    "Gut tube": "#A5DEE4",  # Pale Blue
    "Nascent mesoderm": "#5CA4A9",  # Blue-Green
    "Mixed mesoderm": "#6B5B95",  # Lavender
    "Intermediate mesoderm": "#F9A03F",  # Orange
    "Caudal Mesoderm": "#F7DB6A",  # Light Yellow
    "Paraxial mesoderm": "#EEBB4D",  # Light Amber
    "Somitic mesoderm": "#D6E4B2",  # Pale Green
    "Pharyngeal mesoderm": "#A8DADC",  # Pale Cyan
    "Splanchnic mesoderm": "#3D5A80",  # Dark Blue
    "Cardiomyocytes": "#3E3F8A",  # Navy Blue
    "Allantois": "#218380",  # Teal
    "ExE mesoderm": "#90BE6D",  # Soft Green-Yellow
    "Lateral plate mesoderm": "#FFD369",  # Yellow
    "Mesenchyme": "#ED553B",  # Red-Orange
    "Mixed mesenchymal mesoderm": "#DA627D",  # Mauve
    "Haematoendothelial progenitors": "#6C5B7B",  # Purple
    "Endothelium": "#4ECDC4",  # Mint
    "Blood progenitors 1": "#65AADD",  # Sky Blue
    "Blood progenitors 2": "#8FBFE0",  # Powder Blue
    "Erythroid1": "#A2D2FF",  # Pale Sky Blue
    "Erythroid2": "#F3C969",  # Light Amber-Yellow
    "Erythroid3": "#EE6C4D",  # Light Red-Orange
    "Erythroid": "#EC4E20",  # Bright Red
    "Blood progenitors": "#D64161",  # Dark Pink
    "NMP": "#FF7A5A",  # Bright Coral
    "Rostral neurectoderm": "#E7A977",  # Light Coral
    "Caudal neurectoderm": "#FECE44",  # Bright Yellow
    "Neural crest": "#FFC55F",  # Yellow-Orange
    "Forebrain/Midbrain/Hindbrain": "#F89E7B",  # Light Coral-Orange
    "Spinal cord": "#7ECEFD",  # Baby Blue
    "Surface ectoderm": "#C9B1BD",  # Pale Mauve
    "Visceral endoderm": "#E6A0C4",  # Light Pink
    "ExE endoderm": "#E36BAE",  # Bright Pink
    "ExE ectoderm": "#8B5B6E",  # Mauve-Brown
    "Parietal endoderm": "#748CAB",  # Blue-Gray
    "Low quality": "#E5E5E5",  # Light Gray
    "Cranial mesoderm": "#C4C4C4",  # Gray
    "Anterior somitic tissues": "#A4A4A4",  # Dark Gray
    "Sclerotome": "#4D4D4D",  # Charcoal
    "Dermomyotome": "#F8B195",  # Dusty Peach
    "Posterior somitic tissues": "#F67280",  # Salmon Pink
    "Presomitic mesoderm": "#C06C84",  # Rose
}

### 1.3 Run Notebook Setup

In [None]:
# Capture the execution time once; its DDMMYYYY_HHMMSS rendering is used to
# timestamp saved artifacts
now = datetime.now()
current_timestamp = f"{now:%d%m%Y_%H%M%S}"

### 1.4 Configure Paths and Create Directories

In [None]:
# Input data locations (relative to the notebook's working directory)
gp_data_folder_path = "../data/gene_programs"
so_data_folder_path = "../data/spatial_omics"
omnipath_ligand_receptor_interactions_file_path = (
    f"{gp_data_folder_path}/omnipath_lr_interactions.csv")
nichenet_ligand_target_interactions_file_path = (
    f"{gp_data_folder_path}/nichenet_ligand_target_matrix.csv")
mebocost_enzyme_sensor_interactions_folder_path = (
    f"{gp_data_folder_path}/metabolite_enzyme_sensor_gps")

# Output locations for figures and trained models
artifacts_folder_path = "../artifacts"
figure_folder_path = f"{artifacts_folder_path}/figures"
model_folder_path = f"{artifacts_folder_path}/models"

# Make sure the output directories exist (no error if they already do)
for output_folder_path in (figure_folder_path, model_folder_path):
    os.makedirs(output_folder_path, exist_ok=True)

## 2. Model Preparation

### 2.1 Create Prior Knowledge Cell-Cell-Interaction Gene Program (GP) Mask

The user can provide a custom GP mask to NicheCompass based on the dataset, application, and hypothesis of interest. 

As a default, we create a GP mask based on three databases of prior cell-cell-interaction knowledge:
- OmniPath
- NicheNet
- MEBOCOST

In [None]:
# Build the OmniPath gene programs (source: ligand gene; target: receptor
# gene). The interactions are not loaded from disk here but are written to
# 'omnipath_ligand_receptor_interactions_file_path'.
omnipath_gp_dict = extract_gp_dict_from_omnipath_lr_interactions(
    file_path=omnipath_ligand_receptor_interactions_file_path,
    load_from_disk=False,
    save_to_disk=True,
    min_curation_effort=0,
    plot_gp_gene_count_distributions=True,
)

# Collect every source and target gene of the OmniPath GPs so that these
# genes can be kept when filtering
omnipath_genes = get_unique_genes_from_gp_dict(
    gp_dict=omnipath_gp_dict,
    retrieved_gene_entities=["sources", "targets"],
)

In [None]:
# Build the NicheNet gene programs (source: ligand gene; target: target
# genes). The matrix is not loaded from disk here but is written to
# 'nichenet_ligand_target_interactions_file_path'.
nichenet_gp_dict = extract_gp_dict_from_nichenet_ligand_target_mx(
    file_path=nichenet_ligand_target_interactions_file_path,
    load_from_disk=False,
    save_to_disk=True,
    keep_target_genes_ratio=0.01,
    max_n_target_genes_per_gp=1000,
    plot_gp_gene_count_distributions=True,
)

# Collect only the source genes of the NicheNet GPs so that these genes can
# be kept when filtering
nichenet_source_genes = get_unique_genes_from_gp_dict(
    gp_dict=nichenet_gp_dict,
    retrieved_gene_entities=["sources"],
)

In [None]:
# Build the MEBOCOST gene programs (source: enzyme genes; target: sensor
# genes) from the files in the enzyme-sensor interactions folder
mebocost_gp_dict = extract_gp_dict_from_mebocost_es_interactions(
    species="mouse",
    dir_path=mebocost_enzyme_sensor_interactions_folder_path,
    plot_gp_gene_count_distributions=True,
)

# Collect every source and target gene of the MEBOCOST GPs so that these
# genes can be kept when filtering
mebocost_genes = get_unique_genes_from_gp_dict(
    gp_dict=mebocost_gp_dict,
    retrieved_gene_entities=["sources", "targets"],
)

In [None]:
# Merge all GPs into one new dictionary; on identical GP names, entries
# further to the right (NicheNet, then MEBOCOST) take precedence
combined_gp_dict = {
    **omnipath_gp_dict,
    **nichenet_gp_dict,
    **mebocost_gp_dict,
}

In [None]:
# Filter the merged GP dictionary and combine overlapping GPs
combined_new_gp_dict = filter_and_combine_gp_dict_gps(
    gp_dict=combined_gp_dict,
    gp_filter_mode="subset",
    combine_overlap_gps=True,
    overlap_thresh_source_genes=0.9,
    overlap_thresh_target_genes=0.9,
    overlap_thresh_genes=0.9,
)

# Report how many gene programs the filtering and combining left over
n_gps_before = len(combined_gp_dict)
n_gps_after = len(combined_new_gp_dict)
print(f"Number of gene programs before filtering and combining: {n_gps_before}.")
print(f"Number of gene programs after filtering and combining: {n_gps_after}.")

### 2.2 Load Data & Compute Spatial Neighbor Graph

- NicheCompass expects a precomputed spatial adjacency matrix stored in `adata.obsp[adj_key]`.
- The user can customize the spatial neighbor graph construction based on the dataset and application of interest.

In [None]:
# Load each batch from its own file and compute a separate spatial
# neighborhood graph for it before concatenating all batches
adata_batch_list = []

for batch in batches:
    print(f"Processing batch {batch}...")
    print("Loading data...")
    batch_file_path = f"{so_data_folder_path}/{dataset}_{batch}.h5ad"
    adata_batch = sc.read_h5ad(batch_file_path)

    print("Computing spatial neighborhood graph...\n")
    sq.gr.spatial_neighbors(
        adata_batch,
        spatial_key=spatial_key,
        coord_type="generic",
        n_neighs=n_neighbors,
    )

    # Symmetrize the adjacency matrix: keep an edge if it is present in
    # either direction
    batch_adjacency = adata_batch.obsp[adj_key]
    adata_batch.obsp[adj_key] = batch_adjacency.maximum(batch_adjacency.T)

    adata_batch_list.append(adata_batch)

# Stack the batches along the observations and keep only shared variables
adata = ad.concat(adata_batch_list, join="inner")

# Combine spatial neighborhood graphs as disconnected components: place the
# per-batch adjacency matrices on the diagonal of one block-diagonal matrix so
# that no edges connect cells from different batches. The blocks are given in
# the order of 'adata_batch_list', which is also the order in which the
# batches were concatenated into 'adata'. CSR output is requested explicitly
# so the combined matrix supports row/column indexing.
adata.obsp[adj_key] = sp.block_diag(
    [adata_batch.obsp[adj_key] for adata_batch in adata_batch_list],
    format="csr")

### 2.3 Filter Genes

In [None]:
if filter_genes:
    print("Filtering genes...")
    # Filter genes and only keep ligand, receptor, metabolite enzyme,
    # metabolite sensor and the 'n_hvg' highly variable genes (potential target
    # genes of nichenet)
    gp_dict_genes = get_unique_genes_from_gp_dict(
        gp_dict=combined_new_gp_dict,
        retrieved_gene_entities=["sources", "targets"])
    print(f"Starting with {len(adata.var_names)} genes.")
    # NOTE(review): with 'min_cells=0' no gene can fall below the threshold,
    # so nothing is removed here although the message below suggests that
    # genes expressed in 0 cells are dropped ('min_cells=1' would do that).
    # Left unchanged to preserve the resulting gene set; confirm the intent.
    sc.pp.filter_genes(adata,
                       min_cells=0)
    print(f"Keeping {len(adata.var_names)} genes after filtering genes with "
          "expression in 0 cells.")

    # Choose the highly variable gene flavor depending on whether the counts
    # layer holds integer-valued (raw) counts: truncating to int must not
    # change the total sum in that case
    counts_layer = adata.layers[counts_key]
    if (counts_layer.astype(int).astype(np.float32).sum() ==
            counts_layer.sum()):  # raw counts
        hvg_flavor = "seurat_v3"
    else:  # log normalized counts
        hvg_flavor = "seurat"

    sc.pp.highly_variable_genes(
        adata,
        layer=counts_key,
        n_top_genes=n_hvg,
        flavor=hvg_flavor,
        batch_key=condition_key,
        subset=False)

    # Get gene program relevant genes
    gp_relevant_genes = list(set(omnipath_genes +
                                 nichenet_source_genes +
                                 mebocost_genes))

    # Gene names are upper-cased before matching them against the GP genes
    # (presumably the GP dictionaries hold upper-case gene names; verify)
    adata.var["gp_relevant"] = (
        adata.var.index.str.upper().isin(gp_relevant_genes))
    adata.var["keep_gene"] = (adata.var["gp_relevant"] |
                              adata.var["highly_variable"])
    adata = adata[:, adata.var["keep_gene"]]
    print(f"Keeping {len(adata.var_names)} highly variable or gene program "
          "relevant genes.")

    # Restrict to genes that occur in the final GP dictionary, sorted by name.
    # The subset is materialized with '.copy()' so that 'adata' is a
    # standalone object rather than a view when it is modified afterwards.
    in_gp_dict = adata.var_names.str.upper().isin(gp_dict_genes)
    adata = adata[:, adata.var_names[in_gp_dict].sort_values()].copy()
    print(f"Keeping {len(adata.var_names)} genes after filtering genes not in "
          "gp dict.")

### 2.4 Add GP Mask to Data

In [None]:
# Add the GP dictionary as binary masks to the adata
add_gps_from_gp_dict_to_adata(
    gp_dict=combined_new_gp_dict,
    adata=adata,
    gp_targets_mask_key=gp_targets_mask_key,
    gp_sources_mask_key=gp_sources_mask_key,
    gp_names_key=gp_names_key,
    min_genes_per_gp=1,
    min_source_genes_per_gp=0,
    min_target_genes_per_gp=0,
    max_genes_per_gp=None,
    max_source_genes_per_gp=None,
    max_target_genes_per_gp=None)

### 2.5 Explore Data

In [None]:
# Summarize and plot every batch separately.
# NOTE(review): observations are selected by the label 'embryo<idx>' while the
# summary refers to 'batch<idx>'; this assumes that 'adata.obs[condition_key]'
# holds 'embryo1', 'embryo2', ... - confirm against the data.
for batch_idx, _ in enumerate(batches, start=1):
    is_in_batch = adata.obs[condition_key] == f"embryo{batch_idx}"
    adata_batch = adata[is_in_batch]

    n_nodes, n_node_features = adata_batch.layers[counts_key].shape
    print(f"Summary of batch batch{batch_idx}:")
    print(f"Number of nodes (observations): {n_nodes}")
    print(f"Number of node features (genes): {n_node_features}")

    # Visualize cell-level annotated data in physical space
    sc.pl.spatial(
        adata_batch,
        color=cell_type_key,
        palette=celL_type_colors,
        spot_size=0.03,
    )

## 3. Model Training

### 3.1 Initialize, Train & Save Model

In [None]:
# Initialize model
model = NicheCompass(adata,
                     counts_key=counts_key,
                     adj_key=adj_key,
                     condition_key=condition_key,
                     gp_names_key=gp_names_key,
                     active_gp_names_key=active_gp_names_key,
                     gp_targets_mask_key=gp_targets_mask_key,
                     gp_sources_mask_key=gp_sources_mask_key,
                     latent_key=latent_key,
                     active_gp_thresh_ratio=active_gp_thresh_ratio,
                     log_variational=log_variational,
                     node_label_method=node_label_method)

In [None]:
# Train model
model.train(n_epochs=n_epochs,
            n_epochs_all_gps=n_epochs_all_gps,
            n_epochs_no_cond_contrastive=n_epochs_no_cond_contrastive,
            lr=lr,
            lambda_edge_recon=lambda_edge_recon,
            lambda_gene_expr_recon=lambda_gene_expr_recon,
            lambda_cond_contrastive=lambda_cond_contrastive,
            contrastive_logits_ratio=contrastive_logits_ratio,
            lambda_l1_masked=lambda_l1_masked,
            edge_batch_size=edge_batch_size,
            use_cuda_if_available=use_cuda_if_available)

In [None]:
# Compute latent neighbor graph
sc.pp.neighbors(model.adata,
                use_rep=latent_key,
                key_added=latent_key)


# Compute UMAP embedding
sc.tl.umap(model.adata,
           neighbors_key=latent_key)

In [None]:
# Save trained model
model.save(dir_path=f"{model_folder_path}/{current_timestamp}",
           overwrite=True,
           save_adata=True,
           adata_file_name="adata.h5ad")

## 4. Analysis

In [None]:
# Load trained model
load_timestamp = "17052023_140928"
model = NicheCompass.load(dir_path=f"{model_folder_path}/{load_timestamp}",
                          adata=None,
                          adata_file_name="adata.h5ad",
                          gp_names_key=gp_names_key)

In [None]:
# Check number of active GPs
active_gps = model.get_active_gps()
print(f"Number of total gene programs: {len(model.adata.uns[gp_names_key])}.")
print(f"Number of active gene programs: {len(active_gps)}.")

In [None]:
# Inspect GPs
gp_summary_df = model.get_gp_summary()
gp_summary_df[gp_summary_df["gp_active"] == True].head()

### 4.1 NicheCompass Latent Manifold Overview

In [None]:
# Plot UMAP with cell type annotations
sc.pl.umap(model.adata,
           color=[cell_type_key])

### 4.2 NicheCompass Latent Cluster Differential GP Testing

### 4.3 NicheCompass Cell Type Analysis