# Real Data Model Ablation

- **Creator**: Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Date of Creation:** 08.10.2024
- **Date of Last Modification:** 09.10.2024

## 1. Setup

Run this notebook in the nichecompass-reproducibility environment, installable from ```('../../../envs/environment.yaml')```.

Before running this notebook:
- Clone SDMBench from https://github.com/zhaofangyuan98/SDMBench.git into ```('../benchmarking')```. The following slight modifications to the SDMBench source code are necessary to remove technical bugs:
    - Move _compute_CHAOS function into compute_CHAOS
    - Move _compute_PAS function into compute_PAS
    - Move fx_kNN function into compute_PAS
    - Move fx_1NN function into compute_CHAOS

### 1.1 Import Libraries

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edited modules are
# re-imported automatically before code is executed
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Extend the module search path with the SDMBench clone and the utils folder
# (order matters: SDMBench is registered first)
sys.path.extend(["../benchmarking/SDMBench/SDMBench",
                 "../../utils"])

In [None]:
import gc
import math
import os
import warnings

import anndata as ad
import scanpy as sc
import scib_metrics
import squidpy as sq

from nichecompass.benchmarking import compute_benchmarking_metrics

from SDMBench import sdmbench
from benchmarking_utils import *

### 1.2 Run Notebook Setup

In [None]:
# Suppress all warnings for the rest of the notebook
warnings.filterwarnings(action="ignore")

# Lift the pandas display limits so that full tables are shown
for display_option in ("display.max_columns", "display.max_rows"):
    pd.set_option(display_option, None)

### 1.3 Configure Paths

In [None]:
# Root folder (relative to this notebook) under which models and metrics live
artifacts_folder_path = "../../artifacts"

### 1.4 Define Functions

In [None]:
def compute_metrics(run_dict,
                    hyperparam_dict,
                    job_ids=None):
    """
    Evaluate all stored models of the current ablation task and checkpoint the
    collected results to a CSV file after every evaluated model.

    The function walks ``model_folder_path`` and processes every file named
    ``{dataset_name}_{ablation_task}.h5ad``. For each such file it appends the
    dataset name, job ID, seed, hyperparameter values and the kmeans-based
    NMI / ARI of the NicheCompass latent space to ``run_dict`` and rewrites
    ``{artifacts_folder_path}/ablation/{dataset_name}_metrics_{ablation_task}.csv``.

    Besides its arguments, the function reads the module-level variables
    ``model_folder_path``, ``dataset_name``, ``ablation_task`` and
    ``artifacts_folder_path``.

    Parameters
    ----------
    run_dict:
        Dictionary of lists that is extended in place. Needs the keys
        "dataset", "job_id", "seed", "nnmi", "nari" and one key per entry of
        ``hyperparam_dict``.
    hyperparam_dict:
        Maps a hyperparameter name to a list of values. The list is indexed
        with ``(job_id - 1) // 8``, i.e. 8 consecutive job IDs share one
        value and differ only in their seed (``(job_id - 1) % 8``).
    job_ids:
        Optional container of job IDs to skip. ``None`` (default) skips no
        job.

    Raises
    ------
    ValueError
        Raised by ``pd.DataFrame`` if the lists in ``run_dict`` do not all
        have the same length, e.g. if columns were pre-initialised in
        ``run_dict`` that this function never appends to (such as the
        benchmarking metrics while the benchmarking call below is disabled).
    """
    # ``None`` instead of a mutable ``[]`` default; no job is skipped then
    skipped_job_ids = [] if job_ids is None else job_ids
    target_file_name = f"{dataset_name}_{ablation_task}.h5ad"
    output_folder_path = f"{artifacts_folder_path}/ablation"

    for subdir, _dirs, files in os.walk(model_folder_path):
        for file in files:
            if file != target_file_name:
                continue

            # The job ID is parsed from the last two characters of the
            # directory path.
            # NOTE(review): presumably directories end in `_<job_id>`; IDs
            # above 99 are not handled by this slice - confirm naming scheme.
            job_id = int(subdir[-2:].strip("_"))

            # Check the skip list before loading so that skipped runs do not
            # pay for reading the (potentially large) AnnData file
            if job_id in skipped_job_ids:
                continue

            file_path = os.path.join(subdir, file)
            print(f"Loading file: {file_path}")
            adata = ad.read_h5ad(file_path)

            # 8 seeds per hyperparameter configuration
            hyperparam_idx, seed = divmod(job_id - 1, 8)

            run_dict["dataset"].append(dataset_name)
            run_dict["job_id"].append(job_id)
            run_dict["seed"].append(seed)
            for key, values in hyperparam_dict.items():
                run_dict[key].append(values[hyperparam_idx])

            metrics = scib_metrics.nmi_ari_cluster_labels_kmeans(
                adata.obsm["nichecompass_latent"],
                adata.obs["Main_molecular_tissue_region"])

            run_dict["nnmi"].append(metrics["nmi"])
            run_dict["nari"].append(metrics["ari"])

            # Disabled NicheCompass benchmarking metrics
            #benchmark_dict = compute_benchmarking_metrics(
            #        adata=adata,
            #        metrics=benchmarking_metrics,
            #        cell_type_key=cell_type_key,
            #        batch_key=batch_key,
            #        spatial_key=spatial_key,
            #        latent_key=latent_key,
            #        n_jobs=1,
            #        seed=0,
            #        mlflow_experiment_id=None)
            #
            #for key, value in benchmark_dict.items():
            #    run_dict[key].append(value)

            # Disabled Leiden clustering, niche plotting and SDMBench metrics;
            # kept as an unused string literal so it can be re-enabled
            """
            counter = 0
            while True:
                sc.tl.leiden(adata=adata,
                             resolution=latent_leiden_resolution,
                             key_added="pred_niche_types",
                             neighbors_key=latent_key)

                niche_counts = adata.obs["pred_niche_types"].value_counts()
                valid_niches = niche_counts[niche_counts >= valid_niche_thresh].index
                n_niches = adata.obs[adata.obs["pred_niche_types"].isin(valid_niches)]["pred_niche_types"].nunique()
                print(f"Current number of niches: {n_niches}")
                print(f"Cluster counter: {counter}")
                if n_niches == 13:
                    break
                elif n_niches < 12 and counter < 30:
                    print("Big increase of clustering resolution...")
                    latent_leiden_resolution += leiden_resolution_increments
                elif n_niches < 13 and counter < 60:
                    print("Slight increase of clustering resolution...")
                    latent_leiden_resolution += leiden_resolution_increments/10
                elif n_niches > 14 and counter < 30:
                    print("Big decrease of clustering resolution...")
                    latent_leiden_resolution -= leiden_resolution_increments
                elif n_niches > 13 and counter < 60:
                    print("Slight decrease of clustering resolution...")
                    latent_leiden_resolution -= leiden_resolution_increments/10
                elif counter > 60:
                    break
                counter += 1

            latent_cluster_colors = create_new_color_dict(
                adata=adata,
                cat_key="pred_niche_types")

            # Create plot of latent cluster / niche annotations in physical and latent space
            groups = None # set this to a specific cluster for easy visualization
            save_fig = False

            adata.obs[sample_key] = "batch1"
            samples = adata.obs[sample_key].unique().tolist()

            fig = plt.figure(figsize=(12, 14))
            title = fig.suptitle(t=f"NicheCompass Niches " \
                                   "in Latent and Physical Space",
                                 y=0.96,
                                 x=0.55,
                                 fontsize=20)
            spec1 = gridspec.GridSpec(ncols=1,
                                      nrows=2,
                                      width_ratios=[1],
                                      height_ratios=[3, 2])
            spec2 = gridspec.GridSpec(ncols=len(samples),
                                      nrows=2,
                                      width_ratios=[1] * len(samples),
                                      height_ratios=[3, 2])
            axs = []
            axs.append(fig.add_subplot(spec1[0]))
            sc.pl.umap(adata=adata,
                       color=["pred_niche_types"],
                       groups=groups,
                       palette=latent_cluster_colors,
                       title=f"Niches in Latent Space",
                       ax=axs[0],
                       show=False)
            for idx, sample in enumerate(samples):
                axs.append(fig.add_subplot(spec2[len(samples) + idx]))
                sc.pl.spatial(adata=adata[adata.obs[sample_key] == sample],
                              color=["pred_niche_types"],
                              groups=groups,
                              palette=latent_cluster_colors,
                              spot_size=spot_size,
                              title=f"Niches in Physical Space \n"
                                    f"(Sample: {sample})",
                              legend_loc=None,
                              ax=axs[idx+1],
                              show=False)

            # Create and position shared legend
            handles, labels = axs[0].get_legend_handles_labels()
            lgd = fig.legend(handles,
                             labels,
                             loc="center left",
                             bbox_to_anchor=(0.98, 0.5))
            axs[0].get_legend().remove()

            # Adjust, save and display plot
            plt.subplots_adjust(wspace=0.2, hspace=0.25)
            if save_fig:
                fig.savefig(file_path,
                            bbox_extra_artists=(lgd, title),
                            bbox_inches="tight")
            plt.show()

            run_dict["ari"].append(sdmbench.compute_ARI(
                adata,
                cell_type_key,
                "pred_niche_types"))
            run_dict["nmi"].append(sdmbench.compute_NMI(
                adata,
                cell_type_key,
                "pred_niche_types"))
            run_dict["chaos"].append(sdmbench.compute_CHAOS(
                adata,
                "pred_niche_types"))
            run_dict["pas"].append(sdmbench.compute_PAS(
                adata,
                "pred_niche_types",
                spatial_key="spatial"))
            run_dict["asw"].append(sdmbench.compute_ASW(
                adata,
                "pred_niche_types",
                spatial_key="spatial"))
            run_dict["hom"].append(sdmbench.compute_HOM(
                adata,
                cell_type_key,
                "pred_niche_types"))
            run_dict["com"].append(sdmbench.compute_COM(
                adata,
                cell_type_key,
                "pred_niche_types"))
            """

            # Free the AnnData object before the next file is loaded
            del adata
            gc.collect()

            # Checkpoint after every model so that partial results survive an
            # interruption. All lists in `run_dict` must have equal length
            # here, otherwise `pd.DataFrame` raises a ValueError.
            os.makedirs(output_folder_path, exist_ok=True)
            run_df = pd.DataFrame(run_dict)
            run_df.to_csv(f"{output_folder_path}/{dataset_name}_metrics_{ablation_task}.csv", index=False)

## 2. Compute Metrics

### 2.1 Loss Ablation

In [None]:
# Run selection; these module-level names are read by compute_metrics
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "loss_ablation"
model_folder_path = f"{artifacts_folder_path}/{dataset_name}/models/{ablation_task}"

# AnnData keys and plot settings; only referenced by the currently disabled
# benchmarking / clustering / plotting code in compute_metrics
cell_type_key = "Main_molecular_tissue_region"
batch_key = "batch"
spatial_key = "spatial"
latent_key = "nichecompass_latent"
sample_key = "batch"
spot_size = 0.1

# Metric columns to reserve in the results table
benchmarking_metrics = ["cas", "mlami", "clisis", "gcs", "cnmi", "nasw"]
#sdm_benchmarking_metrics = ["ari", "nmi", "chaos", "pas", "asw", "hom", "com"]

# Settings of the disabled Leiden clustering
#latent_leiden_resolution = 0.05
#leiden_resolution_increments = 0.01
#valid_niche_thresh = 100
#latent_cluster_key = f"latent_leiden_{str(latent_leiden_resolution)}"

# One entry per hyperparameter configuration; compute_metrics selects the
# entry for a run via (job_id - 1) // 8
hyperparam_dict = {
    "lambda_edge_recon": [500000., 0., 500000., 50000., 500000.],
    "lambda_gene_expr_recon": [300., 300., 0., 300., 30],
}

# Results table as a dict of (initially empty) column lists
run_dict = {column: [] for column in ("dataset", "job_id", "seed", "nnmi", "nari")}
run_dict.update({hyperparam: [] for hyperparam in hyperparam_dict})
run_dict.update({metric: [] for metric in benchmarking_metrics})
#run_dict.update({metric: [] for metric in sdm_benchmarking_metrics})

In [None]:
# Evaluate all loss ablation runs (no job IDs are skipped)
compute_metrics(run_dict=run_dict, hyperparam_dict=hyperparam_dict)

### 2.2 Loss Ablation Extended

In [None]:
# Run selection; these module-level names are read by compute_metrics
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "loss_extended_ablation"
model_folder_path = f"{artifacts_folder_path}/{dataset_name}/models/{ablation_task}"

# AnnData keys and plot settings; only referenced by the currently disabled
# benchmarking / clustering / plotting code in compute_metrics
cell_type_key = "Main_molecular_tissue_region"
batch_key = "batch"
spatial_key = "spatial"
latent_key = "nichecompass_latent"
sample_key = "batch"
spot_size = 0.1

# Metric columns to reserve in the results table
benchmarking_metrics = ["cas", "mlami", "clisis", "gcs", "cnmi", "nasw"]
#sdm_benchmarking_metrics = ["ari", "nmi", "chaos", "pas", "asw", "hom", "com"]

# Settings of the disabled Leiden clustering
#latent_leiden_resolution = 0.05
#leiden_resolution_increments = 0.01
#valid_niche_thresh = 100
#latent_cluster_key = f"latent_leiden_{str(latent_leiden_resolution)}"

# One entry per hyperparameter configuration; compute_metrics selects the
# entry for a run via (job_id - 1) // 8
hyperparam_dict = {
    "lambda_l1_masked": [0., 0., 0., 0., 3., 30., 300.],
    "lambda_l1_addon": [0., 3., 30., 300., 3., 30., 300.],
}

# Results table as a dict of (initially empty) column lists
run_dict = {column: [] for column in ("dataset", "job_id", "seed", "nnmi", "nari")}
run_dict.update({hyperparam: [] for hyperparam in hyperparam_dict})
run_dict.update({metric: [] for metric in benchmarking_metrics})
#run_dict.update({metric: [] for metric in sdm_benchmarking_metrics})

In [None]:
# Evaluate all extended loss ablation runs (no job IDs are skipped)
compute_metrics(run_dict=run_dict, hyperparam_dict=hyperparam_dict)

In [None]:
# Collect the job IDs contained in `run_df` as a plain list.
# NOTE(review): `compute_metrics` binds `run_df` only as a local variable, so
# this line depends on a global `run_df` created by another cell (e.g. one of
# the `pd.read_csv` cells in section 3) - confirm the intended cell order.
job_ids = run_df["job_id"].tolist()

In [None]:
# Display the collected job IDs as the cell output
job_ids

### 2.3 Encoder Ablation

In [None]:
# Run selection; these module-level names are read by compute_metrics
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "encoder_ablation"
model_folder_path = f"{artifacts_folder_path}/{dataset_name}/models/{ablation_task}"

# AnnData keys and plot settings; only referenced by the currently disabled
# benchmarking / clustering / plotting code in compute_metrics
cell_type_key = "Main_molecular_tissue_region"
batch_key = "batch"
spatial_key = "spatial"
latent_key = "nichecompass_latent"
sample_key = "batch"
spot_size = 0.1

# Metric columns to reserve in the results table
benchmarking_metrics = ["cas", "mlami", "clisis", "gcs", "cnmi", "nasw"]
#sdm_benchmarking_metrics = ["ari", "nmi", "chaos", "pas", "asw", "hom", "com"]

# Settings of the disabled Leiden clustering
#latent_leiden_resolution = 0.05
#leiden_resolution_increments = 0.01
#valid_niche_thresh = 100
#latent_cluster_key = f"latent_leiden_{str(latent_leiden_resolution)}"

# One entry per hyperparameter configuration; compute_metrics selects the
# entry for a run via (job_id - 1) // 8
hyperparam_dict = {
    "conv_layer_encoder": ["gcnconv", "gatv2conv"],
}

# Results table as a dict of (initially empty) column lists
run_dict = {column: [] for column in ("dataset", "job_id", "seed", "nnmi", "nari")}
run_dict.update({hyperparam: [] for hyperparam in hyperparam_dict})
run_dict.update({metric: [] for metric in benchmarking_metrics})
#run_dict.update({metric: [] for metric in sdm_benchmarking_metrics})

In [None]:
# Evaluate all encoder ablation runs (no job IDs are skipped)
compute_metrics(run_dict=run_dict, hyperparam_dict=hyperparam_dict)

### 2.4 Neighbor Ablation

In [None]:
# Run selection; these module-level names are read by compute_metrics
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "neighbor_ablation"
model_folder_path = f"{artifacts_folder_path}/{dataset_name}/models/{ablation_task}"

# AnnData keys and plot settings; only referenced by the currently disabled
# benchmarking / clustering / plotting code in compute_metrics
cell_type_key = "Main_molecular_tissue_region"
batch_key = "batch"
spatial_key = "spatial"
latent_key = "nichecompass_latent"
sample_key = "batch"
spot_size = 0.1

# Metric columns to reserve in the results table
benchmarking_metrics = ["cas", "mlami", "clisis", "gcs", "cnmi", "nasw"]
#sdm_benchmarking_metrics = ["ari", "nmi", "chaos", "pas", "asw", "hom", "com"]

# Settings of the disabled Leiden clustering
#latent_leiden_resolution = 0.05
#leiden_resolution_increments = 0.01
#valid_niche_thresh = 100
#latent_cluster_key = f"latent_leiden_{str(latent_leiden_resolution)}"

# One entry per hyperparameter configuration; compute_metrics selects the
# entry for a run via (job_id - 1) // 8
hyperparam_dict = {
    "n_neighbors": [4, 8, 12, 16, 20],
}

# Results table as a dict of (initially empty) column lists
run_dict = {column: [] for column in ("dataset", "job_id", "seed", "nnmi", "nari")}
run_dict.update({hyperparam: [] for hyperparam in hyperparam_dict})
run_dict.update({metric: [] for metric in benchmarking_metrics})
#run_dict.update({metric: [] for metric in sdm_benchmarking_metrics})

In [None]:
# Evaluate all neighbor ablation runs (no job IDs are skipped)
compute_metrics(run_dict=run_dict, hyperparam_dict=hyperparam_dict)

### 2.5 De novo GP Ablation

In [None]:
# Run selection; these module-level names are read by compute_metrics
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "denovogp_ablation"
model_folder_path = f"{artifacts_folder_path}/{dataset_name}/models/{ablation_task}"

# AnnData keys and plot settings; only referenced by the currently disabled
# benchmarking / clustering / plotting code in compute_metrics
cell_type_key = "Main_molecular_tissue_region"
batch_key = "batch"
spatial_key = "spatial"
latent_key = "nichecompass_latent"
sample_key = "batch"
spot_size = 0.1

# Metric columns to reserve in the results table
benchmarking_metrics = ["cas", "mlami", "clisis", "gcs", "cnmi", "nasw"]
#sdm_benchmarking_metrics = ["ari", "nmi", "chaos", "pas", "asw", "hom", "com"]

# Settings of the disabled Leiden clustering
#latent_leiden_resolution = 0.05
#leiden_resolution_increments = 0.01
#valid_niche_thresh = 100
#latent_cluster_key = f"latent_leiden_{str(latent_leiden_resolution)}"

# One entry per hyperparameter configuration; compute_metrics selects the
# entry for a run via (job_id - 1) // 8
hyperparam_dict = {
    "n_addon_gp": [0, 10, 30, 100, 500],
}

# Results table as a dict of (initially empty) column lists
run_dict = {column: [] for column in ("dataset", "job_id", "seed", "nnmi", "nari")}
run_dict.update({hyperparam: [] for hyperparam in hyperparam_dict})
run_dict.update({metric: [] for metric in benchmarking_metrics})
#run_dict.update({metric: [] for metric in sdm_benchmarking_metrics})

In [None]:
# Evaluate all de novo GP ablation runs (no job IDs are skipped)
compute_metrics(run_dict=run_dict, hyperparam_dict=hyperparam_dict)

### 2.6 GP Selection Ablation

In [None]:
# Run selection; these module-level names are read by compute_metrics
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "gpselection_ablation"
model_folder_path = f"{artifacts_folder_path}/{dataset_name}/models/{ablation_task}"

# AnnData keys and plot settings; only referenced by the currently disabled
# benchmarking / clustering / plotting code in compute_metrics
cell_type_key = "Main_molecular_tissue_region"
batch_key = "batch"
spatial_key = "spatial"
latent_key = "nichecompass_latent"
sample_key = "batch"
spot_size = 0.1

# Metric names (no columns are reserved for them in this section)
benchmarking_metrics = ["cas", "mlami", "clisis", "gcs", "cnmi", "nasw"]
#sdm_benchmarking_metrics = ["ari", "nmi", "chaos", "pas", "asw", "hom", "com"]

# Settings of the disabled Leiden clustering
#latent_leiden_resolution = 0.05
#leiden_resolution_increments = 0.01
#valid_niche_thresh = 100
#latent_cluster_key = f"latent_leiden_{str(latent_leiden_resolution)}"

# One entry per hyperparameter configuration; compute_metrics selects the
# entry for a run via (job_id - 1) // 8
hyperparam_dict = {
    "active_gp_thresh_ratio": [0., 0.01, 0.03, 0.1, 0.3, 0.5, 1],
}

# Results table as a dict of (initially empty) column lists
run_dict = {column: [] for column in ("dataset", "job_id", "seed", "nnmi", "nari")}
run_dict.update({hyperparam: [] for hyperparam in hyperparam_dict})
#run_dict.update({metric: [] for metric in benchmarking_metrics})
#run_dict.update({metric: [] for metric in sdm_benchmarking_metrics})

In [None]:
# Evaluate all GP selection ablation runs (no job IDs are skipped)
compute_metrics(run_dict=run_dict, hyperparam_dict=hyperparam_dict)

### 2.7 No Prior GP Ablation

In [None]:
# Run selection; these module-level names are read by compute_metrics
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "nopriorgp_ablation"
model_folder_path = f"{artifacts_folder_path}/{dataset_name}/models/{ablation_task}"

# AnnData keys and plot settings; only referenced by the currently disabled
# benchmarking / clustering / plotting code in compute_metrics
cell_type_key = "Main_molecular_tissue_region"
batch_key = "batch"
spatial_key = "spatial"
latent_key = "nichecompass_latent"
sample_key = "batch"
spot_size = 0.1

# Metric names (no columns are reserved for them in this section)
benchmarking_metrics = ["cas", "mlami", "clisis", "gcs", "cnmi", "nasw"]
#sdm_benchmarking_metrics = ["ari", "nmi", "chaos", "pas", "asw", "hom", "com"]

# Settings of the disabled Leiden clustering
#latent_leiden_resolution = 0.05
#leiden_resolution_increments = 0.01
#valid_niche_thresh = 100
#latent_cluster_key = f"latent_leiden_{str(latent_leiden_resolution)}"

# One entry per hyperparameter configuration; compute_metrics selects the
# entry for a run via (job_id - 1) // 8
hyperparam_dict = {
    "priorgp": [1, 0],
}

# Results table as a dict of (initially empty) column lists
run_dict = {column: [] for column in ("dataset", "job_id", "seed", "nnmi", "nari")}
run_dict.update({hyperparam: [] for hyperparam in hyperparam_dict})
#run_dict.update({metric: [] for metric in benchmarking_metrics})
#run_dict.update({metric: [] for metric in sdm_benchmarking_metrics})

In [None]:
# Evaluate all no-prior-GP ablation runs (no job IDs are skipped)
compute_metrics(run_dict=run_dict, hyperparam_dict=hyperparam_dict)

## 3. Visualize Results

### 3.1 Loss Ablation

In [None]:
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "loss_ablation"

# Load the stored metrics of this ablation task, ordered by ascending job ID
run_df = pd.read_csv(
    f"{artifacts_folder_path}/ablation/{dataset_name}_metrics_{ablation_task}.csv"
).sort_values(by="job_id", ascending=True)

In [None]:
spatial_consistency_metrics = ["cas", "mlami", "clisis", "gcs"]
niche_coherence_metrics = ["cnmi", "nasw"]

# Min-max scale every spatial consistency metric across all runs; NaN values
# (e.g. 0 / 0 for a constant column where max == min) are replaced by 0
for metric_name in spatial_consistency_metrics:
    metric_values = run_df[metric_name]
    min_val = metric_values.min()
    max_val = metric_values.max()
    run_df[metric_name + "_scaled"] = (
        (metric_values - min_val) / (max_val - min_val)).fillna(0)

spatial_consistency_metrics_scaled = [
    metric_col + "_scaled" for metric_col in spatial_consistency_metrics]

# Equally weighted average of the scaled spatial consistency metrics
run_df["Spatial Consistency Score"] = np.average(
    run_df[spatial_consistency_metrics_scaled],
    weights=[0.25, 0.25, 0.25, 0.25],
    axis=1)

# Same min-max scaling for the niche coherence metrics
for metric_name in niche_coherence_metrics:
    metric_values = run_df[metric_name]
    min_val = metric_values.min()
    max_val = metric_values.max()
    run_df[metric_name + "_scaled"] = (
        (metric_values - min_val) / (max_val - min_val)).fillna(0)

niche_coherence_metrics_scaled = [
    metric_col + "_scaled" for metric_col in niche_coherence_metrics]

# Equally weighted average of the scaled niche coherence metrics
run_df["Niche Coherence Score"] = np.average(
    run_df[niche_coherence_metrics_scaled],
    weights=[0.5, 0.5],
    axis=1)

In [None]:
cat = "Loss"
title = "Edge & Gene Expression Reconstruction Ablation"

# Display labels keyed by "<lambda_edge_recon>_<lambda_gene_expr_recon>"
mapping_dict = {
    "500000.0_300.0": "Balanced Edge & Gene Expr Recon",
    "0.0_300.0": "Only Gene Expr Recon",
    "500000.0_0.0": "Only Edge Recon",
    "50000.0_300.0": "Weak Edge Recon",
    "500000.0_30.0": "Weak Gene Expr Recon",
}

col1 = "lambda_edge_recon"
col2 = "lambda_gene_expr_recon"

def map_values(row):
    """Return the display label of a run's loss weights ("NA" if unmapped)."""
    lookup_key = "_".join([str(row[col1]), str(row[col2])])
    return mapping_dict.get(lookup_key, "NA")

run_df["Loss"] = run_df.apply(map_values, axis=1)

# One boxplot per (metric column, x-axis label) pair, each saved as an SVG
for metric, x_label in (("nnmi", "NMI"),
                        ("nari", "ARI"),
                        ("Spatial Consistency Score", "Spatial Consistency Score"),
                        ("Niche Coherence Score", "Niche Coherence Score")):
    plt.figure(figsize=(3, 1.5))
    sns.boxplot(x=metric, y=cat, data=run_df)
    plt.suptitle(title, x=0.09, ha="center", va="top", y=1.1)
    plt.xlabel(x_label)
    plt.ylabel(cat.capitalize())
    plt.savefig(f"{artifacts_folder_path}/ablation/{dataset_name}_{metric}_{ablation_task}.svg", bbox_inches="tight")
    plt.show()

### 3.2 Loss Ablation Extended

In [None]:
# Set the dataset explicitly (as the other load cells do) so that this cell
# does not depend on `dataset_name` left over from a previously run cell
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "loss_extended_ablation"

# Load the stored metrics of this ablation task, ordered by ascending job ID
run_df = pd.read_csv(f"{artifacts_folder_path}/ablation/{dataset_name}_metrics_{ablation_task}.csv")

run_df.sort_values(by="job_id", ascending=True, inplace=True)

In [None]:
spatial_consistency_metrics = ["cas", "mlami", "clisis", "gcs"]
niche_coherence_metrics = ["cnmi", "nasw"]

# Min-max scale every spatial consistency metric across all runs; NaN values
# (e.g. 0 / 0 for a constant column where max == min) are replaced by 0
for metric_name in spatial_consistency_metrics:
    metric_values = run_df[metric_name]
    min_val = metric_values.min()
    max_val = metric_values.max()
    run_df[metric_name + "_scaled"] = (
        (metric_values - min_val) / (max_val - min_val)).fillna(0)

spatial_consistency_metrics_scaled = [
    metric_col + "_scaled" for metric_col in spatial_consistency_metrics]

# Equally weighted average of the scaled spatial consistency metrics
run_df["Spatial Consistency Score"] = np.average(
    run_df[spatial_consistency_metrics_scaled],
    weights=[0.25, 0.25, 0.25, 0.25],
    axis=1)

# Same min-max scaling for the niche coherence metrics
for metric_name in niche_coherence_metrics:
    metric_values = run_df[metric_name]
    min_val = metric_values.min()
    max_val = metric_values.max()
    run_df[metric_name + "_scaled"] = (
        (metric_values - min_val) / (max_val - min_val)).fillna(0)

niche_coherence_metrics_scaled = [
    metric_col + "_scaled" for metric_col in niche_coherence_metrics]

# Equally weighted average of the scaled niche coherence metrics
run_df["Niche Coherence Score"] = np.average(
    run_df[niche_coherence_metrics_scaled],
    weights=[0.5, 0.5],
    axis=1)

In [None]:
cat = "Loss"
title = "Gene Expr Regularization Ablation"

# Display labels keyed by "<lambda_l1_masked>_<lambda_l1_addon>"
mapping_dict = {
    "0.0_0.0": "No Reg",
    "0.0_3.0": "Only Weak De-novo Reg",
    "0.0_30.0": "Only Medium De-novo Reg",
    "0.0_300.0": "Only Strong De-novo Reg",
    "3.0_3.0": "Weak Prior & De-novo Reg",
    "30.0_30.0": "Medium Prior & De-novo Reg",
    "300.0_300.0": "Strong Prior & De-novo Reg",
}

col1 = "lambda_l1_masked"
col2 = "lambda_l1_addon"

def map_values(row):
    """Return the display label of a run's L1 weights ("NA" if unmapped)."""
    lookup_key = "_".join([str(row[col1]), str(row[col2])])
    return mapping_dict.get(lookup_key, "NA")

run_df["Loss"] = run_df.apply(map_values, axis=1)

# One boxplot per (metric column, x-axis label) pair, each saved as an SVG
for metric, x_label in (("nnmi", "NMI"),
                        ("nari", "ARI"),
                        ("Spatial Consistency Score", "Spatial Consistency Score"),
                        ("Niche Coherence Score", "Niche Coherence Score")):
    plt.figure(figsize=(3, 1.5))
    sns.boxplot(x=metric, y=cat, data=run_df)
    plt.suptitle(title, x=0.09, ha="center", va="top", y=1.1)
    plt.xlabel(x_label)
    plt.ylabel(cat.capitalize())
    plt.savefig(f"{artifacts_folder_path}/ablation/{dataset_name}_{metric}_{ablation_task}.svg", bbox_inches="tight")
    plt.show()

### 3.3 Encoder Ablation

In [None]:
# Set the dataset explicitly (as the other load cells do) so that this cell
# does not depend on `dataset_name` left over from a previously run cell
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "encoder_ablation"

# Load the stored metrics of this ablation task, ordered by ascending job ID
run_df = pd.read_csv(f"{artifacts_folder_path}/ablation/{dataset_name}_metrics_{ablation_task}.csv")

run_df.sort_values(by="job_id", ascending=True, inplace=True)

In [None]:
spatial_consistency_metrics = ["cas", "mlami", "clisis", "gcs"]
niche_coherence_metrics = ["cnmi", "nasw"]

# Min-max scale every spatial consistency metric across all runs; NaN values
# (e.g. 0 / 0 for a constant column where max == min) are replaced by 0
for metric_name in spatial_consistency_metrics:
    metric_values = run_df[metric_name]
    min_val = metric_values.min()
    max_val = metric_values.max()
    run_df[metric_name + "_scaled"] = (
        (metric_values - min_val) / (max_val - min_val)).fillna(0)

spatial_consistency_metrics_scaled = [
    metric_col + "_scaled" for metric_col in spatial_consistency_metrics]

# Equally weighted average of the scaled spatial consistency metrics
run_df["Spatial Consistency Score"] = np.average(
    run_df[spatial_consistency_metrics_scaled],
    weights=[0.25, 0.25, 0.25, 0.25],
    axis=1)

# Same min-max scaling for the niche coherence metrics
for metric_name in niche_coherence_metrics:
    metric_values = run_df[metric_name]
    min_val = metric_values.min()
    max_val = metric_values.max()
    run_df[metric_name + "_scaled"] = (
        (metric_values - min_val) / (max_val - min_val)).fillna(0)

niche_coherence_metrics_scaled = [
    metric_col + "_scaled" for metric_col in niche_coherence_metrics]

# Equally weighted average of the scaled niche coherence metrics
run_df["Niche Coherence Score"] = np.average(
    run_df[niche_coherence_metrics_scaled],
    weights=[0.5, 0.5],
    axis=1)

In [None]:
cat = "Encoder"
title = "Encoder Ablation"

# Display labels keyed by the value of the "conv_layer_encoder" column
mapping_dict = {
    "gcnconv": "GCNConv (NicheCompass Light)",
    "gatv2conv": "GATv2Conv (NicheCompass)",
}

col1 = "conv_layer_encoder"

def map_values(row):
    """Return the display label of a run's encoder layer ("NA" if unmapped)."""
    lookup_key = str(row[col1])
    return mapping_dict.get(lookup_key, "NA")

run_df["Encoder"] = run_df.apply(map_values, axis=1)

# One boxplot per (metric column, x-axis label) pair, each saved as an SVG
for metric, x_label in (("nnmi", "NMI"),
                        ("nari", "ARI"),
                        ("Spatial Consistency Score", "Spatial Consistency Score"),
                        ("Niche Coherence Score", "Niche Coherence Score")):
    plt.figure(figsize=(3, 1.5))
    sns.boxplot(x=metric, y=cat, data=run_df)
    plt.suptitle(title, x=0.09, ha="center", va="top", y=1.1)
    plt.xlabel(x_label)
    plt.ylabel(cat.capitalize())
    plt.savefig(f"{artifacts_folder_path}/ablation/{dataset_name}_{metric}_{ablation_task}.svg", bbox_inches="tight")
    plt.show()

### 3.4 Neighbor Ablation

In [None]:
# Select the neighbor ablation results and load them ordered by the number
# of neighbors.
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "neighbor_ablation"
run_df = pd.read_csv(
    f"{artifacts_folder_path}/ablation/{dataset_name}_metrics_{ablation_task}.csv"
).sort_values(by="n_neighbors", ascending=True)

In [None]:
spatial_consistency_metrics = ["cas", "mlami", "clisis", "gcs"]
niche_coherence_metrics = ["cnmi", "nasw"]

# Min-max scale each spatial consistency metric over all rows of run_df and
# store it in a "<metric>_scaled" column. NaN results (e.g. 0/0 when a metric
# has the same value in every row, or missing metric values) are set to 0.
for metric_col in spatial_consistency_metrics:
    raw_values = run_df[metric_col]
    min_val = raw_values.min()
    max_val = raw_values.max()
    scaled_values = (raw_values - min_val) / (max_val - min_val)
    run_df[f"{metric_col}_scaled"] = scaled_values.fillna(0)

spatial_consistency_metrics_scaled = [
    f"{metric_col}_scaled" for metric_col in spatial_consistency_metrics]

# Row-wise, equally weighted average of the four scaled metrics.
run_df["Spatial Consistency Score"] = np.average(
    run_df[spatial_consistency_metrics_scaled],
    weights=[0.25, 0.25, 0.25, 0.25],
    axis=1)

# Same min-max scaling for the niche coherence metrics.
for metric_col in niche_coherence_metrics:
    raw_values = run_df[metric_col]
    min_val = raw_values.min()
    max_val = raw_values.max()
    scaled_values = (raw_values - min_val) / (max_val - min_val)
    run_df[f"{metric_col}_scaled"] = scaled_values.fillna(0)

niche_coherence_metrics_scaled = [
    f"{metric_col}_scaled" for metric_col in niche_coherence_metrics]

# Row-wise, equally weighted average of the two scaled metrics.
run_df["Niche Coherence Score"] = np.average(
    run_df[niche_coherence_metrics_scaled],
    weights=[0.5, 0.5],
    axis=1)

In [None]:
cat = "n_neighbors"
cat_label = "KNN"
title = "Neighborhood Ablation"

# Convert the neighbor counts to strings before using them as the plot's
# y variable.
run_df["n_neighbors"] = run_df["n_neighbors"].apply(str)

# One box plot per metric (metric on the x-axis, neighbor count on the
# y-axis); each figure is written to an SVG file and then shown.
metric_axis_labels = [
    ("nnmi", "NMI"),
    ("nari", "ARI"),
    ("Spatial Consistency Score", "Spatial Consistency Score"),
    ("Niche Coherence Score", "Niche Coherence Score"),
]

for metric, x_label in metric_axis_labels:
    plt.figure(figsize=(3, 1.5))
    sns.boxplot(x=metric, y=cat, data=run_df)
    plt.suptitle(title, x=0.09, ha="center", va="top", y=1.1)
    plt.xlabel(x_label)
    plt.ylabel(cat_label)
    plt.savefig(
        f"{artifacts_folder_path}/ablation/{dataset_name}_{metric}_{ablation_task}.svg",
        bbox_inches="tight")
    plt.show()

### 3.5 De novo GP Ablation

In [None]:
# Select the de novo GP ablation results and load them ordered by the
# n_addon_gp column.
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "denovogp_ablation"
run_df = pd.read_csv(
    f"{artifacts_folder_path}/ablation/{dataset_name}_metrics_{ablation_task}.csv"
).sort_values(by="n_addon_gp", ascending=True)

In [None]:
spatial_consistency_metrics = ["cas", "mlami", "clisis", "gcs"]
niche_coherence_metrics = ["cnmi", "nasw"]

# Min-max scale each spatial consistency metric over all rows of run_df and
# store it in a "<metric>_scaled" column. NaN results (e.g. 0/0 when a metric
# has the same value in every row, or missing metric values) are set to 0.
for metric_col in spatial_consistency_metrics:
    raw_values = run_df[metric_col]
    min_val = raw_values.min()
    max_val = raw_values.max()
    scaled_values = (raw_values - min_val) / (max_val - min_val)
    run_df[f"{metric_col}_scaled"] = scaled_values.fillna(0)

spatial_consistency_metrics_scaled = [
    f"{metric_col}_scaled" for metric_col in spatial_consistency_metrics]

# Row-wise, equally weighted average of the four scaled metrics.
run_df["Spatial Consistency Score"] = np.average(
    run_df[spatial_consistency_metrics_scaled],
    weights=[0.25, 0.25, 0.25, 0.25],
    axis=1)

# Same min-max scaling for the niche coherence metrics.
for metric_col in niche_coherence_metrics:
    raw_values = run_df[metric_col]
    min_val = raw_values.min()
    max_val = raw_values.max()
    scaled_values = (raw_values - min_val) / (max_val - min_val)
    run_df[f"{metric_col}_scaled"] = scaled_values.fillna(0)

niche_coherence_metrics_scaled = [
    f"{metric_col}_scaled" for metric_col in niche_coherence_metrics]

# Row-wise, equally weighted average of the two scaled metrics.
run_df["Niche Coherence Score"] = np.average(
    run_df[niche_coherence_metrics_scaled],
    weights=[0.5, 0.5],
    axis=1)

In [None]:
cat = "n_addon_gp"
cat_label = "De-novo GPs"
title = "De-Novo GP Ablation"

# Convert the n_addon_gp values to strings before using them as the plot's
# y variable.
run_df["n_addon_gp"] = run_df["n_addon_gp"].apply(str)

# One box plot per metric (metric on the x-axis, n_addon_gp on the y-axis);
# each figure is written to an SVG file and then shown.
metric_axis_labels = [
    ("nnmi", "NMI"),
    ("nari", "ARI"),
    ("Spatial Consistency Score", "Spatial Consistency Score"),
    ("Niche Coherence Score", "Niche Coherence Score"),
]

for metric, x_label in metric_axis_labels:
    plt.figure(figsize=(3, 1.5))
    sns.boxplot(x=metric, y=cat, data=run_df)
    plt.suptitle(title, x=0.09, ha="center", va="top", y=1.1)
    plt.xlabel(x_label)
    plt.ylabel(cat_label)
    plt.savefig(
        f"{artifacts_folder_path}/ablation/{dataset_name}_{metric}_{ablation_task}.svg",
        bbox_inches="tight")
    plt.show()

In [None]:
cat = "n_addon_gp"
cat_label = "De-novo GPs"
title = "De-novo GP Ablation"

# Convert the n_addon_gp values to strings before using them as the plot's
# y variable.
run_df["n_addon_gp"] = run_df["n_addon_gp"].apply(str)

# NOTE(review): `ablation` (used in the output file names below) is not
# defined in the visible part of this notebook, and the columns read here
# ("nmi", "ari", "hom", "com", the "f1_*" columns, "run_number") differ from
# the metric columns used by the other cells - confirm that this cell
# matches the loaded CSV before running it.

# Box plots for the plain metrics; the x-axis label is the upper-cased
# column name.
for metric in ("nmi", "ari", "hom", "com"):
    plt.figure(figsize=(3, 1.5))
    sns.boxplot(x=metric, y=cat, data=run_df)
    plt.suptitle(title, x=0.09, ha="center", va="top", y=1.1)
    plt.xlabel(metric.upper())
    plt.ylabel(cat_label)
    plt.savefig(f"{artifacts_folder_path}/ablation/{ablation}_{metric}.svg", bbox_inches="tight")
    plt.show()

# Box plots for the F1 metrics. Each plot also gets a dashed red vertical
# line at the mean, over "run_number" groups, of the per-group mean of the
# matching "random" column (column name with "enriched" replaced by "random").
f1_metric_labels = [
    ("f1_enriched_prior_gps", "F1 Prior GPs"),
    ("f1_enriched_prior_gp_source_genes_top3", "F1 Prior GP Source Genes"),
    ("f1_enriched_prior_gp_target_genes_top3", "F1 Prior GP Target Genes"),
    ("f1_enriched_denovo_gp_target_genes_top3", "F1 De-Novo GP\nTarget Genes"),
    ("f1_enriched_denovo_gp_source_genes_top3", "F1 De-Novo GP\nSource Genes"),
]

for metric, metric_label in f1_metric_labels:
    plt.figure(figsize=(2.25, 1.5))
    sns.boxplot(x=metric, y=cat, data=run_df)
    random_metric = metric.replace("enriched", "random")
    random_baseline = np.mean(run_df.groupby("run_number")[random_metric].mean())
    plt.axvline(x=random_baseline, color='red', linestyle='--')
    plt.suptitle(title, x=0.09, ha="center", va="top", y=1.1)
    plt.xlabel(metric_label)
    plt.ylabel(cat_label)
    plt.savefig(f"{artifacts_folder_path}/ablation/{ablation}_{metric}.svg", bbox_inches="tight")
    plt.show()

### 3.6 GP Selection Ablation

In [None]:
# Select the GP selection ablation results and load them ordered by job ID.
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "gpselection_ablation"
run_df = pd.read_csv(
    f"{artifacts_folder_path}/ablation/{dataset_name}_metrics_{ablation_task}.csv"
).sort_values(by="job_id", ascending=True)

In [None]:
cat = "active_gp_thresh_ratio"
title = "GP Pruning Ablation"

# Convert the threshold ratios to strings before using them as the plot's
# y variable.
run_df["active_gp_thresh_ratio"] = run_df["active_gp_thresh_ratio"].apply(str)

# NOTE(review): this mapping is not used anywhere in this cell - confirm
# whether it is still needed.
mapping_dict = {
    "0.0_0.0": "No Reg",
    "0.0_3.0": "Only Weak De-novo Reg",
    "0.0_30.0": "Only Medium De-novo Reg",
    "0.0_300.0": "Only Strong De-novo Reg",
    "3.0_3.0": "Weak Prior & De-novo Reg",
    "30.0_30.0": "Medium Prior & De-novo Reg",
    "300.0_300.0": "Strong Prior & De-novo Reg",
}

# One box plot per metric (metric on the x-axis, threshold ratio on the
# y-axis); each figure is written to an SVG file and then shown.
for metric, x_label in [("nnmi", "NMI"), ("nari", "ARI")]:
    plt.figure(figsize=(3, 1.5))
    sns.boxplot(x=metric, y=cat, data=run_df)
    plt.suptitle(title, x=0.09, ha="center", va="top", y=1.1)
    plt.xlabel(x_label)
    plt.ylabel("Active GP Thresh")
    plt.savefig(
        f"{artifacts_folder_path}/ablation/{dataset_name}_{metric}_{ablation_task}.svg",
        bbox_inches="tight")
    plt.show()

### 3.7 No Prior GP Ablation

In [None]:
# Select the no-prior-GP ablation results and load them ordered by job ID.
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "nopriorgp_ablation"
run_df = pd.read_csv(
    f"{artifacts_folder_path}/ablation/{dataset_name}_metrics_{ablation_task}.csv"
).sort_values(by="job_id", ascending=True)

In [None]:
cat = "priorgp"
title = "Prior GP Ablation"

# Display names for the 0/1 values of the "priorgp" column.
mapping_dict = {
    0: "No Prior GPs",
    1: "Prior GPs",
}


def map_values(row):
    """Return the display name for the row's `cat` value (None if unmapped)."""
    flag = row[cat]
    return mapping_dict.get(flag)


run_df["priorgp"] = run_df.apply(map_values, axis=1)

# One box plot per metric (metric on the x-axis, prior GP category on the
# y-axis); each figure is written to an SVG file and then shown.
for metric, x_label in [("nnmi", "NMI"), ("nari", "ARI")]:
    plt.figure(figsize=(3, 1.5))
    sns.boxplot(x=metric, y=cat, data=run_df)
    plt.suptitle(title, x=0.09, ha="center", va="top", y=1.1)
    plt.xlabel(x_label)
    plt.ylabel("Use of Prior GPs")
    plt.savefig(
        f"{artifacts_folder_path}/ablation/{dataset_name}_{metric}_{ablation_task}.svg",
        bbox_inches="tight")
    plt.show()