# Autotalker Benchmarking

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 07.12.2022
- **Date of Last Modification:** 13.12.2022

## 1. Setup

### 1.1 Import Libraries

In [1]:
# IPython autoreload extension in mode 2: imported modules are reloaded before
# executing code, so edits to the locally imported `autotalker` package are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make the parent directory importable so the local `autotalker` package
# imported further below can be found.
sys.path.extend([".."])

In [3]:
import itertools
import os
import random
import warnings
from datetime import datetime

import anndata as ad
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import scipy.sparse as sp
import seaborn as sns
import squidpy as sq

from autotalker.models import Autotalker
from autotalker.utils import (add_gps_from_gp_dict_to_adata,
                              extract_gp_dict_from_mebocost_es_interactions,
                              extract_gp_dict_from_nichenet_ligand_target_mx,
                              extract_gp_dict_from_omnipath_lr_interactions,
                              filter_and_combine_gp_dict_gps)

### 1.2 Configure Paths and Create Directories

In [4]:
# Define paths
figure_path = "../figures"
model_artefacts_path = "../model_artefacts"
gp_data_folder_path = "../datasets/gp_data/"  # gene program data
srt_data_folder_path = "../datasets/srt_data/"  # spatially resolved transcriptomics data
# os.path.join only inserts a separator when needed; the folder paths above
# already end with "/", so plain f-string joining would produce "//".
srt_data_gold_folder_path = os.path.join(srt_data_folder_path, "gold")
nichenet_ligand_target_mx_file_path = os.path.join(
    gp_data_folder_path, "nichenet_ligand_target_matrix.csv")
omnipath_lr_interactions_file_path = os.path.join(
    gp_data_folder_path, "omnipath_lr_interactions.csv")

# Create required directories (no-op if they already exist)
for directory in (figure_path, model_artefacts_path, "mlruns", gp_data_folder_path):
    os.makedirs(directory, exist_ok=True)

### 1.3 Run Notebook Setup

In [5]:
# Silence every FutureWarning for the rest of the session
warnings.filterwarnings("ignore", category=FutureWarning)

### 1.4 Define Parameters

In [6]:
# Dataset to benchmark on (file name in the gold SRT data folder)
dataset = "squidpy_seqfish_mouse_organogenesis"

# Existing AnnData fields: raw counts layer, cell type annotation,
# spatial coordinates and the spatial adjacency matrix
counts_key = "counts"
cell_type_key = "celltype_mapped_refined"
spatial_key = "spatial"
adj_key = "spatial_connectivities"

# Keys handed to Autotalker for gene program names, active gene program
# names, source/target masks and the latent representation
gp_names_key = "autotalker_gp_names"
active_gp_names_key = "autotalker_active_gp_names"
gp_sources_mask_key = "autotalker_gp_sources"
gp_targets_mask_key = "autotalker_gp_targets"
latent_key = "autotalker_latent"

# Seed forwarded to the benchmarking routine
random_seed = 42

### 1.5 Define Functions

In [10]:
def _build_combined_priors_gp_dict():
    """
    Build the combined prior gene program (GP) dictionary.

    Extracts NicheNet ligand-target, OmniPath ligand-receptor and MEBOCOST
    metabolite-enzyme-sensor GPs, merges them (on duplicate GP names the
    later source wins, as with ``dict.update``) and filters/combines
    overlapping GPs.

    Returns
    -------
    combined_new_gp_dict:
        Filtered and combined GP dictionary.
    gp_params:
        The GP construction settings, for the caller to log to MLflow.
    """
    gp_params = {"nichenet_keep_target_ratio": 0.01,
                 "omnipath_min_curation_effort": 0,
                 "gp_filter_mode": "subset",
                 "combine_overlap_gps": True,
                 "overlap_thresh_source_genes": 0.9,
                 "overlap_thresh_target_genes": 0.9,
                 "overlap_thresh_genes": 0.9}

    nichenet_gp_dict = extract_gp_dict_from_nichenet_ligand_target_mx(
        keep_target_ratio=gp_params["nichenet_keep_target_ratio"],
        load_from_disk=False,
        save_to_disk=True,
        file_path=nichenet_ligand_target_mx_file_path)
    omnipath_gp_dict = extract_gp_dict_from_omnipath_lr_interactions(
        min_curation_effort=gp_params["omnipath_min_curation_effort"],
        load_from_disk=False,
        save_to_disk=True,
        file_path=omnipath_lr_interactions_file_path)
    mebocost_gp_dict = extract_gp_dict_from_mebocost_es_interactions(
        dir_path="../datasets/gp_data/metabolite_enzyme_sensor_gps/",
        species="mouse",
        genes_uppercase=True)
    combined_gp_dict = dict(nichenet_gp_dict)
    combined_gp_dict.update(omnipath_gp_dict)
    combined_gp_dict.update(mebocost_gp_dict)

    # Filter and combine gene programs
    combined_new_gp_dict = filter_and_combine_gp_dict_gps(
        gp_dict=combined_gp_dict,
        gp_filter_mode=gp_params["gp_filter_mode"],
        combine_overlap_gps=gp_params["combine_overlap_gps"],
        overlap_thresh_source_genes=gp_params["overlap_thresh_source_genes"],
        overlap_thresh_target_genes=gp_params["overlap_thresh_target_genes"],
        overlap_thresh_genes=gp_params["overlap_thresh_genes"],
        verbose=True)

    print(f"Number of gene programs before filtering and combining: {len(combined_gp_dict)}.")
    print(f"Number of gene programs after filtering and combining: {len(combined_new_gp_dict)}.")
    print("")
    return combined_new_gp_dict, gp_params


def _add_fc_gp_masks(adata, n_latent_fc_gps, node_label_method):
    """
    Attach fully-connected GP masks (every GP covers every gene) to ``adata``.

    AnnData requires ``varm`` arrays to have ``n_vars`` rows, so the masks
    are stored as (n_genes, n_latent_fc_gps).

    Returns
    -------
    n_hidden_encoder:
        Encoder hidden size used for the FC setting (half the GP count).
    """
    fc_mask = np.ones((adata.n_vars, n_latent_fc_gps))
    adata.varm[gp_targets_mask_key] = fc_mask
    if node_label_method != "self":
        # Non-"self" label methods reconstruct twice as many outputs
        # (targets + sources), so a sources mask is needed as well.
        adata.varm[gp_sources_mask_key] = fc_mask.copy()
    adata.uns[gp_names_key] = np.array(
        [f"FC_GP_{i}" for i in range(n_latent_fc_gps)])
    return n_latent_fc_gps // 2


def _run_benchmark_iteration(dataset,
                             hyperparam_dict,
                             gp_mask,
                             gp_dict,
                             gp_mask_params,
                             experiment,
                             experiment_name,
                             save_model):
    """
    Load the data, attach GP masks, train, benchmark and (optionally) save
    one model, logging everything to the currently active MLflow run.
    """
    # Logged per run (not once before the loop) so every run carries them
    mlflow.log_params(gp_mask_params)

    print("--- DATASET ---")
    print(f"Using dataset {dataset}.")
    mlflow.log_param("dataset", dataset)
    adata = ad.read_h5ad(f"{srt_data_gold_folder_path}/{dataset}.h5ad")
    n_nodes, n_genes = adata.layers[counts_key].shape
    print(f"Number of nodes (cells): {n_nodes}")
    print(f"Number of node features (genes): {n_genes}")
    mlflow.log_param("n_nodes", n_nodes)
    mlflow.log_param("n_genes", n_genes)

    if gp_mask == "fc":
        n_hidden_encoder = _add_fc_gp_masks(
            adata=adata,
            n_latent_fc_gps=hyperparam_dict["n_latent_fc_gps"],
            node_label_method=hyperparam_dict["node_label_method"])
    else:
        min_source_genes_per_gp = 1
        min_target_genes_per_gp = 1
        mlflow.log_param("min_source_genes_per_gp", min_source_genes_per_gp)
        mlflow.log_param("min_target_genes_per_gp", min_target_genes_per_gp)
        # Add the gene program dictionary as binary masks to the adata for model training
        add_gps_from_gp_dict_to_adata(
            gp_dict=gp_dict,
            adata=adata,
            genes_uppercase=True,
            gp_targets_mask_key=gp_targets_mask_key,
            gp_sources_mask_key=gp_sources_mask_key,
            gp_names_key=gp_names_key,
            min_genes_per_gp=1,
            min_source_genes_per_gp=min_source_genes_per_gp,
            min_target_genes_per_gp=min_target_genes_per_gp,
            max_genes_per_gp=None,
            max_source_genes_per_gp=None,
            max_target_genes_per_gp=None)
        n_hidden_encoder = len(adata.uns[gp_names_key])

    # Summarize gene programs
    gp_names = list(adata.uns[gp_names_key])
    print(f"Number of gene programs with probed genes: {len(gp_names)}.")
    print(f"Example gene programs: {random.sample(gp_names, min(5, len(gp_names)))}.")
    print(f"Number of gene program target genes: {adata.varm[gp_targets_mask_key].sum()}.")
    # The FC "self" setting has no sources mask
    if gp_sources_mask_key in adata.varm:
        print(f"Number of gene program source genes: {adata.varm[gp_sources_mask_key].sum()}.")

    print("")
    print("--- SPATIAL CONNECTIVITY ---")
    # Compute spatial neighborhood
    sq.gr.spatial_neighbors(adata,
                            coord_type="generic",
                            spatial_key=spatial_key,
                            n_neighs=hyperparam_dict["n_neighs"])
    avg_edges_per_node = round(adata.obsp['spatial_connectivities'].sum(axis=0).mean(), 2)
    print(f"Average number of edges per node: {avg_edges_per_node}")
    # Upper triangle without diagonal counts each undirected edge once
    n_edges = int(sp.triu(adata.obsp['spatial_connectivities'], k=1).sum())
    print(f"Number of total edges: {n_edges}")
    mlflow.log_param("n_neighbors", hyperparam_dict["n_neighs"])
    mlflow.log_param("n_edges", n_edges)

    # Initialize model
    print("")
    model = Autotalker(adata,
                       counts_key=counts_key,
                       adj_key=adj_key,
                       gp_names_key=gp_names_key,
                       active_gp_names_key=active_gp_names_key,
                       gp_targets_mask_key=gp_targets_mask_key,
                       gp_sources_mask_key=gp_sources_mask_key,
                       latent_key=latent_key,
                       include_edge_recon_loss=hyperparam_dict["include_edge_recon_loss"],
                       include_gene_expr_recon_loss=hyperparam_dict["include_gene_expr_recon_loss"],
                       gene_expr_recon_dist=hyperparam_dict["gene_expr_recon_dist"],
                       node_label_method=hyperparam_dict["node_label_method"],
                       active_gp_thresh_ratio=hyperparam_dict["active_gp_thresh_ratio"],
                       n_hidden_encoder=n_hidden_encoder,
                       conv_layer_encoder=hyperparam_dict["conv_layer_encoder"],
                       encoder_n_attention_heads=hyperparam_dict["encoder_n_attention_heads"],
                       n_addon_gps=hyperparam_dict["n_addon_gps"])

    # Train model
    print("")
    model.train(n_epochs=hyperparam_dict["n_epochs"],
                n_epochs_all_gps=hyperparam_dict["n_epochs_all_gps"],
                lr=hyperparam_dict["lr"],
                lambda_edge_recon=hyperparam_dict["lambda_edge_recon"],
                lambda_gene_expr_recon=hyperparam_dict["lambda_gene_expr_recon"],
                lambda_group_lasso=hyperparam_dict["lambda_group_lasso"],
                lambda_l1_addon=hyperparam_dict["lambda_l1_addon"],
                mlflow_experiment_id=experiment.experiment_id,
                verbose=True)

    print("")
    # Benchmark model
    benchmark_dict = model.run_benchmarks(
        adata=model.adata,
        cell_type_key=cell_type_key,
        spatial_key=spatial_key,
        spatial_knng_key="autotalker_spatial_knng",
        latent_knng_key="autotalker_latent_knng",
        n_neighbors=hyperparam_dict["n_neighs"],
        seed=random_seed,
        mlflow_experiment_id=experiment.experiment_id)
    print("--- BENCHMARKING RESULTS ---")
    print(benchmark_dict)

    if save_model:
        # Timestamp keeps artefacts of repeated runs apart
        current_timestamp = datetime.now().strftime("%d%m%Y_%H%M%S")
        model.save(dir_path=f"{model_artefacts_path}/{dataset}/{experiment_name}/{current_timestamp}",
                   overwrite=True,
                   save_adata=True,
                   adata_file_name=f"{dataset}.h5ad")


def run_hyperparam_benchmarking(dataset,
                                hyperparam_option_dict,
                                n_iters,
                                experiment_name,
                                gp_mask="combined_priors",
                                save_model=False):
    """
    Train and benchmark one Autotalker model per hyperparameter combination.

    Every combination in the Cartesian product of the option lists in
    ``hyperparam_option_dict`` is run ``n_iters`` times. Each run reloads
    the dataset, attaches the gene program (GP) masks, builds the spatial
    neighbor graph, trains the model, computes the benchmarking metrics and
    logs everything to its own MLflow run in ``experiment_name``.

    Parameters
    ----------
    dataset:
        Name of the ``.h5ad`` file (without extension) in
        ``srt_data_gold_folder_path``.
    hyperparam_option_dict:
        Maps each hyperparameter name to a list of candidate values. Keys
        read: n_neighs, include_edge_recon_loss,
        include_gene_expr_recon_loss, gene_expr_recon_dist,
        node_label_method, active_gp_thresh_ratio, conv_layer_encoder,
        encoder_n_attention_heads, n_addon_gps, n_epochs, n_epochs_all_gps,
        lr, lambda_edge_recon, lambda_gene_expr_recon, lambda_group_lasso,
        lambda_l1_addon; ``gp_mask="fc"`` additionally needs n_latent_fc_gps.
    n_iters:
        Number of repetitions of every hyperparameter combination.
    experiment_name:
        MLflow experiment all runs are logged to; also the artefact
        sub-folder used when ``save_model`` is True.
    gp_mask:
        ``"combined_priors"`` (NicheNet + OmniPath + MEBOCOST prior GPs) or
        ``"fc"`` (fully-connected GP masks).
    save_model:
        If True, save every trained model together with its AnnData.

    Raises
    ------
    ValueError
        If ``gp_mask`` is not a supported option.
    """
    if gp_mask not in ("combined_priors", "fc"):
        raise ValueError(
            f"Unsupported gp_mask '{gp_mask}'. Use 'combined_priors' or 'fc'.")

    # Retrieve gene program mask (built once, reused by every run)
    print("--- GP MASK ---")
    print(f"Using '{gp_mask}' GP mask.")
    # Only collected here: logging now, before an experiment is set, would
    # auto-start a run in MLflow's default experiment.
    gp_mask_params = {"gp_mask": gp_mask}
    combined_new_gp_dict = None
    if gp_mask == "combined_priors":
        combined_new_gp_dict, prior_gp_params = _build_combined_priors_gp_dict()
        gp_mask_params.update(prior_gp_params)

    # Loop `n_iters` times through every combination of hyperparams
    hyperparams, hyperparam_values = zip(*hyperparam_option_dict.items())
    for hyperparam_comb in itertools.product(*hyperparam_values, range(n_iters)):
        # zip() drops the trailing repetition index of the product tuple
        hyperparam_dict = dict(zip(hyperparams, hyperparam_comb))
        experiment = mlflow.set_experiment(experiment_name)
        try:
            _run_benchmark_iteration(dataset=dataset,
                                     hyperparam_dict=hyperparam_dict,
                                     gp_mask=gp_mask,
                                     gp_dict=combined_new_gp_dict,
                                     gp_mask_params=gp_mask_params,
                                     experiment=experiment,
                                     experiment_name=experiment_name,
                                     save_model=save_model)
        except Exception:
            # Close the run so later runs/calls do not log into it
            mlflow.end_run(status="FAILED")
            raise
        mlflow.end_run()
        print("--------------------")
        print("")
        print("--------------------")

In [11]:
def plot_benchmarking_metrics(fig_title,
                              df,
                              y_col_name,
                              save_fig=False,
                              save_dir="../figures",
                              file_name="benchmarking_metrics.png"):
    """
    Draw one boxplot per benchmarking metric, grouped by ``y_col_name``.

    Parameters
    ----------
    fig_title:
        Title of the whole figure.
    df:
        One row per run; must contain the metric columns ``gcd``, ``mlnmi``,
        ``cad``, ``arclisi``, ``cca``, ``germse`` and ``y_col_name``.
    y_col_name:
        Column whose values define the boxes (shared y axis).
    save_fig:
        If True, save the figure into ``save_dir``.
    save_dir:
        Output directory; created if missing.
    file_name:
        Output file name. A timestamp is inserted before the extension
        (``benchmarking_metrics.png`` -> ``benchmarking_metrics_<ts>.png``);
        ``.png`` is used when no extension is given.
    """
    # (column, panel title) in row-major panel order
    metrics = [("gcd", "GCD"),          # Graph Connectivity Distance
               ("mlnmi", "MLNMI"),      # Maximum Leiden Normalized Mutual Info
               ("cad", "CAD"),          # Cell-Type Affinity Distance
               ("arclisi", "ARCLISI"),  # Avg. Abs. Log Relative Cell-Type LISI
               ("cca", "CCA"),          # Cell Classification Accuracy
               ("germse", "GERMSE")]    # Gene Expression Regression MSE

    fig, axes = plt.subplots(3, 2, sharey=True, figsize=(10, 10))
    fig.suptitle(fig_title)
    for ax, (metric, title) in zip(axes.flat, metrics):
        sns.boxplot(data=df, ax=ax, x=metric, y=y_col_name)
        ax.set_title(title)

    plt.subplots_adjust(left=0.1,
                        bottom=0.1,
                        right=0.9,
                        top=0.9,
                        wspace=0.2,
                        hspace=0.4)
    if save_fig:
        # Timestamp keeps figures of repeated runs apart
        current_timestamp = datetime.now().strftime("%d%m%Y_%H%M%S")
        # Split off the extension so it is not duplicated after the timestamp
        base_name, extension = os.path.splitext(file_name)
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(f"{save_dir}/{base_name}_{current_timestamp}{extension or '.png'}",
                    bbox_inches='tight')

## 2. Benchmarking Experiments

### 2.1. Benchmark Number of Neighbors

In [None]:
# Hyperparameter grid: every list is one axis of the Cartesian product that
# run_hyperparam_benchmarking iterates over; only n_neighs is varied here.
hyperparam_option_dict = {
    "n_neighs": [4, 8, 16, 32, 64],  # <--- benchmark
    "include_edge_recon_loss": [True],
    "include_gene_expr_recon_loss": [True],
    "gene_expr_recon_dist": ["nb"],
    "node_label_method": ["one-hop-norm"],
    "active_gp_thresh_ratio": [0.],
    "conv_layer_encoder": ["gcnconv"],
    "encoder_n_attention_heads": [1],
    "n_addon_gps": [0],
    "n_epochs": [10],
    "n_epochs_all_gps": [2],
    "lr": [0.01],
    "lambda_edge_recon": [1.],
    "lambda_gene_expr_recon": [1.],
    "lambda_group_lasso": [0],
    "lambda_l1_addon": [0]}

run_hyperparam_benchmarking(dataset=dataset,
                            hyperparam_option_dict=hyperparam_option_dict,
                            n_iters=5,
                            experiment_name="benchmark_n_neighs",
                            gp_mask="combined_priors",
                            save_model=False)

--- GP MASK ---
Using 'combined_priors' GP mask.
Downloading NicheNet ligand target potential matrix from the web. This might take a while...


In [None]:
experiment_name = "benchmark_loss_inclusions"

# One row per MLflow run, holding its metrics and params (a param overrides
# a metric of the same name).
runs = mlflow.search_runs(experiment_names=[experiment_name],
                          output_format="list")
data = [{**run.data.metrics, **run.data.params} for run in runs]
df = pd.DataFrame.from_dict(data)

In [None]:
def get_recon_loss_inclusion(row):
    """
    Label which reconstruction losses a run included.

    ``row`` holds MLflow params as strings ("True"/"False"). Returns one of
    "edge_+_gene_expr", "only_edge", "only_gene_expr" or "none".
    """
    edge_flag = row["include_edge_recon_loss_"]
    if edge_flag == "True":
        gene_expr_flag = row["include_gene_expr_recon_loss_"]
        if gene_expr_flag == "True":
            return "edge_+_gene_expr"
        if gene_expr_flag == "False":
            return "only_edge"
        return "none"
    if edge_flag == "False" and row["include_gene_expr_recon_loss_"] == "True":
        return "only_gene_expr"
    return "none"

df["recon_loss_inclusions"] = df.apply(get_recon_loss_inclusion, axis=1)

In [None]:
# Same six-metric boxplot layout as the other benchmarks, via the shared
# helper; saved as
# f"{figure_path}/{dataset}/benchmark_loss_inclusions_<timestamp>.png".
plot_benchmarking_metrics(
    fig_title="Reconstruction Loss Inclusion Benchmarking Metrics",
    df=df,
    y_col_name="recon_loss_inclusions",
    save_fig=True,
    save_dir=f"{figure_path}/{dataset}",
    file_name="benchmark_loss_inclusions")

### 2.2. Benchmark Gene Expression Reconstruction Distribution with FC GPs

In [None]:
hyperparams = {}
hyperparams["n_neighs"] = [2] # [2, 4, 8, 16]
hyperparams["n_latent_fc_gps"] = [32, 64, 128, 256, 512]
hyperparams["include_edge_recon_loss"] = [True]
hyperparams["lambda_edge_recon"] = [None]
hyperparams["include_gene_expr_recon_loss"] = [True]
hyperparams["lambda_gene_expr_recon"] = [1.]
hyperparams["node_label_method"] = ["self"]
hyperparams["gene_expr_recon_dist"] = ["nb", "zinb"] # <--- benchmark
hyperparams["conv_layer_encoder"] = ["gcnconv"]
n_epochs = 10
lr = 0.01

# Loop through combination of hyperparams
iters = range(1)
keys, values = zip(*hyperparams.items())
for hyperparam_config in itertools.product(*values, iters):
    # Map each hyperparameter name to its value in this combination; zip()
    # stops before the trailing repetition index.
    hyperparam_dict = dict(zip(keys, hyperparam_config))
    n_neighs = hyperparam_dict["n_neighs"]
    n_latent_fc_gps = hyperparam_dict["n_latent_fc_gps"]
    include_edge_recon_loss = hyperparam_dict["include_edge_recon_loss"]
    lambda_edge_recon = hyperparam_dict["lambda_edge_recon"]
    include_gene_expr_recon_loss = hyperparam_dict["include_gene_expr_recon_loss"]
    lambda_gene_expr_recon = hyperparam_dict["lambda_gene_expr_recon"]
    node_label_method = hyperparam_dict["node_label_method"]
    gene_expr_recon_dist = hyperparam_dict["gene_expr_recon_dist"]
    conv_layer_encoder = hyperparam_dict["conv_layer_encoder"]

    experiment = mlflow.set_experiment("benchmark_gene_expr_recon_dist")

    print("--- DATASET ---")
    adata = ad.read_h5ad(f"{srt_data_gold_folder_path}/{dataset}.h5ad")
    print(f"Using dataset {dataset}.")
    mlflow.log_param("dataset", dataset)
    n_nodes, n_genes = adata.layers[counts_key].shape
    print(f"Number of nodes (cells): {n_nodes}")
    print(f"Number of node features (genes): {n_genes}")
    mlflow.log_param("n_nodes", n_nodes)
    mlflow.log_param("n_genes", n_genes)

    # Create fully-connected mask that allows all latent dims to reconstruct all genes
    gp_targets_mask = np.ones((n_latent_fc_gps, n_genes))
    # Non-"self" label methods also need a sources mask
    gp_sources_mask = (None if node_label_method == "self"
                       else np.ones((n_latent_fc_gps, n_genes)))
    adata.uns[gp_names_key] = np.array([f"FC_GP_{i}" for i in range(n_latent_fc_gps)])

    # Hidden encoder layer is half as wide as the latent space
    n_hidden_encoder = n_latent_fc_gps // 2

    print("")
    print("--- SPATIAL CONNECTIVITIES STATS ---")
    # Compute spatial neighborhood
    sq.gr.spatial_neighbors(adata,
                            coord_type="generic",
                            spatial_key=spatial_key,
                            n_neighs=n_neighs)
    avg_edges_per_node = round(adata.obsp['spatial_connectivities'].sum(axis=0).mean(), 2)
    print(f"Average number of edges per node: {avg_edges_per_node}")
    n_edges = int(sp.triu(adata.obsp['spatial_connectivities'], k=1).sum())
    print(f"Number of total edges: {n_edges}")
    mlflow.log_param("n_neighbors", n_neighs)
    mlflow.log_param("n_edges", n_edges)

    print("")
    # Initialize model
    # NOTE(review): run_hyperparam_benchmarking passes masks via adata.varm
    # keys (gp_targets_mask_key / gp_sources_mask_key) instead of the
    # gp_targets_mask / gp_sources_mask arrays used here — confirm the
    # current Autotalker signature still accepts these arguments.
    model = Autotalker(adata,
                       counts_key=counts_key,
                       adj_key=adj_key,
                       gp_names_key=gp_names_key,
                       active_gp_names_key=active_gp_names_key,
                       latent_key=latent_key,
                       include_edge_recon_loss=include_edge_recon_loss,
                       include_gene_expr_recon_loss=include_gene_expr_recon_loss,
                       gene_expr_recon_dist=gene_expr_recon_dist,
                       node_label_method=node_label_method,
                       n_hidden_encoder=n_hidden_encoder,
                       conv_layer_encoder=conv_layer_encoder,
                       gp_targets_mask=gp_targets_mask,
                       gp_sources_mask=gp_sources_mask,
                       n_addon_gps=0)

    print("")
    # Train model
    model.train(n_epochs=n_epochs,
                lr=lr,
                lambda_edge_recon=lambda_edge_recon,
                lambda_gene_expr_recon=lambda_gene_expr_recon,
                mlflow_experiment_id=experiment.experiment_id,
                verbose=True)

    print("")
    # Benchmark model
    benchmark_dict = model.run_benchmarks(
        adata=model.adata,
        cell_type_key=cell_type_key,
        spatial_key=spatial_key,
        spatial_knng_key="autotalker_spatial_knng",
        latent_knng_key="autotalker_latent_knng",
        n_neighbors=n_neighs,
        seed=random_seed,
        mlflow_experiment_id=experiment.experiment_id)
    print("--- BENCHMARKING RESULTS ---")
    print(benchmark_dict)

    print("--------------------")
    print("")
    print("--------------------")

    mlflow.end_run()

In [None]:
experiment_name = "benchmark_gene_expr_recon_dist"

# Flatten each MLflow run into a single record: metrics first, then params
# (so a param wins over a metric with the same name).
runs = mlflow.search_runs(experiment_names=[experiment_name],
                          output_format="list")
data = [{**run.data.metrics, **run.data.params} for run in runs]
df = pd.DataFrame.from_dict(data)

In [None]:
# Boxplots of every benchmarking metric per gene expression reconstruction
# distribution; the figure is only displayed (save_fig=False).
plot_benchmarking_metrics("Gene Expression Reconstruction Distribution Benchmarking Metrics",
                          df,
                          "gene_expr_recon_dist_",
                          save_fig=False,
                          save_dir=figure_path,
                          file_name="benchmarking_gene_expr_recon_dist.png")

### 2.3. Benchmark Encoder Conv Layer & Node Label Method with FC GPs

In [None]:
# Grid for this benchmark. Entries marked "<--- benchmark" are the axes being
# compared; all single-value entries are held fixed.
hyperparams = {"n_neighs": [16],  # previously explored: [2, 4, 8, 16]
               "n_latent_fc_gps": [32, 64, 128],
               "include_edge_recon_loss": [True],
               "lambda_edge_recon": [None],
               "include_gene_expr_recon_loss": [True],
               "lambda_gene_expr_recon": [1.],
               "node_label_method": ["self",
                                     "one-hop-sum",
                                     "one-hop-norm",
                                     "one-hop-attention"],  # <--- benchmark
               "gene_expr_recon_dist": ["nb"],
               "conv_layer_encoder": ["gcnconv", "gatv2conv"]}  # <--- benchmark
n_epochs = 10
lr = 0.01

# Each combination is repeated len(iters) times
iters = range(1)
keys, values = zip(*hyperparams.items())
for hyperparam_config in itertools.product(*values, iters):
    # The final tuple element is the repetition index, which is not needed,
    # so only the nine hyperparameter values are unpacked.
    (n_neighs, n_latent_fc_gps,
     include_edge_recon_loss, lambda_edge_recon,
     include_gene_expr_recon_loss, lambda_gene_expr_recon,
     node_label_method, gene_expr_recon_dist,
     conv_layer_encoder) = hyperparam_config[:9]

    experiment = mlflow.set_experiment("benchmark_conv_layer_and_node_label_method")

    # Load the dataset and log its size
    print("--- DATASET ---")
    adata = ad.read_h5ad(f"{srt_data_gold_folder_path}/{dataset}.h5ad")
    print(f"Using dataset {dataset}.")
    mlflow.log_param("dataset", dataset)
    n_nodes, n_genes = adata.layers["counts"].shape
    print(f"Number of nodes (cells): {n_nodes}")
    print(f"Number of node features (genes): {n_genes}")
    mlflow.log_param("n_nodes", n_nodes)
    mlflow.log_param("n_genes", n_genes)

    # Fully-connected masks: every latent dimension may reconstruct every
    # gene. Label methods other than "self" get a target AND a source mask.
    if node_label_method != "self":
        n_output = 2 * len(adata.var)
        gp_targets_mask = np.ones((n_latent_fc_gps, len(adata.var)))
        gp_sources_mask = np.ones((n_latent_fc_gps, len(adata.var)))
    else:
        n_output = len(adata.var)
        gp_targets_mask = np.ones((n_latent_fc_gps, n_output))
    adata.uns[gp_names_key] = np.array([f"FC_GP_{i}" for i in range(n_latent_fc_gps)])

    # Hidden encoder layer is half as wide as the latent space
    n_hidden_encoder = n_latent_fc_gps // 2

    # Build the spatial kNN graph and log its size
    print("")
    print("--- SPATIAL CONNECTIVITIES STATS ---")
    sq.gr.spatial_neighbors(adata,
                            coord_type="generic",
                            spatial_key=spatial_key,
                            n_neighs=n_neighs)
    avg_edges_per_node = round(adata.obsp['spatial_connectivities'].sum(axis=0).mean(), 2)
    print(f"Average number of edges per node: {avg_edges_per_node}")
    n_edges = int(sp.triu(adata.obsp['spatial_connectivities'], k=1).sum())
    print(f"Number of total edges: {n_edges}", sep="")
    mlflow.log_param("n_neighbors", n_neighs)
    mlflow.log_param("n_edges", n_edges)

    # Model set-up
    print("")
    model = Autotalker(adata,
                       counts_key=counts_key,
                       adj_key=adj_key,
                       gp_names_key=gp_names_key,
                       active_gp_names_key=active_gp_names_key,
                       latent_key=latent_key,
                       include_edge_recon_loss=include_edge_recon_loss,
                       include_gene_expr_recon_loss=include_gene_expr_recon_loss,
                       gene_expr_recon_dist=gene_expr_recon_dist,
                       node_label_method=node_label_method,
                       n_hidden_encoder=n_hidden_encoder,
                       conv_layer_encoder=conv_layer_encoder,
                       gp_targets_mask=gp_targets_mask,
                       gp_sources_mask=(gp_sources_mask if node_label_method != "self" else None),
                       n_addon_gps=0)

    # Training
    print("")
    model.train(n_epochs=n_epochs,
                lr=lr,
                lambda_edge_recon=lambda_edge_recon,
                lambda_gene_expr_recon=lambda_gene_expr_recon,
                mlflow_experiment_id=experiment.experiment_id,
                verbose=True)

    # Benchmarking
    print("")
    benchmark_dict = model.run_benchmarks(
        adata=model.adata,
        cell_type_key=cell_type_key,
        spatial_key=spatial_key,
        spatial_knng_key="autotalker_spatial_knng",
        latent_knng_key="autotalker_latent_knng",
        n_neighbors=n_neighs,
        seed=random_seed,
        mlflow_experiment_id=experiment.experiment_id)
    print("--- BENCHMARKING RESULTS ---")
    print(benchmark_dict)

    # Timestamp for artefact saving (saving itself is currently disabled)
    now = datetime.now()
    current_timestamp = now.strftime("%d%m%Y_%H%M%S")
    print("--------------------")
    print("")
    print("--------------------")

    mlflow.end_run()

In [None]:
# Grid: node label method x reconstruction distribution x latent size.
# n_neighs=64 combined with n_latent_fc_gps=512 -> OOM
hyperparams = {"n_neighs": [4],
               "n_latent_fc_gps": [256, 512],
               "include_edge_recon_loss": [True],
               "lambda_edge_recon": [None],
               "include_gene_expr_recon_loss": [True],
               "lambda_gene_expr_recon": [1.],
               "node_label_method": ["self",
                                     "one-hop-sum",
                                     "one-hop-norm",
                                     "one-hop-attention"],
               "gene_expr_recon_dist": ["nb",
                                        "zinb"]}
n_epochs = 10
lr = 0.01

# Each combination is repeated len(iters) times
iters = range(1)
keys, values = zip(*hyperparams.items())
for hyperparam_config in itertools.product(*values, iters):
    # Drop the trailing repetition index; unpack the eight hyperparameters
    (n_neighs, n_latent_fc_gps,
     include_edge_recon_loss, lambda_edge_recon,
     include_gene_expr_recon_loss, lambda_gene_expr_recon,
     node_label_method, gene_expr_recon_dist) = hyperparam_config[:8]

    experiment = mlflow.set_experiment("benchmark_gene_expr_recon")

    # Timestamp identifying this run's saved artefacts
    now = datetime.now()
    current_timestamp = now.strftime("%d%m%Y_%H%M%S")

    # Load the dataset and log its size
    print("--- DATASET ---")
    adata = ad.read_h5ad(f"{srt_data_gold_folder_path}/{dataset}.h5ad")
    print(f"Using dataset {dataset}.")
    mlflow.log_param("dataset", dataset)
    n_nodes, n_genes = adata.layers["counts"].shape
    print(f"Number of nodes (cells): {n_nodes}")
    print(f"Number of node features (genes): {n_genes}")
    mlflow.log_param("n_nodes", n_nodes)
    mlflow.log_param("n_genes", n_genes)

    # Fully-connected masks: every latent dimension may reconstruct every
    # gene. Label methods other than "self" get a target AND a source mask.
    if node_label_method != "self":
        n_output = 2 * len(adata.var)
        gp_targets_mask = np.ones((n_latent_fc_gps, len(adata.var)))
        gp_sources_mask = np.ones((n_latent_fc_gps, len(adata.var)))
    else:
        n_output = len(adata.var)
        gp_targets_mask = np.ones((n_latent_fc_gps, n_output))
    adata.uns[gp_names_key] = np.array([f"FC_GP_{i}" for i in range(n_latent_fc_gps)])

    # Hidden encoder layer is half as wide as the latent space
    n_hidden_encoder = n_latent_fc_gps // 2

    # Build the spatial kNN graph and log its size
    print("")
    print("--- SPATIAL CONNECTIVITIES STATS ---")
    sq.gr.spatial_neighbors(adata,
                            coord_type="generic",
                            spatial_key=spatial_key,
                            n_neighs=n_neighs)
    avg_edges_per_node = round(adata.obsp['spatial_connectivities'].sum(axis=0).mean(), 2)
    print(f"Average number of edges per node: {avg_edges_per_node}")
    n_edges = int(sp.triu(adata.obsp['spatial_connectivities'], k=1).sum())
    print(f"Number of total edges: {n_edges}", sep="")
    mlflow.log_param("n_neighbors", n_neighs)
    mlflow.log_param("n_edges", n_edges)

    # Model set-up
    print("")
    model = Autotalker(adata,
                       counts_key=counts_key,
                       adj_key=adj_key,
                       gp_names_key=gp_names_key,
                       active_gp_names_key=active_gp_names_key,
                       latent_key=latent_key,
                       include_edge_recon_loss=include_edge_recon_loss,
                       include_gene_expr_recon_loss=include_gene_expr_recon_loss,
                       gene_expr_recon_dist=gene_expr_recon_dist,
                       node_label_method=node_label_method,
                       n_hidden_encoder=n_hidden_encoder,
                       gp_targets_mask=gp_targets_mask,
                       gp_sources_mask=(gp_sources_mask if node_label_method != "self" else None),
                       n_addon_gps=0)

    # Training
    print("")
    model.train(n_epochs=n_epochs,
                lr=lr,
                lambda_edge_recon=lambda_edge_recon,
                lambda_gene_expr_recon=lambda_gene_expr_recon,
                mlflow_experiment_id=experiment.experiment_id,
                verbose=True)

    # Benchmarking
    print("")
    benchmark_dict = model.run_benchmarks(
        adata=model.adata,
        cell_type_key=cell_type_key,
        spatial_key=spatial_key,
        spatial_knng_key="autotalker_spatial_knng",
        latent_knng_key="autotalker_latent_knng",
        n_neighbors=n_neighs,
        seed=random_seed,
        mlflow_experiment_id=experiment.experiment_id)
    print("--- BENCHMARKING RESULTS ---")
    print(benchmark_dict)

    print("--------------------")
    print("")
    print("--------------------")
    # Persist the trained model together with its AnnData
    model.save(dir_path=f"{model_artefacts_path}/{dataset}/benchmark_gene_expr_recon/{current_timestamp}",
               overwrite=True,
               save_adata=True,
               adata_file_name=f"{dataset}.h5ad")

    mlflow.end_run()

In [None]:
model.adata.obsm[latent_key]

In [None]:
# Inspect the active gene programs of the most recently trained model
model.get_active_gps()

In [None]:
# Inspect the active-GP threshold ratio stored on the trained model;
# presumably mirrors the `active_gp_thresh_ratio` init argument — confirm
model.active_gp_thresh_ratio_

In [None]:
# compute_arclisi is not among the notebook's imports, so bring it into scope here.
# NOTE(review): module path assumed to be autotalker.benchmarking — confirm.
from autotalker.benchmarking import compute_arclisi

# Use the configured cell type key instead of repeating the literal
arclisi = compute_arclisi(model.adata, cell_type_key=cell_type_key)