# Real Data Model Ablation

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Date of Creation:** 08.10.2024
- **Date of Last Modification:** 21.12.2024

## 1. Setup

Run this notebook in the nichecompass-reproducibility environment, installable from ```('../../../envs/environment.yaml')```.

Before running this notebook:
- Clone SDMBench from https://github.com/zhaofangyuan98/SDMBench.git into ```('../benchmarking')```. Some slight modifications to the SDMBench source code are necessary to fix technical bugs:
    - Move _compute_CHAOS function into compute_CHAOS
    - Move _compute_PAS function into compute_PAS
    - Move fx_kNN function into compute_PAS
    - Move fx_1NN function into compute_CHAOS

### 1.1 Import Libraries

In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# are reloaded automatically before code from a cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
# Make the locally cloned SDMBench source folder importable (see the setup
# notes at the top of the notebook).
sys.path.append("../benchmarking/SDMBench/SDMBench")
# Make the shared utility modules importable (presumably where
# ``benchmarking_utils`` lives -- verify against the repository layout).
sys.path.append("../../utils")

In [None]:
import gc
import math
import os
import warnings

import anndata as ad
import scanpy as sc
import scib_metrics
import squidpy as sq
from nichecompass.benchmarking import compute_benchmarking_metrics

from SDMBench import sdmbench
from benchmarking_utils import *

### 1.2 Run Notebook Setup

In [None]:
# Silence all warnings for the remainder of the session.
warnings.filterwarnings("ignore")

# Lift the pandas display limits so that data frames are always shown in full.
for display_option in ('display.max_columns', 'display.max_rows'):
    pd.set_option(display_option, None)

In [None]:
# Use 10 pt Helvetica as the default font of all matplotlib figures.
plt.rcParams.update({'font.family': 'Helvetica',
                     'font.size': 10})

### 1.3 Configure Paths

In [None]:
# Root folder (relative to the notebook's working directory) below which the
# metric tables are read/written and the figures are saved.
artifacts_folder_path = "../../artifacts"

### 1.4 Define Functions

In [None]:
def compute_metrics(run_dict,
                    hyperparam_dict,
                    job_ids=None):
    """
    Compute the clustering metrics (NMI & ARI) of all trained models of an
    ablation task and store them together with the run metadata in a csv file.

    The folder ```model_folder_path``` is searched recursively for files named
    ```{dataset_name}_{ablation_task}.h5ad```. For every such file whose job ID
    is not in ```job_ids```, the NicheCompass latent representation is
    clustered with k-means and compared against the
    'Main_molecular_tissue_region' labels. After every processed file, the
    accumulated results are written to
    ```{artifacts_folder_path}/ablation/{dataset_name}_metrics_{ablation_task}.csv```
    so that partial results survive an interruption.

    The function reads the notebook-level variables ```model_folder_path```,
    ```dataset_name```, ```ablation_task``` and ```artifacts_folder_path```;
    they have to be defined before it is called.

    Parameters
    ----------
    run_dict:
        Dictionary of lists to which the results are appended in place. It has
        to contain a list for the keys 'dataset', 'job_id', 'seed', 'nnmi' and
        'nari' as well as for every key of ```hyperparam_dict```.
    hyperparam_dict:
        Dictionary that maps a hyperparameter name to the sequence of its
        values, with one entry per hyperparameter configuration. Job IDs are
        assumed to start at 1 with 8 consecutive job IDs (seeds 0 to 7) per
        configuration.
    job_ids:
        Job IDs to be skipped (e.g. because their metrics were already
        computed). ```None``` (default) skips no job.
    """
    skipped_job_ids = set(job_ids) if job_ids is not None else set()
    n_seeds_per_config = 8

    # Loop invariants
    model_file_name = f"{dataset_name}_{ablation_task}.h5ad"
    metrics_folder_path = f"{artifacts_folder_path}/ablation"
    metrics_file_path = f"{metrics_folder_path}/{dataset_name}_metrics_{ablation_task}.csv"

    for subdir, _, files in os.walk(model_folder_path):
        if model_file_name not in files:
            continue

        # The job ID is the suffix of the run folder name ('..._<job_id>').
        # NOTE: only the last two characters are inspected, so job IDs above
        # 99 are not supported by this parsing.
        job_id = int(subdir[-2:].strip("_"))

        # Decide whether to skip before touching the (potentially large) file
        if job_id in skipped_job_ids:
            continue

        file_path = os.path.join(subdir, model_file_name)
        print(f"Loading file: {file_path}")
        adata = ad.read_h5ad(file_path)

        # Compute the metrics before mutating 'run_dict' so that a failure
        # cannot leave its lists with unequal lengths
        metrics = scib_metrics.nmi_ari_cluster_labels_kmeans(
            adata.obsm["nichecompass_latent"],
            adata.obs["Main_molecular_tissue_region"])

        # Free the memory of the loaded object before the next one is read
        del adata
        gc.collect()

        config_idx, seed = divmod(job_id - 1, n_seeds_per_config)
        run_dict["dataset"].append(dataset_name)
        run_dict["job_id"].append(job_id)
        run_dict["seed"].append(seed)
        for key, values in hyperparam_dict.items():
            run_dict[key].append(values[config_idx])
        run_dict["nnmi"].append(metrics["nmi"])
        run_dict["nari"].append(metrics["ari"])

        # Checkpoint the results collected so far
        os.makedirs(metrics_folder_path, exist_ok=True)
        pd.DataFrame(run_dict).to_csv(metrics_file_path, index=False)

## 2. Compute Metrics

### 2.1 Loss Ablation

### 2.2 Loss Ablation Extended

### 2.3 Encoder Ablation

### 2.4 Neighbor Ablation

### 2.5 De novo GP Ablation

### 2.6 GP Selection Ablation

### 2.7 No Prior GP Ablation

## 3. Visualize Results

### 3.1 Loss Ablation

In [None]:
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "loss_ablation"

# Read the stored metrics of the loss ablation runs, ordered by job ID
metrics_file_path = f"{artifacts_folder_path}/ablation/{dataset_name}_metrics_{ablation_task}.csv"
run_df = pd.read_csv(metrics_file_path).sort_values(by="job_id", ascending=True)

In [None]:
### Supplementary Fig. 12a ###
cat = "Loss"
title = "Edge & Gene Expression Reconstruction Ablation"

# Display label per "<lambda_edge_recon>_<lambda_gene_expr_recon>" combination
mapping_dict = {"500000.0_300.0": "Balanced Edge & Gene Expr Recon",
                "0.0_300.0": "Only Gene Expr Recon",
                "500000.0_0.0": "Only Edge Recon",
                "50000.0_300.0": "Weak Edge Recon",
                "500000.0_30.0": "Weak Gene Expr Recon"
               }

col1 = "lambda_edge_recon"
col2 = "lambda_gene_expr_recon"

def map_values(row):
    """Return the display label of a run's loss weights ("NA" if unknown)."""
    combination = "_".join((str(row[col1]), str(row[col2])))
    return mapping_dict.get(combination, "NA")

run_df["Loss"] = run_df.apply(map_values, axis=1)

# One figure per metric: box plot with the individual runs overlaid
for metric in ("nnmi", "nari"):
    plt.figure(figsize=(3, 1.5))
    sns.boxplot(x=metric, y=cat, data=run_df)
    sns.stripplot(x=metric,
                  y=cat,
                  data=run_df,
                  dodge=True,
                  alpha=0.6,
                  size=5,
                  jitter=0.2,
                  edgecolor="black",
                  linewidth=0.5)
    plt.suptitle(title, x=0.09, ha="center", va="top", y=1.1)
    # The metric columns carry an extra leading "n" that is not displayed
    plt.xlabel(metric.upper().replace("NNMI", "NMI").replace("NARI", "ARI"))
    plt.ylabel(cat.capitalize())
    plt.savefig(f"{artifacts_folder_path}/ablation/{dataset_name}_{metric}_{ablation_task}.svg", bbox_inches="tight")
    plt.show()

### 3.2 Loss Ablation Extended

In [None]:
# Set the dataset explicitly (as in the other sections) so that this section
# can be run on its own
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "loss_extended_ablation"

# Read the stored metrics of the extended loss ablation runs, ordered by job ID
run_df = pd.read_csv(f"{artifacts_folder_path}/ablation/{dataset_name}_metrics_{ablation_task}.csv")
run_df.sort_values(by="job_id", ascending=True, inplace=True)

In [None]:
### Supplementary Fig. 12b ###
cat = "Loss"
title = "Gene Expr Regularization Ablation"

# Display label per "<lambda_l1_masked>_<lambda_l1_addon>" combination
mapping_dict = {"0.0_0.0": "No Reg",
                "0.0_3.0": "Only Weak De-novo Reg",
                "0.0_30.0": "Only Medium De-novo Reg",
                "0.0_300.0": "Only Strong De-novo Reg",
                "3.0_3.0": "Weak Prior & De-novo Reg",
                "30.0_30.0": "Medium Prior & De-novo Reg",
                "300.0_300.0": "Strong Prior & De-novo Reg",
               }

col1 = "lambda_l1_masked"
col2 = "lambda_l1_addon"

def map_values(row):
    """Return the display label of a run's regularization weights ("NA" if unknown)."""
    combination = "_".join((str(row[col1]), str(row[col2])))
    return mapping_dict.get(combination, "NA")

run_df["Loss"] = run_df.apply(map_values, axis=1)

# One figure per metric: box plot with the individual runs overlaid
for metric in ("nnmi", "nari"):
    plt.figure(figsize=(3, 1.5))
    sns.boxplot(x=metric, y=cat, data=run_df)
    sns.stripplot(x=metric,
                  y=cat,
                  data=run_df,
                  dodge=True,
                  alpha=0.6,
                  size=5,
                  jitter=0.2,
                  edgecolor="black",
                  linewidth=0.5)
    plt.suptitle(title, x=0.09, ha="center", va="top", y=1.1)
    # The metric columns carry an extra leading "n" that is not displayed
    plt.xlabel(metric.upper().replace("NNMI", "NMI").replace("NARI", "ARI"))
    plt.ylabel(cat.capitalize())
    plt.savefig(f"{artifacts_folder_path}/ablation/{dataset_name}_{metric}_{ablation_task}.svg", bbox_inches="tight")
    plt.show()

### 3.3 Encoder Ablation

In [None]:
# Set the dataset explicitly (as in the other sections) so that this section
# can be run on its own
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "encoder_ablation"

# Read the stored metrics of the encoder ablation runs, ordered by job ID
run_df = pd.read_csv(f"{artifacts_folder_path}/ablation/{dataset_name}_metrics_{ablation_task}.csv")
run_df.sort_values(by="job_id", ascending=True, inplace=True)

In [None]:
### Supplementary Fig. 12c ###
cat = "Encoder"
title = "Encoder Ablation"

# Display label per encoder convolution layer
mapping_dict = {"gcnconv": "GCNConv (NicheCompass Light)",
                "gatv2conv": "GATv2Conv (NicheCompass)",
               }

col1 = "conv_layer_encoder"

def map_values(row):
    """Return the display label of a run's encoder layer ("NA" if unknown)."""
    encoder_layer = str(row[col1])
    return mapping_dict.get(encoder_layer, "NA")

run_df["Encoder"] = run_df.apply(map_values, axis=1)

# One figure per metric: box plot with the individual runs overlaid
for metric in ("nnmi", "nari"):
    plt.figure(figsize=(3, 1.5))
    sns.boxplot(x=metric, y=cat, data=run_df)
    sns.stripplot(x=metric,
                  y=cat,
                  data=run_df,
                  dodge=True,
                  alpha=0.6,
                  size=5,
                  jitter=0.2,
                  edgecolor="black",
                  linewidth=0.5)
    plt.suptitle(title, x=0.09, ha="center", va="top", y=1.1)
    # The metric columns carry an extra leading "n" that is not displayed
    plt.xlabel(metric.upper().replace("NNMI", "NMI").replace("NARI", "ARI"))
    plt.ylabel(cat.capitalize())
    plt.savefig(f"{artifacts_folder_path}/ablation/{dataset_name}_{metric}_{ablation_task}.svg", bbox_inches="tight")
    plt.show()

### 3.4 Neighbor Ablation

In [None]:
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "neighbor_ablation"

# Read the stored metrics of the neighbor ablation runs, ordered by the
# number of neighbors
metrics_file_path = f"{artifacts_folder_path}/ablation/{dataset_name}_metrics_{ablation_task}.csv"
run_df = pd.read_csv(metrics_file_path).sort_values(by="n_neighbors", ascending=True)

In [None]:
### Supplementary Fig. 12d ###
cat = "n_neighbors"
cat_label = "KNN"
title = "Neighborhood Ablation"

# Convert to strings so that the numbers of neighbors are plotted as
# categories instead of as a numeric axis
run_df["n_neighbors"] = run_df["n_neighbors"].apply(str)

# One figure per metric: box plot with the individual runs overlaid
for metric in ("nnmi", "nari"):
    plt.figure(figsize=(3, 1.5))
    sns.boxplot(x=metric, y=cat, data=run_df)
    sns.stripplot(x=metric,
                  y=cat,
                  data=run_df,
                  dodge=True,
                  alpha=0.6,
                  size=5,
                  jitter=0.2,
                  edgecolor="black",
                  linewidth=0.5)
    plt.suptitle(title, x=0.09, ha="center", va="top", y=1.1)
    # The metric columns carry an extra leading "n" that is not displayed
    plt.xlabel(metric.upper().replace("NNMI", "NMI").replace("NARI", "ARI"))
    plt.ylabel(cat_label)
    plt.savefig(f"{artifacts_folder_path}/ablation/{dataset_name}_{metric}_{ablation_task}.svg", bbox_inches="tight")
    plt.show()

### 3.5 De novo GP Ablation

In [None]:
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "denovogp_ablation"

# Read the stored metrics of the de novo GP ablation runs, ordered by the
# number of add-on GPs
metrics_file_path = f"{artifacts_folder_path}/ablation/{dataset_name}_metrics_{ablation_task}.csv"
run_df = pd.read_csv(metrics_file_path).sort_values(by="n_addon_gp", ascending=True)

In [None]:
### Supplementary Fig. 12e ###
cat = "n_addon_gp"
cat_label = "De-novo GPs"
title = "De-Novo GP Ablation"

# Convert to strings so that the numbers of de novo GPs are plotted as
# categories instead of as a numeric axis
run_df["n_addon_gp"] = run_df["n_addon_gp"].apply(str)

# One figure per metric: box plot with the individual runs overlaid
for metric in ("nnmi", "nari"):
    plt.figure(figsize=(3, 1.5))
    sns.boxplot(x=metric, y=cat, data=run_df)
    sns.stripplot(x=metric,
                  y=cat,
                  data=run_df,
                  dodge=True,
                  alpha=0.6,
                  size=5,
                  jitter=0.2,
                  edgecolor="black",
                  linewidth=0.5)
    plt.suptitle(title, x=0.09, ha="center", va="top", y=1.1)
    # The metric columns carry an extra leading "n" that is not displayed
    plt.xlabel(metric.upper().replace("NNMI", "NMI").replace("NARI", "ARI"))
    plt.ylabel(cat_label)
    plt.savefig(f"{artifacts_folder_path}/ablation/{dataset_name}_{metric}_{ablation_task}.svg", bbox_inches="tight")
    plt.show()

### 3.6 GP Selection Ablation

In [None]:
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "gpselection_ablation"

# Read the stored metrics of the GP selection ablation runs, ordered by job ID
metrics_file_path = f"{artifacts_folder_path}/ablation/{dataset_name}_metrics_{ablation_task}.csv"
run_df = pd.read_csv(metrics_file_path).sort_values(by="job_id", ascending=True)

In [None]:
### Supplementary Fig. 12f ###
cat = "active_gp_thresh_ratio"
title = "GP Pruning Ablation"

# Convert to strings so that the threshold ratios are plotted as categories
# instead of as a numeric axis
run_df["active_gp_thresh_ratio"] = run_df["active_gp_thresh_ratio"].apply(str)

# One figure per metric: box plot with the individual runs overlaid
for metric in ("nnmi", "nari"):
    plt.figure(figsize=(3, 1.5))
    sns.boxplot(x=metric, y=cat, data=run_df)
    sns.stripplot(x=metric,
                  y=cat,
                  data=run_df,
                  dodge=True,
                  alpha=0.6,
                  size=5,
                  jitter=0.2,
                  edgecolor="black",
                  linewidth=0.5)
    plt.suptitle(title, x=0.09, ha="center", va="top", y=1.1)
    # The metric columns carry an extra leading "n" that is not displayed
    plt.xlabel(metric.upper().replace("NNMI", "NMI").replace("NARI", "ARI"))
    plt.ylabel("Active GP Thresh")
    plt.savefig(f"{artifacts_folder_path}/ablation/{dataset_name}_{metric}_{ablation_task}.svg", bbox_inches="tight")
    plt.show()

### 3.7 No Prior GP Ablation

In [None]:
dataset_name = "starmap_plus_mouse_cns"
ablation_task = "nopriorgp_ablation"

# Read the stored metrics of the prior GP ablation runs, ordered by job ID
metrics_file_path = f"{artifacts_folder_path}/ablation/{dataset_name}_metrics_{ablation_task}.csv"
run_df = pd.read_csv(metrics_file_path).sort_values(by="job_id", ascending=True)

In [None]:
### Supplementary Fig. 12g ###
cat = "priorgp"
title = "Prior GP Ablation"

# Display label per value of the 'priorgp' flag
mapping_dict = {0: "No Prior GPs",
                1: "Prior GPs",
               }
known_labels = set(mapping_dict.values())


def map_values(row):
    """
    Return the display label of a run's 'priorgp' flag (None if unknown).

    Values that already are display labels are returned unchanged: the column
    is overwritten in place, so without this, re-running the cell would turn
    every label into None.
    """
    value = row[cat]
    if value in known_labels:
        return value
    return mapping_dict.get(value)

run_df["priorgp"] = run_df.apply(map_values, axis=1)

# One figure per metric: box plot with the individual runs overlaid
for metric in ("nnmi", "nari"):
    plt.figure(figsize=(3, 1.5))
    sns.boxplot(x=metric, y=cat, data=run_df)
    sns.stripplot(x=metric,
                  y=cat,
                  data=run_df,
                  dodge=True,
                  alpha=0.6,
                  size=5,
                  jitter=0.2,
                  edgecolor="black",
                  linewidth=0.5)
    plt.suptitle(title, x=0.09, ha="center", va="top", y=1.1)
    # The metric columns carry an extra leading "n" that is not displayed
    plt.xlabel(metric.upper().replace("NNMI", "NMI").replace("NARI", "ARI"))
    plt.ylabel("Use of Prior GPs")
    plt.savefig(f"{artifacts_folder_path}/ablation/{dataset_name}_{metric}_{ablation_task}.svg", bbox_inches="tight")
    plt.show()