# Ablation

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 06.01.2023
- **Date of Last Modification:** 30.06.2023

## 1. Setup

### 1.1 Import Libraries

In [1]:
# Reload imported modules (e.g. ablation_utils) automatically before executing code
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make modules in ../../utils importable (presumably where ablation_utils lives)
sys.path.extend(["../../utils"])

In [3]:
import os
import warnings
from datetime import datetime

import mlflow
import numpy as np
import pandas as pd
import pickle
import plottable
import scanpy as sc
import scib


from nichecompass.benchmarking import compute_benchmarking_metrics
from nichecompass.models import NicheCompass
from nichecompass.utils import (add_gps_from_gp_dict_to_adata,
                                create_new_color_dict,
                                extract_gp_dict_from_mebocost_es_interactions,
                                extract_gp_dict_from_nichenet_lrt_interactions,
                                extract_gp_dict_from_omnipath_lr_interactions,
                                filter_and_combine_gp_dict_gps)

from ablation_utils import *

Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  __import__("pkg_resources").declare_namespace(__name__)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  __import__("pkg_resources").declare_namespace(__name__)
  bokeh_bool_types += (np.bool8,)
  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()
  warn(


### 1.2 Define Parameters

In [4]:
# Keys of the NicheCompass latent representation and its nearest neighbor graph
latent_key = "nichecompass_latent"
latent_knng_key = f"{latent_key}_knng"
# Keys of the spatial coordinates and the spatial nearest neighbor graph
spatial_key = "spatial"
spatial_knng_key = f"{spatial_key}_knng"
# Key of the gene program names
gp_names_key = "nichecompass_gp_names"

### 1.4 Run Notebook Setup

In [5]:
# Set the default scanpy figure size to 6 x 6
sc.set_figure_params(figsize=(6, 6))

  IPython.display.set_matplotlib_formats(*ipython_format)


In [6]:
# Ignore future warnings, user warnings and deprecation warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)
warnings.simplefilter(action="ignore", category=DeprecationWarning)

In [7]:
# Timestamp (DDMMYYYY_HHMMSS) of this notebook execution, used to tag saved artifacts
now = datetime.now()
current_timestamp = f"{now:%d%m%Y_%H%M%S}"

In [8]:
# Point mlflow at the tracking server; the server must already be running
# on localhost port 8889
mlflow.set_tracking_uri("http://localhost:8889")

### 1.5 Configure Paths and Directories

In [9]:
# Root folder of all artifacts and the subfolder for miscellaneous outputs
# (e.g. the mlflow summary CSVs). The first path is a plain string literal:
# it has no placeholders, so an f-string prefix is unnecessary (ruff F541).
artifact_folder_path = "../../artifacts"
miscellaneous_folder_path = f"{artifact_folder_path}/miscellaneous"

## 2. Ablation

### 2.1 Loss Weights (Edge Reconstruction & Gene Expression Reconstruction)

- Different combinations of the edge reconstruction loss and gene expression reconstruction loss weighting hyperparameters are tested.
- The number of neighbors of the spatial neighborhood graph is varied across ```4```, ```8```, ```12```, and ```16```. For each dataset and 'n_neighbors' value, one run is performed with a fully connected gene program (GP) mask and one with the NicheCompass default GP mask, resulting in a total of 8 runs per dataset and loss weight combination.

In [10]:
ablation_task = "loss_weights"

# One tuple per dataset: (dataset, cell type key, condition key, mlflow experiment id),
# transposed into the parallel lists used by the cells below
datasets, cell_type_keys, condition_keys, experiment_ids = map(list, zip(
    ("xenium_human_breast_cancer", "cell_states", None, 3),
    ("starmap_plus_mouse_cns", "Main_molecular_cell_type", None, 4),
))

In [40]:
# Retrieve metrics and params of ablation runs from mlflow and store in summary df.
# Runs that are still in progress are skipped. Missing validation metrics become
# NaN, and so do the params "nichenet_keep_target_genes_ratio" and
# "add_fc_gps_instead_of_gp_dict_gps", which some runs may not have logged
# (indexing them directly would raise a KeyError for such runs).
# NOTE(review): ``mlflow.list_run_infos`` was removed in MLflow 2.0; switch to
# ``mlflow.search_runs`` when upgrading.
run_records = []
for dataset, experiment_id in zip(datasets, experiment_ids):
    for run_info in mlflow.list_run_infos(experiment_id):
        run = mlflow.get_run(run_info.run_id)
        if run.info.status == "RUNNING":
            continue
        run_params = run.data.params
        run_metrics = run.data.metrics
        # Collect one record per run and build the DataFrame once at the end,
        # instead of re-concatenating a growing DataFrame for every run
        run_records.append({
            "dataset": dataset,
            "timestamp": run_params["timestamp"],
            "val_auroc_score": run_metrics.get("val_auroc_score", np.nan),
            "val_gene_expr_mse_score": run_metrics.get("val_gene_expr_mse_score", np.nan),
            "lambda_edge_recon_": run_params["lambda_edge_recon_"],
            "lambda_gene_expr_recon_": run_params["lambda_gene_expr_recon_"],
            "n_neighbors": run_params["n_neighbors"],
            "nichenet_keep_target_genes_ratio": run_params.get(
                "nichenet_keep_target_genes_ratio", np.nan),
            "add_fc_gps_instead_of_gp_dict_gps": run_params.get(
                "add_fc_gps_instead_of_gp_dict_gps", np.nan),
        })
summary_df = pd.DataFrame(run_records)
if summary_df.empty:
    raise ValueError(
        f"No finished mlflow runs found in experiments {experiment_ids}.")
summary_df["loss_weights"] = summary_df.apply(get_loss_weights, axis=1)

# Write one summary CSV per dataset
for dataset in datasets:
    summary_df[summary_df["dataset"] == dataset].to_csv(
        f"{miscellaneous_folder_path}/mlflow_summary_{ablation_task}_ablation_{dataset}.csv")



In [43]:
# Average the numeric validation scores per NicheNet target gene ratio.
# ``numeric_only=True`` restricts the mean to numeric columns; pandas >= 2.0 no
# longer drops the string-valued param columns silently and would raise instead.
summary_df.groupby("nichenet_keep_target_genes_ratio").mean(numeric_only=True)

Unnamed: 0_level_0,val_auroc_score,val_gene_expr_mse_score
nichenet_keep_target_genes_ratio,Unnamed: 1_level_1,Unnamed: 2_level_1
0.01,0.93491,1.767176
0.1,0.941705,1.760437
1.0,0.936999,1.816169


In [53]:
# Inspect the starmap runs with a NicheNet target gene ratio of 0.1, ordered by
# loss weights. mlflow params are strings, so the ratio is compared numerically
# (this also works when summary_df was re-read from CSV, where the column is
# float), and the lambdas are sorted numerically: a plain string sort would place
# e.g. "1000.0" before "300.0".
summary_df[
    (summary_df["dataset"] == "starmap_plus_mouse_cns")
    & (pd.to_numeric(summary_df["nichenet_keep_target_genes_ratio"], errors="coerce") == 0.1)
].sort_values(
    ["lambda_edge_recon_", "lambda_gene_expr_recon_"],
    key=lambda col: pd.to_numeric(col, errors="coerce"))

Unnamed: 0,dataset,timestamp,val_auroc_score,val_gene_expr_mse_score,lambda_edge_recon_,lambda_gene_expr_recon_,n_neighbors,nichenet_keep_target_genes_ratio,add_fc_gps_instead_of_gp_dict_gps,loss_weights
212,starmap_plus_mouse_cns,01072023_143216_3,0.524029,0.900167,0.0,0.0,12,0.1,False,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
213,starmap_plus_mouse_cns,01072023_142146_2,0.528775,0.962689,0.0,0.0,8,0.1,False,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
214,starmap_plus_mouse_cns,01072023_141752_1,0.550148,0.974185,0.0,0.0,4,0.1,False,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
148,starmap_plus_mouse_cns,03072023_135013_1003,0.994157,0.673327,0.0,30.0,16,0.1,False,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
149,starmap_plus_mouse_cns,03072023_132805_1002,0.993887,0.641392,0.0,30.0,12,0.1,False,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
150,starmap_plus_mouse_cns,03072023_131455_1001,0.991201,0.643315,0.0,30.0,8,0.1,False,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
151,starmap_plus_mouse_cns,03072023_131347_1000,0.969641,0.693365,0.0,30.0,4,0.1,False,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
144,starmap_plus_mouse_cns,03072023_171838_1007,0.996494,0.593698,0.0,300.0,16,0.1,False,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
145,starmap_plus_mouse_cns,03072023_171838_1006,0.996093,0.624603,0.0,300.0,12,0.1,False,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
146,starmap_plus_mouse_cns,03072023_141549_1005,0.991448,0.633484,0.0,300.0,8,0.1,False,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...


In [23]:
# Write one summary CSV per dataset
for dataset in datasets:
    is_dataset_run = summary_df["dataset"].eq(dataset)
    summary_df.loc[is_dataset_run].to_csv(
        f"{miscellaneous_folder_path}/mlflow_summary_{ablation_task}_ablation_{dataset}.csv")

In [37]:
# Load the combined run summary, using the first CSV column as index.
# NOTE(review): the export cells in this notebook write one CSV per dataset
# (``mlflow_summary_{ablation_task}_ablation_{dataset}.csv``); this combined file
# is not produced by any cell here — confirm it is still up to date.
summary_df = pd.read_csv(
    f"{miscellaneous_folder_path}/mlflow_summary_{ablation_task}_ablation.csv",
    index_col=0,
)

In [38]:
# Display the current run summary
summary_df

Unnamed: 0,dataset,timestamp,val_auroc_score,val_gene_expr_mse_score,lambda_edge_recon_,lambda_gene_expr_recon_,n_neighbors,add_fc_gps_instead_of_gp_dict_gps,loss_weights
0,xenium_human_breast_cancer,02072023_073546_2011,0.996472,1.660857,500000.0,300.0,16,False,lambda_edge_recon_500000.0_+_lambda_gene_expr_...
1,xenium_human_breast_cancer,02072023_072424_2010,0.995574,1.605819,500000.0,300.0,12,False,lambda_edge_recon_500000.0_+_lambda_gene_expr_...
2,xenium_human_breast_cancer,02072023_064949_2009,0.992328,1.533650,500000.0,300.0,8,False,lambda_edge_recon_500000.0_+_lambda_gene_expr_...
3,xenium_human_breast_cancer,02072023_064251_2008,0.979258,1.426696,500000.0,300.0,4,False,lambda_edge_recon_500000.0_+_lambda_gene_expr_...
4,xenium_human_breast_cancer,02072023_053747_2007,0.996632,1.799987,500000.0,30.0,16,False,lambda_edge_recon_500000.0_+_lambda_gene_expr_...
...,...,...,...,...,...,...,...,...,...
286,starmap_plus_mouse_cns,29062023_161622_7,0.997866,0.470623,0.0,30.0,8,True,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
287,starmap_plus_mouse_cns,29062023_161622_11,0.979904,0.499574,0.0,300.0,4,True,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
288,starmap_plus_mouse_cns,29062023_161622_12,0.997421,0.481453,0.0,300.0,8,True,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
289,starmap_plus_mouse_cns,29062023_161622_1,0.560831,0.707935,0.0,0.0,4,True,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...


In [33]:
# Restrict the summary to the xenium runs
summary_df = summary_df.loc[summary_df["dataset"].eq("xenium_human_breast_cancer")]

In [34]:
# Drop the first 46 rows (by position) and overwrite the xenium loss weights CSV.
# NOTE(review): 46 is a hard-coded cut-off, presumably excluding stale runs — confirm.
summary_df.iloc[46:].to_csv(
    f"{miscellaneous_folder_path}/mlflow_summary_loss_weights_ablation_xenium_human_breast_cancer.csv")

In [35]:
# Reload the xenium loss weights summary, using the first CSV column as index
summary_df = pd.read_csv(
    f"{miscellaneous_folder_path}/mlflow_summary_loss_weights_ablation_xenium_human_breast_cancer.csv",
    index_col=0,
)

In [36]:
# Display the reloaded xenium run summary
summary_df

Unnamed: 0,dataset,timestamp,val_auroc_score,val_gene_expr_mse_score,lambda_edge_recon_,lambda_gene_expr_recon_,n_neighbors,add_fc_gps_instead_of_gp_dict_gps,loss_weights
46,xenium_human_breast_cancer,01072023_174536_6,0.969510,1.547869,0.0,30.0,8,False,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
47,xenium_human_breast_cancer,01072023_174214_5,0.932853,1.402551,0.0,30.0,4,False,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
48,xenium_human_breast_cancer,01072023_173609_4,0.500000,5.073387,0.0,0.0,16,False,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
49,xenium_human_breast_cancer,01072023_173609_3,0.500000,5.118559,0.0,0.0,12,False,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
50,xenium_human_breast_cancer,01072023_162730_2,0.500000,5.130722,0.0,0.0,8,False,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
...,...,...,...,...,...,...,...,...,...
139,xenium_human_breast_cancer,28062023_164600_11,0.965506,2.606598,0.0,300.0,4,True,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
140,xenium_human_breast_cancer,28062023_164600_12,0.993967,2.656636,0.0,300.0,8,True,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
141,xenium_human_breast_cancer,28062023_164600_13,0.997351,2.624020,0.0,300.0,12,True,lambda_edge_recon_0.0_+_lambda_gene_expr_recon...
142,xenium_human_breast_cancer,28062023_164600_16,0.973018,4.571494,50000.0,0.0,4,True,lambda_edge_recon_50000.0_+_lambda_gene_expr_r...


In [14]:
# Compute benchmarking metrics for every ablation run that has a validation AUROC
# and left-join them onto the run summary by (dataset, timestamp).
# NOTE: after this cell, ``dataset`` and ``timestamps`` still hold the values of the
# last entry of ``datasets``; later cells read these globals.
metrics_df = pd.DataFrame()
for i, dataset in enumerate(datasets):
    finished_runs_mask = (
        summary_df["dataset"].eq(dataset) & summary_df["val_auroc_score"].notnull())
    timestamps = summary_df.loc[finished_runs_mask, "timestamp"].tolist()

    current_iteration_metrics_df = compute_metrics(
        artifact_folder_path=artifact_folder_path,
        dataset=dataset,
        task=f"{ablation_task}_ablation",
        timestamps=timestamps,
        cell_type_key=cell_type_keys[i],
        condition_key=condition_keys[i],
        spatial_key=spatial_key,
        spatial_knng_key=spatial_knng_key,
        latent_key=latent_key,
        latent_knng_key=latent_knng_key)
    metrics_df = pd.concat([metrics_df, current_iteration_metrics_df], axis=0)

summary_df = summary_df.merge(metrics_df, on=["dataset", "timestamp"], how="left")

Loading xenium_human_breast_cancer model with timestamp 29062023_014049_44.
Computing metric CLISIS
Computing spatial nearest neighbor graph for entire dataset...




Computing spatial cell CLISI scores for entire dataset...
Computing latent nearest neighbor graph...
Computing latent cell CLISI scores...



KeyboardInterrupt



In [None]:
# Keep only rows without any missing value (e.g. runs for which the left merge
# above found no computed metrics)
test_df = summary_df.dropna(axis="index", how="any")

In [None]:
# Display the runs that have all metrics available
test_df

In [None]:
# Columns identifying a group: dataset first, then the ablated setting
dataset_col = "dataset"
ablation_col = "loss_weights"
group_cols = [dataset_col, ablation_col]

# Benchmark metrics and the weights passed along with them (same order)
metric_cols = ["cas", "mlami", "gcs", "cca"]
metric_cols_weights = [0.6, 0.6, 0.6, 0.4]
metric_score_cols = [f"{metric}_points" for metric in metric_cols]

# Turn the per-run metrics into weighted ablation points per ablation setting;
# "total_score" is passed as the sort column
points_df = compute_ablation_points(
    df=test_df,
    metric_cols=metric_cols,
    metric_cols_weights=metric_cols_weights,
    group_col=ablation_col,
    sort_metric_col="total_score",
)

In [None]:
# Plot metrics
for dataset in datasets:
    dataset_df = points_df[points_df["dataset"] == dataset]
    plot_metrics(
        fig_title=f"Loss Weights Ablation: {dataset}",
        df=dataset_df,
        group_col="loss_weights",
        metric_cols=metric_cols,
        plot_ratio_active_gps=False,
        save_fig=False,
        file_name="ablation_metrics.png")

In [None]:
# Re-declare the grouping and metric columns used for the mean points below
ablation_col, dataset_col = "loss_weights", "dataset"
group_cols = [dataset_col, ablation_col]
metric_cols = ["cas", "mlami", "gcs", "cca"]
metric_cols_weights = [0.6] * 3 + [0.4]
metric_score_cols = [f"{col}_points" for col in metric_cols]

In [None]:
# Unroll points df and compute means over group columns and score type
unrolled_points_df = pd.melt(points_df, 
   id_vars=group_cols,
   value_vars=metric_score_cols,
   var_name="score_type", 
   value_name="score")

# Compute metric means over all runs
mean_points_df = unrolled_points_df.groupby(group_cols + ["score_type"]).mean()
mean_points_df.reset_index(inplace=True)

In [None]:
# Display the mean ablation points per group and score type
mean_points_df

In [None]:
# Debug: inspect the columns the mean points are grouped by
group_cols + ["score_type"]

In [None]:
# Debug: inspect the grouping columns
group_cols

In [None]:
# Display the mean ablation points again before plotting
mean_points_df

In [None]:
# Plot the mean ablation points per dataset and ablation setting.
# save_dir=None — presumably the figure is only shown, not written to disk.
plot_ablation_points(
    df=mean_points_df,
    metric_cols=metric_score_cols,
    ablation_col=ablation_col,
    ablation_col_width=7,
    group_col=dataset_col,
    show=True,
    save_dir=None,
    save_name="ablation_results.svg",
)

In [None]:
# Visualize the latent embeddings of the xenium ablation runs. The timestamps are
# derived here for this dataset (runs with a validation AUROC, as in the metrics
# cell) instead of reusing the global ``timestamps``, which the metrics loop
# leaves set to the runs of the last entry of ``datasets`` (the starmap dataset).
xenium_dataset = "xenium_human_breast_cancer"
xenium_timestamps = summary_df.loc[
    (summary_df["dataset"] == xenium_dataset)
    & summary_df["val_auroc_score"].notnull(),
    "timestamp"].tolist()
visualize_latent_embeddings(artifact_folder_path=artifact_folder_path,
                            plot_label="test",
                            task=ablation_task + "_ablation",
                            timestamps=xenium_timestamps,
                            dataset=xenium_dataset,
                            cat_key="cell_states",
                            sample_key="batch",
                            groups=None,
                            spot_size=30.,
                            save_fig=False)

### 2.2 Encoder Architecture (GCN vs GATv2)

#### 2.1.2 Categorical Covariates Contrastive Loss

### 2.5 Spatial Neighborhood Graph

In [None]:
# Visualize latent niches (Leiden clusters at resolution 0.2) for the ablation runs.
# NOTE(review): ``dataset`` and ``timestamps`` are not set in this cell; they hold
# whatever the loops above left behind (the last entry of ``datasets``) — consider
# setting them explicitly.
visualize_niches(
    artifact_folder_path=artifact_folder_path,
    dataset=dataset,
    timestamps=timestamps,
    task=f"{ablation_task}_ablation",
    sample_key="batch",
    latent_key=latent_key,
    latent_cluster_key="nichecompass_latent_clusters",
    latent_leiden_resolution=0.2,
    spot_size=30.,
)

### 2.6 Gene Program Mask

In [31]:
"""
# Log additional mlflow param
run_ids = []
runs_info = mlflow.list_run_infos(3)
for run_info in runs_info:
    run = mlflow.get_run(run_info.run_uuid)
    if "add_fc_gps_instead_of_gp_dict_gps" not in run.data.params:
        run_ids.append(run_info.run_uuid)
        
for run_id in run_ids:
    with mlflow.start_run(run_id=run_id) as run:
        mlflow.log_param("add_fc_gps_instead_of_gp_dict_gps", True)
"""