# Autotalker Data Analysis: Mouse Organogenesis (Imputed)

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 22.01.2023
- **Date of Last Modification:** 09.02.2023

## 1. Setup

### 1.1 Import Libraries

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("../../autotalker")

In [None]:
import argparse
import os
import random
import warnings
from copy import deepcopy
from datetime import datetime

import anndata as ad
import matplotlib
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import seaborn as sns
import squidpy as sq
import torch
from matplotlib import gridspec
from matplotlib.pyplot import rc_context

from autotalker.models import Autotalker
from autotalker.utils import (add_gps_from_gp_dict_to_adata,
                              extract_gp_dict_from_mebocost_es_interactions,
                              extract_gp_dict_from_nichenet_ligand_target_mx,
                              extract_gp_dict_from_omnipath_lr_interactions,
                              filter_and_combine_gp_dict_gps,
                              get_unique_genes_from_gp_dict)

### 1.2 Define Parameters

In [None]:
## Dataset
dataset = "seqfish_mouse_organogenesis"
batch1 = "embryo1_z2"
batch2 = "embryo1_z5"
batch3 = "embryo2_z2"
batch4 = "embryo2_z5"
batch5 = "embryo3_z2"
batch6 = "embryo3_z5"
n_neighbors = 12

## Model
# AnnData Keys
counts_key = "log_normalized_counts" # raw counts not available
cell_type_key = "celltype_mapped_refined"
adj_key = "spatial_connectivities"
spatial_key = "spatial"
gp_names_key = "autotalker_gp_names"
active_gp_names_key = "autotalker_active_gp_names"
gp_targets_mask_key = "autotalker_gp_targets"
gp_sources_mask_key = "autotalker_gp_sources"
latent_key = "autotalker_latent"
condition_key = "batch"
genes_idx_key = "autotalker_genes_idx"
query_enriched_cell_type_key = "autotalker_query_enriched_cell_types"
query_enriched_latent_cluster_key = "autotalker_query_enriched_latent_clusters"
mapping_entity_key = "mapping_entity"

# Architecture
active_gp_thresh_ratio = 0.03
gene_expr_recon_dist = "nb"
n_cond_embed = 3
log_variational = False # log normalized counts as input

# Trainer
n_epochs = 40
n_epochs_all_gps = 20
lr = 0.001
query_cond_embed_lr = 0.01
lambda_edge_recon = 0.01
lambda_gene_expr_recon = 0.0033
lambda_l1_addon = 0.01

# Benchmarking
spatial_knng_key = "autotalker_spatial_knng"
latent_knng_key = "autotalker_latent_knng"

# Analysis
enriched_query_cell_type_prop_thresh = 5
enriched_query_latent_cluster_prop_thresh = 2.5
latent_leiden_resolution = 0.3
latent_cluster_key = f"latent_leiden_{str(latent_leiden_resolution)}"
# Keys under which differential GP score tables are stored in ``adata.uns``;
# each analysis gets its own distinct key so results do not collide.
denovo_cell_type_differential_gp_scores_key = "autotalker_denovo_cell_type_differential_gp_scores"
cell_type_differential_gp_scores_key = "autotalker_cell_type_differential_gp_scores"
query_enriched_cell_type_differential_gp_scores_key = "autotalker_query_enriched_cell_type_differential_gp_scores"
latent_cluster_differential_gp_scores_key = "autotalker_latent_cluster_differential_gp_scores"
query_enriched_latent_cluster_differential_gp_scores_key = "autotalker_query_enriched_latent_cluster_differential_gp_scores"
n_top_up_gps = 3
n_top_down_gps = 3
n_top_genes_per_gp = 3

## Others
random_seed = 42
load_timestamp = "05022023_160853" # "06022023_064344" # saved model to be loaded

### 1.3 Run Notebook Setup

In [None]:
# Default scanpy figure size and a seaborn "whitegrid" style with the grid turned off
sc.set_figure_params(figsize=(6, 6))
sns.set_style(style="whitegrid", rc={"axes.grid": False})

In [None]:
# Silence future, user and runtime warnings to keep the notebook output readable
for _ignored_warning_category in (FutureWarning, UserWarning, RuntimeWarning):
    warnings.simplefilter(action="ignore", category=_ignored_warning_category)
del _ignored_warning_category

In [None]:
# Timestamp that names this run's figure/artifact folders: reuse
# ``load_timestamp`` when a saved model is loaded, else use the current time
now = datetime.now()
current_timestamp = (load_timestamp if load_timestamp is not None
                     else now.strftime("%d%m%Y_%H%M%S"))

### 1.4 Configure Paths and Create Directories

In [None]:
# Output folders, namespaced by dataset and run timestamp
figure_folder_path = f"../figures/{dataset}/analysis/{current_timestamp}"
model_artifacts_folder_path = f"../artifacts/{dataset}/analysis/{current_timestamp}"

# Input data locations
gp_data_folder_path = "../datasets/gp_data"  # gene program data
srt_data_folder_path = "../datasets/srt_data"  # spatially resolved transcriptomics data
srt_data_gold_folder_path = f"{srt_data_folder_path}/gold"
nichenet_ligand_target_mx_file_path = f"{gp_data_folder_path}/nichenet_ligand_target_matrix.csv"
omnipath_lr_interactions_file_path = f"{gp_data_folder_path}/omnipath_lr_interactions.csv"

# Create the output folders (no error if they already exist)
for _output_folder_path in (figure_folder_path, model_artifacts_folder_path):
    os.makedirs(_output_folder_path, exist_ok=True)
del _output_folder_path

### 1.5 Define Functions

In [None]:
def plot_gp_scores(adata,
                   top_cats,
                   top_gps,
                   top_genes,
                   top_genes_importances,
                   top_genes_signs,
                   top_genes_entities,
                   gene_importances_sums,
                   plot_type,
                   plot_category,
                   feature_space,
                   suptitle,
                   cat_title,
                   fig_name,
                   sample=None):
    """Plot gene program (GP) scores next to the expression of their genes.

    Draws one figure row per GP in ``top_gps``.
    - Column 0 shows ``plot_category`` with ``top_cats[i]`` highlighted.
    - Column 1 shows the GP score (``top_gps[i]`` passed as ``color``).
    - The remaining columns depend on ``plot_type``.

    The figure is saved to ``{figure_folder_path}/{fig_name}.png`` and shown.

    Parameters
    ----------
    adata:
        AnnData object to plot. For ``plot_type="gene_categories"`` it must
        contain the ``f"{gp}_{gene_category}_weighted_mean_gene_expr"``
        columns in ``obs``. For ``feature_space="physical"`` it must contain
        a ``"sample"`` column in ``obs``.
    top_cats:
        Category to highlight for each GP (same order as ``top_gps``).
    top_gps:
        Names of the GPs to plot, one figure row each.
    top_genes, top_genes_importances, top_genes_signs, top_genes_entities:
        Per-GP sequences with the top genes, their importances, their signs
        (``"+"``/``"-"``) and their entities (e.g. ``"target"``/``"source"``).
        Used for ``plot_type="individual_genes"``. Indexed positionally.
    gene_importances_sums:
        Per-GP sequences with the summed gene importances of the four gene
        categories, shown in the ``"gene_categories"`` subplot titles.
    plot_type:
        ``"gene_categories"`` or ``"individual_genes"``.
    plot_category:
        ``adata.obs`` key of the categorical annotation to plot.
    feature_space:
        ``"latent"`` (UMAP) or ``"physical"`` (spatial coordinates).
    suptitle:
        Figure title.
    cat_title:
        Title of the category subplots.
    fig_name:
        File name (without extension) of the saved figure.
    sample:
        Value of ``adata.obs["sample"]`` to plot in physical space.

    Raises
    ------
    ValueError
        If ``plot_type`` or ``feature_space`` is not supported, or
        ``top_gps`` is empty.
    """
    gene_categories = ["pos_sign_target_genes",
                       "pos_sign_source_genes",
                       "neg_sign_target_genes",
                       "neg_sign_source_genes"]

    # Layout: two leading columns (category + GP score), then one column per
    # gene category or per top gene
    if plot_type == "gene_categories":
        ncols = 2 + len(gene_categories)
        fig_width = 36
        wspace = 0.155
    elif plot_type == "individual_genes":
        ncols = 2 + n_top_genes_per_gp
        fig_width = 12 + (6 * n_top_genes_per_gp)
        wspace = 0.3
    else:
        raise ValueError(f"Unsupported plot_type '{plot_type}'. "
                         "Use 'gene_categories' or 'individual_genes'.")
    if feature_space not in ("latent", "physical"):
        raise ValueError(f"Unsupported feature_space '{feature_space}'. "
                         "Use 'latent' or 'physical'.")
    if len(top_gps) == 0:
        raise ValueError("'top_gps' must contain at least one gene program.")

    # Physical-space panels only show one sample; subset once, not per panel
    adata_sample = (adata[adata.obs["sample"] == sample]
                    if feature_space == "physical" else None)

    def _plot_panel(ax, color, title, **kwargs):
        """Plot ``color`` on ``ax`` in the selected feature space and hide the axis labels."""
        if feature_space == "latent":
            sc.pl.umap(adata, color=color, ax=ax, title=title, show=False, **kwargs)
        else:
            sc.pl.spatial(adata=adata_sample, color=color, ax=ax, title=title,
                          spot_size=0.03, show=False, **kwargs)
        ax.xaxis.label.set_visible(False)
        ax.yaxis.label.set_visible(False)

    # One row per GP actually passed in; squeeze=False keeps ``axs`` 2-D even
    # when only a single GP is plotted
    n_rows = len(top_gps)
    fig, axs = plt.subplots(nrows=n_rows,
                            ncols=ncols,
                            figsize=(fig_width, 6 * n_rows),
                            squeeze=False)

    suptitle_artist = fig.suptitle(t=suptitle,
                                   x=0.55,
                                   y=0.94,
                                   fontsize=20)

    # Negatively signed genes use reversed PuBu. "PuBu_r" is matplotlib's
    # registered reversed colormap and avoids the deprecated
    # ``plt.cm.get_cmap``.
    neg_sign_cmap = "PuBu_r"

    for i, gp in enumerate(top_gps):
        # In physical space, gene panels are additionally restricted to the
        # GP's highlighted category via ``groups``
        physical_only_kwargs = ({"groups": top_cats[i]}
                                if feature_space == "physical" else {})

        # Column 0: the GP's category highlighted within ``plot_category``
        _plot_panel(axs[i, 0],
                    plot_category,
                    cat_title,
                    groups=top_cats[i],
                    legend_loc="on data",
                    na_in_legend=False)

        # Column 1: the GP score
        if feature_space == "latent":
            gp_title = (f"{gp[:gp.index('_')]}\n"
                        f"{gp[gp.index('_') + 1: gp.rindex('_')].replace('_', ' ')}\n"
                        f"{gp[gp.rindex('_') + 1:]} score")
            _plot_panel(axs[i, 1], gp, gp_title, color_map="PuBuGn")
        else:
            gp_title = f"{gp.split('_', 1)[0]}\n{gp.split('_', 1)[1]}"
            _plot_panel(axs[i, 1], gp, gp_title, color_map="PuBuGn", legend_loc=None)

        if plot_type == "gene_categories":
            for j, gene_category in enumerate(gene_categories):
                ax = axs[i, 2 + j]
                expr_key = f"{gp}_{gene_category}_weighted_mean_gene_expr"
                if adata.obs[expr_key].isna().any():
                    # Nothing meaningful to plot (e.g. a category without genes)
                    ax.set_visible(False)
                    continue
                _plot_panel(ax,
                            expr_key,
                            f"Weighted mean gene expression \n {gene_category.replace('_', ' ')} "
                            f"({gene_importances_sums[i][j]:.2f})",
                            color_map=("BuGn" if "pos_sign" in gene_category else neg_sign_cmap),
                            legend_loc="on data",
                            na_in_legend=False,
                            **physical_only_kwargs)
        else:
            # Convert to lists so genes are indexed by position even if the
            # inputs are pandas Series with a non-default index
            genes = list(top_genes[i])
            importances = list(top_genes_importances[i])
            signs = list(top_genes_signs[i])
            entities = list(top_genes_entities[i])
            for j, gene in enumerate(genes):
                _plot_panel(axs[i, 2 + j],
                            gene,
                            f"{gene}: {importances[j]:.2f} ({entities[j][0]}; {signs[j]})",
                            color_map=("BuGn" if signs[j] == "+" else neg_sign_cmap),
                            legend_loc="on data",
                            na_in_legend=False,
                            **physical_only_kwargs)
            # Hide unused gene columns when a GP has fewer genes than columns
            for k in range(len(genes), ncols - 2):
                axs[i, 2 + k].set_visible(False)

    # Save and display plot
    plt.subplots_adjust(wspace=wspace, hspace=0.275)
    fig.savefig(f"{figure_folder_path}/{fig_name}.png",
                bbox_extra_artists=(suptitle_artist,),
                bbox_inches="tight")
    plt.show()

In [None]:
def _add_gene_category_weighted_mean_gene_expr(adata, gp, gp_gene_importances_df):
    """Store the importance-weighted mean expression of a GP's gene categories.

    The GP's genes are split into four categories by the sign of
    ``gene_weight_sign_corrected`` and by ``gene_entity`` (``"target"`` or
    ``"source"``). For each category:
    - the expression of its genes in ``adata.X`` is multiplied by their
      ``gene_importance`` and averaged per observation;
    - the result is written to
      ``adata.obs[f"{gp}_{category}_weighted_mean_gene_expr"]``.

    A category without genes yields an all-NaN column.

    Returns
    -------
    np.ndarray
        Summed gene importances per category, in the order pos_sign_target,
        pos_sign_source, neg_sign_target, neg_sign_source.
    """
    weight_signs = gp_gene_importances_df["gene_weight_sign_corrected"]
    gene_entities = gp_gene_importances_df["gene_entity"]
    category_masks = {
        "pos_sign_target_genes": (weight_signs > 0) & (gene_entities == "target"),
        "pos_sign_source_genes": (weight_signs > 0) & (gene_entities == "source"),
        "neg_sign_target_genes": (weight_signs < 0) & (gene_entities == "target"),
        "neg_sign_source_genes": (weight_signs < 0) & (gene_entities == "source"),
    }

    importances_sums = []
    for gene_category, mask in category_masks.items():
        genes = gp_gene_importances_df.loc[mask, "gene"].tolist()
        # Row vector so it broadcasts across the (n_obs, n_genes) expression matrix
        importances = gp_gene_importances_df.loc[mask, "gene_importance"].values.reshape(1, -1)
        importances_sums.append(importances.sum())

        gene_expr = adata[:, genes].X
        # Accept both sparse and dense expression matrices
        gene_expr = gene_expr.toarray() if sp.issparse(gene_expr) else np.asarray(gene_expr)
        adata.obs[f"{gp}_{gene_category}_weighted_mean_gene_expr"] = (
            np.mean(gene_expr * importances, axis=1))
    return np.array(importances_sums)


def run_differential_gp_analysis(analysis_label,
                                 model,
                                 selected_gps,
                                 selected_cats,
                                 differential_gp_scores_key,
                                 n_top_up_gps,
                                 n_top_down_gps,
                                 physical_feature_space_sample):
    """Run a differential GP score analysis and plot the top GPs.

    Works on a copy of ``model.adata``. The function:
    1. computes differential GP scores across the categories of
       ``cell_type_key`` (label contains ``"cell_type"``) or
       ``latent_cluster_key`` (label contains ``"latent_cluster"``);
    2. displays the top up- and downregulated GPs, a dotplot and their GP
       summary rows;
    3. plots the GP scores and gene expressions in latent and physical space
       via ``plot_gp_scores``.

    Parameters
    ----------
    analysis_label:
        Label used to pick the category key and to name titles/figures.
    model:
        Trained Autotalker model.
    selected_gps:
        GPs to include (``None`` for all).
    selected_cats:
        Categories to include (``None`` for all).
    differential_gp_scores_key:
        ``adata.uns`` key under which the scores table is stored.
    n_top_up_gps, n_top_down_gps:
        Number of top up-/downregulated GPs to retrieve and plot.
    physical_feature_space_sample:
        ``adata.obs["sample"]`` value shown in physical-space plots.

    Raises
    ------
    ValueError
        If ``analysis_label`` contains neither ``"cell_type"`` nor
        ``"latent_cluster"``.
    """
    # Work on a copy so added obs columns / uns entries leave model.adata untouched
    adata = model.adata.copy()

    if "cell_type" in analysis_label:
        cat_key = cell_type_key
    elif "latent_cluster" in analysis_label:
        cat_key = latent_cluster_key
    else:
        raise ValueError(f"Cannot infer category key from analysis_label '{analysis_label}'. "
                         "It must contain 'cell_type' or 'latent_cluster'.")

    # Compute gene program enrichments and retrieve top up- and downregulated gene programs
    top_unique_gps = model.compute_differential_gp_scores(cat_key=cat_key,
                                                          adata=adata,
                                                          selected_gps=selected_gps,
                                                          selected_cats=selected_cats,
                                                          gp_scores_weight_normalization=True,
                                                          comparison_cats="rest",
                                                          n_sample=10000,
                                                          key_added=differential_gp_scores_key,
                                                          n_top_up_gps_retrieved=n_top_up_gps,
                                                          n_top_down_gps_retrieved=n_top_down_gps,
                                                          seed=random_seed)
    differential_gp_scores_df = adata.uns[differential_gp_scores_key]

    # Display top upregulated gene programs
    top_up_gp_df = differential_gp_scores_df[:n_top_up_gps]
    display(top_up_gp_df)

    # Display top downregulated gene programs (last rows, most downregulated first).
    # An explicit start index avoids ``[-0:]`` returning the whole table when
    # n_top_down_gps == 0.
    down_start = max(len(differential_gp_scores_df) - n_top_down_gps, 0)
    top_down_gp_df = differential_gp_scores_df[down_start:][::-1]
    display(top_down_gp_df)

    fig = sc.pl.dotplot(adata,
                        top_unique_gps,
                        groupby=cat_key,
                        dendrogram=True,
                        title=f"{analysis_label.replace('_', ' ').title()} Differential GP Scores",
                        swap_axes=True,
                        return_fig=True)
    # Save and display plot
    fig.savefig(f"{figure_folder_path}/{analysis_label}_differential_gp_scores.png")
    plt.show()

    # Inspect top up- and downregulated gene programs
    display(gp_summary_df[gp_summary_df["gp_name"].isin(top_unique_gps)])

    # Collect per-GP plotting inputs (``Series.append`` was removed in pandas 2.0)
    top_cats = pd.concat([top_up_gp_df["category"], top_down_gp_df["category"]]).to_list()
    top_gps = pd.concat([top_up_gp_df["gene_program"], top_down_gp_df["gene_program"]]).to_list()
    top_genes = []
    top_genes_importances = []
    top_genes_signs = []
    top_genes_entities = []
    gene_importances_sums = []

    for gp in top_gps:
        gp_gene_importances_df = model.compute_gp_gene_importances(selected_gp=gp)
        gene_importances_sums.append(
            _add_gene_category_weighted_mean_gene_expr(adata, gp, gp_gene_importances_df))

        # Plain lists so the plotting code indexes genes by position, regardless
        # of the index of ``gp_gene_importances_df``
        top_genes.append(gp_gene_importances_df["gene"].iloc[:n_top_genes_per_gp].tolist())
        top_genes_importances.append(
            gp_gene_importances_df["gene_importance"].iloc[:n_top_genes_per_gp].tolist())
        top_genes_signs.append(
            np.where(gp_gene_importances_df["gene_weight_sign_corrected"] > 0, "+", "-").tolist())
        top_genes_entities.append(gp_gene_importances_df["gene_entity"].tolist())

    for feature_space in ["latent", "physical"]:
        for plot_type in ["gene_categories", "individual_genes"]:
            plot_gp_scores(adata=adata,
                           top_cats=top_cats,
                           top_gps=top_gps,
                           top_genes=top_genes,
                           top_genes_importances=top_genes_importances,
                           top_genes_signs=top_genes_signs,
                           top_genes_entities=top_genes_entities,
                           gene_importances_sums=gene_importances_sums,
                           plot_type=plot_type,
                           plot_category=cat_key,
                           feature_space=feature_space,
                           suptitle=f"{analysis_label.replace('_', ' ').title()} Differential GPs: "
                                    f"GP Scores and {'Weighted Mean ' if plot_type == 'gene_categories' else ''}"
                                    f"Gene Expression of {plot_type.replace('_', ' ').title()} in {feature_space.capitalize()} Space",
                           cat_title=f"GP-enriched \n {analysis_label.replace('_', ' ')}",
                           fig_name=f"{analysis_label}_gp_scores_{'weighted_mean_' if plot_type == 'gene_categories' else ''}gene_expr_"
                                    f"{plot_type}_{feature_space}_space",
                           sample=physical_feature_space_sample)

## 2. Analysis

### 2.1 Load Model

In [None]:
# Load from the batch integration run identified by ``load_timestamp`` if set,
# otherwise from this run's artifact folder
if load_timestamp is not None:
    model_artifacts_load_folder_path = f"../artifacts/{dataset}/batch_integration/{load_timestamp}"
else:
    model_artifacts_load_folder_path = model_artifacts_folder_path

# Load the trained model from the "reference_query" subfolder; use the shared
# ``gp_names_key`` parameter rather than a duplicated string literal
model = Autotalker.load(dir_path=f"{model_artifacts_load_folder_path}/reference_query",
                        adata=None,
                        adata_file_name=f"{dataset}.h5ad",
                        gp_names_key=gp_names_key)

### 2.2 Retrieve GP Summary

In [None]:
# Report how many of the model's gene programs are active
active_gps = model.get_active_gps()
n_total_gps = len(model.adata.uns[gp_names_key])
print(f"Number of total gene programs: {n_total_gps}.")
print(f"Number of active gene programs: {len(active_gps)}.")

In [None]:
# GP summary table (also used by the differential GP analysis); preview the
# first active gene programs
gp_summary_df = model.get_gp_summary()
active_gp_summary_mask = gp_summary_df["gp_active"]
gp_summary_df.loc[active_gp_summary_mask].head()

### 2.3 Cell Type-based Analysis

#### 2.3.1 Visualize Cell Types in Physical and Latent Space

In [None]:
# Cell type annotations: one spatial panel per embryo (top row) and the
# latent-space UMAP (large bottom panel), with a shared legend
fig = plt.figure(figsize=(12, 14))
title = fig.suptitle(t="Autotalker: Cell Types in Physical and Latent Space",
                     y=0.96,
                     x=0.55,
                     fontsize=20)
spec = gridspec.GridSpec(ncols=1, nrows=2, width_ratios=[1], height_ratios=[1, 5])
ax1, ax2, ax3 = (fig.add_subplot(2, 3, panel_idx) for panel_idx in (1, 2, 3))
ax4 = fig.add_subplot(spec[1])

for embryo_ax, embryo_idx in zip((ax1, ax2, ax3), (1, 2, 3)):
    sc.pl.spatial(adata=model.adata[model.adata.obs["sample"] == f"embryo{embryo_idx}"],
                  color=[cell_type_key],
                  spot_size=0.03,
                  title=f"Physical Space Embryo {embryo_idx}",
                  legend_loc=None,
                  ax=embryo_ax,
                  show=False)
sc.pl.umap(adata=model.adata,
           color=[cell_type_key],
           size=240000/len(model.adata),
           title="Autotalker Latent Space",
           ax=ax4,
           show=False)

# Move the UMAP legend to a figure-level legend shared by all panels
handles, labels = ax4.get_legend_handles_labels()
lgd = fig.legend(handles, labels, bbox_to_anchor=(1.25, 0.8))
ax4.get_legend().remove()

# Adjust spacing, then save and display the figure
plt.subplots_adjust(wspace=0., hspace=0.85)
fig.savefig(f"{figure_folder_path}/cell_types_physical_latent_space.png",
            bbox_extra_artists=(lgd, title),
            bbox_inches="tight")
plt.show()

In [None]:
# Cell types highlighted in the following figure
cell_types = [
    "Cardiomyocytes",
    "Endothelium",
]

In [None]:
# Highlight the selected ``cell_types`` in each embryo (top row) and in the
# latent-space UMAP (large bottom panel), with a shared legend
fig = plt.figure(figsize=(12, 14))
title = fig.suptitle(t="Autotalker: Cell Types in Physical and Latent Space",
                     y=0.96,
                     x=0.55,
                     fontsize=20)
spec = gridspec.GridSpec(ncols=1, nrows=2, width_ratios=[1], height_ratios=[1, 5])
ax1, ax2, ax3 = (fig.add_subplot(2, 3, panel_idx) for panel_idx in (1, 2, 3))
ax4 = fig.add_subplot(spec[1])

for embryo_ax, embryo_idx in zip((ax1, ax2, ax3), (1, 2, 3)):
    sc.pl.spatial(adata=model.adata[model.adata.obs["sample"] == f"embryo{embryo_idx}"],
                  color=[cell_type_key],
                  palette="Accent",
                  groups=cell_types,
                  size=160000/len(model.adata),
                  spot_size=0.03,
                  title=f"Physical Space Embryo {embryo_idx}",
                  legend_loc=None,
                  ax=embryo_ax,
                  show=False)
sc.pl.umap(adata=model.adata,
           color=[cell_type_key],
           palette="Accent",
           groups=cell_types,
           size=1280000/len(model.adata),
           title="Autotalker Latent Space",
           ax=ax4,
           show=False)

# Move the UMAP legend to a figure-level legend shared by all panels
handles, labels = ax4.get_legend_handles_labels()
lgd = fig.legend(handles, labels, bbox_to_anchor=(1.12, 0.6475))
ax4.get_legend().remove()

# Adjust spacing, then save and display the figure
plt.subplots_adjust(wspace=0., hspace=0.85)
fig.savefig(f"{figure_folder_path}/cell_types_cardiomyocytes_endothelium_physical_latent_space.png",
            bbox_extra_artists=(lgd, title),
            bbox_inches="tight")
plt.show()

#### 2.3.2 Visualize Mapping Entities in Physical and Latent Space

In [None]:
# Fixed colors for the mapping entity categories
model.adata.uns["mapping_entity_colors"] = ['#1f77b4', '#ff7f0e']

# Mapping entity annotations: one spatial panel per embryo (top row) and the
# latent-space UMAP (large bottom panel), with a shared legend
fig = plt.figure(figsize=(12, 14))
title = fig.suptitle(t="Autotalker: Mapping Entities in Physical and Latent Space",
                     y=0.96,
                     x=0.55,
                     fontsize=20)
spec = gridspec.GridSpec(ncols=1, nrows=2, width_ratios=[1], height_ratios=[1, 5])
ax1, ax2, ax3 = (fig.add_subplot(2, 3, panel_idx) for panel_idx in (1, 2, 3))
ax4 = fig.add_subplot(spec[1])

for embryo_ax, embryo_idx in zip((ax1, ax2, ax3), (1, 2, 3)):
    sc.pl.spatial(adata=model.adata[model.adata.obs["sample"] == f"embryo{embryo_idx}"],
                  color=[mapping_entity_key],
                  spot_size=0.03,
                  title=f"Physical Space Embryo {embryo_idx}",
                  legend_loc=None,
                  ax=embryo_ax,
                  show=False)
sc.pl.umap(adata=model.adata,
           color=[mapping_entity_key],
           size=240000/len(model.adata),
           title="Autotalker Latent Space",
           ax=ax4,
           show=False)

# Move the UMAP legend to a figure-level legend shared by all panels
handles, labels = ax4.get_legend_handles_labels()
lgd = fig.legend(handles, labels, bbox_to_anchor=(1.08, 0.645))
ax4.get_legend().remove()

# Adjust spacing, then save and display the figure
plt.subplots_adjust(wspace=0., hspace=0.85)
fig.savefig(f"{figure_folder_path}/mapping_entities_physical_latent_space.png",
            bbox_extra_artists=(lgd, title),
            bbox_inches="tight")
plt.show()

#### 2.3.3 Visualize Query-enriched Cell Types in Physical and Latent Space

In [None]:
# Query-enriched cell types: ratio of a cell type's share among query cells
# to its share among reference cells
adata_reference = model.adata[model.adata.obs[mapping_entity_key] == "reference"]
adata_query = model.adata[model.adata.obs[mapping_entity_key] == "query"]
cell_type_reference_proportions = adata_reference.obs[cell_type_key].value_counts().sort_index() / len(adata_reference)
cell_type_query_proportions = adata_query.obs[cell_type_key].value_counts().sort_index() / len(adata_query)
# Align both proportions on the union of cell types, filling missing ones with 0.
# A cell type present only in the query then gets an infinite ratio (maximally
# enriched). Without this, the ratio would be NaN and the cell type would be
# silently dropped by the threshold comparison below.
cell_type_query_proportions, cell_type_reference_proportions = cell_type_query_proportions.align(
    cell_type_reference_proportions, join="outer", fill_value=0)
relative_cell_type_query_proportions = cell_type_query_proportions / cell_type_reference_proportions
relative_cell_type_query_proportions.sort_values(ascending=False, inplace=True)
display(relative_cell_type_query_proportions)

# Label cells of enriched cell types with their cell type, all others with a
# shared "not enriched" label
query_enriched_cell_types = relative_cell_type_query_proportions[
    relative_cell_type_query_proportions > enriched_query_cell_type_prop_thresh].index.to_list()
model.adata.obs[query_enriched_cell_type_key] = "Cell types not enriched in query"
for cell_type in query_enriched_cell_types:
    model.adata.obs.loc[model.adata.obs[cell_type_key] == cell_type, query_enriched_cell_type_key] = cell_type

In [None]:
# Query-enriched cell type annotations: one spatial panel per embryo (top row)
# and the latent-space UMAP (large bottom panel), with a shared legend
fig = plt.figure(figsize=(12, 14))
title = fig.suptitle(t="Autotalker: Query-enriched Cell Types in Physical and Latent Space",
                     y=0.96,
                     x=0.55,
                     fontsize=20)
spec = gridspec.GridSpec(ncols=1, nrows=2, width_ratios=[1], height_ratios=[1, 5])
ax1, ax2, ax3 = (fig.add_subplot(2, 3, panel_idx) for panel_idx in (1, 2, 3))
ax4 = fig.add_subplot(spec[1])

for embryo_ax, embryo_idx in zip((ax1, ax2, ax3), (1, 2, 3)):
    sc.pl.spatial(adata=model.adata[model.adata.obs["sample"] == f"embryo{embryo_idx}"],
                  color=[query_enriched_cell_type_key],
                  palette="coolwarm",
                  spot_size=0.03,
                  title=f"Physical Space Embryo {embryo_idx}",
                  legend_loc=None,
                  ax=embryo_ax,
                  show=False)
sc.pl.umap(adata=model.adata,
           color=[query_enriched_cell_type_key],
           palette="coolwarm",
           size=240000/len(model.adata),
           title="Autotalker Latent Space",
           ax=ax4,
           show=False)

# Move the UMAP legend to a figure-level legend shared by all panels
handles, labels = ax4.get_legend_handles_labels()
lgd = fig.legend(handles, labels, bbox_to_anchor=(1.25, 0.6525))
ax4.get_legend().remove()

# Adjust spacing, then save and display the figure
plt.subplots_adjust(wspace=0., hspace=0.85)
fig.savefig(f"{figure_folder_path}/query_enriched_cell_types_physical_latent_space.png",
            bbox_extra_artists=(lgd, title),
            bbox_inches="tight")
plt.show()

#### 2.3.4 Analyze GP Enrichments

##### 2.3.4.1 Explore All GPs Across All Cell Types

In [None]:
# Differential GP analysis over all GPs across all cell types;
# physical-space plots show embryo 3
run_differential_gp_analysis(
    analysis_label="cell_type",
    model=model,
    selected_gps=None,  # all GPs
    selected_cats=None,  # all cell types
    differential_gp_scores_key=cell_type_differential_gp_scores_key,
    n_top_up_gps=n_top_up_gps,
    n_top_down_gps=n_top_down_gps,
    physical_feature_space_sample="embryo3",
)

##### 2.3.4.2 Explore All GPs Across Query-enriched Cell Types

In [None]:
# Differential GP analysis over every GP, restricted to the query-enriched
# cell-type categories
run_differential_gp_analysis(
    analysis_label="query_enriched_cell_type",
    model=model,
    differential_gp_scores_key=cell_type_differential_gp_scores_key,
    selected_gps=None,
    selected_cats=query_enriched_cell_types,
    n_top_up_gps=n_top_up_gps,
    n_top_down_gps=n_top_down_gps,
    physical_feature_space_sample="embryo3")

##### 2.3.4.3 Explore Specific GPs Across All Cell Types

In [None]:
# Differential GP analysis of the single EPOR ligand-receptor GP across all
# cell-type categories
run_differential_gp_analysis(
    analysis_label="cell_type_epor_gp",
    model=model,
    differential_gp_scores_key=cell_type_differential_gp_scores_key,
    selected_gps=["EPOR_ligand_receptor_GP"],
    selected_cats=None,
    n_top_up_gps=n_top_up_gps,
    n_top_down_gps=n_top_down_gps,
    physical_feature_space_sample="embryo3")

### 2.4 Latent Clustering-based Analysis

In [None]:
# Leiden clustering on the graph stored in obsp under
# "<latent_key>_knng_connectivities"; labels are written to obs under a key
# that encodes the clustering resolution.
sc.tl.leiden(model.adata,
             resolution=latent_leiden_resolution,
             key_added=f"latent_leiden_{str(latent_leiden_resolution)}",
             adjacency=model.adata.obsp[f"{latent_key}_knng_connectivities"],
             random_state=random_seed)

#### 2.4.1 Visualize Latent Leiden Clusters in Latent and Physical Space

In [None]:
# Latent Leiden clusters: a large latent-space UMAP on top and one
# physical-space panel per embryo along the bottom row
fig = plt.figure(figsize=(12, 14))
title = fig.suptitle(t="Autotalker: Latent Leiden Clusters in Physical and Latent Space",
                     y=0.96,
                     x=0.55,
                     fontsize=20)
# The UMAP uses the first row of a 1-column, 5:1 height-weighted grid; the
# three spatial panels use positions 4-6 of a separate 2x3 subplot grid
spec = gridspec.GridSpec(ncols=1,
                         nrows=2,
                         width_ratios=[1],
                         height_ratios=[5, 1])
ax1 = fig.add_subplot(spec[0])
ax2, ax3, ax4 = (fig.add_subplot(2, 3, grid_pos) for grid_pos in (4, 5, 6))

sc.pl.umap(adata=model.adata,
           color=[f"latent_leiden_{str(latent_leiden_resolution)}"],
           size=240000/len(model.adata),
           title="Autotalker Latent Space",
           ax=ax1,
           show=False)

# Spatial panels carry no legend of their own; the shared legend is built below
for spatial_ax, embryo_id, panel_title in ((ax2, "embryo1", "Physical Space Embryo 1"),
                                            (ax3, "embryo2", "Physical Space Embryo 2"),
                                            (ax4, "embryo3", "Physical Space Embryo 3")):
    sc.pl.spatial(adata=model.adata[model.adata.obs["sample"] == embryo_id],
                  color=[f"latent_leiden_{str(latent_leiden_resolution)}"],
                  spot_size=0.03,
                  title=panel_title,
                  legend_loc=None,
                  ax=spatial_ax,
                  show=False)

# Move the UMAP legend to a single figure-level legend
handles, labels = ax1.get_legend_handles_labels()
lgd = fig.legend(handles, labels, bbox_to_anchor=(1.025, 0.8))
ax1.get_legend().remove()

# Tighten spacing, write the figure to disk, then render it inline
plt.subplots_adjust(wspace=0., hspace=0.85)
fig.savefig(f"{figure_folder_path}/latent_leiden_clusters_latent_physical_space.png",
            bbox_extra_artists=(lgd, title),
            bbox_inches="tight")
plt.show()

#### 2.4.2 Visualize Query-enriched Latent Clusters in Physical and Latent Space

In [None]:
# Get query-enriched latent leiden clusters
# Split the cells by mapping entity (reference vs. query) and compute, for each
# latent Leiden cluster, the fraction of that subset's cells falling into it
# (value counts sorted by cluster label, divided by the subset size).
adata_reference = model.adata[model.adata.obs[mapping_entity_key] == "reference"]
adata_query = model.adata[model.adata.obs[mapping_entity_key] == "query"]
latent_cluster_reference_proportions = adata_reference.obs[f"latent_leiden_{str(latent_leiden_resolution)}"].value_counts().sort_index() / len(adata_reference)
latent_cluster_query_proportions = adata_query.obs[f"latent_leiden_{str(latent_leiden_resolution)}"].value_counts().sort_index() / len(adata_query)
# Enrichment ratio per cluster = query fraction / reference fraction (aligned on
# cluster label), shown from most to least query-enriched.
# NOTE(review): for a cluster with no reference cells this gives inf (flagged as
# enriched) only if the Leiden column is categorical, since value_counts then
# keeps zero counts; for a non-categorical column the label would be missing
# from the reference counts and the ratio would be NaN (never flagged) — confirm dtype.
relative_latent_cluster_query_proportions = latent_cluster_query_proportions / latent_cluster_reference_proportions
relative_latent_cluster_query_proportions.sort_values(ascending=False, inplace=True)
display(relative_latent_cluster_query_proportions)

# Clusters whose ratio exceeds the threshold are labelled "Cluster <label>" in
# a new obs column; every other cell keeps the default "not enriched" label.
query_enriched_latent_clusters = relative_latent_cluster_query_proportions[relative_latent_cluster_query_proportions > enriched_query_latent_cluster_prop_thresh].index.to_list()
model.adata.obs[query_enriched_latent_cluster_key] = "Latent clusters not enriched in query"
for latent_cluster in query_enriched_latent_clusters:
    model.adata.obs.loc[model.adata.obs[f"latent_leiden_{str(latent_leiden_resolution)}"] == latent_cluster, query_enriched_latent_cluster_key] = f"Cluster {latent_cluster}"

In [None]:
# Create subplot of query-enriched latent clusters in physical and latent
# space: one physical-space panel per embryo on top, a large latent-space UMAP
# below.
fig = plt.figure(figsize=(12, 14))
title = fig.suptitle(t="Autotalker: Query-enriched Latent Clusters in Physical and Latent Space",
                     y=0.96,
                     x=0.55,
                     fontsize=20)
# The UMAP uses the second row of a 1-column, 1:5 height-weighted grid; the
# three spatial panels use positions 1-3 of a separate 2x3 subplot grid
spec = gridspec.GridSpec(ncols=1,
                         nrows=2,
                         width_ratios=[1],
                         height_ratios=[1, 5])
ax1 = fig.add_subplot(2, 3, 1)
ax2 = fig.add_subplot(2, 3, 2)
ax3 = fig.add_subplot(2, 3, 3)
ax4 = fig.add_subplot(spec[1])
sc.pl.umap(adata=model.adata,
           color=[query_enriched_latent_cluster_key],
           size=240000/len(model.adata),
           title="Autotalker Latent Space",
           ax=ax4,
           show=False)
sc.pl.spatial(adata=model.adata[model.adata.obs["sample"] == "embryo1"],
              color=[query_enriched_latent_cluster_key],
              spot_size=0.03,
              title="Physical Space Embryo 1",
              legend_loc=None,
              ax=ax1,
              show=False)
sc.pl.spatial(adata=model.adata[model.adata.obs["sample"] == "embryo2"],
              color=[query_enriched_latent_cluster_key],
              spot_size=0.03,
              title="Physical Space Embryo 2",
              legend_loc=None,
              ax=ax2,
              show=False)
sc.pl.spatial(adata=model.adata[model.adata.obs["sample"] == "embryo3"],
              color=[query_enriched_latent_cluster_key],
              spot_size=0.03,
              title="Physical Space Embryo 3",
              legend_loc=None,
              ax=ax3,
              show=False)

# Create and position shared legend (taken from the UMAP panel)
handles, labels = ax4.get_legend_handles_labels()
lgd = fig.legend(handles, labels, bbox_to_anchor=(1.285, 0.6575))
ax4.get_legend().remove()

# Adjust, save (keeping legend and suptitle inside the bounding box) and
# display plot
plt.subplots_adjust(wspace=0., hspace=0.85)
fig.savefig(f"{figure_folder_path}/query_enriched_latent_clusters_physical_latent_space.png",
            bbox_extra_artists=(lgd, title),
            bbox_inches="tight")
plt.show()

#### 2.4.3 Analyze GP Enrichments

##### 2.4.3.1 Explore All GPs Across All Latent Clusters

In [None]:
# Differential GP analysis over every GP and every latent Leiden cluster
run_differential_gp_analysis(
    analysis_label="latent_cluster",
    model=model,
    differential_gp_scores_key=latent_cluster_differential_gp_scores_key,
    selected_gps=None,
    selected_cats=None,
    n_top_up_gps=n_top_up_gps,
    n_top_down_gps=n_top_down_gps,
    physical_feature_space_sample="embryo3")

##### 2.4.3.2 Explore All GPs Across Query-enriched Latent Clusters

In [None]:
# Differential GP analysis over every GP, restricted to the query-enriched
# latent clusters
run_differential_gp_analysis(
    analysis_label="query_enriched_latent_cluster",
    model=model,
    differential_gp_scores_key=latent_cluster_differential_gp_scores_key,
    selected_gps=None,
    selected_cats=query_enriched_latent_clusters,
    n_top_up_gps=n_top_up_gps,
    n_top_down_gps=n_top_down_gps,
    physical_feature_space_sample="embryo3")

##### 2.4.3.3 Explore Specific GPs Across All Latent Clusters

In [None]:
# Differential GP analysis of the single EPOR ligand-receptor GP across all
# latent clusters
run_differential_gp_analysis(
    analysis_label="latent_cluster_epor_gp",
    model=model,
    differential_gp_scores_key=latent_cluster_differential_gp_scores_key,
    selected_gps=["EPOR_ligand_receptor_GP"],
    selected_cats=None,
    n_top_up_gps=n_top_up_gps,
    n_top_down_gps=n_top_down_gps,
    physical_feature_space_sample="embryo3")

## 3. Learn De-Novo CCI GPs

### 3.1 Initialize, Train & Save Model

In [None]:
# Load trained model: use the artifacts of an earlier run when a
# `load_timestamp` is given, otherwise this run's artifacts folder
model_artifacts_load_folder_path = (
    f"../artifacts/{dataset}/batch_integration/{load_timestamp}"
    if load_timestamp is not None
    else model_artifacts_folder_path)

# Reload the reference model with 10 add-on GPs; the unfreeze_* flags request
# that only the add-on GP weights be made trainable
model = Autotalker.load(dir_path=f"{model_artifacts_load_folder_path}/reference",
                        adata=None,
                        adata_file_name=f"{dataset}.h5ad",
                        n_addon_gps=10,
                        gp_names_key=gp_names_key,
                        genes_idx_key=genes_idx_key,
                        unfreeze_all_weights=False,
                        unfreeze_addon_gp_weights=True)

In [None]:
# List every model parameter with its requires_grad flag to check which
# weights were unfrozen
for param_name, param in model.model.named_parameters():
    print(f"{param_name} {param.requires_grad}")

In [None]:
# Short fine-tuning run of the reloaded model (3 epochs)
model.train(n_epochs=3,
            n_epochs_all_gps=3,
            lr=lr,
            # lambda_* weighting coefficients passed to the trainer
            lambda_edge_recon=1.0,
            lambda_gene_expr_recon=0.001,
            lambda_l1_addon=0.003,
            verbose=True)

In [None]:
# Save trained model together with its AnnData, replacing any previous save
model.save(dir_path=f"{model_artifacts_folder_path}/reference_query_denovo_gps",
           adata_file_name=f"{dataset}.h5ad",
           save_adata=True,
           overwrite=True)

In [None]:
# Load trained model (the one saved with de-novo GPs above)
model = Autotalker.load(dir_path=f"{model_artifacts_folder_path}/reference_query_denovo_gps",
                        adata=None,
                        adata_file_name=f"{dataset}.h5ad",
                        gp_names_key=gp_names_key)

In [None]:
# Check number of active gene programs against the total number of GPs
active_gps = model.get_active_gps()
for gp_count_label, gp_count in (("total", len(model.adata.uns[gp_names_key])),
                                 ("active", len(active_gps))):
    print(f"Number of {gp_count_label} gene programs: {gp_count}.")

In [None]:
# Print the last 10 entries of the active GP list (fewer if the list is
# shorter). NOTE(review): presumably add-on GPs are appended at the end — confirm.
print(active_gps[-10:])

In [None]:
# Display the gene names (var_names) of the model's AnnData
model.adata.var_names

In [None]:
# GP summary table, displayed only for active GPs whose name contains "addon"
addon_gp_summary_df = model.get_gp_summary()
addon_gp_summary_df.loc[addon_gp_summary_df["gp_name"].str.contains("addon")
                        & addon_gp_summary_df["gp_active"]]

In [None]:
# Active GPs whose name marks them as add-on GPs
active_addon_gp_names = list(filter(lambda gp_name: "addon" in gp_name, active_gps))

In [None]:
# Compute gene program enrichments of the active add-on GPs across all cell
# types (each cell type vs. the rest) and retrieve the top up- and
# downregulated gene programs. The full score table is stored in
# adata.uns[denovo_cell_type_differential_gp_scores_key]; the returned GP
# names are used for the dot plot below.
top_unique_gps = model.compute_differential_gp_scores(cat_key=cell_type_key,
                                                      adata=model.adata,
                                                      selected_gps=active_addon_gp_names,
                                                      selected_cats=None,
                                                      gp_scores_weight_normalization=False,
                                                      comparison_cats="rest",
                                                      n_sample=1000,
                                                      key_added=denovo_cell_type_differential_gp_scores_key,
                                                      n_top_up_gps_retrieved=n_top_up_gps,
                                                      n_top_down_gps_retrieved=n_top_down_gps,
                                                      seed=random_seed)

In [None]:
# Display top upregulated gene programs: the first `n_top_up_gps` rows of the
# stored differential score ranking
top_up_gp_df = model.adata.uns[denovo_cell_type_differential_gp_scores_key][0:n_top_up_gps]
display(top_up_gp_df)

In [None]:
# Display top downregulated gene programs: the last `n_top_down_gps` rows of
# the stored differential score ranking, reversed so the most downregulated
# GP comes first. Slicing from an explicit, clamped start index avoids the
# `[-n:]` pitfall: with n == 0 that returns the whole table (-0 == 0), and
# with n larger than the table it wraps around incorrectly.
denovo_differential_gp_scores = model.adata.uns[denovo_cell_type_differential_gp_scores_key]
top_down_start = max(len(denovo_differential_gp_scores) - n_top_down_gps, 0)
top_down_gp_df = denovo_differential_gp_scores[top_down_start:][::-1]
display(top_down_gp_df)

In [None]:
# Dot plot of the add-on GPs returned by the differential analysis, grouped by
# cell type. With return_fig=True scanpy returns a DotPlot object rather than
# drawing immediately; its savefig() builds and writes the figure.
fig = sc.pl.dotplot(model.adata,
                    top_unique_gps,
                    groupby=cell_type_key,
                    title="Differential GP Scores Add-on GPs",
                    dendrogram=True,
                    swap_axes=True,
                    return_fig=True)

# Write the plot to disk, then render it inline
fig.savefig(f"{figure_folder_path}/differential_gp_scores_addon_gps.png")
plt.show()

In [None]:
# Gene importances for the first active add-on GP, truncated to the first 50
# entries (presumably ordered by importance — confirm). Raises IndexError if
# there are no active add-on GPs.
model.compute_gp_gene_importances(selected_gp=active_addon_gp_names[0])[:50]

In [None]:
# Gene importances for the second active add-on GP, truncated to the first 50
# entries (presumably ordered by importance — confirm). Raises IndexError if
# fewer than two add-on GPs are active.
model.compute_gp_gene_importances(selected_gp=active_addon_gp_names[1])[:50]

In [None]:
# Create subplot of query-enriched cell types in physical and latent space for
# the model fine-tuned with de-novo (add-on) GPs: two latent-space UMAPs
# (coloured by mapping entity and by query-enriched cell type) plus one
# physical-space panel per embryo.
# NOTE(review): neighbors/UMAP are not recomputed in this section after
# `model.train(...)`, so the stored UMAP may still reflect the latent space
# from before fine-tuning — confirm it is refreshed before running this cell.
fig = plt.figure(figsize=(12, 20))
title = fig.suptitle(t="Autotalker: Query-enriched Cell Types in Physical and Latent Space",
                     y=0.96,
                     x=0.55,
                     fontsize=20)
# Rows 1-2: full-width UMAP panels (1-column spec); row 3: three side-by-side
# spatial panels (3-column spec with the same height ratios)
spec = gridspec.GridSpec(ncols=1,
                         nrows=3,
                         width_ratios=[1],
                         height_ratios=[1, 1, 0.7])
spec2 = gridspec.GridSpec(ncols=3,
                          nrows=3,
                          width_ratios=[1, 1, 1],
                          height_ratios=[1, 1, 0.7])
ax1 = fig.add_subplot(spec[0])
ax2 = fig.add_subplot(spec[1])
ax3 = fig.add_subplot(spec2[6])
ax4 = fig.add_subplot(spec2[7])
ax5 = fig.add_subplot(spec2[8])
# Use the shared key variable (as in the query-enrichment analysis) instead of
# a hard-coded column name
sc.pl.umap(adata=model.adata,
           color=[mapping_entity_key],
           palette="Set1",
           size=240000/len(model.adata),
           title="Autotalker Latent Space",
           ax=ax1,
           show=False)
sc.pl.umap(adata=model.adata,
           color=[query_enriched_cell_type_key],
           palette="coolwarm",
           size=240000/len(model.adata),
           title="Autotalker Latent Space",
           ax=ax2,
           show=False)
# One physical-space panel per embryo; per-panel legends are suppressed in
# favour of the shared legend below
for spatial_ax, embryo_id, panel_title in ((ax3, "embryo1", "Physical Space Embryo 1"),
                                            (ax4, "embryo2", "Physical Space Embryo 2"),
                                            (ax5, "embryo3", "Physical Space Embryo 3")):
    sc.pl.spatial(adata=model.adata[model.adata.obs["sample"] == embryo_id],
                  color=[query_enriched_cell_type_key],
                  palette="coolwarm",
                  spot_size=0.03,
                  title=panel_title,
                  legend_loc=None,
                  ax=spatial_ax,
                  show=False)

# Create and position shared legend (taken from the query-enriched UMAP)
handles, labels = ax2.get_legend_handles_labels()
lgd = fig.legend(handles, labels, bbox_to_anchor=(1.25, 0.465))
ax2.get_legend().remove()

# Adjust, save and display plot. The file name is distinct so this figure no
# longer overwrites the section-2 figure saved as
# "query_enriched_cell_types_physical_latent_space.png".
plt.subplots_adjust(wspace=0., hspace=0.2)
fig.savefig(f"{figure_folder_path}/query_enriched_cell_types_physical_latent_space_denovo_gps.png",
            bbox_extra_artists=(lgd, title),
            bbox_inches="tight")
plt.show()