# Autotalker

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 13.01.2023
- **Date of Last Modification:** 22.02.2023

- The Autotalker source code is available at https://github.com/Talavera-Lopez-Lab/autotalker.
- The workflow of this notebook follows the tutorial at https://github.com/sebastianbirk/autotalker/blob/main/notebooks/autotalker_tutorial.ipynb.
- It is recommended to use raw counts as input to Autotalker. Therefore, we use raw counts (stored in adata.X).

## 1. Setup

### 1.1 Import Libraries

In [None]:
import sys
# Make the local autotalker checkout importable. The path is relative, so this
# assumes the notebook is executed from its own directory
# (NOTE(review): confirm the repository layout matches "../../../autotalker").
sys.path.append("../../../autotalker")

In [None]:
import argparse
import os
import random
import time
import warnings
from copy import deepcopy
from datetime import datetime

import anndata as ad
import matplotlib
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import seaborn as sns
import squidpy as sq
import torch
from matplotlib.pyplot import rc_context

from autotalker.models import Autotalker
from autotalker.utils import (add_gps_from_gp_dict_to_adata,
                              extract_gp_dict_from_mebocost_es_interactions,
                              extract_gp_dict_from_nichenet_ligand_target_mx,
                              extract_gp_dict_from_omnipath_lr_interactions,
                              filter_and_combine_gp_dict_gps)

### 1.2 Define Parameters

In [None]:
# Name of the benchmarked method; reused in figure paths, storage keys and titles
model_name = "autotalker"
# adata.obsm key under which the trained model's latent representation is read
latent_key = model_name + "_latent"

# Settings for Leiden clustering of the latent space
leiden_resolution = 0.3  # clustering resolution
random_seed = 0  # random_state passed to sc.tl.leiden

### 1.3 Run Notebook Setup

In [None]:
# Default figure size (6, 6) for all scanpy plots created in this notebook
sc.set_figure_params(figsize=(6, 6))

In [None]:
# Timestamp of this notebook execution (format ddmmYYYY_HHMMSS); it keeps the
# artifacts saved by different executions in separate folders
now = datetime.now()
current_timestamp = f"{now:%d%m%Y_%H%M%S}"

### 1.4 Configure Paths and Directories

In [None]:
# Input spatial datasets, output figures and gene program (GP) resources
data_folder_path = "../../datasets/srt_data/gold/"
figure_folder_path = "../../figures"
gp_data_folder_path = "../../datasets/gp_data"
nichenet_ligand_target_mx_file_path = (
    f"{gp_data_folder_path}/nichenet_ligand_target_matrix.csv")
omnipath_lr_interactions_file_path = (
    f"{gp_data_folder_path}/omnipath_lr_interactions.csv")

# Only the gene program data folder is created here; the figure folders are
# created per dataset inside the training function
os.makedirs(gp_data_folder_path, exist_ok=True)

## 2. Autotalker Model

### 2.1 Prepare Gene Program Mask

In [None]:
# Gene programs from the NicheNet ligand-target matrix (keep_target_ratio=0.01).
# Both disk flags are off, so file_path is presumably neither read nor written
# here (NOTE(review): confirm against the autotalker utility).
nichenet_gp_dict = extract_gp_dict_from_nichenet_ligand_target_mx(
    file_path=nichenet_ligand_target_mx_file_path,
    load_from_disk=False,
    save_to_disk=False,
    keep_target_ratio=0.01)

In [None]:
# Gene programs from OmniPath ligand-receptor interactions, with no minimum
# curation effort (min_curation_effort=0). Disk caching is disabled.
omnipath_gp_dict = extract_gp_dict_from_omnipath_lr_interactions(
    file_path=omnipath_lr_interactions_file_path,
    load_from_disk=False,
    save_to_disk=False,
    min_curation_effort=0)

In [None]:
# Metabolite-enzyme-sensor gene programs from the MEBOCOST resources folder.
# NOTE(review): species="mouse" is also applied to the human nanostring
# dataset below; presumably genes_uppercase=True is meant to make mouse symbols
# match human ones - confirm this is intended.
mebocost_gp_dict = extract_gp_dict_from_mebocost_es_interactions(
    genes_uppercase=True,
    species="mouse",
    dir_path=f"{gp_data_folder_path}/metabolite_enzyme_sensor_gps/")

In [None]:
# Merge all gene program sources into a single dictionary. If a gene program
# name occurs in several sources, the later one wins
# (NicheNet < OmniPath < MEBOCOST).
combined_gp_dict = {**nichenet_gp_dict, **omnipath_gp_dict, **mebocost_gp_dict}

In [None]:
# Filter gene programs (filter mode "subset") and combine overlapping ones, using
# an overlap threshold of 0.9 for source genes, target genes and all genes.
combined_new_gp_dict = filter_and_combine_gp_dict_gps(
    gp_dict=combined_gp_dict,
    overlap_thresh_genes=0.9,
    overlap_thresh_source_genes=0.9,
    overlap_thresh_target_genes=0.9,
    combine_overlap_gps=True,
    gp_filter_mode="subset",
    verbose=True)

# Report how much filtering and combining reduced the gene program set
print(f"Number of gene programs before filtering and combining: {len(combined_gp_dict)}.")
print(f"Number of gene programs after filtering and combining: {len(combined_new_gp_dict)}.")

### 2.2 Define Training Function

In [None]:
def train_autotalker_models(dataset,
                            cell_type_key,
                            n_runs=10,
                            n_neighbor_list=None):
    """Train ``n_runs`` Autotalker models on one dataset and store their results.

    For every run the dataset is reloaded from ``data_folder_path``, the
    combined gene program masks (``combined_new_gp_dict``) are added, a
    symmetric spatial neighborhood graph with that run's number of neighbors is
    computed, and an Autotalker model is initialized and trained. Afterwards a
    latent UMAP colored by cell type and a latent/physical Leiden cluster
    comparison are saved as figures under
    ``{figure_folder_path}/{dataset}/method_benchmarking/{model_name}/{current_timestamp}``.

    The training duration, latent representation and active gene programs of
    every run are collected in a separate copy of the dataset. That copy is
    written to ``{data_folder_path}/{dataset}_{model_name}.h5ad`` after each
    run and once more at the end.

    Relies on notebook-level globals defined in earlier cells
    (``data_folder_path``, ``figure_folder_path``, ``model_name``,
    ``latent_key``, ``current_timestamp``, ``combined_new_gp_dict``,
    ``leiden_resolution``, ``random_seed``).

    NOTE(review): the model is initialized with ``counts_key="counts"``, while
    the notebook header says raw counts are stored in ``adata.X``; confirm
    which of the two the model actually reads.

    Parameters
    ----------
    dataset : str
        Dataset name; ``{data_folder_path}{dataset}.h5ad`` is loaded.
    cell_type_key : str
        ``adata.obs`` column used to color the latent UMAP.
    n_runs : int
        Number of training runs. Run ``i`` (1-based) uses model seed ``i - 1``.
    n_neighbor_list : iterable of int, optional
        Number of spatial neighbors for each run. It must provide at least
        ``n_runs`` entries; extra entries are ignored. Defaults to
        ``[4, 4, 8, 8, 12, 12, 16, 16, 20, 20]``.

    Raises
    ------
    ValueError
        If ``n_neighbor_list`` has fewer than ``n_runs`` entries (previously
        the surplus runs were silently skipped).
    """
    # None sentinel instead of a shared mutable default list
    if n_neighbor_list is None:
        n_neighbor_list = [4, 4, 8, 8, 12, 12, 16, 16, 20, 20]
    n_neighbor_list = list(n_neighbor_list)
    if len(n_neighbor_list) < n_runs:
        raise ValueError(
            f"n_neighbor_list has {len(n_neighbor_list)} entries but n_runs is "
            f"{n_runs}; provide one number of neighbors per run.")

    # Keys shared between mask creation, model initialization and result
    # extraction, so that the three always agree
    gp_targets_mask_key = "autotalker_gp_targets"
    gp_sources_mask_key = "autotalker_gp_sources"
    gp_names_key = "autotalker_gp_names"
    active_gp_names_key = "autotalker_active_gp_names"
    leiden_key = f"latent_{model_name}_leiden_{leiden_resolution}"

    input_file_path = data_folder_path + f"{dataset}.h5ad"
    output_file_path = f"{data_folder_path}/{dataset}_{model_name}.h5ad"

    # Configure figure folder path
    dataset_figure_folder_path = (
        f"{figure_folder_path}/{dataset}/method_benchmarking/"
        f"{model_name}/{current_timestamp}")
    os.makedirs(dataset_figure_folder_path, exist_ok=True)

    # Separate copy of the dataset that collects the results of all runs
    adata_original = sc.read_h5ad(input_file_path)

    for run_number, n_neighbors in zip(range(1, n_runs + 1), n_neighbor_list):
        # Reload the data so every run starts from an unmodified dataset
        adata = sc.read_h5ad(input_file_path)

        # The model is conditioned on obs["batch"] (condition_key="batch").
        # Datasets without such a column are treated as a single batch; an
        # existing batch annotation is kept unchanged.
        if "batch" not in adata.obs:
            adata.obs["batch"] = "batch1"

        # Add the gene program dictionary as binary masks to the adata for model training
        add_gps_from_gp_dict_to_adata(
            gp_dict=combined_new_gp_dict,
            adata=adata,
            genes_uppercase=True,
            gp_targets_mask_key=gp_targets_mask_key,
            gp_sources_mask_key=gp_sources_mask_key,
            gp_names_key=gp_names_key,
            min_genes_per_gp=1,
            min_source_genes_per_gp=0,
            min_target_genes_per_gp=0,
            max_genes_per_gp=None,
            max_source_genes_per_gp=None,
            max_target_genes_per_gp=None)

        # Hidden encoder width equals the number of gene programs that were added
        n_hidden_encoder = len(adata.uns[gp_names_key])

        # Compute spatial neighborhood graph
        sq.gr.spatial_neighbors(adata,
                                coord_type="generic",
                                spatial_key="spatial",
                                n_neighs=n_neighbors)

        # kNN graphs are not symmetric in general; symmetrize via element-wise max
        adata.obsp["spatial_connectivities"] = (
            adata.obsp["spatial_connectivities"].maximum(
                adata.obsp["spatial_connectivities"].T))

        # Monotonic clock for measuring the duration of model setup + training
        start_time = time.perf_counter()

        # Initialize model
        model = Autotalker(adata,
                           counts_key="counts",
                           adj_key="spatial_connectivities",
                           condition_key="batch",
                           cond_embed_injection=["encoder",
                                                 "gene_expr_decoder",
                                                 "graph_decoder"],
                           n_cond_embed=3,
                           gp_names_key=gp_names_key,
                           active_gp_names_key=active_gp_names_key,
                           gp_targets_mask_key=gp_targets_mask_key,
                           gp_sources_mask_key=gp_sources_mask_key,
                           latent_key=latent_key,
                           active_gp_thresh_ratio=0.03,
                           gene_expr_recon_dist="nb",
                           n_layers_encoder=2,
                           n_hidden_encoder=n_hidden_encoder,
                           log_variational=True)
        print("")

        # Train model; run i uses seed i - 1 (0, 1, 2, ...)
        model.train(n_epochs=40,
                    n_epochs_all_gps=20,
                    lr=0.001,
                    lambda_edge_recon=0.01,
                    lambda_gene_expr_recon=0.0033,
                    lambda_l1_masked=0,
                    edge_batch_size=64,
                    node_batch_size=8,
                    seed=run_number - 1,
                    verbose=True)

        # Measure time for model training
        elapsed_time = time.perf_counter() - start_time
        hours, rem = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(rem, 60)
        print(f"Duration of model training in run {run_number}: "
              f"{int(hours)} hours, {int(minutes)} minutes and {int(seconds)} seconds.")
        adata_original.uns[f"{model_name}_model_training_duration_run{run_number}"] = (
            elapsed_time)

        # Use Autotalker latent space for UMAP generation
        sc.pp.neighbors(adata,
                        use_rep=latent_key,
                        n_neighbors=n_neighbors)
        sc.tl.umap(adata)
        fig = sc.pl.umap(adata,
                         color=[cell_type_key],
                         title=f"Latent Space with Cell Types: {model_name.capitalize()}",
                         return_fig=True)
        fig.savefig(f"{dataset_figure_folder_path}/latent_{model_name}"
                    f"_cell_types_run{run_number}.png",
                    bbox_inches="tight")

        # Compute latent Leiden clustering
        sc.tl.leiden(adata=adata,
                     resolution=leiden_resolution,
                     random_state=random_seed,
                     key_added=leiden_key)

        # Create subplot of latent Leiden cluster annotations in physical and latent space
        fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(6, 12))
        title = fig.suptitle(t="Latent and Physical Space with Leiden Clusters: "
                               f"{model_name.capitalize()}")
        sc.pl.umap(adata=adata,
                   color=[leiden_key],
                   title="Latent Space with Leiden Clusters",
                   ax=axs[0],
                   show=False)
        sq.pl.spatial_scatter(adata=adata,
                              color=[leiden_key],
                              title="Physical Space with Leiden Clusters",
                              shape=None,
                              ax=axs[1])

        # Replace the per-axis legends with one shared figure legend
        handles, labels = axs[0].get_legend_handles_labels()
        lgd = fig.legend(handles, labels, bbox_to_anchor=(1.25, 0.9185))
        for ax in axs:
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()

        # Adjust, save and display plot
        plt.subplots_adjust(wspace=0, hspace=0.2)
        fig.savefig(f"{dataset_figure_folder_path}/latent_physical_comparison_"
                    f"{model_name}_run{run_number}.png",
                    bbox_extra_artists=(lgd, title),
                    bbox_inches="tight")
        plt.show()

        # Store latent representation
        adata_original.obsm[latent_key + f"_run{run_number}"] = adata.obsm[latent_key]

        # Store active gene programs
        adata_original.uns[f"{model_name}_active_gp_names_run{run_number}"] = (
            adata.uns[active_gp_names_key])

        # Store intermediate adata so finished runs survive a later failure
        adata_original.write(output_file_path)

    # Store final adata to disk (also written when n_runs is 0)
    adata_original.write(output_file_path)

### 2.3 Train Models on Benchmarking Datasets

In [None]:
# seqFISH benchmark: uses the function's default number of runs and neighbor schedule
train_autotalker_models(cell_type_key="celltype_mapped_refined",
                        dataset="seqfish_mouse_organogenesis_embryo2")

In [None]:
# Vizgen MERFISH benchmark: a single run with 20 spatial neighbors.
# NOTE(review): presumably reduced because of dataset size - confirm.
train_autotalker_models(n_runs=1,
                        n_neighbor_list=[20],
                        cell_type_key="Cell_Type",
                        dataset="vizgen_merfish_mouse_liver")

In [None]:
# STARmap PLUS benchmark: uses the function's default number of runs and neighbor schedule
train_autotalker_models(cell_type_key="Main_molecular_cell_type",
                        dataset="starmap_plus_mouse_cns")

In [None]:
# NanoString CosMx benchmark: uses the function's default number of runs and neighbor schedule
train_autotalker_models(cell_type_key="cell_type",
                        dataset="nanostring_cosmx_human_nsclc")