# Autotalker Data Analysis Tutorial

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 22.01.2023
- **Date of Last Modification:** 17.05.2023

- To run this notebook, a trained model must be stored under `../artifacts/{dataset}/models/{load_timestamp}/{model_label}` (with an additional `run{run_number}` subfolder when `model_label` is `sample_integration_method_benchmarking`).

## 1. Setup

### 1.1 Import Libraries

In [None]:
# Enable IPython's autoreload extension (mode 2) so that edits to the local
# modules imported below take effect without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Make the local autotalker package and the notebook helper modules
# (color_utils, analysis_utils) importable, in this order
sys.path.extend(["../../autotalker", "../utils"])

In [None]:
import argparse
import os
import random
import warnings
from datetime import datetime

import anndata as ad
import matplotlib
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import seaborn as sns
import squidpy as sq
import torch
from matplotlib import gridspec
from matplotlib.pyplot import rc_context

from autotalker.models import Autotalker
from autotalker.utils import (add_gps_from_gp_dict_to_adata,
                              aggregate_obsp_matrix_per_cell_type,
                              create_cell_type_chord_plot_from_df,
                              extract_gp_dict_from_mebocost_es_interactions,
                              extract_gp_dict_from_nichenet_ligand_target_mx,
                              extract_gp_dict_from_omnipath_lr_interactions,
                              filter_and_combine_gp_dict_gps,
                              get_unique_genes_from_gp_dict)

from color_utils import (batch_colors,
                         latent_cluster_colors,
                         mapping_entity_colors,
                         seqfish_mouse_organogenesis_cell_type_colors,
                         spatial_atac_rna_seq_mouse_embryo_and_brain_rna_colors,
                         spatial_atac_rna_seq_mouse_embryo_and_brain_atac_colors,
                         starmap_plus_mouse_cns_cell_type_colors,
                         visium_human_heart_colors)
from analysis_utils import (add_cell_type_latent_cluster_emphasis,
                            add_sub_cell_type,
                            compute_cell_type_latent_clusters,
                            generate_gp_info_plots,
                            plot_physical_latent_for_cell_types,
                            plot_cell_type_latent_clusters,
                            plot_latent,
                            plot_latent_clusters_in_latent_and_physical_space,
                            store_top_gps_summary)

### 1.2 Define Parameters

In [None]:
## Model
# Dataset whose trained model is loaded
# (alternative: "spatial_atac_rna_seq_mouse_brain_batch2")
dataset = "seqfish_mouse_organogenesis_imputed"
model_label = "reference"
# Run subfolder; only used if model_label == "sample_integration_method_benchmarking"
run_number = 1

# AnnData keys
sub_cell_type_key = "sub_cell_type"
gp_names_key = "autotalker_gp_names"
latent_key = "autotalker_latent"
latent_knng_key = latent_key  # neighbors key of the latent KNN graph (same string as latent_key)
mapping_entity_key = "mapping_entity"

## Others
random_seed = 0

In [None]:
# Dataset-specific settings: timestamp of the model to load, AnnData keys,
# color palettes, sample layout and latent Leiden clustering / plotting
# parameters. Unknown datasets fail here with a clear error instead of a
# NameError further down.
if dataset == "seqfish_mouse_organogenesis_imputed":
    load_timestamp = "12052023_101740"
    cell_type_key = "celltype_mapped_refined"
    cell_type_colors = seqfish_mouse_organogenesis_cell_type_colors
    dataset_str = "seqFISH Mouse Organogenesis Imputed"
    multimodal = False
    sample_key = "sample"
    samples = ["embryo1", "embryo2", "embryo3"]
    condition_key = "batch"
    latent_leiden_resolution = 0.4
    latent_cluster_spot_size = 0.03
elif dataset == "starmap_plus_mouse_cns":
    # NOTE(review): incomplete configuration. load_timestamp, cell_type_key,
    # sample_key, samples, condition_key, latent_leiden_resolution and
    # latent_cluster_spot_size are not set for this dataset, so the
    # latent_cluster_key assignment below raises a NameError in a fresh kernel.
    cell_type_colors = starmap_plus_mouse_cns_cell_type_colors
    dataset_str = "STARmap PLUS Mouse Central Nervous System"
    multimodal = False
elif dataset == "visium_human_heart":
    load_timestamp = "15052023_102158"
    cell_type_key = "majority_cell_type"
    cell_type_colors = visium_human_heart_colors
    dataset_str = "Visium Human Heart"
    multimodal = False
    sample_key = "batch"
    samples = ["batch1", "batch2", "batch3"]
    condition_key = "batch"
    latent_leiden_resolution = 0.5
    latent_cluster_spot_size = 200
elif dataset == "spatial_atac_rna_seq_mouse_brain_batch2":
    load_timestamp = "15052023_114414"
    multimodal = True
    cell_type_key = "RNA_clusters"
    cell_type_colors = spatial_atac_rna_seq_mouse_embryo_and_brain_rna_colors
    # Palettes for the modality-specific RNA / ATAC cluster plots
    rna_cluster_colors = spatial_atac_rna_seq_mouse_embryo_and_brain_rna_colors
    atac_cluster_colors = spatial_atac_rna_seq_mouse_embryo_and_brain_atac_colors
    dataset_str = "Spatial ATAC-RNA-Seq Mouse Brain"
    sample_key = "batch"
    samples = ["p22"]
    condition_key = "batch"
    latent_leiden_resolution = 0.5
    latent_cluster_spot_size = 30
else:
    raise ValueError(f"Unknown dataset '{dataset}'. Add its settings to this cell.")

# obs key under which the latent Leiden clustering is stored (set once for
# all datasets)
latent_cluster_key = f"latent_leiden_{str(latent_leiden_resolution)}"

### 1.3 Run Notebook Setup

In [None]:
# Default scanpy figure size and a seaborn "whitegrid" style with grid lines disabled
sc.set_figure_params(figsize=(6, 6))
sns.set_style(style="whitegrid", rc={"axes.grid": False})

In [None]:
# Ignore future, user and runtime warnings to keep the notebook output readable
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

In [None]:
# Timestamp for this analysis' artifacts: reuse the loaded model's timestamp
# so figures are grouped with that model, otherwise use the execution time
now = datetime.now()
current_timestamp = (load_timestamp if load_timestamp is not None
                     else now.strftime("%d%m%Y_%H%M%S"))

### 1.4 Configure Paths and Create Directories

In [None]:
# Input data locations (relative to the notebook directory)
gp_data_folder_path = "../datasets/gp_data"  # gene program data
srt_data_folder_path = "../datasets/srt_data"  # spatially resolved transcriptomics data
srt_data_gold_folder_path = f"{srt_data_folder_path}/gold"

# Output figures and the trained model's artifacts
figure_folder_path = f"../figures/{dataset}/analysis/{current_timestamp}"
model_artifacts_folder_path = f"../artifacts/{dataset}/models/{load_timestamp}"

# Make sure both folders exist (no error if they already do)
os.makedirs(figure_folder_path, exist_ok=True)
os.makedirs(model_artifacts_folder_path, exist_ok=True)

## 2. Model

### 2.1 Load Model

In [None]:
# Directory of the model to load; benchmarking models live in per-run subfolders
model_dir_path = f"{model_artifacts_folder_path}/{model_label}"
if model_label == "sample_integration_method_benchmarking":
    model_dir_path += f"/run{run_number}"

In [None]:
# Load the trained model together with its stored AnnData; multimodal models
# additionally load the ATAC AnnData
load_kwargs = dict(dir_path=model_dir_path,
                   adata=None,
                   adata_file_name=f"{dataset}_{model_label}.h5ad",
                   gp_names_key=gp_names_key)
if multimodal:
    load_kwargs.update(adata_atac=None,
                       adata_atac_file_name=f"{dataset}_{model_label}_atac.h5ad")
model = Autotalker.load(**load_kwargs)

### 2.2 Retrieve GP Summary

In [None]:
# Report the total number of gene programs and how many of them are active
active_gps = model.get_active_gps()
n_total_gps = len(model.adata.uns[gp_names_key])
print(f"Number of total gene programs: {n_total_gps}.")
print(f"Number of active gene programs: {len(active_gps)}.")

In [None]:
# GP summary of the model; display the first 20 active gene programs
gp_summary_df = model.get_gp_summary()
gp_summary_df.loc[gp_summary_df["gp_active"].eq(True)].head(20)

## 3. Analysis

### 3.1 NicheCompass Latent Manifold Overview

In [None]:
# Latent space UMAP colored by batch / condition
save_fig = True
file_path = f"{figure_folder_path}/batch_annotations_latent_space.svg"

plot_latent(
    adata=model.adata,
    dataset_label=dataset_str,
    groups=None,
    color_by=condition_key,
    color_palette=batch_colors,
    file_path=file_path,
    save_fig=save_fig,
)

In [None]:
# Latent space UMAP colored by cell type
save_fig = True
file_path = f"{figure_folder_path}/cell_type_annotations_latent_space.svg"

plot_latent(
    adata=model.adata,
    dataset_label=dataset_str,
    groups=None,
    color_by=cell_type_key,
    color_palette=cell_type_colors,
    file_path=file_path,
    save_fig=save_fig,
)

In [None]:
# Latent space UMAP colored by mapping entity
save_fig = True
file_path = f"{figure_folder_path}/mapping_entity_annotations_latent_space.svg"

plot_latent(
    adata=model.adata,
    dataset_label=dataset_str,
    groups=None,
    color_by=mapping_entity_key,
    color_palette=mapping_entity_colors,
    file_path=file_path,
    save_fig=save_fig,
)

In [None]:
# Leiden clustering on the latent KNN graph; labels are stored under latent_cluster_key
sc.tl.leiden(adata=model.adata,
             neighbors_key=latent_knng_key,
             resolution=latent_leiden_resolution,
             key_added=latent_cluster_key)

In [None]:
# Latent Leiden clusters in latent and physical space for every sample
save_fig = True
file_path = (f"{figure_folder_path}/res_{latent_leiden_resolution}_"
             "latent_cluster_annotations_latent_physical_space.svg")

plot_latent_clusters_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Latent Clusters",
    latent_cluster_key=latent_cluster_key,
    latent_cluster_colors=latent_cluster_colors,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    # Point size shrinks as the number of observations grows
    size=640000/len(model.adata),
    spot_size=latent_cluster_spot_size,
    file_path=file_path,
    save_fig=save_fig)

In [None]:
# For multimodal (RNA + ATAC) models, plot the modality-specific RNA and ATAC
# cluster annotations in latent and physical space. The helper is called with
# the same sample arguments as the latent cluster plot; the undefined
# `conditions` variable is no longer referenced.
if multimodal:
    save_fig = True
    file_path = f"{figure_folder_path}/res_{latent_leiden_resolution}_" \
                "rna_cluster_annotations_latent_physical_space.svg"

    plot_latent_clusters_in_latent_and_physical_space(
        adata=model.adata,
        plot_label="RNA Clusters",
        latent_cluster_key="RNA_clusters",
        groups=None,
        sample_key=sample_key,
        samples=samples,
        latent_cluster_colors=rna_cluster_colors,
        size=640000/len(model.adata),
        spot_size=latent_cluster_spot_size,
        save_fig=save_fig,
        file_path=file_path)

    save_fig = True
    file_path = f"{figure_folder_path}/res_{latent_leiden_resolution}_" \
                "atac_cluster_annotations_latent_physical_space.svg"

    plot_latent_clusters_in_latent_and_physical_space(
        adata=model.adata,
        plot_label="ATAC Clusters",
        latent_cluster_key="ATAC_clusters",
        groups=None,
        sample_key=sample_key,
        samples=samples,
        latent_cluster_colors=atac_cluster_colors,
        size=640000/len(model.adata),
        spot_size=latent_cluster_spot_size,
        save_fig=save_fig,
        file_path=file_path)

### 3.2 NicheCompass Latent Cluster Differential GP Testing

In [None]:
# Run differential GP testing across the latent clusters and show the enriched
# GPs per latent cluster in a dotplot
title = "NicheCompass Latent Cluster Enriched Gene Programs"
log_bayes_factor_thresh = 4.6 # 2.3, 4.6
save_fig = True
file_path = f"{figure_folder_path}/log_bayes_factor_{log_bayes_factor_thresh}" \
            "_enriched_gps_dotplot.svg"

# Latent clusters vs the rest (selected_cats=None, presumably meaning every
# latent cluster is tested)
enriched_gps = model.run_differential_gp_tests(
    cat_key=latent_cluster_key,
    selected_cats=None,
    comparison_cats="rest",
    log_bayes_factor_thresh=log_bayes_factor_thresh)

if len(enriched_gps) == 0:
    # Nothing to plot; the computed figure height would also be zero
    print("No enriched gene programs at log Bayes factor threshold "
          f"{log_bayes_factor_thresh}; skipping dotplot.")
else:
    fig = sc.pl.dotplot(model.adata,
                        enriched_gps,
                        groupby=latent_cluster_key,
                        dendrogram=True,
                        title=title,
                        swap_axes=True,
                        return_fig=True,
                        figsize=(model.adata.obs[latent_cluster_key].nunique() / 2,
                                 len(enriched_gps) / 2))
    if save_fig:
        fig.savefig(file_path)
    else:
        fig.show()

In [None]:
# Summarize the enriched gene programs (ordered as in enriched_gps) and save
# the summary as CSV or display it
save_file = True
file_path = f"{figure_folder_path}/log_bayes_factor_" \
            f"{log_bayes_factor_thresh}_enriched_gps_summary.csv"

gp_summary_cols = ["gp_name",
                   "n_source_genes",
                   "n_non_zero_source_genes",
                   "n_target_genes",
                   "n_non_zero_target_genes",
                   "gp_source_genes",
                   "gp_target_genes",
                   "gp_source_genes_weights_sign_corrected",
                   "gp_target_genes_weights_sign_corrected",
                   "gp_source_genes_importances",
                   "gp_target_genes_importances"]
if multimodal:
    # Peak-level columns of multimodal (RNA + ATAC) models
    gp_summary_cols = gp_summary_cols + [
        "n_source_peaks",
        "n_target_peaks",
        "gp_source_peaks",
        "gp_target_peaks"]
    # Peak counterparts of the gene weight/importance columns. The gene columns
    # must not be appended a second time, as that duplicates them in the CSV.
    # NOTE(review): these names are assumed to mirror the gene ones; only
    # columns actually present in gp_summary_df are kept. Confirm against
    # Autotalker.get_gp_summary().
    gp_summary_cols = gp_summary_cols + [
        col for col in ["gp_source_peaks_weights_sign_corrected",
                        "gp_target_peaks_weights_sign_corrected",
                        "gp_source_peaks_importances",
                        "gp_target_peaks_importances"]
        if col in gp_summary_df.columns]

# Restrict to the enriched GPs. The copy ensures the column assignment below
# modifies an independent frame, not a view of gp_summary_df.
enriched_gp_summary_df = gp_summary_df[gp_summary_df["gp_name"].isin(enriched_gps)].copy()
# Order rows like enriched_gps via an ordered categorical. Plain column
# assignment is used because `.loc[:, col] = ...` may keep the object dtype in
# pandas >= 2.0, which would silently sort alphabetically instead.
cat_dtype = pd.CategoricalDtype(categories=enriched_gps, ordered=True)
enriched_gp_summary_df["gp_name"] = enriched_gp_summary_df["gp_name"].astype(cat_dtype)
enriched_gp_summary_df = enriched_gp_summary_df.sort_values(by="gp_name")
enriched_gp_summary_df = enriched_gp_summary_df[gp_summary_cols]

if save_file:
    enriched_gp_summary_df.to_csv(file_path)
else:
    display(enriched_gp_summary_df)

In [None]:
# Top-gene plots for the GPs enriched in the latent clusters
generate_gp_info_plots(
    model=model,
    analysis_label="latent_cluster",
    plot_category=latent_cluster_key,
    cell_type_key=cell_type_key,
    cell_type_colors=cell_type_colors,
    latent_cluster_colors=latent_cluster_colors,
    differential_gp_test_results_key="autotalker_differential_gp_test_results",
    log_bayes_factor_thresh=log_bayes_factor_thresh,
    # Other feature spaces: "physical_embryo1", "physical_embryo2", "physical_embryo3"
    feature_spaces=["latent"],
    plot_types=["top_genes"],
    figure_folder_path=figure_folder_path,
    save_figs=True)

### 3.2 NicheCompass Cell Type Analysis

#### 3.2.1 Forebrain/Midbrain/Hindbrain

In [None]:
# Cell type analysed in this section and the names derived from it
cell_type = "Forebrain/Midbrain/Hindbrain"
cell_type_fmt = cell_type.lower().replace("/", "_")  # filename-safe label
cell_type_differential_gp_scores_key = cell_type + "_differential_gp_scores"

##### 3.2.1.1 Overview

- Regionally specific developing brain subtypes: the cell type separates into Rhombencephalon, Tegmentum, Mesencephalon and Prosencephalon.

In [None]:
# Plot the cell type in physical and latent space, using the dataset's cell
# type palette selected in the parameter section (consistent with the other
# plots), not a hard-coded dataset palette
plot_physical_latent_for_cell_types(adata=model.adata,
                                    cell_types=[cell_type],
                                    sample_key=sample_key,
                                    cell_type_key=cell_type_key,
                                    cell_type_colors=cell_type_colors,
                                    figure_folder_path=figure_folder_path,
                                    save_fig=True)

In [None]:
# Run differential GP testing of the cell type against all other cell types
# and show the enriched GPs across cell types in a dotplot
title = f"Enriched Gene Programs {cell_type}"
log_bayes_factor_thresh = 2.3
save_fig = True
file_path = f"{figure_folder_path}/{cell_type_fmt}_log_bayes_factor_{log_bayes_factor_thresh}" \
            "_enriched_gps_dotplot.svg"

# Cell type vs other cell types
enriched_gps = model.run_differential_gp_tests(
    cat_key=cell_type_key,
    selected_cats=[cell_type],
    comparison_cats="rest",
    log_bayes_factor_thresh=log_bayes_factor_thresh)

if len(enriched_gps) == 0:
    # Nothing to plot; the computed figure height would also be zero
    print(f"No enriched gene programs for {cell_type} at log Bayes factor "
          f"threshold {log_bayes_factor_thresh}; skipping dotplot.")
else:
    fig = sc.pl.dotplot(model.adata,
                        enriched_gps,
                        groupby=cell_type_key,
                        dendrogram=True,
                        title=title,
                        swap_axes=True,
                        return_fig=True,
                        figsize=(model.adata.obs[cell_type_key].nunique()/2.5, len(enriched_gps)/2.5))
    if save_fig:
        fig.savefig(file_path)
    else:
        fig.show()

In [None]:
# Summarize the GPs enriched in the cell type (ordered as in enriched_gps) and
# save the summary as CSV or display it
save_file = True
file_path = f"{figure_folder_path}/{cell_type_fmt}_log_bayes_factor_" \
            f"{log_bayes_factor_thresh}_enriched_gps_summary.csv"

# Restrict to the enriched GPs. The copy ensures the column assignment below
# modifies an independent frame, not a view of gp_summary_df.
enriched_gp_summary_df = gp_summary_df[gp_summary_df["gp_name"].isin(enriched_gps)].copy()
# Order rows like enriched_gps via an ordered categorical. Plain column
# assignment is used because `.loc[:, col] = ...` may keep the object dtype in
# pandas >= 2.0, which would silently sort alphabetically instead.
cat_dtype = pd.CategoricalDtype(categories=enriched_gps, ordered=True)
enriched_gp_summary_df["gp_name"] = enriched_gp_summary_df["gp_name"].astype(cat_dtype)
enriched_gp_summary_df = enriched_gp_summary_df.sort_values(by="gp_name")
enriched_gp_summary_df = enriched_gp_summary_df[[
        "gp_name",
        "n_source_genes",
        "n_non_zero_source_genes",
        "n_target_genes",
        "n_non_zero_target_genes",
        "gp_source_genes",
        "gp_target_genes",
        "gp_source_genes_weights_sign_corrected",
        "gp_target_genes_weights_sign_corrected",
        "gp_source_genes_importances",
        "gp_target_genes_importances"]]

if save_file:
    enriched_gp_summary_df.to_csv(file_path)
else:
    display(enriched_gp_summary_df)

In [None]:
# Top-gene plots for the GPs enriched in the cell type
generate_gp_info_plots(
    model=model,
    analysis_label=cell_type_fmt,
    plot_category=cell_type_key,
    cell_type_key=cell_type_key,
    cell_type_colors=cell_type_colors,
    latent_cluster_colors=latent_cluster_colors,
    differential_gp_test_results_key="autotalker_differential_gp_test_results",
    log_bayes_factor_thresh=log_bayes_factor_thresh,
    # Other feature spaces: "physical_embryo1", "physical_embryo2", "physical_embryo3"
    feature_spaces=["latent"],
    plot_types=["top_genes"],
    figure_folder_path=figure_folder_path,
    save_figs=True)

##### 3.2.1.2 Latent Cluster Analysis

In [None]:
# Leiden resolution for sub-clustering the cell type in latent space
# (values tried: 0.2, 0.4, 0.6) and the obs keys derived from it
cell_type_latent_resolution = 0.2
cell_type_latent_cluster_key = f"{cell_type_fmt}_latent_leiden_{cell_type_latent_resolution}"
cell_type_latent_clusters_of_interest_key = f"{cell_type_latent_cluster_key}_clusters_of_interest"
cell_type_latent_cluster_emphasis_key = f"{cell_type_latent_cluster_key}_emphasis"

In [None]:
# Sub-cluster the cells of this cell type in latent space; the labels are
# stored under cell_type_latent_cluster_key
compute_cell_type_latent_clusters(
    adata=model.adata,
    cell_type_key=cell_type_key,
    cell_type=cell_type,
    latent_key=latent_key,
    latent_knng_key=latent_knng_key,
    cell_type_latent_resolution=cell_type_latent_resolution,
    cell_type_latent_cluster_key=cell_type_latent_cluster_key)

In [None]:
# Plot all latent clusters of the cell type in physical and latent space
groups = None
save_fig = True
cell_type_plot_prefix = (
    f"{figure_folder_path}/{cell_type.replace('/', '_').replace(' ', '_').lower()}"
    f"_res_{cell_type_latent_resolution}")

plot_cell_type_latent_clusters(
    adata=model.adata,
    cell_type=cell_type,
    cell_type_latent_cluster_key=cell_type_latent_cluster_key,
    groups=groups,
    condition_key=sample_key,
    latent_cluster_colors=latent_cluster_colors,
    save_fig=save_fig,
    file_path=(f"{cell_type_plot_prefix}_"
               f"{'latent_clusters' if groups is None else 'latent_cluster_' + groups}"
               "_physical_latent_space.svg"))

In [None]:
# Retrieve cell type latent clusters with at least 'min_obs_per_cluster' observations
min_obs_per_cluster = 100
cell_type_latent_cluster_counts = model.adata.obs[cell_type_latent_cluster_key].value_counts()
# `>=` so a cluster with exactly min_obs_per_cluster cells is kept ("at least")
cell_type_latent_clusters_of_interest = cell_type_latent_cluster_counts[
    cell_type_latent_cluster_counts >= min_obs_per_cluster].index.tolist()

In [None]:
# Plot each latent cluster of interest separately in physical and latent space
cell_type_plot_prefix = (
    f"{figure_folder_path}/{cell_type.replace('/', '_').replace(' ', '_').lower()}"
    f"_res_{cell_type_latent_resolution}")

for latent_cluster in cell_type_latent_clusters_of_interest:
    groups = latent_cluster
    save_fig = True
    group_label = 'latent_clusters' if groups is None else 'latent_cluster_' + groups

    plot_cell_type_latent_clusters(
        adata=model.adata,
        cell_type=cell_type,
        cell_type_latent_cluster_key=cell_type_latent_cluster_key,
        groups=groups,
        condition_key=sample_key,
        latent_cluster_colors=latent_cluster_colors,
        save_fig=save_fig,
        file_path=f"{cell_type_plot_prefix}_{group_label}_physical_latent_space.svg")

In [None]:
# Differential GP testing between the cell type's latent clusters of interest:
# each cluster is tested against the other clusters of interest, and the
# union of enriched GPs is shown in one dotplot
title = f"Enriched Gene Programs {cell_type} Latent Clusters (Resolution: {cell_type_latent_resolution})"
log_bayes_factor_thresh = 2.3 # 2.3, 4.6
save_fig = True
file_path = f"{figure_folder_path}/{cell_type_fmt}_res_{cell_type_latent_resolution}_" \
            f"latent_clusters_log_bayes_factor_" \
            f"{log_bayes_factor_thresh}_enriched_gps_dotplot.svg"

# Create new column with only clusters of interest (NaN elsewhere) for the dotplot
model.adata.obs.loc[:, cell_type_latent_clusters_of_interest_key] = np.nan
model.adata.obs.loc[model.adata.obs[cell_type_latent_cluster_key].isin(
    cell_type_latent_clusters_of_interest), cell_type_latent_clusters_of_interest_key] = model.adata.obs[cell_type_latent_cluster_key]

enriched_gps = []
for cell_type_latent_cluster in cell_type_latent_clusters_of_interest:

    # Compare this cluster with all other clusters of interest; filtering the
    # ordered list keeps the comparison order deterministic
    selected_cats = [cell_type_latent_cluster]
    comparison_cats = [cat for cat in cell_type_latent_clusters_of_interest
                       if cat not in selected_cats]

    # Latent cluster vs other latent clusters of interest
    latent_cluster_enriched_gps = model.run_differential_gp_tests(
        cat_key=cell_type_latent_cluster_key,
        selected_cats=selected_cats,
        comparison_cats=comparison_cats,
        log_bayes_factor_thresh=log_bayes_factor_thresh)

    enriched_gps.extend(latent_cluster_enriched_gps)

# Deduplicate while keeping first-seen order. list(set(...)) gave a
# session-dependent order (string hashing is randomized), which reshuffled the
# dotplot rows and the summary CSV between runs.
enriched_gps = list(dict.fromkeys(enriched_gps))

if len(enriched_gps) == 0:
    # Nothing to plot; the computed figure height would also be zero
    print(f"No enriched gene programs between {cell_type} latent clusters at "
          f"log Bayes factor threshold {log_bayes_factor_thresh}; skipping dotplot.")
else:
    fig = sc.pl.dotplot(model.adata,
                        enriched_gps,
                        groupby=cell_type_latent_clusters_of_interest_key,
                        dendrogram=True,
                        title=title,
                        swap_axes=True,
                        return_fig=True,
                        figsize=(model.adata.obs[cell_type_latent_clusters_of_interest_key].nunique()/1.5, len(enriched_gps)/2.5))
    if save_fig:
        fig.savefig(file_path)
    else:
        fig.show()

In [None]:
# Summarize the GPs enriched between the cell type's latent clusters (ordered
# as in enriched_gps) and save the summary as CSV or display it
save_file = True
file_path = f"{figure_folder_path}/{cell_type_fmt}_res_{cell_type_latent_resolution}" \
            f"_latent_clusters_" \
            f"log_bayes_factor_{log_bayes_factor_thresh}_enriched_gps_summary.csv"

# Restrict to the enriched GPs. The copy ensures the column assignment below
# modifies an independent frame, not a view of gp_summary_df.
enriched_gp_summary_df = gp_summary_df[gp_summary_df["gp_name"].isin(enriched_gps)].copy()
# Order rows like enriched_gps via an ordered categorical. Plain column
# assignment is used because `.loc[:, col] = ...` may keep the object dtype in
# pandas >= 2.0, which would silently sort alphabetically instead.
cat_dtype = pd.CategoricalDtype(categories=enriched_gps, ordered=True)
enriched_gp_summary_df["gp_name"] = enriched_gp_summary_df["gp_name"].astype(cat_dtype)
enriched_gp_summary_df = enriched_gp_summary_df.sort_values(by="gp_name")
enriched_gp_summary_df = enriched_gp_summary_df[[
        "gp_name",
        "n_source_genes",
        "n_non_zero_source_genes",
        "n_target_genes",
        "n_non_zero_target_genes",
        "gp_source_genes",
        "gp_target_genes",
        "gp_source_genes_weights_sign_corrected",
        "gp_target_genes_weights_sign_corrected",
        "gp_source_genes_importances",
        "gp_target_genes_importances"]]

if save_file:
    enriched_gp_summary_df.to_csv(file_path)
else:
    display(enriched_gp_summary_df)

#### 3.2.2 Cardiomyocytes

In [None]:
# Cell type analysed in this section and the names derived from it
cell_type = "Cardiomyocytes"
cell_type_fmt = cell_type.lower().replace("/", "_")  # filename-safe label
cell_type_differential_gp_scores_key = cell_type + "_differential_gp_scores"

##### 3.2.2.1 Overview

In [None]:
# Plot the cell type in physical and latent space, using the dataset's cell
# type palette selected in the parameter section (consistent with the other
# plots), not a hard-coded dataset palette
plot_physical_latent_for_cell_types(adata=model.adata,
                                    cell_types=[cell_type],
                                    sample_key=sample_key,
                                    cell_type_key=cell_type_key,
                                    cell_type_colors=cell_type_colors,
                                    figure_folder_path=figure_folder_path,
                                    save_fig=True)

In [None]:
# Run differential GP testing of the cell type against all other cell types
# and show the enriched GPs across cell types in a dotplot
title = f"Enriched Gene Programs {cell_type}"
log_bayes_factor_thresh = 2.3
save_fig = True
file_path = f"{figure_folder_path}/{cell_type_fmt}_log_bayes_factor_{log_bayes_factor_thresh}" \
            "_enriched_gps_dotplot.svg"

# Cell type vs other cell types
enriched_gps = model.run_differential_gp_tests(
    cat_key=cell_type_key,
    selected_cats=[cell_type],
    comparison_cats="rest",
    log_bayes_factor_thresh=log_bayes_factor_thresh)

if len(enriched_gps) == 0:
    # Nothing to plot; the computed figure height would also be zero
    print(f"No enriched gene programs for {cell_type} at log Bayes factor "
          f"threshold {log_bayes_factor_thresh}; skipping dotplot.")
else:
    fig = sc.pl.dotplot(model.adata,
                        enriched_gps,
                        groupby=cell_type_key,
                        dendrogram=True,
                        title=title,
                        swap_axes=True,
                        return_fig=True,
                        figsize=(model.adata.obs[cell_type_key].nunique()/2.5, len(enriched_gps)/2.5))
    if save_fig:
        fig.savefig(file_path)
    else:
        fig.show()

In [None]:
# Summarize the GPs enriched in the cell type (ordered as in enriched_gps) and
# save the summary as CSV or display it
save_file = True
file_path = f"{figure_folder_path}/{cell_type_fmt}_log_bayes_factor_" \
            f"{log_bayes_factor_thresh}_enriched_gps_summary.csv"

# Restrict to the enriched GPs. The copy ensures the column assignment below
# modifies an independent frame, not a view of gp_summary_df.
enriched_gp_summary_df = gp_summary_df[gp_summary_df["gp_name"].isin(enriched_gps)].copy()
# Order rows like enriched_gps via an ordered categorical. Plain column
# assignment is used because `.loc[:, col] = ...` may keep the object dtype in
# pandas >= 2.0, which would silently sort alphabetically instead.
cat_dtype = pd.CategoricalDtype(categories=enriched_gps, ordered=True)
enriched_gp_summary_df["gp_name"] = enriched_gp_summary_df["gp_name"].astype(cat_dtype)
enriched_gp_summary_df = enriched_gp_summary_df.sort_values(by="gp_name")
enriched_gp_summary_df = enriched_gp_summary_df[[
        "gp_name",
        "n_source_genes",
        "n_non_zero_source_genes",
        "n_target_genes",
        "n_non_zero_target_genes",
        "gp_source_genes",
        "gp_target_genes",
        "gp_source_genes_weights_sign_corrected",
        "gp_target_genes_weights_sign_corrected",
        "gp_source_genes_importances",
        "gp_target_genes_importances"]]

if save_file:
    enriched_gp_summary_df.to_csv(file_path)
else:
    display(enriched_gp_summary_df)

In [None]:
# Top-gene plots for the GPs enriched in the cell type
generate_gp_info_plots(
    model=model,
    analysis_label=cell_type_fmt,
    plot_category=cell_type_key,
    cell_type_key=cell_type_key,
    cell_type_colors=cell_type_colors,
    latent_cluster_colors=latent_cluster_colors,
    differential_gp_test_results_key="autotalker_differential_gp_test_results",
    log_bayes_factor_thresh=log_bayes_factor_thresh,
    # Other feature spaces: "physical_embryo1", "physical_embryo2", "physical_embryo3"
    feature_spaces=["latent"],
    plot_types=["top_genes"],
    figure_folder_path=figure_folder_path,
    save_figs=True)

##### 3.2.2.2 Latent Cluster Analysis

In [None]:
# Leiden resolution for sub-clustering the cell type in latent space
# (other values: 0.8, 1.0) and the obs keys derived from it
cell_type_latent_resolution = 0.6
cell_type_latent_cluster_key = f"{cell_type_fmt}_latent_leiden_{cell_type_latent_resolution}"
cell_type_latent_clusters_of_interest_key = f"{cell_type_latent_cluster_key}_clusters_of_interest"
cell_type_latent_cluster_emphasis_key = f"{cell_type_latent_cluster_key}_emphasis"

In [None]:
# Sub-cluster the cells of this cell type in latent space; the labels are
# stored under cell_type_latent_cluster_key
compute_cell_type_latent_clusters(
    adata=model.adata,
    cell_type_key=cell_type_key,
    cell_type=cell_type,
    latent_key=latent_key,
    latent_knng_key=latent_knng_key,
    cell_type_latent_resolution=cell_type_latent_resolution,
    cell_type_latent_cluster_key=cell_type_latent_cluster_key)

In [None]:
# Plot all latent clusters of the cell type in physical and latent space
groups = None
save_fig = True
cell_type_plot_prefix = (
    f"{figure_folder_path}/{cell_type.replace('/', '_').replace(' ', '_').lower()}"
    f"_res_{cell_type_latent_resolution}")

plot_cell_type_latent_clusters(
    adata=model.adata,
    cell_type=cell_type,
    cell_type_latent_cluster_key=cell_type_latent_cluster_key,
    groups=groups,
    condition_key=sample_key,
    latent_cluster_colors=latent_cluster_colors,
    save_fig=save_fig,
    file_path=(f"{cell_type_plot_prefix}_"
               f"{'latent_clusters' if groups is None else 'latent_cluster_' + groups}"
               "_physical_latent_space.svg"))

In [None]:
# Retrieve cell type latent clusters with at least 'min_obs_per_cluster' observations
min_obs_per_cluster = 30
cell_type_latent_cluster_counts = model.adata.obs[cell_type_latent_cluster_key].value_counts()
# `>=` so a cluster with exactly min_obs_per_cluster cells is kept ("at least")
cell_type_latent_clusters_of_interest = cell_type_latent_cluster_counts[
    cell_type_latent_cluster_counts >= min_obs_per_cluster].index.tolist()

In [None]:
# Plot each latent cluster of interest separately in physical and latent space
cell_type_plot_prefix = (
    f"{figure_folder_path}/{cell_type.replace('/', '_').replace(' ', '_').lower()}"
    f"_res_{cell_type_latent_resolution}")

for latent_cluster in cell_type_latent_clusters_of_interest:
    groups = latent_cluster
    save_fig = True
    group_label = 'latent_clusters' if groups is None else 'latent_cluster_' + groups

    plot_cell_type_latent_clusters(
        adata=model.adata,
        cell_type=cell_type,
        cell_type_latent_cluster_key=cell_type_latent_cluster_key,
        groups=groups,
        condition_key=sample_key,
        latent_cluster_colors=latent_cluster_colors,
        save_fig=save_fig,
        file_path=f"{cell_type_plot_prefix}_{group_label}_physical_latent_space.svg")

In [None]:
# Differential GP testing between the cell type's latent clusters of interest:
# each cluster is tested against the other clusters of interest, and the
# union of enriched GPs is shown in one dotplot
title = f"Enriched Gene Programs {cell_type} Latent Clusters (Resolution: {cell_type_latent_resolution})"
log_bayes_factor_thresh = 2.3 # 2.3, 4.6
save_fig = True
file_path = f"{figure_folder_path}/{cell_type_fmt}_res_{cell_type_latent_resolution}_" \
            f"latent_clusters_log_bayes_factor_" \
            f"{log_bayes_factor_thresh}_enriched_gps_dotplot.svg"

# Create new column with only clusters of interest (NaN elsewhere) for the dotplot
model.adata.obs.loc[:, cell_type_latent_clusters_of_interest_key] = np.nan
model.adata.obs.loc[model.adata.obs[cell_type_latent_cluster_key].isin(
    cell_type_latent_clusters_of_interest), cell_type_latent_clusters_of_interest_key] = model.adata.obs[cell_type_latent_cluster_key]

enriched_gps = []
for cell_type_latent_cluster in cell_type_latent_clusters_of_interest:

    # Compare this cluster with all other clusters of interest; filtering the
    # ordered list keeps the comparison order deterministic
    selected_cats = [cell_type_latent_cluster]
    comparison_cats = [cat for cat in cell_type_latent_clusters_of_interest
                       if cat not in selected_cats]

    # Latent cluster vs other latent clusters of interest
    latent_cluster_enriched_gps = model.run_differential_gp_tests(
        cat_key=cell_type_latent_cluster_key,
        selected_cats=selected_cats,
        comparison_cats=comparison_cats,
        log_bayes_factor_thresh=log_bayes_factor_thresh)

    enriched_gps.extend(latent_cluster_enriched_gps)

# Deduplicate while keeping first-seen order. list(set(...)) gave a
# session-dependent order (string hashing is randomized), which reshuffled the
# dotplot rows and the summary CSV between runs.
enriched_gps = list(dict.fromkeys(enriched_gps))

if len(enriched_gps) == 0:
    # Nothing to plot; the computed figure height would also be zero
    print(f"No enriched gene programs between {cell_type} latent clusters at "
          f"log Bayes factor threshold {log_bayes_factor_thresh}; skipping dotplot.")
else:
    fig = sc.pl.dotplot(model.adata,
                        enriched_gps,
                        groupby=cell_type_latent_clusters_of_interest_key,
                        dendrogram=True,
                        title=title,
                        swap_axes=True,
                        return_fig=True,
                        figsize=(model.adata.obs[cell_type_latent_clusters_of_interest_key].nunique()/1.5, len(enriched_gps)/2.5))
    if save_fig:
        fig.savefig(file_path)
    else:
        fig.show()

In [None]:
# Summarize the GPs enriched between the cell type's latent clusters (ordered
# as in enriched_gps) and save the summary as CSV or display it
save_file = True
file_path = f"{figure_folder_path}/{cell_type_fmt}_res_{cell_type_latent_resolution}" \
            f"_latent_clusters_" \
            f"log_bayes_factor_{log_bayes_factor_thresh}_enriched_gps_summary.csv"

# Restrict to the enriched GPs. The copy ensures the column assignment below
# modifies an independent frame, not a view of gp_summary_df.
enriched_gp_summary_df = gp_summary_df[gp_summary_df["gp_name"].isin(enriched_gps)].copy()
# Order rows like enriched_gps via an ordered categorical. Plain column
# assignment is used because `.loc[:, col] = ...` may keep the object dtype in
# pandas >= 2.0, which would silently sort alphabetically instead.
cat_dtype = pd.CategoricalDtype(categories=enriched_gps, ordered=True)
enriched_gp_summary_df["gp_name"] = enriched_gp_summary_df["gp_name"].astype(cat_dtype)
enriched_gp_summary_df = enriched_gp_summary_df.sort_values(by="gp_name")
enriched_gp_summary_df = enriched_gp_summary_df[[
        "gp_name",
        "n_source_genes",
        "n_non_zero_source_genes",
        "n_target_genes",
        "n_non_zero_target_genes",
        "gp_source_genes",
        "gp_target_genes",
        "gp_source_genes_weights_sign_corrected",
        "gp_target_genes_weights_sign_corrected",
        "gp_source_genes_importances",
        "gp_target_genes_importances"]]

if save_file:
    enriched_gp_summary_df.to_csv(file_path)
else:
    display(enriched_gp_summary_df)

#### 3.2.3 Save Results

In [None]:
# Persist the annotated AnnData (including all clusterings added above) next to the figures
model.adata.write_h5ad(f"{figure_folder_path}/adata.h5ad")

In [None]:
# Display the model's AnnData object as this cell's output
model.adata

In [None]:
# Add cell type latent cluster emphasis for plotting
# NOTE(review): `selected_cats` and `comparison_cats` are not set in this cell;
# in a top-to-bottom run they hold the values from the last iteration of the
# Cardiomyocytes latent-cluster differential-testing loop, so the emphasis
# reflects only that final comparison. Confirm this is intended.
# Presumably add_cell_type_latent_cluster_emphasis labels each row according
# to whether its latent cluster is selected or compared (verify in
# analysis_utils).
model.adata.obs[cell_type_latent_cluster_emphasis_key] = model.adata.obs.apply(
    lambda row: add_cell_type_latent_cluster_emphasis(row,
                                                      cell_type_latent_cluster_key,
                                                      selected_cats,
                                                      comparison_cats), axis=1)
# Store the emphasis labels as a categorical column
model.adata.obs[cell_type_latent_cluster_emphasis_key] = model.adata.obs[cell_type_latent_cluster_emphasis_key].astype("category")

In [None]:
# TODO: unfinished cell — analysis step still to be added here

In [None]:
# Plot information on the differential GP test results stored in
# ``model.adata.uns["autotalker_differential_gp_test_results"]``, grouped by the
# emphasized latent cluster categories (NaN categories are excluded from plot_group)
generate_gp_info_plots(
    analysis_label=f"{cell_type_fmt}_latent_cluster",
    model=model,
    cell_type_key=cell_type_key,
    cell_type_colors=cell_type_colors,
    latent_cluster_colors=latent_cluster_colors,
    differential_gp_test_results_df=model.adata.uns["autotalker_differential_gp_test_results"],
    plot_category=cell_type_latent_cluster_emphasis_key,
    plot_group=[x for x in model.adata.obs[cell_type_latent_cluster_emphasis_key].unique().tolist() if str(x) != "nan"],
    feature_spaces=["latent"], # "physical_embryo1", "physical_embryo2", "physical_embryo3"
    figure_folder_path=figure_folder_path,
    save_figs=False)

In [None]:
# Cell type vs other cell types: run differential GP tests of ``cell_type``
# against all remaining ("rest") categories of ``cell_type_key``
enriched_gps = model.run_differential_gp_tests(
    cat_key=cell_type_key,
    selected_cats=[cell_type],
    comparison_cats="rest")

In [None]:
# Specific latent clusters: cluster the cells of ``cell_type`` in latent space at
# ``cell_type_latent_resolution`` and plot the result.
# NOTE(review): ``groups=["2"]`` presumably restricts the highlighted clusters to
# latent cluster "2" — confirm against the helper's definition.
cluster_and_plot_cell_type_latent_clusters(adata=model.adata,
                                           cell_type_latent_resolution=cell_type_latent_resolution,
                                           cell_type_latent_cluster_key=cell_type_latent_cluster_key,
                                           latent_knng_key=latent_knng_key,
                                           cell_type_key=cell_type_key,
                                           cell_type=cell_type,
                                           groups=["2"],
                                           latent_cluster_colors=latent_cluster_colors,
                                           condition_key=sample_key,
                                           save_fig=False,
                                           figure_folder_path=figure_folder_path)

In [None]:
# Retrieve all latent clusters with cells of the given cell type
cell_type_latent_clusters = [cell_type_latent_cluster for cell_type_latent_cluster
                             in model.adata.obs[cell_type_latent_cluster_key].unique().tolist()
                             if str(cell_type_latent_cluster) != "nan"]

# Define latent clusters of interest according to visualization
cell_type_latent_clusters_of_interest = ["5"]

# Add column with sub cell types (computed row-wise, stored as categorical)
model.adata.obs[sub_cell_type_key] = model.adata.obs.apply(
    lambda row: add_sub_cell_type(row,
                                  cell_type_key=cell_type_key,
                                  cell_type=cell_type,
                                  cell_type_latent_cluster_key=cell_type_latent_cluster_key),
    axis=1)
model.adata.obs[sub_cell_type_key] = model.adata.obs[sub_cell_type_key].astype("category")

# Squidpy nhood enrichment is sorted alphabetically, so sort the labels the same
# way to align them with the rows/columns of the z-score matrix
sub_cell_types = model.adata.obs[sub_cell_type_key].unique().tolist()
sub_cell_types.sort()

sq.gr.nhood_enrichment(model.adata, cluster_key=sub_cell_type_key)

# Positions (in the sorted label list) of the sub cell types of the given cell type
cell_type_sub_cell_types_idx = [idx for idx, label in enumerate(sub_cell_types)
                                if cell_type in label]

# Retrieve cell type latent cluster neighborhood enrichments: for each sub cell
# type of the given cell type, map every sub cell type to its z-score, ordered
# by descending z-score
enrichment_dict = {
    f"{sub_cell_types[idx]}": {
        label: zscore for zscore, label in sorted(
            zip(model.adata.uns[f"{sub_cell_type_key}_nhood_enrichment"]["zscore"][idx],
                sub_cell_types),
            reverse=True)}
    for idx in cell_type_sub_cell_types_idx}

enrichment_df = pd.DataFrame(enrichment_dict)

In [None]:
# Differential GP analysis of one latent cluster of the cell type, compared either
# with the other clusters of interest or with all remaining cells ("rest")
compare_clusters_of_interest_only = False
selected_cats = ["3"]
if compare_clusters_of_interest_only:
    comparison_cats = list(set(cell_type_latent_clusters_of_interest) - set(selected_cats))
else:
    comparison_cats="rest"
# uns key under which the differential GP scores of this cluster are stored
latent_cluster_differential_gp_scores_key = f"autotalker_latent_cluster_{selected_cats[0]}_differential_gp_scores"

# Recompute the emphasis column for the selected cluster and store it as categorical
model.adata.obs[cell_type_latent_cluster_emphasis_key] = model.adata.obs.apply(
    lambda row: add_cell_type_latent_cluster_emphasis(row, cell_type_latent_cluster_key, selected_cats[0], comparison_cats), axis=1)
model.adata.obs[cell_type_latent_cluster_emphasis_key] = model.adata.obs[cell_type_latent_cluster_emphasis_key].astype("category")

# Run the differential analysis (up to 10 up- and 10 down-regulated GPs) and save figures
top_unique_gps = get_differential_analysis_results(
    analysis_label=f"{cell_type_fmt}_latent_cluster_{selected_cats[0]}",
    model=model,
    adata=model.adata,
    cell_type_key=cell_type_key,
    cell_type_colors=cell_type_colors,
    latent_cluster_colors=latent_cluster_colors,
    random_seed=random_seed,
    cat_key=cell_type_latent_cluster_key,
    selected_cats=selected_cats,
    differential_gp_scores_key=latent_cluster_differential_gp_scores_key,
    comparison_cats=comparison_cats,
    plot_category=cell_type_latent_cluster_emphasis_key,
    plot_group=[x for x in model.adata.obs[cell_type_latent_cluster_emphasis_key].unique().tolist() if str(x) != "nan"],
    selected_gps=None,
    n_top_up_gps=10,
    n_top_down_gps=10,
    feature_spaces=["latent"], # "physical_embryo1", "physical_embryo2", "physical_embryo3"
    figure_folder_path=figure_folder_path,
    save_figs=True)

# Store a summary of the top GPs as CSV in the figure folder and display it
top_gps_summary_df = store_top_gps_summary(model=model,
                                           top_gps=top_unique_gps,
                                           file_path=f"{figure_folder_path}/{cell_type_fmt}_cluster_{selected_cats[0]}_gp_summary.csv")
display(top_gps_summary_df)

In [None]:
# Display a summary of the AnnData object
model.adata

In [None]:
# Inspect the differential GP scores of the latent cluster analyzed in the previous
# cell (reuses the key built there instead of hard-coding the cluster id)
model.adata.uns[latent_cluster_differential_gp_scores_key]

In [None]:
# NOTE(review): this bare expression is not the last line of the cell, so Jupyter
# does not display it — it has no effect
cell_type_latent_cluster_key
# Flag read by plotting cells to decide whether figures are written to disk
save_fig = False

In [None]:
def create_dotplot_of_differential_gps(adata,
                                       groupby_key,
                                       title,
                                       save_fig,
                                       file_path):
    """Create, optionally save, and show a dot plot of differential GP scores.

    Differential GP score columns are the ``adata.obs`` columns whose names end
    with ``"_GP"``.

    Parameters
    ----------
    adata:
        AnnData object whose ``obs`` holds the GP score columns to plot.
    groupby_key:
        Key in ``adata.obs`` used to group cells on the dot plot.
    title:
        Title of the dot plot.
    save_fig:
        If ``True``, write the figure to ``file_path`` before showing it.
    file_path:
        Destination path of the figure when ``save_fig`` is ``True``.
    """
    # Read the GP columns from the AnnData object being plotted (not from the
    # global ``model.adata``) so every requested column exists in ``adata``
    differential_gps = [col for col in adata.obs.columns if col.endswith("_GP")]

    fig = sc.pl.dotplot(adata,
                        differential_gps,
                        groupby=groupby_key,
                        dendrogram=True,
                        title=title,
                        swap_axes=True,
                        return_fig=True)
    # Save and display plot
    if save_fig:
        fig.savefig(file_path)
    fig.show()

In [None]:
# Retrieve all latent clusters with cells of the given cell type
cell_type_latent_clusters = [cell_type_latent_cluster for cell_type_latent_cluster
                             in model.adata.obs[cell_type_latent_cluster_key].unique().tolist()
                             if str(cell_type_latent_cluster) != "nan"]

# Define latent clusters of interest according to visualization
cell_type_latent_clusters_of_interest = ["2", "25", "43"]

# Add column with sub cell types (computed row-wise, stored as categorical)
model.adata.obs[sub_cell_type_key] = model.adata.obs.apply(lambda row: add_sub_cell_type(row, cell_type=cell_type), axis=1)
model.adata.obs[sub_cell_type_key] = model.adata.obs[sub_cell_type_key].astype("category")

# Squidpy nhood enrichment is sorted alphabetically, so sort the labels the same
# way to align them with the rows/columns of the z-score matrix
sub_cell_types = model.adata.obs[sub_cell_type_key].unique().tolist()
sub_cell_types.sort()

sq.gr.nhood_enrichment(model.adata, cluster_key=sub_cell_type_key)

# Positions (in the sorted label list) of the sub cell types of the given cell type
cell_type_sub_cell_types_idx = [idx for idx, label in enumerate(sub_cell_types)
                                if cell_type in label]

# Retrieve cell type latent cluster neighborhood enrichments: for each sub cell
# type of the given cell type, map every sub cell type to its z-score, ordered
# by descending z-score
enrichment_dict = {
    f"{sub_cell_types[idx]}": {
        label: zscore for zscore, label in sorted(
            zip(model.adata.uns[f"{sub_cell_type_key}_nhood_enrichment"]["zscore"][idx],
                sub_cell_types),
            reverse=True)}
    for idx in cell_type_sub_cell_types_idx}

enrichment_df = pd.DataFrame(enrichment_dict)

In [None]:
# Emphasize the selected latent cluster against all other latent clusters of
# ``cell_type_adata``.
# NOTE(review): ``selected_cats = [""]`` looks like a placeholder — fill in the
# latent cluster to analyze before running this cell.
selected_cats = [""]
comparison_cats = list(
    set(cell_type_adata.obs[cell_type_latent_cluster_key].unique()) - set(selected_cats))

# NOTE(review): only ``comparison_cats`` is passed here, while other cells also
# pass the cluster key and selected categories — confirm the helper's signature.
model.adata.obs[cell_type_latent_cluster_emphasis_key] = model.adata.obs.apply(
    lambda row: add_cell_type_latent_cluster_emphasis(row, comparison_cats), axis=1)
model.adata.obs[cell_type_latent_cluster_emphasis_key] = model.adata.obs[cell_type_latent_cluster_emphasis_key].astype("category")

In [None]:
# Differential GP analysis of ``selected_cats`` vs ``comparison_cats`` (top 3 up-
# and 3 down-regulated GPs) plotted in the "physical_embryo2" feature space.
# NOTE(review): ``latent_cluster_differential_gp_scores_key`` is whatever an
# earlier cell last set it to — make sure it matches ``selected_cats``.
get_differential_analysis_results(analysis_label="forebrain_midbrain_hindbrain_latent_cluster",
                                  model=model,
                                  adata=model.adata,
                                  cat_key=cell_type_latent_cluster_key,
                                  selected_cats=selected_cats,
                                  differential_gp_scores_key=latent_cluster_differential_gp_scores_key,
                                  comparison_cats=comparison_cats,
                                  plot_category=cell_type_latent_cluster_emphasis_key,
                                  selected_gps=None,
                                  n_top_up_gps=3,
                                  n_top_down_gps=3,
                                  feature_spaces=["physical_embryo2"], # "physical_embryo1", "physical_embryo2", "physical_embryo3"
                                  save_figs=False)

In [None]:
# Compute gene importances for the gene program "GDF3_ligand_targetgenes_GP"
model.compute_gp_gene_importances(selected_gp="GDF3_ligand_targetgenes_GP")

In [None]:
# Compare latent cluster "7" against latent cluster "8"
selected_cats = ["7"]
comparison_cats = ["8"]

# Recompute the emphasis column and store it as categorical
model.adata.obs[cell_type_latent_cluster_emphasis_key] = model.adata.obs.apply(
    lambda row: add_cell_type_latent_cluster_emphasis(row, comparison_cats), axis=1)
model.adata.obs[cell_type_latent_cluster_emphasis_key] = model.adata.obs[cell_type_latent_cluster_emphasis_key].astype("category")

In [None]:
# Differential GP analysis of latent cluster 7 vs 8 (top 3 up- and 3
# down-regulated GPs) plotted in the "physical_embryo2" feature space
get_differential_analysis_results(analysis_label="forebrain_midbrain_hindbrain_latent_cluster",
                                  model=model,
                                  adata=model.adata,
                                  cat_key=cell_type_latent_cluster_key,
                                  selected_cats=selected_cats,
                                  differential_gp_scores_key=latent_cluster_differential_gp_scores_key,
                                  comparison_cats=comparison_cats,
                                  plot_category=cell_type_latent_cluster_emphasis_key,
                                  selected_gps=None,
                                  n_top_up_gps=3,
                                  n_top_down_gps=3,
                                  feature_spaces=["physical_embryo2"], # "physical_embryo1", "physical_embryo2", "physical_embryo3"
                                  save_figs=False)

### 3.4 Gut Tube

In [None]:
# Cell type analyzed in this section and the Leiden resolution for its latent clusters
cell_type, cell_type_latent_resolution = "Gut tube", 0.04

In [None]:
# Plot the selected cell type with the physical/latent plotting helper and save the figure
plot_physical_latent_for_cell_types(adata=model.adata,
                                    cell_types=[cell_type],
                                    save_fig=True)

In [None]:
# Subset to the selected cell type. ``.copy()`` makes an independent AnnData, so
# the neighbor graph, clustering and UMAP computed on it in the next cell do not
# operate on (and implicitly materialize) a view of ``model.adata``.
cell_type_adata = model.adata[model.adata.obs[cell_type_key] == cell_type.replace("_", " ").capitalize()].copy()

In [None]:
# Cell-type-only latent analysis: neighbor graph, Leiden clusters and UMAP are
# computed on ``cell_type_adata`` alone

# Compute latent nearest neighbor graph for cell type only
sc.pp.neighbors(cell_type_adata,
                use_rep=latent_key,
                key_added=f"{cell_type}_latent_knng")

# Compute latent Leiden clustering for cell type
sc.tl.leiden(adata=cell_type_adata,
             resolution=cell_type_latent_resolution,
             random_state=random_seed,
             key_added=cell_type_latent_cluster_key,
             adjacency=cell_type_adata.obsp[f"{cell_type}_latent_knng_connectivities"])

# Use cell type latent space for UMAP generation
sc.tl.umap(cell_type_adata,
           neighbors_key=f"{cell_type}_latent_knng")

In [None]:
# Copy the cell-type latent clusters into model.adata; cells of other cell types get NaN
model.adata.obs[cell_type_latent_cluster_key] = np.nan
model.adata.obs.loc[model.adata.obs[cell_type_key] == cell_type.replace("_", " ").capitalize(),
                    cell_type_latent_cluster_key] = cell_type_adata.obs[cell_type_latent_cluster_key]

In [None]:
# Plot the latent clusters of the cell type selected for this section
plot_latent_physical_for_cell_type_latent_clusters(adata=model.adata,
                                                   cell_type=cell_type,
                                                   save_fig=False)

Dorsal-ventral separation of esophageal and tracheal progenitor populations in the gut tube (not visible in scVI latent space).

Dorsal-ventral and rostral-caudal spatially resolved patterns of the midbrain and hindbrain region.

In [None]:
# Store the adjacency returned by model.get_recon_adj() in obsp for aggregation below
model.adata.obsp["autotalker_recon_adj"] = model.get_recon_adj()

In [None]:
# Aggregate three cell-by-cell obsp matrices per cell type:
# aggregation attention weights, reconstructed adjacency, and spatial connectivities
att_weight_df = aggregate_obsp_matrix_per_cell_type(adata=model.adata,
                                                    obsp_key="autotalker_agg_alpha",
                                                    cell_type_key=cell_type_key,
                                                    agg_rows=True)

recon_adj_df = aggregate_obsp_matrix_per_cell_type(adata=model.adata,
                                                   obsp_key="autotalker_recon_adj",
                                                   cell_type_key=cell_type_key,
                                                   agg_rows=True)

adj_df = aggregate_obsp_matrix_per_cell_type(adata=model.adata,
                                             obsp_key="spatial_connectivities",
                                             cell_type_key=cell_type_key,
                                             agg_rows=True)

In [None]:
# Chord plot of the per-cell-type aggregated attention weights, saved as circos.png.
# NOTE(review): ``link_threshold=.5`` presumably hides links with weight below 0.5 — confirm.
create_cell_type_chord_plot_from_df(adata=model.adata,
                                    df=att_weight_df,
                                    title="Aggregation Module Cell Type Attention",
                                    link_threshold=.5,
                                    cell_type_key=cell_type_key,
                                    save_fig=True,
                                    save_path=f"{model_artifacts_folder_path}/circos.png")

#### 3.2.1 Cardiomyocytes

In [None]:
# Cell type analyzed in this section and the Leiden resolution for its latent clusters
cell_type = "Cardiomyocytes"
cell_type_latent_resolution = 1.0
# File-name friendly version of the cell type label
cell_type_fmt = cell_type.replace("/", "_").lower()

# Keys under which latent clusters, their emphasis and differential GP scores are stored
cell_type_latent_cluster_key = "{}_latent_leiden_{}".format(cell_type_fmt, cell_type_latent_resolution)
cell_type_latent_cluster_emphasis_key = cell_type_latent_cluster_key + "_emphasis"
cell_type_differential_gp_scores_key = cell_type + "_differential_gp_scores"

In [None]:
# Plot the selected cell type with the physical/latent plotting helper (not saved)
plot_physical_latent_for_cell_types(adata=model.adata,
                                    cell_types=[cell_type],
                                    sample_key=sample_key,
                                    cell_type_key=cell_type_key,
                                    cell_type_colors=seqfish_mouse_organogenesis_cell_type_colors,
                                    figure_folder_path=figure_folder_path,
                                    save_fig=False)

In [None]:
# Differential GP analysis of the selected cell type against all other cell types
# ("rest"), keeping up to 10 up- and 10 down-regulated GPs
top_unique_gps = get_differential_analysis_results(
    analysis_label=cell_type_fmt,
    model=model,
    adata=model.adata,
    cell_type_key=cell_type_key,
    cell_type_colors=cell_type_colors,
    cat_key=cell_type_key,
    selected_cats=[cell_type],
    differential_gp_scores_key=cell_type_differential_gp_scores_key,
    comparison_cats="rest",
    plot_category=cell_type_key,
    plot_group=cell_type,
    selected_gps=None,
    n_top_up_gps=10,
    n_top_down_gps=10,
    feature_spaces=["latent"], # "physical_embryo1", "physical_embryo2", "physical_embryo3"
    save_figs=False,
    random_seed=random_seed)

# Store a summary of the top GPs as CSV in the model artifacts folder and display it
top_gps_summary_df = store_top_gps_summary(model=model,
                                           top_gps=top_unique_gps,
                                           file_path=f"{model_artifacts_folder_path}/{cell_type_fmt}_gp_summary.csv")
display(top_gps_summary_df)

In [None]:
# Compute latent Leiden clustering with cell-type-specific resolution
sc.tl.leiden(adata=model.adata,
             resolution=cell_type_latent_resolution,
             random_state=random_seed,
             key_added=cell_type_latent_cluster_key,
             neighbors_key=latent_knng_key)

# Filter for cell type
cell_type_adata = model.adata[model.adata.obs[cell_type_key] == cell_type]

# Only keep latent clusters for cell type and set rest to NaN
model.adata.obs[cell_type_latent_cluster_key] = np.nan
model.adata.obs.loc[model.adata.obs[cell_type_key] == cell_type,
                    cell_type_latent_cluster_key] = cell_type_adata.obs[cell_type_latent_cluster_key]

In [None]:
# Plot the latent clusters of the selected cell type (not saved)
plot_latent_physical_for_cell_type_latent_clusters(adata=model.adata,
                                                   cell_type=cell_type,
                                                   save_fig=False)

In [None]:
# Retrieve all latent clusters with cells of the given cell type
cell_type_latent_clusters = [cell_type_latent_cluster for cell_type_latent_cluster
                             in model.adata.obs[cell_type_latent_cluster_key].unique().tolist()
                             if str(cell_type_latent_cluster) != "nan"]

# Define latent clusters of interest according to visualization
cell_type_latent_clusters_of_interest = ["2", "25", "43"]

# Add column with sub cell types (computed row-wise, stored as categorical)
model.adata.obs[sub_cell_type_key] = model.adata.obs.apply(lambda row: add_sub_cell_type(row, cell_type=cell_type), axis=1)
model.adata.obs[sub_cell_type_key] = model.adata.obs[sub_cell_type_key].astype("category")

# Squidpy nhood enrichment is sorted alphabetically, so sort the labels the same
# way to align them with the rows/columns of the z-score matrix
sub_cell_types = model.adata.obs[sub_cell_type_key].unique().tolist()
sub_cell_types.sort()

sq.gr.nhood_enrichment(model.adata, cluster_key=sub_cell_type_key)

# Positions (in the sorted label list) of the sub cell types of the given cell type
cell_type_sub_cell_types_idx = [idx for idx, label in enumerate(sub_cell_types)
                                if cell_type in label]

# Retrieve cell type latent cluster neighborhood enrichments: for each sub cell
# type of the given cell type, map every sub cell type to its z-score, ordered
# by descending z-score
enrichment_dict = {
    f"{sub_cell_types[idx]}": {
        label: zscore for zscore, label in sorted(
            zip(model.adata.uns[f"{sub_cell_type_key}_nhood_enrichment"]["zscore"][idx],
                sub_cell_types),
            reverse=True)}
    for idx in cell_type_sub_cell_types_idx}

enrichment_df = pd.DataFrame(enrichment_dict)

#### 3.2.1 Latent Cluster 2

In [None]:
# Differential GP analysis of latent cluster "2" against the other clusters of interest
selected_cats = ["2"]
comparison_cats = list(set(cell_type_latent_clusters_of_interest) - set(selected_cats))
latent_cluster_differential_gp_scores_key = f"autotalker_latent_cluster_{selected_cats[0]}_differential_gp_scores"

# Recompute the emphasis column for this comparison and store it as categorical
model.adata.obs[cell_type_latent_cluster_emphasis_key] = model.adata.obs.apply(
    lambda row: add_cell_type_latent_cluster_emphasis(row, comparison_cats), axis=1)
model.adata.obs[cell_type_latent_cluster_emphasis_key] = model.adata.obs[cell_type_latent_cluster_emphasis_key].astype("category")

top_unique_gps = get_differential_analysis_results(
    analysis_label=f"{cell_type}_latent_cluster",
    model=model,
    adata=model.adata,
    cat_key=cell_type_latent_cluster_key,
    selected_cats=selected_cats,
    differential_gp_scores_key=latent_cluster_differential_gp_scores_key,
    comparison_cats=comparison_cats,
    plot_category=cell_type_latent_cluster_emphasis_key,
    plot_group=[x for x in model.adata.obs[cell_type_latent_cluster_emphasis_key].unique().tolist() if str(x) != "nan"],
    selected_gps=None,
    n_top_up_gps=10,
    n_top_down_gps=10,
    feature_spaces=["latent"], # "physical_embryo1", "physical_embryo2", "physical_embryo3"
    save_figs=False)

# NOTE(review): other cells call store_top_gps_summary with ``file_path=``; confirm
# that ``file_name=`` is still accepted by the helper.
top_gps_summary_df = store_top_gps_summary(model=model,
                                           top_gps=top_unique_gps,
                                           file_name=f"{cell_type_fmt}_cluster{selected_cats[0]}_gp_summary.csv")
display(top_gps_summary_df)

#### 3.2.2 Latent Cluster 25

In [None]:
# Differential GP analysis of latent cluster "25" against the other clusters of interest
selected_cats = ["25"]
comparison_cats = list(set(cell_type_latent_clusters_of_interest) - set(selected_cats))
latent_cluster_differential_gp_scores_key = f"autotalker_latent_cluster_{selected_cats[0]}_differential_gp_scores"

# Compute the emphasis column for this comparison (once) and store it as categorical
model.adata.obs[cell_type_latent_cluster_emphasis_key] = model.adata.obs.apply(
    lambda row: add_cell_type_latent_cluster_emphasis(row, comparison_cats), axis=1)
model.adata.obs[cell_type_latent_cluster_emphasis_key] = model.adata.obs[cell_type_latent_cluster_emphasis_key].astype("category")

top_unique_gps = get_differential_analysis_results(
    analysis_label=f"{cell_type}_latent_cluster",
    model=model,
    adata=model.adata,
    cat_key=cell_type_latent_cluster_key,
    selected_cats=selected_cats,
    differential_gp_scores_key=latent_cluster_differential_gp_scores_key,
    comparison_cats=comparison_cats,
    plot_category=cell_type_latent_cluster_emphasis_key,
    plot_group=[x for x in model.adata.obs[cell_type_latent_cluster_emphasis_key].unique().tolist() if str(x) != "nan"],
    selected_gps=None,
    n_top_up_gps=10,
    n_top_down_gps=10,
    feature_spaces=["latent"], # "physical_embryo1", "physical_embryo2", "physical_embryo3"
    save_figs=False)

# Store and display a summary of the top GPs
top_gps_summary_df = store_top_gps_summary(model=model,
                                           top_gps=top_unique_gps,
                                           file_name=f"{cell_type_fmt}_cluster{selected_cats[0]}_gp_summary.csv")
display(top_gps_summary_df)

#### 3.2.3 Latent Cluster 43

In [None]:
# Differential GP analysis of latent cluster "43" against the other clusters of interest
selected_cats = ["43"]
comparison_cats = list(set(cell_type_latent_clusters_of_interest) - set(selected_cats))
latent_cluster_differential_gp_scores_key = f"autotalker_latent_cluster_{selected_cats[0]}_differential_gp_scores"

# Compute the emphasis column for this comparison (once) and store it as categorical
model.adata.obs[cell_type_latent_cluster_emphasis_key] = model.adata.obs.apply(
    lambda row: add_cell_type_latent_cluster_emphasis(row, comparison_cats), axis=1)
model.adata.obs[cell_type_latent_cluster_emphasis_key] = model.adata.obs[cell_type_latent_cluster_emphasis_key].astype("category")

top_unique_gps = get_differential_analysis_results(
    analysis_label=f"{cell_type}_latent_cluster",
    model=model,
    adata=model.adata,
    cat_key=cell_type_latent_cluster_key,
    selected_cats=selected_cats,
    differential_gp_scores_key=latent_cluster_differential_gp_scores_key,
    comparison_cats=comparison_cats,
    plot_category=cell_type_latent_cluster_emphasis_key,
    plot_group=[x for x in model.adata.obs[cell_type_latent_cluster_emphasis_key].unique().tolist() if str(x) != "nan"],
    selected_gps=None,
    n_top_up_gps=10,
    n_top_down_gps=10,
    feature_spaces=["latent"], # "physical_embryo1", "physical_embryo2", "physical_embryo3"
    save_figs=False)

# Store and display a summary of the top GPs
top_gps_summary_df = store_top_gps_summary(model=model,
                                           top_gps=top_unique_gps,
                                           file_name=f"{cell_type_fmt}_cluster{selected_cats[0]}_gp_summary.csv")
display(top_gps_summary_df)

In [None]:
# Collect the first ``n_top_up_gps`` and last ``n_top_down_gps`` entries of the
# "gene_program" column of each cluster of interest's differential GP scores.
# NOTE(review): ``n_top_up_gps`` / ``n_top_down_gps`` are only passed as keyword
# arguments in the cells above — they must be defined as notebook variables.
top_gps = [model.adata.uns[f"autotalker_latent_cluster_{latent_cluster}_differential_gp_scores"]["gene_program"][:n_top_up_gps].tolist() + 
           model.adata.uns[f"autotalker_latent_cluster_{latent_cluster}_differential_gp_scores"]["gene_program"][-n_top_down_gps:].tolist()
           for latent_cluster in cell_type_latent_clusters_of_interest]

# Flatten the per-cluster lists (a GP may appear more than once)
top_gps = [gp for gp_list in top_gps for gp in gp_list]

# Copy the GP score columns of cells that have a latent cluster label into cell_type_adata
for gp in top_gps:
    cell_type_adata.obs[gp] = model.adata.obs.loc[model.adata.obs[cell_type_latent_cluster_key].notnull(), gp]

In [None]:
# Plot the differential GP scores of the latent clusters next to their
# neighborhood enrichments. The figure and both axes are created first so the
# scanpy dot plot can be drawn into ``ax1`` and the heatmap into ``ax2``.
fig = plt.figure(figsize=(24, 12))
spec = gridspec.GridSpec(ncols=2,
                         nrows=1,
                         width_ratios=[1, 1],
                         height_ratios=[1],
                         figure=fig)
ax1 = fig.add_subplot(spec[0])
ax2 = fig.add_subplot(spec[1])

sc.pl.dotplot(cell_type_adata,
              top_gps,
              groupby=cell_type_latent_cluster_key,
              dendrogram=True,
              title=f"{cell_type} Latent Clusters Differential GP Scores",
              swap_axes=True,
              ax=ax1,
              show=False)

sns.heatmap(enrichment_df, annot=True, fmt=".2f", cmap="viridis", ax=ax2)
ax2.set_title("Neighborhood enrichments")
plt.show()

In [None]:
# Heatmap of the top GP scores per latent cluster, with the neighborhood
# enrichment z-scores drawn as a second heatmap on the same figure
axs = sc.pl.heatmap(cell_type_adata,
                   top_gps,
                   groupby=cell_type_latent_cluster_key,
                   cmap="viridis",
                   dendrogram=True,
                   swap_axes=True,
                   figsize=(12, 12),
                   show=False)
fig = axs["heatmap_ax"].get_figure()
fig.subplots_adjust(left=0.4)
# ``right=2.2`` (figure coordinates) extends the grid past the figure's right edge,
# so ax2 is placed to the right of the scanpy heatmap
spec = matplotlib.gridspec.GridSpec(ncols=2, nrows=1, right=2.2)
ax2 = fig.add_subplot(spec[1])
sns.heatmap(enrichment_df, annot=True, fmt=".2f", cmap="viridis", ax=ax2)
ax2.set_title("Neighborhood enrichments")

In [None]:
# Dot plot of the top GP scores per latent cluster of the selected cell type
fig = sc.pl.dotplot(cell_type_adata,
                    top_gps,
                    groupby=cell_type_latent_cluster_key,
                    dendrogram=True, 
                    title=f"{cell_type_fmt} Latent Clusters Differential GP Scores",
                    cmap="magma",
                    swap_axes=True,
                    return_fig=True)
# Save and display plot
# NOTE(review): the commented-out save relies on ``save_figs`` and ``analysis_label``,
# which are not defined as notebook variables in this section
#if save_figs:
#    fig.savefig(f"{figure_folder_path}/{analysis_label}_differential_gp_scores.png")
fig.show()

In [None]:
# Retrieve summary information for top gene programs
cell_type_gp_df = gp_summary_df[gp_summary_df["gp_name"].isin(top_gps)][[
    "gp_name",
    "n_source_genes",
    "n_non_zero_source_genes",
    "n_target_genes",
    "n_non_zero_target_genes",
    "gp_source_genes",
    "gp_target_genes",
    "gp_source_genes_weights_sign_corrected",
    "gp_target_genes_weights_sign_corrected",
    "gp_source_genes_importances",
    "gp_target_genes_importances"]]

# Write to disk
cell_type_gp_df.to_csv(f"{model_artifacts_folder_path}/cell_type_gp_df.csv")
display(cell_type_gp_df)

#### 3.2.4 Cardiomyocytes vs Endothelium

Cardiomyocytes are spatially and morphologically distinct from other cell types, whereas the endothelium is interspersed and spread across the entire embryo.

In [None]:
# Plot cardiomyocytes and endothelium together with the physical/latent helper and save the figure
plot_physical_latent_for_cell_types(adata=model.adata,
                                    cell_types=["Cardiomyocytes", "Endothelium"],
                                    save_fig=True)

### 3.3 Mixed Mesenchymal Mesoderm

"[...] mixed mesenchymal mesoderm, represent a cell state rather than a defined cell type. Mesenchyme represents a state in which cells express markers characteristic of migratory cells loosely dispersed within an extracellular matrix [56]. This strong overriding transcriptional signature of mesenchyme, irrespective of location, makes it challenging to distinguish which cell types this mixed mesenchymal mesoderm population represents using classical scRNA-seq data."

We can identify distinct subpopulations that are spatially defined.

In [None]:
# Cell type analyzed in this section and the Leiden resolution for its latent clusters
cell_type = "Mixed mesenchymal mesoderm"
cell_type_latent_resolution = 0.5
# File-name friendly version of the cell type label
cell_type_fmt = cell_type.replace(" ", "_").lower()

# Keys under which latent clusters, their emphasis and differential GP scores are stored
cell_type_latent_cluster_key = "{}_latent_leiden_{}".format(cell_type_fmt, cell_type_latent_resolution)
cell_type_latent_cluster_emphasis_key = cell_type_latent_cluster_key + "_emphasis"
cell_type_differential_gp_scores_key = cell_type + "_differential_gp_scores"

In [None]:
# Plot the cell type selected for this section with the physical/latent helper (not saved)
plot_physical_latent_for_cell_types(adata=model.adata,
                                    cell_types=[cell_type],
                                    save_fig=False)

In [None]:
# Differential GP analysis of the selected cell type against all other cell types
# ("rest"), keeping up to 10 up- and 10 down-regulated GPs
top_unique_gps = get_differential_analysis_results(
    analysis_label=cell_type_fmt,
    model=model,
    adata=model.adata,
    cat_key=cell_type_key,
    selected_cats=[cell_type],
    differential_gp_scores_key=cell_type_differential_gp_scores_key,
    comparison_cats="rest",
    plot_category=cell_type_key,
    plot_group=cell_type,
    selected_gps=None,
    n_top_up_gps=10,
    n_top_down_gps=10,
    feature_spaces=["latent"], # "physical_embryo1", "physical_embryo2", "physical_embryo3"
    save_figs=False)

# NOTE(review): other cells call store_top_gps_summary with ``file_path=``; confirm
# that ``file_name=`` is still accepted by the helper.
top_gps_summary_df = store_top_gps_summary(model=model,
                                           top_gps=top_unique_gps,
                                           file_name=f"{cell_type_fmt}_gp_summary.csv")
display(top_gps_summary_df)

In [None]:
# Compute latent Leiden clustering with cell-type-specific resolution
sc.tl.leiden(adata=model.adata,
             resolution=cell_type_latent_resolution,
             random_state=random_seed,
             key_added=cell_type_latent_cluster_key,
             neighbors_key=latent_knng_key)

# Filter for cell type
cell_type_adata = model.adata[model.adata.obs[cell_type_key] == cell_type]

# Only keep latent clusters for cell type and set rest to NaN
model.adata.obs[cell_type_latent_cluster_key] = np.nan
model.adata.obs.loc[model.adata.obs[cell_type_key] == cell_type,
                    cell_type_latent_cluster_key] = cell_type_adata.obs[cell_type_latent_cluster_key]

In [None]:
# Plot the latent clusters of the selected cell type (not saved)
plot_latent_physical_for_cell_type_latent_clusters(adata=model.adata,
                                                   cell_type=cell_type,
                                                   save_fig=False)

In [None]:
# Highlight latent cluster "11" of the selected cell type on the UMAP, reusing the
# cluster key defined in the section's settings cell
fig = sc.pl.umap(
    model.adata,
    color=cell_type_latent_cluster_key,
    palette=latent_cluster_colors,
    groups="11",
    # Dot size scales inversely with the number of cells
    size=2560000/len(model.adata),
    return_fig=True)
fig.set_size_inches(15, 10)

In [None]:
# Retrieve all latent clusters with cells of the given cell type
cell_type_latent_clusters = [cell_type_latent_cluster for cell_type_latent_cluster 
                             in model.adata.obs[cell_type_latent_cluster_key].unique().tolist()
                             if str(cell_type_latent_cluster) != "nan"]

# Define latent clusters of interest according to visualization
cell_type_latent_clusters_of_interest = ["1", "22", "16", "13", "20", "24"]

# Add column with sub cell types
model.adata.obs[sub_cell_type_key] = model.adata.obs.apply(lambda row: add_sub_cell_type(row, cell_type=cell_type), axis=1)
model.adata.obs[sub_cell_type_key] = model.adata.obs[sub_cell_type_key].astype("category")

# Squidpy nhood enrichment is sorted alphabetically
sub_cell_types = model.adata.obs[sub_cell_type_key].unique().tolist()
sub_cell_types.sort()

sq.gr.nhood_enrichment(model.adata, cluster_key=sub_cell_type_key)

cell_type_sub_cell_types_idx = []
for i, sub_cell_type in enumerate(sub_cell_types):
    if cell_type in sub_cell_type:
        cell_type_sub_cell_types_idx.append(i)

        # Retrieve cell type latent cluster neighborhood enrichments
enrichment_dict = {f"{sub_cell_types[cell_type_sub_cell_types_idx[i]]}":{
    sub_cell_type: zscore for zscore, sub_cell_type in sorted(
        zip(model.adata.uns[f"{sub_cell_type_key}_nhood_enrichment"]["zscore"][cell_type_sub_cell_types_idx[i]], sub_cell_types), reverse=True)} for i in range(len(cell_type_sub_cell_types_idx))}

enrichment_df = pd.DataFrame(enrichment_dict)

#### 3.3.1 Latent Cluster 1

In [None]:
# Differential GP analysis of latent cluster "1" against the other clusters of interest
selected_cats = ["1"]
comparison_cats = list(set(cell_type_latent_clusters_of_interest) - set(selected_cats))
latent_cluster_differential_gp_scores_key = f"autotalker_latent_cluster_{selected_cats[0]}_differential_gp_scores"

# Recompute the emphasis column for this comparison and store it as categorical
model.adata.obs[cell_type_latent_cluster_emphasis_key] = model.adata.obs.apply(
    lambda row: add_cell_type_latent_cluster_emphasis(row, comparison_cats), axis=1)
model.adata.obs[cell_type_latent_cluster_emphasis_key] = model.adata.obs[cell_type_latent_cluster_emphasis_key].astype("category")

top_unique_gps = get_differential_analysis_results(
    analysis_label=f"{cell_type}_latent_cluster",
    model=model,
    adata=model.adata,
    cat_key=cell_type_latent_cluster_key,
    selected_cats=selected_cats,
    differential_gp_scores_key=latent_cluster_differential_gp_scores_key,
    comparison_cats=comparison_cats,
    plot_category=cell_type_latent_cluster_emphasis_key,
    plot_group=[x for x in model.adata.obs[cell_type_latent_cluster_emphasis_key].unique().tolist() if str(x) != "nan"],
    selected_gps=None,
    n_top_up_gps=10,
    n_top_down_gps=10,
    feature_spaces=["latent"], # "physical_embryo1", "physical_embryo2", "physical_embryo3"
    save_figs=False)

# NOTE(review): other cells call store_top_gps_summary with ``file_path=``; confirm
# that ``file_name=`` is still accepted by the helper.
top_gps_summary_df = store_top_gps_summary(model=model,
                                           top_gps=top_unique_gps,
                                           file_name=f"{cell_type_fmt}_cluster{selected_cats[0]}_gp_summary.csv")
display(top_gps_summary_df)

#### Overview

In [None]:
# Mean "MPZ_ligand_targetgenes_GP" score of latent cluster "0"
model.adata.obs.loc[model.adata.obs[cell_type_latent_cluster_key] == "0", "MPZ_ligand_targetgenes_GP"].mean()

In [None]:
# Mean "MPZ_ligand_targetgenes_GP" score of latent cluster "1"
model.adata.obs.loc[model.adata.obs[cell_type_latent_cluster_key] == "1", "MPZ_ligand_targetgenes_GP"].mean()

In [None]:
# Mean "MPZ_ligand_targetgenes_GP" score of latent cluster "2"
model.adata.obs.loc[model.adata.obs[cell_type_latent_cluster_key] == "2", "MPZ_ligand_targetgenes_GP"].mean()

In [None]:
# Mean "MPZ_ligand_targetgenes_GP" score of latent cluster "3"
model.adata.obs.loc[model.adata.obs[cell_type_latent_cluster_key] == "3", "MPZ_ligand_targetgenes_GP"].mean()

In [None]:
# Plot the neighborhood enrichment computed for the sub cell type column
sq.pl.nhood_enrichment(model.adata, cluster_key=sub_cell_type_key)

In [None]:
# Retrieve the neighborhood enrichment z-scores of the sub cell types of the given
# cell type. Squidpy orders the rows/columns of its z-score matrix by the
# categories of ``sub_cell_type_key``, so those categories are used both to find
# the rows and to label the columns (no hard-coded row offset or foreign label list).
sub_cell_type_categories = model.adata.obs[sub_cell_type_key].cat.categories.tolist()
enrichment_dict = {
    f"{category}": {
        label: zscore for zscore, label in sorted(
            zip(model.adata.uns[f"{sub_cell_type_key}_nhood_enrichment"]["zscore"][row_idx],
                sub_cell_type_categories),
            reverse=True)}
    for row_idx, category in enumerate(sub_cell_type_categories)
    if cell_type in category}

enrichment_df = pd.DataFrame(enrichment_dict)

In [None]:
# Collect the first ``n_top_up_gps`` and last ``n_top_down_gps`` entries of the
# "gene_program" column for every (non-NaN) latent cluster of the cell type.
# NOTE(review): this assumes a differential GP scores entry exists in ``uns``
# for every latent cluster, not only the clusters analyzed above — confirm.
top_gps = [model.adata.uns[f"autotalker_latent_cluster_{latent_cluster}_differential_gp_scores"]["gene_program"][:n_top_up_gps].tolist() + 
           model.adata.uns[f"autotalker_latent_cluster_{latent_cluster}_differential_gp_scores"]["gene_program"][-n_top_down_gps:].tolist()
           for latent_cluster in [cell_type_latent_cluster for cell_type_latent_cluster in model.adata.obs[cell_type_latent_cluster_key].unique().tolist() if str(cell_type_latent_cluster) != "nan"]]

# Flatten the per-cluster lists (a GP may appear more than once)
top_gps = [gp for gp_list in top_gps for gp in gp_list]

# Copy the GP score columns of cells that have a latent cluster label into cell_type_adata
for gp in top_gps:
    cell_type_adata.obs[gp] = model.adata.obs.loc[model.adata.obs[cell_type_latent_cluster_key].notnull(), gp]

In [None]:
# Heatmap of the top GP scores per latent cluster, with the neighborhood
# enrichment z-scores drawn as a second heatmap on the same figure
axs = sc.pl.heatmap(cell_type_adata,
                   top_gps,
                   groupby=cell_type_latent_cluster_key,
                   cmap="viridis",
                   dendrogram=True,
                   swap_axes=True,
                   figsize=(12, 12),
                   show=False)
fig = axs["heatmap_ax"].get_figure()
fig.subplots_adjust(left=0.4)
# ``right=2.2`` (figure coordinates) extends the grid past the figure's right edge,
# so ax2 is placed to the right of the scanpy heatmap
spec = matplotlib.gridspec.GridSpec(ncols=2, nrows=1, right=2.2)
ax2 = fig.add_subplot(spec[1])
sns.heatmap(enrichment_df, annot=True, fmt=".2f", cmap="viridis", ax=ax2)
ax2.set_title("Neighborhood enrichments")

In [None]:
# Dot plot of the top GP scores per latent cluster of the selected cell type
fig = sc.pl.dotplot(cell_type_adata,
                    top_gps,
                    groupby=cell_type_latent_cluster_key,
                    dendrogram=True, 
                    title="Mixed Mesenchymal Mesoderm Latent Clusters Differential GP Scores",
                    swap_axes=True,
                    return_fig=True)
# Save and display plot
# NOTE(review): the commented-out save relies on ``save_figs`` and ``analysis_label``,
# which are not defined as notebook variables in this section
#if save_figs:
#    fig.savefig(f"{figure_folder_path}/{analysis_label}_differential_gp_scores.png")
plt.show()

In [None]:
# Same dot plot as the previous cell, rendered via the DotPlot object's show()
fig = sc.pl.dotplot(cell_type_adata,
                    top_gps,
                    groupby=cell_type_latent_cluster_key,
                    dendrogram=True, 
                    title="Mixed Mesenchymal Mesoderm Latent Clusters Differential GP Scores",
                    swap_axes=True,
                    return_fig=True)
fig.show()

In [None]:
# Plot the neighborhood enrichment z-scores on freshly created axes; reusing
# ``ax2`` from an earlier cell would draw into a figure that is not shown here
fig, ax2 = plt.subplots(figsize=(10, 10))
sns.heatmap(enrichment_df, annot=True, ax=ax2)
ax2.set_title("Neighborhood enrichments")

In [None]:
# Display the object currently bound to ``fig`` (whichever cell assigned it last)
fig

### 3.5 CellRank