# Mouse Brain Multimodal Tutorial

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>).
- **Affiliation:** Helmholtz Munich, Institute of AI for Health (AIH), Talavera-López Lab
- **Date of Creation:** 18.05.2023
- **Date of Last Modification:** 08.02.2024

In this tutorial we apply NicheCompass to a single multimodal sample (postnatal day 22 coronal section) of the spatial ATAC-RNA-seq mouse brain dataset from [Zhang, D. et al. Spatial epigenome–transcriptome co-profiling of mammalian tissues. Nature 1–10 (2023)](https://www.nature.com/articles/s41586-023-05795-1).

The sample has:
- 9,215 observations at spot resolution with RNA cluster and ATAC cluster annotations
- 22,914 probed genes
- 121,068 called peaks

- Check the repository [README.md](https://github.com/sebastianbirk/nichecompass#installation) for NicheCompass installation instructions.
- The data for this tutorial can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1l9W0MDVZ451k1L7s6GGH4ONH4tEK4EKj). It has to be stored under ```<repository_root>/data/spatial_omics/```.
    - spatial_atac_rna_seq_mouse_brain_atac.h5ad
    - spatial_atac_rna_seq_mouse_brain.h5ad
- A pretrained model to run only the analysis can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1z2DQHV9hG22B5OSWox8U3usf_8LKZGGH). It has to be stored under ```<repository_root>/artifacts/multimodal/<timestamp>/model/```.
    - 02062023_151955

## 1. Setup

### 1.1 Import Libraries

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import random
import warnings
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import squidpy as sq
from matplotlib import gridspec

from nichecompass.models import NicheCompass
from nichecompass.utils import (add_gps_from_gp_dict_to_adata,
                                add_multimodal_mask_to_adata,
                                create_new_color_dict,
                                extract_gp_dict_from_collectri_tf_network,
                                extract_gp_dict_from_mebocost_es_interactions,
                                extract_gp_dict_from_nichenet_lrt_interactions,
                                extract_gp_dict_from_omnipath_lr_interactions,
                                filter_and_combine_gp_dict_gps,
                                get_gene_annotations,
                                generate_enriched_gp_info_plots,
                                generate_multimodal_mapping_dict,
                                get_unique_genes_from_gp_dict)

### 1.2 Define Parameters

In [None]:
### Dataset ###
dataset = "spatial_atac_rna_seq_mouse_brain"
species = "mouse"
spatial_key = "spatial"
n_neighbors = 8
n_sampled_neighbors = 4
filter_genes = True
n_svg = 3000
n_hvg = 0
n_svp = 15000
n_hvp = 0
filter_peaks = True
min_cell_peak_thresh_ratio = 0.005 # 0.5%
min_cell_gene_thresh_ratio = 0.005 # 0.5%

### GP Mask ###
add_collectri_gps = True
add_marker_genes_gps = False

### Model ###
# AnnData keys
counts_key = "counts"
adj_key = "spatial_connectivities"
gp_names_key = "nichecompass_gp_names"
active_gp_names_key = "nichecompass_active_gp_names"
gp_targets_mask_key = "nichecompass_gp_targets"
gp_targets_categories_mask_key = "nichecompass_gp_targets_categories"
gp_sources_mask_key = "nichecompass_gp_sources"
gp_sources_categories_mask_key = "nichecompass_gp_sources_categories"
latent_key = "nichecompass_latent"

# Architecture
active_gp_thresh_ratio = 0.03
conv_layer_encoder = "gcnconv"
node_label_method = "one-hop-norm" # alternative: "one-hop-attention"

# Trainer
n_epochs = 400
n_epochs_all_gps = 25
lr = 0.001
lambda_edge_recon = 5000000.
lambda_gene_expr_recon = 3000.
lambda_chrom_access_recon = 1000.
lambda_l1_masked = 300.
lambda_l1_addon = 300.
lambda_group_lasso = 0.
edge_batch_size = 2048 # reduce if not enough memory
use_cuda_if_available = True

### Analysis ###
rna_cluster_key = "RNA_clusters"
atac_cluster_key = "ATAC_clusters"
latent_leiden_resolution = 0.6
latent_cluster_key = f"latent_leiden_{str(latent_leiden_resolution)}"
sample_key = "batch"
spot_size = 30
agg_weights_key = "agg_weights"
differential_gp_test_results_key = "nichecompass_differential_gp_test_results"

### 1.3 Run Notebook Setup

In [None]:
warnings.filterwarnings("ignore")
pd.set_option("display.max_columns", None)

In [None]:
# Get time of notebook execution for timestamping saved artifacts
now = datetime.now()
current_timestamp = now.strftime("%d%m%Y_%H%M%S")

### 1.4 Configure Paths

In [None]:
# Define paths
ga_data_folder_path = "../data/gene_annotations"
gp_data_folder_path = "../data/gene_programs"
so_data_folder_path = "../data/spatial_omics"
omnipath_lr_network_file_path = f"{gp_data_folder_path}/omnipath_lr_network.csv"
nichenet_lr_network_file_path = f"{gp_data_folder_path}/nichenet_lr_network_v2_{species}.csv"
nichenet_ligand_target_matrix_file_path = f"{gp_data_folder_path}/nichenet_ligand_target_matrix_v2_{species}.csv"
mebocost_enzyme_sensor_interactions_folder_path = f"{gp_data_folder_path}/metabolite_enzyme_sensor_gps"
collectri_tf_network_file_path = f"{gp_data_folder_path}/collectri_tf_network_{species}.csv"
marker_gp_folder_path = f"{gp_data_folder_path}/marker_gps"
gene_orthologs_mapping_file_path = f"{ga_data_folder_path}/human_mouse_gene_orthologs.csv"
gtf_file_path = f"{ga_data_folder_path}/gencode.vM25.chr_patch_hapl_scaff.annotation.gtf.gz"
artifacts_folder_path = "../artifacts"
model_folder_path = f"{artifacts_folder_path}/multimodal/{current_timestamp}/model"
figure_folder_path = f"{artifacts_folder_path}/multimodal/{current_timestamp}/figures"

# Create the figure folder so that plots can be saved to it
os.makedirs(figure_folder_path, exist_ok=True)

## 2. Model Preparation

### 2.1 Create Prior Knowledge Cell-Cell-Interaction (CCI) Gene Program (GP) Mask

- NicheCompass expects a prior GP mask as input, which it will use to make its latent feature space interpretable (through linear masked decoders). 
- The user can provide a custom GP mask to NicheCompass based on the biological question of interest.
- As a default, we create a GP mask based on four databases of prior knowledge of inter- and intracellular interaction pathways:
    - OmniPath (Ligand-Receptor GPs)
    - MEBOCOST (Enzyme-Sensor GPs)
    - CollecTRI (Transcriptional Regulation GPs)
    - NicheNet (Combined Interaction GPs)
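
The idea behind the linear masked decoders can be sketched in a few lines of NumPy. This is a simplified illustration under assumptions, not NicheCompass's implementation: the toy genes, GP names, mask values and the `masked_linear_decode` helper are hypothetical, and the real decoders additionally use separate source and target masks, add-on GPs and count likelihoods. In this sketch, each latent dimension corresponds to one GP, and the binary mask zeroes out all decoder weights outside that GP, so a GP's latent activity can only influence its own genes.

```python
import numpy as np

# Hypothetical toy setup: 4 genes, 2 GPs
genes = ["Lig1", "Rec1", "GeneA", "GeneB"]
gp_names = ["Lig1_ligand_receptor_GP", "L2L3_marker_GP"]

# Binary GP mask (genes x GPs): gene i can be decoded from GP j only if mask[i, j] == 1
mask = np.array([[0, 0],
                 [1, 1],
                 [0, 1],
                 [0, 0]])


def masked_linear_decode(z, weights, mask):
    """Map GP activities z (cells x GPs) to gene outputs (cells x genes) using
    only the decoder weights (genes x GPs) that the binary mask allows."""
    return z @ (weights * mask).T


rng = np.random.default_rng(0)
weights = rng.normal(size=mask.shape)
z = rng.normal(size=(3, len(gp_names)))
recon = masked_linear_decode(z, weights, mask)

# Changing the first GP's activity leaves genes outside that GP (e.g. GeneA) unchanged
z_perturbed = z.copy()
z_perturbed[:, 0] += 5.0
recon_perturbed = masked_linear_decode(z_perturbed, weights, mask)
print(np.allclose(recon[:, 2], recon_perturbed[:, 2]))  # True
```

Because only masked weights can be non-zero, the learned weights in one mask column can be read as the contribution of each member gene to that GP, which is what makes the latent dimensions interpretable.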

In [None]:
# Retrieve OmniPath GPs (source: ligand genes; target: receptor genes)
omnipath_gp_dict = extract_gp_dict_from_omnipath_lr_interactions(
    species=species,
    min_curation_effort=0,
    load_from_disk=False,
    save_to_disk=True,
    lr_network_file_path=omnipath_lr_network_file_path,
    gene_orthologs_mapping_file_path=gene_orthologs_mapping_file_path,
    plot_gp_gene_count_distributions=True,
    gp_gene_count_distributions_save_path=f"{figure_folder_path}" \
                                           "/omnipath_gp_gene_count_distributions.svg")

In [None]:
# Display example OmniPath GP
omnipath_gp_names = list(omnipath_gp_dict.keys())
random.shuffle(omnipath_gp_names)
omnipath_gp_name = omnipath_gp_names[0]
print(f"{omnipath_gp_name}: {omnipath_gp_dict[omnipath_gp_name]}")

In [None]:
# Retrieve NicheNet GPs (source: ligand gene; target: target genes)
nichenet_gp_dict = extract_gp_dict_from_nichenet_lrt_interactions(
    species=species,
    version="v2",
    keep_target_genes_ratio=1.,
    max_n_target_genes_per_gp=250,
    load_from_disk=True,
    save_to_disk=True,
    lr_network_file_path=nichenet_lr_network_file_path,
    ligand_target_matrix_file_path=nichenet_ligand_target_matrix_file_path,
    gene_orthologs_mapping_file_path=gene_orthologs_mapping_file_path,
    plot_gp_gene_count_distributions=True)

# Retrieve unique source genes from NicheNet GPs to keep
# those genes when filtering
nichenet_lr_genes = get_unique_genes_from_gp_dict(
    gp_dict=nichenet_gp_dict,
    retrieved_gene_categories=["ligand", "receptor"])

In [None]:
len(nichenet_gp_dict)

In [None]:
# Display example NicheNet GPs
for i, (key, value) in enumerate(nichenet_gp_dict.items()):
    if i > 2:
        break
    print(key, value)

In [None]:
# Retrieve MEBOCOST GPs (source: enzyme genes; target: sensor genes)
mebocost_gp_dict = extract_gp_dict_from_mebocost_es_interactions(
    dir_path=mebocost_enzyme_sensor_interactions_folder_path,
    species=species,
    plot_gp_gene_count_distributions=True)

# Retrieve unique source and target genes from MEBOCOST GPs to keep
# those genes when filtering
mebocost_genes = get_unique_genes_from_gp_dict(
    gp_dict=mebocost_gp_dict,
    retrieved_gene_entities=["sources", "targets"])

In [None]:
# Display example MEBOCOST GPs
for i, (key, value) in enumerate(mebocost_gp_dict.items()):
    if i > 2:
        break
    print(key, value)

In [None]:
len(mebocost_gp_dict)

In [None]:
if add_collectri_gps:
    collectri_gp_dict = extract_gp_dict_from_collectri_tf_network(
        species=species,
        tf_network_file_path=collectri_tf_network_file_path,
        load_from_disk=True,
        save_to_disk=True,
        plot_gp_gene_count_distributions=True)

In [None]:
if add_collectri_gps:
    # Display example CollecTRI GPs
    for i, (key, value) in enumerate(collectri_gp_dict.items()):
        if i > 2:
            break
        print(key, value)

In [None]:
if add_collectri_gps:
    print(len(collectri_gp_dict))

In [None]:
# Add spatial layer marker gene GPs (optional)
if add_marker_genes_gps:
    # Load experimentally validated marker genes
    validated_marker_genes_df = pd.read_csv(f"{marker_gp_folder_path}/Validated_markers_MM_layers.tsv",
                                            sep="\t",
                                            header=None,
                                            names=["gene_name", "ensembl_id", "layer"])
    validated_marker_genes_df = validated_marker_genes_df[["layer", "gene_name"]]
    
    # Load ranked marker genes and get top 100 per layer
    ranked_marker_genes_df = pd.DataFrame()
    for ranked_marker_genes_file_name in [
        "Ranked_mm_L2L3.tsv",
        "Ranked_mm_L4.tsv",
        "Ranked_mm_L5.tsv",
        "Ranked_mm_L6.tsv",
        "Ranked_mm_L6b.tsv"]:
        ranked_marker_genes_layer_df = pd.read_csv(
            f"{gp_data_folder_path}/marker_gps/{ranked_marker_genes_file_name}",
            sep="\t",
            header=None,
            names=["ensembl_id", "gene_name", "layer"])
        ranked_marker_genes_layer_df = ranked_marker_genes_layer_df[:100] # filter top 100 genes
        ranked_marker_genes_layer_df = ranked_marker_genes_layer_df[["layer", "gene_name"]]
        ranked_marker_genes_df = pd.concat([ranked_marker_genes_df, ranked_marker_genes_layer_df])
    marker_genes_df = pd.concat([validated_marker_genes_df, ranked_marker_genes_df])
        
    marker_genes_grouped_df = marker_genes_df.groupby("layer")["gene_name"].agg(list).reset_index()
    marker_genes_grouped_df.columns = ["layer", "marker_genes"]
    marker_genes_grouped_df["layer"] = marker_genes_grouped_df["layer"] + "_marker_GP"
                                                               
    marker_genes_gp_dict = {}
    for layer, marker_genes in zip(marker_genes_grouped_df["layer"], marker_genes_grouped_df["marker_genes"]):
        marker_genes_gp_dict[layer] = {
            "sources": [],
            "targets": marker_genes,
            "sources_categories": [],
            "targets_categories": ["marker"] * len(marker_genes)}

In [None]:
if add_marker_genes_gps:
    # Display example marker gene GPs
    for i, (key, value) in enumerate(marker_genes_gp_dict.items()):
        if i > 2:
            break
        print(key, value)

In [None]:
# Add GPs into one combined dictionary
# for model training
combined_gp_dict = dict(omnipath_gp_dict)
combined_gp_dict.update(nichenet_gp_dict)
combined_gp_dict.update(mebocost_gp_dict)
if add_collectri_gps:
    combined_gp_dict.update(collectri_gp_dict)    
if add_marker_genes_gps:
    combined_gp_dict.update(marker_genes_gp_dict)

In [None]:
# Filter and combine GPs to avoid overlaps
combined_new_gp_dict = filter_and_combine_gp_dict_gps(
    gp_dict=combined_gp_dict,
    gp_filter_mode="subset",
    combine_overlap_gps=True,
    overlap_thresh_source_genes=0.9,
    overlap_thresh_target_genes=0.9,
    overlap_thresh_genes=0.9)

print("Number of gene programs before filtering and combining: "
      f"{len(combined_gp_dict)}.")
print(f"Number of gene programs after filtering and combining: "
      f"{len(combined_new_gp_dict)}.")

### 2.2 Load Data & Compute Spatial Neighbor Graph

- NicheCompass expects a precomputed spatial adjacency matrix stored in `adata.obsp[adj_key]`.
- The user can customize the spatial neighbor graph construction based on the dataset, application, and hypothesis of interest.
- In the multimodal setting, we provide one AnnData object per modality to NicheCompass.
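
To make the graph construction concrete, here is a dense NumPy sketch of a k-nearest-neighbor spatial graph followed by symmetrization. It assumes Euclidean distances on 2D spot coordinates and uses a hypothetical `knn_adjacency` helper; `sq.gr.spatial_neighbors` builds a comparable graph as a sparse matrix, and the dense O(n²) distance matrix here is only suitable for toy inputs.

```python
import numpy as np


def knn_adjacency(coords, n_neighbors):
    """Binary, directed k-nearest-neighbor adjacency matrix (n_obs x n_obs)
    from spot coordinates, without self-loops."""
    diffs = coords[:, None, :] - coords[None, :, :]
    dists = np.sqrt((diffs ** 2).sum(axis=-1))
    np.fill_diagonal(dists, np.inf)  # a spot is not its own neighbor
    nearest = np.argsort(dists, axis=1)[:, :n_neighbors]
    adj = np.zeros_like(dists)
    rows = np.repeat(np.arange(len(coords)), n_neighbors)
    adj[rows, nearest.ravel()] = 1.0
    return adj


coords = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [5.0, 5.0]])
adj = knn_adjacency(coords, n_neighbors=1)

# Spot 3's nearest neighbor is spot 2, but not vice versa, so the kNN graph is
# not symmetric. The element-wise maximum with the transpose keeps an edge if it
# exists in either direction, like the `.maximum(...)` call applied to
# `adata.obsp[adj_key]` in this section.
adj_sym = np.maximum(adj, adj.T)
print(adj_sym.astype(int))
```

After symmetrization, spots can have more than `n_neighbors` neighbors (spot 2 ends up connected to both spot 0 and spot 3), which is the expected trade-off for an undirected graph.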

In [None]:
# Read data
adata = sc.read_h5ad(
        f"{so_data_folder_path}/{dataset}.h5ad")
adata_atac = sc.read_h5ad(
        f"{so_data_folder_path}/{dataset}_atac.h5ad")

In [None]:
# Compute (separate) spatial neighborhood graphs
sq.gr.spatial_neighbors(adata,
                        coord_type="generic",
                        spatial_key=spatial_key,
                        n_neighs=n_neighbors)

# Make adjacency matrix symmetric
adata.obsp[adj_key] = (
    adata.obsp[adj_key].maximum(
        adata.obsp[adj_key].T))

### 2.3 Filter Genes & Peaks

In [None]:
if filter_genes:
    print("Filtering genes...")
    # Keep genes that are highly variable (top 'n_hvg'), spatially variable
    # (top 'n_svg' by Moran's I) or gene program relevant (see below)
    gp_dict_genes = get_unique_genes_from_gp_dict(
        gp_dict=combined_new_gp_dict,
        retrieved_gene_entities=["sources", "targets"])
    print(f"Starting with {len(adata.var_names)} genes.")
    # Filter out genes that are detected in fewer than 'min_cells' spots
    min_cells = int(adata.shape[0] * min_cell_gene_thresh_ratio)
    sc.pp.filter_genes(adata, min_cells=min_cells)
    print(f"Keeping {len(adata.var_names)} genes after filtering genes with "
          f"counts in less than {min_cells} cells.")
    
    # Identify highly variable genes
    sc.pp.highly_variable_genes(
        adata,
        layer="counts",
        n_top_genes=n_hvg,
        flavor="seurat_v3",
        subset=False)
    
    # Identify spatially variable genes
    sq.gr.spatial_autocorr(adata, mode="moran", genes=adata.var_names)
    svg_genes = adata.uns["moranI"].index[:n_svg].tolist()
    adata.var["spatially_variable"] = adata.var_names.isin(svg_genes)

    # Get gene program relevant genes. The list is left empty here, so only
    # highly variable and spatially variable genes are kept; populate it, e.g.
    # with the upper-cased 'gp_dict_genes', to always retain gene program genes
    gp_relevant_genes = []

    adata.var["gp_relevant"] = (
        adata.var.index.str.upper().isin(gp_relevant_genes))
    adata.var["keep_gene"] = (adata.var["gp_relevant"] |
                              adata.var["highly_variable"] |
                              adata.var["spatially_variable"])
    adata = adata[:, adata.var["keep_gene"]].copy()
    print(f"Keeping {len(adata.var_names)} highly variable, spatially variable "
          "or gene program relevant genes.")
    
if filter_peaks:
    print("\nFiltering peaks...")
    print(f"Starting with {len(adata_atac.var_names)} peaks.")
    # Filter out peaks that are rarely detected to reduce GPU footprint of model
    min_cells = int(adata_atac.shape[0] * min_cell_peak_thresh_ratio)
    sc.pp.filter_genes(adata_atac, min_cells=min_cells)
    print(f"Keeping {len(adata_atac.var_names)} peaks after filtering peaks with "
          f"counts in less than {int(adata_atac.shape[0] * min_cell_peak_thresh_ratio)} cells.")
    
    # Filter highly variable peaks
    sc.pp.highly_variable_genes(
        adata_atac,
        layer="counts",
        n_top_genes=n_hvp,
        flavor="seurat_v3",
        subset=False)
    
    # Filter spatially variable peaks
    if n_svp > 0:
        # Reuse the spatial graph of the RNA modality (both modalities share the same spots)
        adata_atac.obsp[adj_key] = adata.obsp[adj_key]
        adata_atac.obsp["spatial_distances"] = adata.obsp["spatial_distances"]

        sq.gr.spatial_autocorr(adata_atac,
                               mode="moran",
                               genes=adata_atac.var_names)
        sv_peaks = adata_atac.uns["moranI"].index[:n_svp].tolist()
        adata_atac.var["spatially_variable"] = adata_atac.var_names.isin(sv_peaks)

        adata_atac.var["keep_peak"] = (adata_atac.var["highly_variable"] |
                                       adata_atac.var["spatially_variable"])
        adata_atac = adata_atac[:, adata_atac.var["keep_peak"]].copy()
        print(f"Keeping {len(adata_atac.var_names)} peaks after filtering spatially variable "
              "peaks.")
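
The spatially variable genes and peaks above are selected by ranking features by Moran's I, computed by `sq.gr.spatial_autocorr` on the spatial graph. As a minimal sketch of the statistic itself (global Moran's I with a binary weight matrix; the `morans_i` helper and the toy data are hypothetical, and squidpy may additionally normalize the weights), positive values indicate that neighboring spots have similar values, while negative values indicate that neighbors tend to differ:

```python
import numpy as np


def morans_i(x, weights):
    """Global Moran's I of feature values x (n_obs,) under spatial weights (n_obs x n_obs):
    I = (n / sum(w)) * sum_ij w_ij z_i z_j / sum_i z_i^2, with z = x - mean(x)."""
    x = np.asarray(x, dtype=float)
    z = x - x.mean()
    return (len(x) / weights.sum()) * (z @ weights @ z) / (z @ z)


# Four spots on a line, each connected to its direct neighbors
w = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)

print(morans_i([1.0, 2.0, 3.0, 4.0], w))  # smooth gradient, approx. 0.333
print(morans_i([1.0, 0.0, 1.0, 0.0], w))  # alternating pattern, approx. -1.0
```

Features with the highest Moran's I are the most spatially structured, which is why the notebook keeps the first `n_svg` genes and `n_svp` peaks of the ranking stored in `adata.uns["moranI"]`.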

### 2.4 Annotate Genes & Peaks

Next, we add genomic position annotations (chromosome and start/end coordinates in bp) to genes and peaks so that peaks can be matched to genes that lie close to them on the genome.
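
For peaks, the coordinates are typically encoded in the feature name itself (for example `chr17:85617697-85618556` in this dataset). A minimal pandas sketch of turning such names into the `chrom`, `chromStart` and `chromEnd` columns, assuming every peak name follows the `chrom:start-end` pattern; `parse_peak_names` is a hypothetical helper, and `get_gene_annotations` may derive these columns differently (gene coordinates come from the GTF file at `gtf_file_path`):

```python
import pandas as pd


def parse_peak_names(peak_names):
    """Split peak names of the form 'chrom:start-end' into a DataFrame with
    'chrom', 'chromStart' and 'chromEnd' columns, indexed by peak name."""
    coords = pd.Series(peak_names).str.extract(
        r"^(?P<chrom>[^:]+):(?P<chromStart>\d+)-(?P<chromEnd>\d+)$")
    coords[["chromStart", "chromEnd"]] = coords[["chromStart", "chromEnd"]].astype(int)
    coords.index = list(peak_names)
    return coords


peak_df = parse_peak_names(["chr17:85617697-85618556", "chr1:3000500-3001200"])
print(peak_df)
```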

In [None]:
adata, adata_atac = get_gene_annotations(
    adata=adata,
    adata_atac=adata_atac,
    gtf_file_path=gtf_file_path)

In [None]:
# Display gene annotations
adata.var[["chrom", "chromStart", "chromEnd"]]

In [None]:
# Display peak annotations
adata_atac.var[["chrom", "chromStart", "chromEnd"]]

### 2.5 Add GP Mask to Data

In [None]:
# Add the GP dictionary as binary masks to the adata
add_gps_from_gp_dict_to_adata(
    gp_dict=combined_new_gp_dict,
    adata=adata,
    gp_targets_mask_key=gp_targets_mask_key,
    gp_targets_categories_mask_key=gp_targets_categories_mask_key,
    gp_sources_mask_key=gp_sources_mask_key,
    gp_sources_categories_mask_key=gp_sources_categories_mask_key,
    gp_names_key=gp_names_key,
    min_genes_per_gp=1,
    min_source_genes_per_gp=0,
    min_target_genes_per_gp=0,
    max_genes_per_gp=None,
    max_source_genes_per_gp=None,
    max_target_genes_per_gp=None,
    plot_gp_gene_count_distributions=True)
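
Conceptually, this step turns each GP's source and target gene lists into binary gene-by-GP columns. A simplified sketch of that conversion (the `gp_dict_to_masks` helper and the toy data are hypothetical; it assumes exact, case-sensitive gene name matching and ignores gene categories and the `max_*_genes_per_gp` options):

```python
import numpy as np


def gp_dict_to_masks(gp_dict, var_names, min_genes_per_gp=1):
    """Convert a GP dictionary into binary target and source masks (genes x GPs).
    Genes missing from var_names are ignored, and GPs with fewer than
    min_genes_per_gp genes present in var_names are dropped."""
    gene_idx = {gene: i for i, gene in enumerate(var_names)}
    kept_gps, target_cols, source_cols = [], [], []
    for gp_name, gp in gp_dict.items():
        targets = np.zeros(len(var_names), dtype=np.int8)
        sources = np.zeros(len(var_names), dtype=np.int8)
        for gene in gp["targets"]:
            if gene in gene_idx:
                targets[gene_idx[gene]] = 1
        for gene in gp["sources"]:
            if gene in gene_idx:
                sources[gene_idx[gene]] = 1
        # Number of distinct GP genes found in the data
        if np.count_nonzero(targets | sources) >= min_genes_per_gp:
            kept_gps.append(gp_name)
            target_cols.append(targets)
            source_cols.append(sources)
    shape = (len(var_names), len(kept_gps))
    targets_mask = np.column_stack(target_cols) if kept_gps else np.zeros(shape, np.int8)
    sources_mask = np.column_stack(source_cols) if kept_gps else np.zeros(shape, np.int8)
    return kept_gps, targets_mask, sources_mask


toy_gp_dict = {
    "Lig1_ligand_receptor_GP": {"sources": ["Lig1"], "targets": ["Rec1"]},
    "L2L3_marker_GP": {"sources": [], "targets": ["GeneA", "GeneNotMeasured"]},
    "Empty_GP": {"sources": ["GeneNotMeasured"], "targets": []},
}
gp_names, targets_mask, sources_mask = gp_dict_to_masks(
    toy_gp_dict, var_names=["Lig1", "Rec1", "GeneA"])
print(gp_names)  # ['Lig1_ligand_receptor_GP', 'L2L3_marker_GP']
```

`Empty_GP` is dropped because none of its genes are measured, mirroring the role of `min_genes_per_gp=1` in the call above.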

### 2.6 Add Chromatin Accessibility Mask to Data

Based on the genomic proximity of peaks to the genes in the GP mask, we add a chromatin accessibility mask.
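
One common way to define such genomic proximity is a window around each gene body. A minimal pandas sketch of this idea, assuming a fixed `window_bp` of 2,000 bp on both sides (a hypothetical value, not necessarily the rule used by `generate_multimodal_mapping_dict`) and DataFrames with the `chrom`, `chromStart` and `chromEnd` columns displayed in section 2.4; `map_peaks_to_genes` and `gp_peak_mask` are hypothetical helpers:

```python
import pandas as pd


def map_peaks_to_genes(gene_df, peak_df, window_bp=2000):
    """Map each gene to the peaks on the same chromosome that overlap the gene
    body extended by window_bp on both sides."""
    mapping = {}
    for gene, g in gene_df.iterrows():
        start, end = g["chromStart"] - window_bp, g["chromEnd"] + window_bp
        same_chrom = peak_df[peak_df["chrom"] == g["chrom"]]
        overlaps = (same_chrom["chromStart"] <= end) & (same_chrom["chromEnd"] >= start)
        mapping[gene] = same_chrom[overlaps].index.tolist()
    return mapping


def gp_peak_mask(gp_genes, gene_peak_mapping, peak_names):
    """Binary chromatin accessibility mask for one GP: 1 for peaks linked to any GP gene."""
    linked = {peak for gene in gp_genes for peak in gene_peak_mapping.get(gene, [])}
    return [int(peak in linked) for peak in peak_names]


gene_df = pd.DataFrame({"chrom": ["chr1", "chr2"],
                        "chromStart": [10_000, 50_000],
                        "chromEnd": [12_000, 51_000]},
                       index=["GeneA", "GeneB"])
peak_df = pd.DataFrame({"chrom": ["chr1", "chr1", "chr1", "chr2"],
                        "chromStart": [7_500, 11_000, 20_000, 60_000],
                        "chromEnd": [8_200, 11_500, 20_500, 60_500]},
                       index=["chr1:7500-8200", "chr1:11000-11500",
                              "chr1:20000-20500", "chr2:60000-60500"])
gene_peak_mapping = map_peaks_to_genes(gene_df, peak_df)
print(gene_peak_mapping)
print(gp_peak_mask(["GeneA", "GeneB"], gene_peak_mapping, peak_df.index))
```

Peaks that end up linked to no GP gene at all would carry no signal through such a mask, which matches the subsequent filtering of peaks with no matching genes.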

In [None]:
gene_peak_mapping_dict = generate_multimodal_mapping_dict(
    adata=adata,
    adata_atac=adata_atac)

In [None]:
adata, adata_atac = add_multimodal_mask_to_adata(
    adata=adata,
    adata_atac=adata_atac,
    gene_peak_mapping_dict=gene_peak_mapping_dict)

print(f"Keeping {len(adata_atac.var_names)} peaks after filtering peaks with "
      "no matching genes in gp mask.")

### 2.7 Explore Data

In [None]:
rna_cluster_colors = create_new_color_dict(
    adata=adata,
    cat_key=rna_cluster_key)

atac_cluster_colors = create_new_color_dict(
    adata=adata,
    cat_key=atac_cluster_key)

In [None]:
print(f"Number of nodes (observations): {adata.layers['counts'].shape[0]}")
print(f"Number of gene node features: {adata.layers['counts'].shape[1]}")
print(f"Number of peak node features: {adata_atac.layers['counts'].shape[1]}")

# Visualize spot-level annotated data in physical space
sc.pl.spatial(adata,
              color=rna_cluster_key,
              palette=rna_cluster_colors,
              spot_size=spot_size)
sc.pl.spatial(adata_atac,
              color=atac_cluster_key,
              palette=atac_cluster_colors,
              spot_size=spot_size) 

## 3. Model Training

### 3.1 Initialize, Train & Save Model

In [None]:
# Override the loss weights from section 1.2 for this training run
lambda_l1_masked = 100
lambda_l1_addon = 100
lambda_edge_recon = 500000

In [None]:
# Initialize model
model = NicheCompass(adata,
                     adata_atac,
                     #cat_covariates_embeds_injection=["gene_expr_decoder", "chrom_access_decoder"],
                     #cat_covariates_keys=["celltype"],
                     #cat_covariates_no_edges=[False],
                     #cat_covariates_embeds_nums=[8],
                     counts_key=counts_key,
                     adj_key=adj_key,
                     gp_names_key=gp_names_key,
                     active_gp_names_key=active_gp_names_key,
                     gp_targets_mask_key=gp_targets_mask_key,
                     gp_targets_categories_mask_key=gp_targets_categories_mask_key,
                     gp_sources_mask_key=gp_sources_mask_key,
                     gp_sources_categories_mask_key=gp_sources_categories_mask_key,
                     latent_key=latent_key,
                     n_fc_layers_encoder=1,
                     n_layers_encoder=1,
                     encoder_n_attention_heads=4,
                     encoder_use_bn=False,
                     conv_layer_encoder=conv_layer_encoder,
                     active_gp_thresh_ratio=active_gp_thresh_ratio,
                     node_label_method=node_label_method)

In [None]:
# Optionally free unused CPU and GPU memory before training
import gc

import torch

gc.collect()
torch.cuda.empty_cache()

In [None]:
# Train model
model.train(n_epochs=n_epochs,
            n_epochs_all_gps=n_epochs_all_gps,
            n_epochs_no_edge_recon=0,
            lr=lr,
            lambda_edge_recon=lambda_edge_recon,
            lambda_gene_expr_recon=lambda_gene_expr_recon,
            lambda_chrom_access_recon=lambda_chrom_access_recon,
            lambda_l1_masked=lambda_l1_masked,
            lambda_l1_addon=lambda_l1_addon,
            lambda_group_lasso=lambda_group_lasso,
            edge_batch_size=edge_batch_size,
            use_cuda_if_available=use_cuda_if_available,
            n_sampled_neighbors=n_sampled_neighbors,
            verbose=True)

In [None]:
# Display example active GPs
gp_summary_df = model.get_gp_summary()
gp_summary_df[gp_summary_df["gp_active"] == True].tail(20)

In [None]:
# Compute latent neighbor graph
sc.pp.neighbors(model.adata,
                use_rep=latent_key,
                key_added=latent_key)

# Compute UMAP embedding
sc.tl.umap(model.adata,
           neighbors_key=latent_key)

In [None]:
# Save trained model together with the AnnData objects of both modalities
model.save(dir_path=model_folder_path,
           overwrite=True,
           save_adata=True,
           adata_file_name="adata.h5ad",
           save_adata_atac=True,
           adata_atac_file_name="adata_atac.h5ad")

## 4. Analysis

In [None]:
# Set 'load_timestamp' to 'current_timestamp' to analyze the model trained
# above, or to the timestamp folder of a downloaded pretrained model
load_timestamp = "02102023_231635"

figure_folder_path = f"{artifacts_folder_path}/multimodal/{load_timestamp}/figures"
model_folder_path = f"{artifacts_folder_path}/multimodal/{load_timestamp}/model"

os.makedirs(figure_folder_path, exist_ok=True)

In [None]:
# Load trained model
model = NicheCompass.load(dir_path=model_folder_path,
                          adata=None,
                          adata_file_name="adata.h5ad",
                          adata_atac=None,
                          adata_atac_file_name="adata_atac.h5ad",
                          gp_names_key=gp_names_key)

### 4.1 Visualize NicheCompass Embeddings

Let's look at how well the RNA and ATAC cluster spot annotations are preserved in the embedding space. Note that the goal of NicheCompass is not a perfect separation of RNA or ATAC clusters but rather the identification of spatial cellular niches. Nevertheless, it can be useful to inspect the spot annotations when they are available.

In [None]:
output = model.get_omics_decoder_outputs(
                only_active_gps=True,
                node_batch_size=512)

In [None]:
# Store the reconstructed expression (negative binomial means) of each gene in obs
for i, gene in enumerate(model.adata.var_names):
    model.adata.obs[f"{gene}_rec"] = output["target_rna_nb_means"][:, i]

In [None]:
model.adata.var_names

In [None]:
model.adata.varm["nichecompass_gp_sources"].sum(1)

In [None]:
model.adata.varm["nichecompass_gp_targets"].sum(1)

In [None]:
model.features_scale_factors_

In [None]:
gene = "Mbp"

sc.pl.spatial(
    use_raw=False,
    adata=model.adata,
    color=gene,
    color_map="RdBu_r",
    spot_size=30,
    title=f"",
    legend_loc=None,
    colorbar_loc="bottom",
    show=True)

sc.pl.spatial(
    use_raw=False,
    adata=model.adata,
    color=f"{gene}_rec",
    color_map="RdBu_r",
    spot_size=30,
    title=f"",
    legend_loc=None,
    colorbar_loc="bottom",
    show=True)
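
Beyond visually comparing single genes such as `Mbp`, the agreement between observed and reconstructed expression can be summarized per gene. A minimal NumPy sketch computing one Pearson correlation per gene (column), assuming dense NumPy arrays of shape (spots x genes) aligned to the same genes, e.g. observed counts converted with `.toarray()` if sparse and reconstructed means such as `output["target_rna_nb_means"]`; `per_gene_pearson` is a hypothetical helper:

```python
import numpy as np


def per_gene_pearson(observed, reconstructed):
    """Pearson correlation between observed and reconstructed values of each
    gene (column). Genes with zero variance in either matrix yield NaN."""
    obs = observed - observed.mean(axis=0)
    rec = reconstructed - reconstructed.mean(axis=0)
    denom = np.sqrt((obs ** 2).sum(axis=0) * (rec ** 2).sum(axis=0))
    with np.errstate(invalid="ignore", divide="ignore"):
        return (obs * rec).sum(axis=0) / denom


observed = np.array([[1.0, 0.0, 5.0],
                     [2.0, 0.0, 3.0],
                     [3.0, 0.0, 1.0]])
reconstructed = np.array([[2.0, 1.0, 1.0],
                          [4.0, 1.0, 2.0],
                          [6.0, 1.0, 3.0]])
print(per_gene_pearson(observed, reconstructed))
```

Unexpressed genes (like the constant middle column) yield NaN rather than a misleading score, so they should be excluded before averaging.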

In [None]:
samples = model.adata.obs[sample_key].unique().tolist()

rna_cluster_colors = create_new_color_dict(
    adata=model.adata,
    cat_key=rna_cluster_key)

atac_cluster_colors = create_new_color_dict(
    adata=model.adata,
    cat_key=atac_cluster_key)

In [None]:
# Create plot of RNA cluster annotations in physical and latent space
groups = None # set this to a specific cluster for easy visualization, e.g. ["R0"]
save_fig = True
file_path = f"{figure_folder_path}/" \
            "rna_clusters_latent_physical_space.svg"

fig = plt.figure(figsize=(12, 14))
title = fig.suptitle(t=f"Original RNA Clusters " \
                       "in NicheCompass Latent and Physical Space",
                     y=0.96,
                     x=0.55,
                     fontsize=20)
spec1 = gridspec.GridSpec(ncols=1,
                          nrows=2,
                          width_ratios=[1],
                          height_ratios=[3, 2])
spec2 = gridspec.GridSpec(ncols=len(samples),
                          nrows=2,
                          width_ratios=[1] * len(samples),
                          height_ratios=[3, 2])
axs = []
axs.append(fig.add_subplot(spec1[0]))
sc.pl.umap(adata=model.adata,
           color=[rna_cluster_key],
           groups=groups,
           palette=rna_cluster_colors,
           title=f"Original RNA Clusters in NicheCompass Latent Space",
           ax=axs[0],
           show=False)
for idx, sample in enumerate(samples):
    axs.append(fig.add_subplot(spec2[len(samples) + idx]))
    sc.pl.spatial(adata=model.adata[model.adata.obs[sample_key] == sample],
                  color=[rna_cluster_key],
                  groups=groups,
                  palette=rna_cluster_colors,
                  spot_size=spot_size,
                  title=f"Original RNA Clusters in Physical Space \n"
                        f"(Sample: {sample})",
                  legend_loc=None,
                  ax=axs[idx+1],
                  show=False)

# Create and position shared legend
handles, labels = axs[0].get_legend_handles_labels()
lgd = fig.legend(handles,
                 labels,
                 loc="center left",
                 bbox_to_anchor=(0.98, 0.5))
axs[0].get_legend().remove()

# Adjust, save and display plot
plt.subplots_adjust(wspace=0.2, hspace=0.25)
if save_fig:
    fig.savefig(file_path,
                bbox_extra_artists=(lgd, title),
                bbox_inches="tight")
plt.show()

In [None]:
# Create plot of ATAC cluster annotations in physical and latent space
groups = None # set this to a specific cluster for easy visualization, e.g. ["C0"]
save_fig = True
file_path = f"{figure_folder_path}/" \
            "atac_clusters_latent_physical_space.svg"

fig = plt.figure(figsize=(12, 14))
title = fig.suptitle(t=f"Original ATAC Clusters " \
                       "in NicheCompass Latent and Physical Space",
                     y=0.96,
                     x=0.55,
                     fontsize=20)
spec1 = gridspec.GridSpec(ncols=1,
                          nrows=2,
                          width_ratios=[1],
                          height_ratios=[3, 2])
spec2 = gridspec.GridSpec(ncols=len(samples),
                          nrows=2,
                          width_ratios=[1] * len(samples),
                          height_ratios=[3, 2])
axs = []
axs.append(fig.add_subplot(spec1[0]))
sc.pl.umap(adata=model.adata,
           color=[atac_cluster_key],
           groups=groups,
           palette=atac_cluster_colors,
           title=f"Original ATAC Clusters in NicheCompass Latent Space",
           ax=axs[0],
           show=False)
for idx, sample in enumerate(samples):
    axs.append(fig.add_subplot(spec2[len(samples) + idx]))
    sc.pl.spatial(adata=model.adata[model.adata.obs[sample_key] == sample],
                  color=[atac_cluster_key],
                  groups=groups,
                  palette=atac_cluster_colors,
                  spot_size=spot_size,
                  title=f"Original ATAC Clusters in Physical Space \n"
                        f"(Sample: {sample})",
                  legend_loc=None,
                  ax=axs[idx+1],
                  show=False)

# Create and position shared legend
handles, labels = axs[0].get_legend_handles_labels()
lgd = fig.legend(handles,
                 labels,
                 loc="center left",
                 bbox_to_anchor=(0.98, 0.5))
axs[0].get_legend().remove()

# Adjust, save and display plot
plt.subplots_adjust(wspace=0.2, hspace=0.25)
if save_fig:
    fig.savefig(file_path,
                bbox_extra_artists=(lgd, title),
                bbox_inches="tight")
plt.show()

### 4.2 Identify Niches

We will compute Leiden clustering based on the NicheCompass embeddings to identify spatial cellular niches.
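
Once niches are identified, one way to characterize them is by their composition of annotated spot clusters. A minimal pandas sketch with hypothetical toy labels, computing the fraction of each RNA cluster within each niche:

```python
import pandas as pd

# Hypothetical toy annotations: one niche label and one RNA cluster label per spot
obs = pd.DataFrame({
    "latent_leiden_0.5": ["0", "0", "0", "1", "1"],
    "RNA_clusters": ["R0", "R0", "R1", "R1", "R1"],
})

# Fraction of each RNA cluster within each niche (each row sums to 1)
composition = pd.crosstab(obs["latent_leiden_0.5"], obs["RNA_clusters"],
                          normalize="index")
print(composition)
```

On the real data, the analogous call would be `pd.crosstab(model.adata.obs[latent_cluster_key], model.adata.obs[rna_cluster_key], normalize="index")` once the Leiden clustering has been computed.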

In [None]:
# Optionally override the Leiden resolution defined in section 1.2
latent_leiden_resolution = 0.5
latent_cluster_key = f"latent_leiden_{latent_leiden_resolution}"

In [None]:
# Compute latent Leiden clustering
sc.tl.leiden(adata=model.adata,
             resolution=latent_leiden_resolution,
             key_added=latent_cluster_key,
             neighbors_key=latent_key)

In [None]:
# Optionally refine selected latent clusters via Leiden subclustering.
# The cluster IDs and resolutions below depend on the clustering result and
# may need to be adapted when the model is retrained
for cluster, subcluster_resolution in [("8", 0.1), ("3", 0.08), ("9", 0.05), ("5", 0.06)]:
    sc.tl.leiden(adata=model.adata,
                 resolution=subcluster_resolution,
                 key_added=latent_cluster_key,
                 restrict_to=(latent_cluster_key, [cluster]),
                 neighbors_key=latent_key)

In [None]:
latent_cluster_colors = create_new_color_dict(
    adata=model.adata,
    cat_key=latent_cluster_key)

In [None]:
# Create plot of latent cluster / niche annotations in physical and latent space
groups = None # set this to a specific cluster for easy visualization, e.g. ["11"]
save_fig = True
file_path = f"{figure_folder_path}/" \
            f"res_{latent_leiden_resolution}_" \
            "latent_clusters_latent_physical_space.svg"

fig = plt.figure(figsize=(12, 14))
title = fig.suptitle(t=f"NicheCompass Latent Clusters " \
                       "in Latent and Physical Space",
                     y=0.96,
                     x=0.55,
                     fontsize=20)
spec1 = gridspec.GridSpec(ncols=1,
                          nrows=2,
                          width_ratios=[1],
                          height_ratios=[3, 2])
spec2 = gridspec.GridSpec(ncols=len(samples),
                          nrows=2,
                          width_ratios=[1] * len(samples),
                          height_ratios=[3, 2])
axs = []
axs.append(fig.add_subplot(spec1[0]))
sc.pl.umap(adata=model.adata,
           color=[latent_cluster_key],
           groups=groups,
           palette=latent_cluster_colors,
           title=f"Latent Clusters in Latent Space",
           ax=axs[0],
           show=False)
for idx, sample in enumerate(samples):
    axs.append(fig.add_subplot(spec2[len(samples) + idx]))
    sc.pl.spatial(adata=model.adata[model.adata.obs[sample_key] == sample],
                  color=[latent_cluster_key],
                  groups=groups,
                  palette=latent_cluster_colors,
                  spot_size=spot_size,
                  title=f"Latent Clusters in Physical Space \n"
                        f"(Sample: {sample})",
                  legend_loc=None,
                  ax=axs[idx+1],
                  show=False)

# Create and position shared legend
handles, labels = axs[0].get_legend_handles_labels()
lgd = fig.legend(handles,
                 labels,
                 loc="center left",
                 bbox_to_anchor=(0.98, 0.5))
axs[0].get_legend().remove()

# Adjust, save and display plot
plt.subplots_adjust(wspace=0.2, hspace=0.25)
if save_fig:
    fig.savefig(file_path,
                bbox_extra_artists=(lgd, title),
                bbox_inches="tight")
plt.show()

In [None]:
cat_covariates_embeds = model.get_cat_covariates_embeds()

In [None]:
cat_covariates_embeds[0].shape

In [None]:
"chr17:85617697-85618556" in model.adata_atac.var_names

In [None]:
model.adata_atac.raw.var_names

In [None]:
# Look up Sox2 gene programs in the gene program summary
gp_summary_df = model.get_gp_summary()
gp_summary_df[gp_summary_df["gp_name"].str.contains("Sox2_")]

In [None]:
model.adata.uns[gp_names_key].tolist().index("Sox2_TF_target_genes_GP")

In [None]:
model.adata_atac.var_names.tolist().index("chr2:26500622-26501037")

In [None]:
len(model.adata_atac.var_names)

In [None]:
model.model.get_gp_weights()[0].shape

In [None]:
model.model.get_gp_weights()[1][1130, 789]

In [None]:
# Add active GP scores to obs, then visualize the accessibility of an example
# ATAC peak (values from `.X`, not `.raw`)
model.add_active_gp_scores_to_obs()

gp = "PTPN11_ligand_receptor_GP"

sc.pl.spatial(
    use_raw=False,
    adata=model.adata_atac,
    color="chr17:85617697-85618556",
    color_map="RdBu_r",
    spot_size=30,
    title=f"",
    legend_loc=None,
    colorbar_loc="bottom",
    show=True) 

In [None]:
sc.pl.spatial(
    use_raw=False,
    adata=model.adata_atac,
    color="chr2:26500622-26501037",
    color_map="RdBu_r",
    spot_size=30,
    title=f"",
    legend_loc=None,
    colorbar_loc="bottom",
    show=True) 

In [None]:
# Check number of active GPs
active_gps = model.get_active_gps()
print(f"Number of total gene programs: {len(model.adata.uns[gp_names_key])}.")
print(f"Number of active gene programs: {len(active_gps)}.")

In [None]:
# Retrieve the weights of all gene programs, including features outside the
# gene program masks
gp_weights = model.model.get_gp_weights(only_masked_features=False,
                                        gp_type="all")[0]

In [None]:
# Count weight columns (reduction over dim 0) with an L1 norm above 0.001
(gp_weights.norm(p=1, dim=0) > 0.001).sum()
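The cell above computes one L1 norm per column (`Tensor.norm(p=1, dim=0)` reduces over dimension 0) and counts the columns whose norm exceeds `0.001`, i.e. columns whose weights have not been shrunk to (numerically) zero. We assume here that columns correspond to gene programs. This ad hoc threshold need not coincide with the criterion used by `model.get_active_gps()`, so the two counts can differ. A plain-Python sketch of the same computation on a toy matrix, using a hypothetical helper `count_columns_above_l1`:

```python
def count_columns_above_l1(weights, thresh=0.001):
    """Count columns of a row-major matrix whose L1 norm exceeds `thresh`."""
    n_cols = len(weights[0])
    # L1 norm of each column: sum of absolute values over all rows
    l1_norms = [sum(abs(row[j]) for row in weights) for j in range(n_cols)]
    return sum(norm > thresh for norm in l1_norms)


toy_weights = [[0.5, 0.0, -0.0005],
               [-0.2, 0.0, 0.0002]]
print(count_columns_above_l1(toy_weights))
```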

In [None]:
# Display example active GPs
gp_summary_df = model.get_gp_summary()
gp_summary_df[gp_summary_df["gp_active"] == True].head(20)

In [None]:
model.model.get_active_gp_mask()

### 4.3 Analyze Niches

#### 4.3.1 Niche Composition

We can analyze the composition of each niche in terms of the RNA cluster labels of its spots.
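The stacked bar chart is built from a niche-by-RNA-cluster contingency table. As a dependency-free sketch of what `groupby([...]).size().unstack()` computes, the hypothetical helper `niche_composition` counts (niche, RNA cluster) pairs and can optionally convert each niche's counts into proportions:

```python
from collections import Counter


def niche_composition(niches, rna_clusters, normalize=False):
    """Count RNA cluster labels per niche; optionally return proportions."""
    counts = Counter(zip(niches, rna_clusters))
    table = {}
    for (niche, rna_cluster), n in counts.items():
        table.setdefault(niche, {})[rna_cluster] = n
    if normalize:
        # Divide each niche's counts by the number of spots in that niche
        table = {niche: {k: v / sum(row.values()) for k, v in row.items()}
                 for niche, row in table.items()}
    return table


print(niche_composition(["0", "0", "1", "1", "1"],
                        ["R1", "R2", "R1", "R1", "R3"]))
```

With pandas, the same table is `pd.crosstab(model.adata.obs[latent_cluster_key], model.adata.obs[rna_cluster_key])`; passing `normalize="index"` yields per-niche proportions, which are easier to compare across niches of very different sizes.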

In [None]:
save_fig = True
file_path = f"{figure_folder_path}/" \
            f"res_{latent_leiden_resolution}_" \
            f"niche_composition.svg"

df_counts = (model.adata.obs.groupby([latent_cluster_key, rna_cluster_key])
             .size().unstack())
df_counts.plot(kind="bar", stacked=True)
legend = plt.legend(bbox_to_anchor=(1, 1), loc="upper left", prop={'size': 10})
legend.set_title("RNA Cluster", prop={'size': 10})
plt.title("RNA Cluster Composition of Niches")
plt.xlabel("Niche")
plt.ylabel("RNA Cluster Counts")
if save_fig:
    plt.savefig(file_path,
                bbox_extra_artists=(legend,),
                bbox_inches="tight")

#### 4.3.2 Spot Annotation Neighbor Importances

Now we will investigate, for each niche, the neighbor importances in terms of the RNA cluster labels of the spots.
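We do not reproduce the exact aggregation performed by `aggregate_obsp_matrix_per_cell_type`; the following is only a mental model under stated assumptions. Treat `model.adata.obsp[agg_weights_key]` as a spot-by-spot matrix whose entry (i, j) is the importance of neighbor j for spot i. Sum these weights per niche of spot i over (RNA cluster of i, RNA cluster of j) pairs, then normalize within each niche so that link strengths are comparable to a cutoff such as `link_threshold=0.01`. The helper `aggregate_neighbor_importances` and its normalization are hypothetical:

```python
from collections import defaultdict


def aggregate_neighbor_importances(edges, niche_of, label_of):
    """Hypothetical per-niche aggregation of neighbor importances.

    edges: iterable of (i, j, weight), the importance of neighbor j for spot i.
    niche_of: maps spot index -> niche label.
    label_of: maps spot index -> RNA cluster label.
    Returns {niche: {(label_i, label_j): share of the niche's total weight}}.
    """
    sums = defaultdict(lambda: defaultdict(float))
    for i, j, weight in edges:
        sums[niche_of[i]][(label_of[i], label_of[j])] += weight
    result = {}
    for niche, pairs in sums.items():
        total = sum(pairs.values())
        result[niche] = {pair: w / total for pair, w in pairs.items()}
    return result


edges = [(0, 1, 0.75), (0, 2, 0.25), (2, 0, 1.0)]
print(aggregate_neighbor_importances(edges,
                                     niche_of={0: "A", 1: "A", 2: "B"},
                                     label_of={0: "R1", 1: "R2", 2: "R1"}))
```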

In [None]:
# Retrieve node neighbor importances 
# (aggregation weights of the node label aggregator)
model.adata.obsp[agg_weights_key] = model.get_neighbor_importances()

In [None]:
# Get cell type neighbor importances for each niche / latent cluster
niche_neighbor_importances_df = aggregate_obsp_matrix_per_cell_type(
    adata=model.adata,
    obsp_key=agg_weights_key,
    cell_type_key=rna_cluster_key,
    group_key=latent_cluster_key,
    agg_rows=True)

In [None]:
display(niche_neighbor_importances_df)

In [None]:
# Generate chord plots showing cell type neighbor importances
# for each niche / latent cluster
groups = "all"
save_fig = True
file_path = f"{figure_folder_path}/" \
            f"res_{latent_leiden_resolution}_" \
            f"niche_neighbor_importances.png"

create_cell_type_chord_plot_from_df(
    adata=model.adata,
    df=niche_neighbor_importances_df,
    link_threshold=0.01,
    cell_type_key=rna_cluster_key,
    group_key=latent_cluster_key,
    groups=groups,
    plot_label="Niche",
    save_fig=save_fig,
    file_path=file_path)

### 4.4 Perform Differential Gene Program Testing

Now we can test which communication gene programs are differentially expressed in a niche. To this end, we will perform differential gene program testing for each latent cluster / niche (```selected_cats = None```) vs all other latent clusters / niches (```comparison_cats = "rest"```).

We could also perform differential gene program testing for a selected niche only, e.g. latent cluster / niche "11" (```selected_cats = ["11"]```) or test only against specified latent clusters / niches, e.g. niches "2" and "3" (```comparison_cats = ["2", "3"]```).
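The two parameters only determine which niches are tested and which reference group each one is compared against. The hypothetical helper `build_comparisons` below (not part of NicheCompass) enumerates these pairs for the options described above; it does not perform the statistical test itself.

```python
def build_comparisons(categories, selected_cats=None, comparison_cats="rest"):
    """Hypothetical helper: list (tested category, reference categories) pairs.

    selected_cats=None tests every category; comparison_cats="rest" compares
    each tested category against all remaining categories.
    """
    selected = list(categories) if selected_cats is None else list(selected_cats)
    pairs = []
    for cat in selected:
        if comparison_cats == "rest":
            reference = [c for c in categories if c != cat]
        else:
            reference = list(comparison_cats)
        pairs.append((cat, reference))
    return pairs


print(build_comparisons(["1", "2", "3"]))
print(build_comparisons(["2", "3", "11"], selected_cats=["11"],
                        comparison_cats=["2", "3"]))
```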

We choose a log Bayes factor threshold of 4.6 to identify decisively enriched gene programs. Alternatively, the threshold could be loosened to 2.3 to also include gene programs that are "only" strongly enriched (see https://en.wikipedia.org/wiki/Bayes_factor).
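Assuming the thresholds are natural-log Bayes factors, ln(K), they map onto Jeffreys' scale as follows: K > 10 counts as "strong" and K > 100 as "decisive" evidence. The short computation below shows where 2.3 and 4.6 come from:

```python
import math

# Assumption: log Bayes factors are natural logarithms, ln(K).
# Jeffreys' scale: K > 10 is "strong", K > 100 is "decisive" evidence.
bayes_factor_cutoffs = {"strong": 10, "decisive": 100}
log_bayes_factor_cutoffs = {label: math.log(k)
                            for label, k in bayes_factor_cutoffs.items()}

for label, log_k in log_bayes_factor_cutoffs.items():
    print(f"{label}: K > {bayes_factor_cutoffs[label]} <=> ln(K) > {log_k:.2f}")
```

The values 2.3 and 4.6 are ln(10) and ln(100) rounded to one decimal, so a cutoff of 4.6 corresponds to a Bayes factor of about exp(4.6) ≈ 99.5.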

In [None]:
gp_weights = model.model.get_gp_weights(gp_type="addon")

In [None]:
gp_weights[0].sum(0)[-100:]

In [None]:
# Set parameters for differential gp testing
selected_cats = None
comparison_cats = "rest"
title = "NicheCompass Latent Cluster Decisively Enriched Gene Programs"
log_bayes_factor_thresh = 4.6 # set to 2.3 to also include strongly enriched GPs
save_fig = True
file_path = f"{figure_folder_path}/" \
            f"log_bayes_factor_{log_bayes_factor_thresh}_" \
            "all_niches_vs_rest_enriched_gps_dotplot.svg"

In [None]:
# Run differential gp testing
enriched_gps = model.run_differential_gp_tests(
    cat_key=latent_cluster_key,
    selected_cats=selected_cats,
    comparison_cats=comparison_cats,
    log_bayes_factor_thresh=log_bayes_factor_thresh)

In [None]:
# Results are stored in a df in the adata object
model.adata.uns[differential_gp_test_results_key]

In [None]:
# Create dotplot of results
fig = sc.pl.dotplot(model.adata,
                    enriched_gps,
                    groupby=latent_cluster_key,
                    expression_cutoff=-np.inf,
                    dendrogram=True, 
                    title=title,
                    swap_axes=True,
                    return_fig=True,
                    figsize=(model.adata.obs[latent_cluster_key].nunique() / 2,
                             len(enriched_gps) / 2))
if save_fig:
    fig.savefig(file_path)
else:
    fig.show()

In [None]:
sc.pl.spatial(model.adata, color=["Hoxc9_TF_target_genes_GP"], spot_size=30)

In [None]:
sc.pl.spatial(model.adata, color=["Myc_TF_target_genes_GP"], spot_size=30)

In [None]:
enriched_gps

In [None]:
# Store gene program summary of enriched gene programs
save_file = True
file_path = f"{figure_folder_path}/" \
            f"log_bayes_factor_{log_bayes_factor_thresh}_" \
            "all_niches_vs_rest_enriched_gps_summary.csv"

gp_summary_cols = ["gp_name",
                   "n_source_genes",
                   "n_non_zero_source_genes",
                   "n_target_genes",
                   "n_non_zero_target_genes",
                   "gp_source_genes",
                   "gp_target_genes",
                   "gp_source_genes_weights_sign_corrected",
                   "gp_target_genes_weights_sign_corrected",
                   "gp_source_genes_importances",
                   "gp_target_genes_importances",
                   "n_source_peaks",
                   "n_target_peaks",
                   "gp_source_peaks",
                   "gp_target_peaks",
                   "gp_source_peaks_weights_sign_corrected",
                   "gp_target_peaks_weights_sign_corrected",
                   "gp_source_peaks_importances",
                   "gp_target_peaks_importances"]

# Keep enriched GPs and order them like `enriched_gps` via an ordered
# categorical (copy first to avoid modifying a view of `gp_summary_df`)
enriched_gp_summary_df = gp_summary_df[
    gp_summary_df["gp_name"].isin(enriched_gps)].copy()
cat_dtype = pd.CategoricalDtype(categories=enriched_gps, ordered=True)
enriched_gp_summary_df["gp_name"] = enriched_gp_summary_df["gp_name"].astype(cat_dtype)
enriched_gp_summary_df = enriched_gp_summary_df.sort_values(by="gp_name")
enriched_gp_summary_df = enriched_gp_summary_df[gp_summary_cols]

if save_file:
    enriched_gp_summary_df.to_csv(file_path)
else:
    display(enriched_gp_summary_df)

### 4.5 Analyze Enriched Gene Programs

Now we will have a look at the gene program scores as well as the (log normalized) count distributions of
the most important omics features of the differentially expressed gene programs.
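For reference, a minimal sketch of per-spot log normalization, assuming the omics layers were processed in the usual scanpy way (`sc.pp.normalize_total` followed by `sc.pp.log1p`). The preprocessing for this dataset is not shown in this section, so the target sum of 10,000 is an assumption, and `log_normalize` is a hypothetical helper:

```python
import math


def log_normalize(counts, target_sum=1e4):
    """Scale one spot's raw counts to `target_sum` in total, then apply log1p.

    Raises ZeroDivisionError for a spot with no counts at all.
    """
    total = sum(counts)
    return [math.log1p(c * target_sum / total) for c in counts]


print(log_normalize([1, 1, 2]))
```

Zero counts stay at zero after the transform, and `log1p` compresses large values, which is why log-normalized distributions are easier to compare across features than raw counts.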

In [None]:
plot_label = f"log_bayes_factor_{log_bayes_factor_thresh}_all_niches_vs_rest"
save_figs = True

generate_enriched_gp_info_plots(
    plot_label=plot_label,
    model=model,
    sample_key=sample_key,
    differential_gp_test_results_key=differential_gp_test_results_key,
    cat_key=latent_cluster_key,
    cat_palette=latent_cluster_colors,
    n_top_enriched_gp_start_idx=0,
    n_top_enriched_gp_end_idx=10,
    feature_spaces=samples, # ["latent"]
    n_top_genes_per_gp=5,
    #n_top_peaks_per_gp=5,
    save_figs=save_figs,
    figure_folder_path=f"{figure_folder_path}/",
    spot_size=spot_size)

In [None]:
"TYMP" in [gene.upper() for gene in model.adata.var_names.tolist()]

In [None]:
"Cda" in model.adata.var_names

In [None]:
mebocost_gp_dict

In [None]:
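def softmax(z):
    """Row-wise softmax of a 2D array whose rows are observations and whose
    columns are the inputs per observation."""
    # Subtract the row-wise maximum for numerical stability (result unchanged)
    e = np.exp(z - np.max(z, axis=1, keepdims=True))
    s = np.sum(e, axis=1, keepdims=True)
    return e / s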

In [None]:
softmax(np.array([[-5, -2, 0], [1, 3, 0]]))

In [None]:
# In IEEE 754 floating point, zero times infinity is undefined and yields nan
0 * -np.inf

In [None]:
import torch

def compute_cosine_similarity(tensor1: torch.Tensor,
                              tensor2: torch.Tensor,
                              eps: float = 1e-8) -> torch.Tensor:
    """
    Compute the element-wise cosine similarity between two 2D tensors.

    Parameters
    ----------
    tensor1:
        First tensor for element-wise cosine similarity computation (dim: n_obs
        x n_features).
    tensor2:
        Second tensor for element-wise cosine similarity computation (dim: n_obs
        x n_features).
    
    Returns
    -------
    cosine_sim:
        Result tensor that contains the computed element-wise cosine
        similarities (dim: n_obs).
    """
    tensor1_norm = tensor1.norm(dim=1)[:, None]
    tensor2_norm = tensor2.norm(dim=1)[:, None]
    tensor1_normalized = tensor1 / torch.max(
            tensor1_norm, eps * torch.ones_like(tensor1_norm))
    tensor2_normalized = tensor2 / torch.max(
            tensor2_norm, eps * torch.ones_like(tensor2_norm))
    cosine_sim = torch.mul(tensor1_normalized, tensor2_normalized).sum(1)
    return cosine_sim

In [None]:
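# The first row pair is parallel and the second orthogonal,
# so the expected result is tensor([1., 0.])
compute_cosine_similarity(torch.tensor(([0., 0.5], [0, 0.5])), torch.tensor(([0., 0.5], [0.5, 0])))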

In [None]:
gp_weights = model.model.get_gp_weights(
    only_masked_features=False)[0]

In [None]:
gp_weights.sum(1).shape