# Results

## Library Import

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import scanpy as sc
import seaborn as sns
from matplotlib import colors as mcolors
from matplotlib.patches import Patch
from scipy.stats import shapiro, ttest_rel, wilcoxon
from statsmodels.stats.multitest import multipletests

## Helper Functions

In [None]:
def plot_merfish_zhuang_large(adata, color_key, figsize=(6, 6), size=5, title=None):
    """Scatter one categorical ``adata.obs`` column over the spatial coordinates.

    Side effect: overwrites ``adata.uns[f"{color_key}_colors"]`` with a ``tab20``
    palette (one hex colour per category) before plotting. The figure is shown
    immediately and nothing is returned.

    Parameters
    ----------
    adata : AnnData
        Must provide ``adata.obsm["spatial"]`` (plotted via ``basis="spatial"``).
    color_key : str
        Column of ``adata.obs`` used to colour the points.
    figsize : tuple of float, default (6, 6)
        Figure size in inches (rendered at 300 dpi).
    size : float, default 5
        Point size forwarded to ``sc.pl.embedding``.
    title : str or None, default None
        Axes title; ``None`` leaves it blank.
    """
    fig, ax = plt.subplots(figsize=figsize, dpi=300)
    domain_label = color_key

    # scanpy maps ``<key>_colors`` onto the column's categories in category order and
    # falls back to its default palette when fewer colours than categories are stored.
    # Size the palette from the categories: ``unique()`` also counts NaN and omits
    # unused categories, so its length can disagree with what scanpy expects.
    categories = adata.obs[domain_label].astype("category").cat.categories
    # NOTE: tab20 has 20 distinct colours; seaborn cycles it for more categories.
    palette = sns.color_palette("tab20", len(categories))
    adata.uns[f"{domain_label}_colors"] = [mcolors.rgb2hex(c) for c in palette]

    sc.pl.embedding(
        adata, basis="spatial", color=domain_label, size=size, ax=ax, show=False, legend_loc=None
    )

    # Inverted y range (11 -> 0). The fixed 0-11 window presumably matches this
    # dataset's coordinate units -- confirm before reusing on other data.
    ax.set_ylim(11, 0)
    ax.set_xlim(0, 11)
    # axis("equal") runs after the limits and may widen one of them to keep 1:1 scaling.
    ax.axis("equal")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)

    plt.show()

## Domain123 Baseline

In [None]:
# Per-sample test metrics for this run. String cells containing "tensor" (presumably
# reprs like "tensor(0.68)") are unwrapped to the float inside the last parenthesis.
test_results = pd.read_csv(
    "../data/domain/results/domain123/baseline/2025-04-16_08-55-57/csv/version_0/test_results.csv"
).map(
    lambda cell: (
        float(cell.rpartition("(")[2].rstrip(")"))
        if isinstance(cell, str) and "tensor" in cell
        else cell
    )
)
test_results

In [None]:
# Long format: one row per (sample, metric) so each metric gets its own panel.
test_results_melted = test_results.melt(
    id_vars="sample_name",
    value_vars=["nmi", "homogeneity", "completeness"],
    var_name="Metric",
    value_name="Score",
)

sns.set(style="whitegrid")
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(4, 2), sharex=True, dpi=300)

metrics = ["nmi", "homogeneity", "completeness"]
titles = ["NMI", "HOM", "COM"]

for ax, metric, title in zip(axes, metrics, titles):
    metric_scores = test_results_melted[test_results_melted["Metric"] == metric]
    sns.boxplot(
        x="Score",
        y="Metric",
        data=metric_scores,
        ax=ax,
        orient="h",
        dodge=False,
    )
    sns.stripplot(
        x="Score",
        y="Metric",
        data=metric_scores,
        ax=ax,
        color="black",
        size=5,
        jitter=True,
        orient="h",
    )
    ax.set_title(title, fontsize=12)
    ax.set_xlabel("Score")
    ax.set_ylabel("")
    ax.set_yticklabels([])

    # Gridlines every 0.1, labelled only at 0.25 and 0.75.
    ax.set_xticks([0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85])
    ax.set_xticklabels(["", "", "0.25", "", "", "", "", "0.75", ""])
    # set_xticks widens the view to include every tick, so a narrower set_xlim placed
    # before it is silently overridden. Set the limits afterwards so the declared
    # range is the rendered one.
    ax.set_xlim(0.05, 0.85)

plt.tight_layout()
plt.show()

In [None]:
# AnnData written by this run for sample MERFISH_small5.
merfish_small5 = sc.read_h5ad(
    filename="../data/domain/results/domain123/baseline/2025-04-16_08-55-57/adata_files/MERFISH_small5.h5ad"
)
merfish_small5

In [None]:
# Negate every spatial column except x (the y coordinate for 2-D data). Presumably the
# stored y axis points down -- confirm. The `.uns` flag makes the cell safe to re-run
# (e.g. after reordering the Leiden colours below) without flipping the section back.
if not merfish_small5.uns.get("spatial_y_flipped", False):
    merfish_small5.obsm["spatial"][:, 1:] *= -1
    merfish_small5.uns["spatial_y_flipped"] = True

# Ground truth (with legend), then the Leiden clustering (legend hidden).
for color_key, plot_title, extra_kwargs in (
    ("domain_annotation", "Ground Truth", {}),
    ("leiden", "NMI: 0.68, HOM: 0.67, COM: 0.69", {"legend_loc": None}),
):
    fig, ax = plt.subplots(figsize=(5, 6), dpi=300)
    sc.pl.embedding(
        merfish_small5,
        basis="spatial",
        color=color_key,
        size=60,
        title=plot_title,
        ax=ax,
        show=False,
        **extra_kwargs,
    )
    plt.show()

In [None]:
# Permute the Leiden palette (presumably so cluster colours resemble the matching
# ground-truth domains -- confirm). The first-seen palette is cached in `.uns`, so
# re-running this cell re-applies the permutation to the same base colours instead of
# permuting an already-permuted list.
leiden_order = [1, 0, 3, 4, 2, 7, 5, 6]
base_leiden_colors = merfish_small5.uns.setdefault(
    "leiden_colors_unpermuted", list(merfish_small5.uns["leiden_colors"])
)
assert sorted(leiden_order) == list(range(len(base_leiden_colors))), (
    "leiden_order must be a permutation of the stored palette indices"
)
merfish_small5.uns["leiden_colors"] = [base_leiden_colors[i] for i in leiden_order]

## Domain123 Augmentation - Baseline + SpatialNoise + FeatureNoise

In [None]:
# Per-sample test metrics for this run. String cells containing "tensor" (presumably
# reprs like "tensor(0.68)") are unwrapped to the float inside the last parenthesis.
test_results = pd.read_csv(
    "../data/domain/results/domain123/augmentation/2025-05-19_08-56-13/csv/version_0/test_results.csv"
).map(
    lambda cell: (
        float(cell.rpartition("(")[2].rstrip(")"))
        if isinstance(cell, str) and "tensor" in cell
        else cell
    )
)
test_results

In [None]:
# Long format: one row per (sample, metric) so each metric gets its own panel.
test_results_melted = test_results.melt(
    id_vars="sample_name",
    value_vars=["nmi", "homogeneity", "completeness"],
    var_name="Metric",
    value_name="Score",
)

sns.set(style="whitegrid")
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(4, 2), sharex=True, dpi=300)

metrics = ["nmi", "homogeneity", "completeness"]
titles = ["NMI", "HOM", "COM"]

for ax, metric, title in zip(axes, metrics, titles):
    metric_scores = test_results_melted[test_results_melted["Metric"] == metric]
    sns.boxplot(
        x="Score",
        y="Metric",
        data=metric_scores,
        ax=ax,
        orient="h",
        dodge=False,
    )
    sns.stripplot(
        x="Score",
        y="Metric",
        data=metric_scores,
        ax=ax,
        color="black",
        size=5,
        jitter=True,
        orient="h",
    )
    ax.set_title(title, fontsize=12)
    ax.set_xlabel("Score")
    ax.set_ylabel("")
    ax.set_yticklabels([])

    # Gridlines every 0.1, labelled only at 0.25 and 0.75.
    ax.set_xticks([0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85])
    ax.set_xticklabels(["", "", "0.25", "", "", "", "", "0.75", ""])
    # set_xticks widens the view to include every tick, so a narrower set_xlim placed
    # before it is silently overridden. Set the limits afterwards so the declared
    # range is the rendered one.
    ax.set_xlim(0.05, 0.85)

plt.tight_layout()
plt.show()

In [None]:
# AnnData written by this run for sample MERFISH_small5.
merfish_small5 = sc.read_h5ad(
    filename="../data/domain/results/domain123/augmentation/2025-05-19_08-56-13/adata_files/MERFISH_small5.h5ad"
)
merfish_small5

In [None]:
# Negate every spatial column except x (the y coordinate for 2-D data). Presumably the
# stored y axis points down -- confirm. The `.uns` flag makes the cell safe to re-run
# (e.g. after reordering the Leiden colours below) without flipping the section back.
if not merfish_small5.uns.get("spatial_y_flipped", False):
    merfish_small5.obsm["spatial"][:, 1:] *= -1
    merfish_small5.uns["spatial_y_flipped"] = True

# Ground truth (with legend), then the Leiden clustering (legend hidden).
for color_key, plot_title, extra_kwargs in (
    ("domain_annotation", "Ground Truth", {}),
    ("leiden", "NMI: 0.69, HOM: 0.66, COM: 0.71", {"legend_loc": None}),
):
    fig, ax = plt.subplots(figsize=(5, 6), dpi=300)
    sc.pl.embedding(
        merfish_small5,
        basis="spatial",
        color=color_key,
        size=60,
        title=plot_title,
        ax=ax,
        show=False,
        **extra_kwargs,
    )
    plt.show()

In [None]:
# Permute the Leiden palette (presumably so cluster colours resemble the matching
# ground-truth domains -- confirm). The first-seen palette is cached in `.uns`, so
# re-running this cell re-applies the permutation to the same base colours instead of
# permuting an already-permuted list.
leiden_order = [1, 0, 6, 2, 7, 3, 5, 4]
base_leiden_colors = merfish_small5.uns.setdefault(
    "leiden_colors_unpermuted", list(merfish_small5.uns["leiden_colors"])
)
assert sorted(leiden_order) == list(range(len(base_leiden_colors))), (
    "leiden_order must be a permutation of the stored palette indices"
)
merfish_small5.uns["leiden_colors"] = [base_leiden_colors[i] for i in leiden_order]

## Domain123 Augmentation - DropImportance

In [None]:
# Per-sample test metrics for this run. String cells containing "tensor" (presumably
# reprs like "tensor(0.68)") are unwrapped to the float inside the last parenthesis.
test_results = pd.read_csv(
    "../data/domain/results/domain123/augmentation/2025-05-19_08-48-55/csv/version_0/test_results.csv"
).map(
    lambda cell: (
        float(cell.rpartition("(")[2].rstrip(")"))
        if isinstance(cell, str) and "tensor" in cell
        else cell
    )
)
test_results

In [None]:
# Long format: one row per (sample, metric) so each metric gets its own panel.
test_results_melted = test_results.melt(
    id_vars="sample_name",
    value_vars=["nmi", "homogeneity", "completeness"],
    var_name="Metric",
    value_name="Score",
)

sns.set(style="whitegrid")
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(4, 2), sharex=True, dpi=300)

metrics = ["nmi", "homogeneity", "completeness"]
titles = ["NMI", "HOM", "COM"]

for ax, metric, title in zip(axes, metrics, titles):
    metric_scores = test_results_melted[test_results_melted["Metric"] == metric]
    sns.boxplot(
        x="Score",
        y="Metric",
        data=metric_scores,
        ax=ax,
        orient="h",
        dodge=False,
    )
    sns.stripplot(
        x="Score",
        y="Metric",
        data=metric_scores,
        ax=ax,
        color="black",
        size=5,
        jitter=True,
        orient="h",
    )
    ax.set_title(title, fontsize=12)
    ax.set_xlabel("Score")
    ax.set_ylabel("")
    ax.set_yticklabels([])

    # Gridlines every 0.1, labelled only at 0.25 and 0.75.
    ax.set_xticks([0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85])
    ax.set_xticklabels(["", "", "0.25", "", "", "", "", "0.75", ""])
    # set_xticks widens the view to include every tick, so a narrower set_xlim placed
    # before it is silently overridden. Set the limits afterwards so the declared
    # range is the rendered one.
    ax.set_xlim(0.05, 0.85)

plt.tight_layout()
plt.show()

In [None]:
# AnnData written by this run for sample MERFISH_small5.
merfish_small5 = sc.read_h5ad(
    filename="../data/domain/results/domain123/augmentation/2025-05-19_08-48-55/adata_files/MERFISH_small5.h5ad"
)
merfish_small5

In [None]:
# Negate every spatial column except x (the y coordinate for 2-D data). Presumably the
# stored y axis points down -- confirm. The `.uns` flag makes the cell safe to re-run
# (e.g. after reordering the Leiden colours below) without flipping the section back.
if not merfish_small5.uns.get("spatial_y_flipped", False):
    merfish_small5.obsm["spatial"][:, 1:] *= -1
    merfish_small5.uns["spatial_y_flipped"] = True

# Ground truth (with legend), then the Leiden clustering (legend hidden).
for color_key, plot_title, extra_kwargs in (
    ("domain_annotation", "Ground Truth", {}),
    ("leiden", "NMI: 0.64, HOM: 0.63, COM: 0.65", {"legend_loc": None}),
):
    fig, ax = plt.subplots(figsize=(5, 6), dpi=300)
    sc.pl.embedding(
        merfish_small5,
        basis="spatial",
        color=color_key,
        size=60,
        title=plot_title,
        ax=ax,
        show=False,
        **extra_kwargs,
    )
    plt.show()

In [None]:
# Permute the Leiden palette (presumably so cluster colours resemble the matching
# ground-truth domains -- confirm). The first-seen palette is cached in `.uns`, so
# re-running this cell re-applies the permutation to the same base colours instead of
# permuting an already-permuted list.
leiden_order = [1, 2, 4, 3, 0, 7, 5, 6]
base_leiden_colors = merfish_small5.uns.setdefault(
    "leiden_colors_unpermuted", list(merfish_small5.uns["leiden_colors"])
)
assert sorted(leiden_order) == list(range(len(base_leiden_colors))), (
    "leiden_order must be a permutation of the stored palette indices"
)
merfish_small5.uns["leiden_colors"] = [base_leiden_colors[i] for i in leiden_order]

## Domain123 Augmentation - DropImportance + SpatialNoise + FeatureNoise

In [None]:
# Per-sample test metrics for this run. String cells containing "tensor" (presumably
# reprs like "tensor(0.68)") are unwrapped to the float inside the last parenthesis.
test_results = pd.read_csv(
    "../data/domain/results/domain123/augmentation/2025-05-19_08-58-11/csv/version_0/test_results.csv"
).map(
    lambda cell: (
        float(cell.rpartition("(")[2].rstrip(")"))
        if isinstance(cell, str) and "tensor" in cell
        else cell
    )
)
test_results

In [None]:
# Long format: one row per (sample, metric) so each metric gets its own panel.
test_results_melted = test_results.melt(
    id_vars="sample_name",
    value_vars=["nmi", "homogeneity", "completeness"],
    var_name="Metric",
    value_name="Score",
)

sns.set(style="whitegrid")
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(4, 2), sharex=True, dpi=300)

metrics = ["nmi", "homogeneity", "completeness"]
titles = ["NMI", "HOM", "COM"]

for ax, metric, title in zip(axes, metrics, titles):
    metric_scores = test_results_melted[test_results_melted["Metric"] == metric]
    sns.boxplot(
        x="Score",
        y="Metric",
        data=metric_scores,
        ax=ax,
        orient="h",
        dodge=False,
    )
    sns.stripplot(
        x="Score",
        y="Metric",
        data=metric_scores,
        ax=ax,
        color="black",
        size=5,
        jitter=True,
        orient="h",
    )
    ax.set_title(title, fontsize=12)
    ax.set_xlabel("Score")
    ax.set_ylabel("")
    ax.set_yticklabels([])

    # Gridlines every 0.1, labelled only at 0.25 and 0.75.
    ax.set_xticks([0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85])
    ax.set_xticklabels(["", "", "0.25", "", "", "", "", "0.75", ""])
    # set_xticks widens the view to include every tick, so a narrower set_xlim placed
    # before it is silently overridden. Set the limits afterwards so the declared
    # range is the rendered one.
    ax.set_xlim(0.05, 0.85)

plt.tight_layout()
plt.show()

In [None]:
# AnnData written by this run for sample MERFISH_small5.
merfish_small5 = sc.read_h5ad(
    filename="../data/domain/results/domain123/augmentation/2025-05-19_08-58-11/adata_files/MERFISH_small5.h5ad"
)
merfish_small5

In [None]:
# Negate every spatial column except x (the y coordinate for 2-D data). Presumably the
# stored y axis points down -- confirm. The `.uns` flag makes the cell safe to re-run
# (e.g. after reordering the Leiden colours below) without flipping the section back.
if not merfish_small5.uns.get("spatial_y_flipped", False):
    merfish_small5.obsm["spatial"][:, 1:] *= -1
    merfish_small5.uns["spatial_y_flipped"] = True

# Ground truth (with legend), then the Leiden clustering (legend hidden).
for color_key, plot_title, extra_kwargs in (
    ("domain_annotation", "Ground Truth", {}),
    ("leiden", "NMI: 0.70, HOM: 0.70, COM: 0.69", {"legend_loc": None}),
):
    fig, ax = plt.subplots(figsize=(5, 6), dpi=300)
    sc.pl.embedding(
        merfish_small5,
        basis="spatial",
        color=color_key,
        size=60,
        title=plot_title,
        ax=ax,
        show=False,
        **extra_kwargs,
    )
    plt.show()

In [None]:
# Permute the Leiden palette (presumably so cluster colours resemble the matching
# ground-truth domains -- confirm). The first-seen palette is cached in `.uns`, so
# re-running this cell re-applies the permutation to the same base colours instead of
# permuting an already-permuted list.
leiden_order = [0, 7, 3, 1, 2, 4, 5, 6]
base_leiden_colors = merfish_small5.uns.setdefault(
    "leiden_colors_unpermuted", list(merfish_small5.uns["leiden_colors"])
)
assert sorted(leiden_order) == list(range(len(base_leiden_colors))), (
    "leiden_order must be a permutation of the stored palette indices"
)
merfish_small5.uns["leiden_colors"] = [base_leiden_colors[i] for i in leiden_order]

## Domain123 Augmentation - DropImportance + SpatialNoise + FeatureNoise + ShufflePositions

In [None]:
# Per-sample test metrics for this run. String cells containing "tensor" (presumably
# reprs like "tensor(0.68)") are unwrapped to the float inside the last parenthesis.
test_results = pd.read_csv(
    "../data/domain/results/domain123/augmentation/2025-05-20_09-55-01/csv/version_0/test_results.csv"
).map(
    lambda cell: (
        float(cell.rpartition("(")[2].rstrip(")"))
        if isinstance(cell, str) and "tensor" in cell
        else cell
    )
)
test_results

In [None]:
# Long format: one row per (sample, metric) so each metric gets its own panel.
test_results_melted = test_results.melt(
    id_vars="sample_name",
    value_vars=["nmi", "homogeneity", "completeness"],
    var_name="Metric",
    value_name="Score",
)

sns.set(style="whitegrid")
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(4, 2), sharex=True, dpi=300)

metrics = ["nmi", "homogeneity", "completeness"]
titles = ["NMI", "HOM", "COM"]

for ax, metric, title in zip(axes, metrics, titles):
    metric_scores = test_results_melted[test_results_melted["Metric"] == metric]
    sns.boxplot(
        x="Score",
        y="Metric",
        data=metric_scores,
        ax=ax,
        orient="h",
        dodge=False,
    )
    sns.stripplot(
        x="Score",
        y="Metric",
        data=metric_scores,
        ax=ax,
        color="black",
        size=5,
        jitter=True,
        orient="h",
    )
    ax.set_title(title, fontsize=12)
    ax.set_xlabel("Score")
    ax.set_ylabel("")
    ax.set_yticklabels([])

    # Gridlines every 0.1, labelled only at 0.25 and 0.75.
    ax.set_xticks([0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85])
    ax.set_xticklabels(["", "", "0.25", "", "", "", "", "0.75", ""])
    # set_xticks widens the view to include every tick, so a narrower set_xlim placed
    # before it is silently overridden. Set the limits afterwards so the declared
    # range is the rendered one.
    ax.set_xlim(0.05, 0.85)

plt.tight_layout()
plt.show()

In [None]:
# AnnData written by this run for sample MERFISH_small5.
merfish_small5 = sc.read_h5ad(
    filename="../data/domain/results/domain123/augmentation/2025-05-20_09-55-01/adata_files/MERFISH_small5.h5ad"
)
merfish_small5

In [None]:
# Negate every spatial column except x (the y coordinate for 2-D data). Presumably the
# stored y axis points down -- confirm. The `.uns` flag makes the cell safe to re-run
# (e.g. after reordering the Leiden colours below) without flipping the section back.
if not merfish_small5.uns.get("spatial_y_flipped", False):
    merfish_small5.obsm["spatial"][:, 1:] *= -1
    merfish_small5.uns["spatial_y_flipped"] = True

# Ground truth (with legend), then the Leiden clustering (legend hidden).
for color_key, plot_title, extra_kwargs in (
    ("domain_annotation", "Ground Truth", {}),
    ("leiden", "NMI: 0.70, HOM: 0.72, COM: 0.68", {"legend_loc": None}),
):
    fig, ax = plt.subplots(figsize=(5, 6), dpi=300)
    sc.pl.embedding(
        merfish_small5,
        basis="spatial",
        color=color_key,
        size=60,
        title=plot_title,
        ax=ax,
        show=False,
        **extra_kwargs,
    )
    plt.show()

In [None]:
# Permute the Leiden palette (presumably so cluster colours resemble the matching
# ground-truth domains -- confirm). The first-seen palette is cached in `.uns`, so
# re-running this cell re-applies the permutation to the same base colours instead of
# permuting an already-permuted list.
leiden_order = [1, 3, 0, 2, 7, 5, 4, 6]
base_leiden_colors = merfish_small5.uns.setdefault(
    "leiden_colors_unpermuted", list(merfish_small5.uns["leiden_colors"])
)
assert sorted(leiden_order) == list(range(len(base_leiden_colors))), (
    "leiden_order must be a permutation of the stored palette indices"
)
merfish_small5.uns["leiden_colors"] = [base_leiden_colors[i] for i in leiden_order]

## Domain123 Augmentation - DropImportance + SpatialNoise + FeatureNoise + AddEdgesByFeatureSimilarity

In [None]:
# Per-sample test metrics for this run. String cells containing "tensor" (presumably
# reprs like "tensor(0.68)") are unwrapped to the float inside the last parenthesis.
test_results = pd.read_csv(
    "../data/domain/results/domain123/augmentation/2025-05-20_13-22-07/csv/version_0/test_results.csv"
).map(
    lambda cell: (
        float(cell.rpartition("(")[2].rstrip(")"))
        if isinstance(cell, str) and "tensor" in cell
        else cell
    )
)
test_results

In [None]:
# Long format: one row per (sample, metric) so each metric gets its own panel.
test_results_melted = test_results.melt(
    id_vars="sample_name",
    value_vars=["nmi", "homogeneity", "completeness"],
    var_name="Metric",
    value_name="Score",
)

sns.set(style="whitegrid")
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(4, 2), sharex=True, dpi=300)

metrics = ["nmi", "homogeneity", "completeness"]
titles = ["NMI", "HOM", "COM"]

for ax, metric, title in zip(axes, metrics, titles):
    metric_scores = test_results_melted[test_results_melted["Metric"] == metric]
    sns.boxplot(
        x="Score",
        y="Metric",
        data=metric_scores,
        ax=ax,
        orient="h",
        dodge=False,
    )
    sns.stripplot(
        x="Score",
        y="Metric",
        data=metric_scores,
        ax=ax,
        color="black",
        size=5,
        jitter=True,
        orient="h",
    )
    ax.set_title(title, fontsize=12)
    ax.set_xlabel("Score")
    ax.set_ylabel("")
    ax.set_yticklabels([])

    # Gridlines every 0.1, labelled only at 0.25 and 0.75.
    ax.set_xticks([0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85])
    ax.set_xticklabels(["", "", "0.25", "", "", "", "", "0.75", ""])
    # set_xticks widens the view to include every tick, so a narrower set_xlim placed
    # before it is silently overridden. Set the limits afterwards so the declared
    # range is the rendered one.
    ax.set_xlim(0.05, 0.85)

plt.tight_layout()
plt.show()

In [None]:
# AnnData written by this run for sample MERFISH_small5.
merfish_small5 = sc.read_h5ad(
    filename="../data/domain/results/domain123/augmentation/2025-05-20_13-22-07/adata_files/MERFISH_small5.h5ad"
)
merfish_small5

In [None]:
# Negate every spatial column except x (the y coordinate for 2-D data). Presumably the
# stored y axis points down -- confirm. The `.uns` flag makes the cell safe to re-run
# (e.g. after reordering the Leiden colours below) without flipping the section back.
if not merfish_small5.uns.get("spatial_y_flipped", False):
    merfish_small5.obsm["spatial"][:, 1:] *= -1
    merfish_small5.uns["spatial_y_flipped"] = True

# Ground truth (with legend), then the Leiden clustering (legend hidden).
for color_key, plot_title, extra_kwargs in (
    ("domain_annotation", "Ground Truth", {}),
    ("leiden", "NMI: 0.67, HOM: 0.65, COM: 0.69", {"legend_loc": None}),
):
    fig, ax = plt.subplots(figsize=(5, 6), dpi=300)
    sc.pl.embedding(
        merfish_small5,
        basis="spatial",
        color=color_key,
        size=60,
        title=plot_title,
        ax=ax,
        show=False,
        **extra_kwargs,
    )
    plt.show()

In [None]:
# Permute the Leiden palette (presumably so cluster colours resemble the matching
# ground-truth domains -- confirm). The first-seen palette is cached in `.uns`, so
# re-running this cell re-applies the permutation to the same base colours instead of
# permuting an already-permuted list.
leiden_order = [1, 0, 2, 6, 7, 3, 4, 5]
base_leiden_colors = merfish_small5.uns.setdefault(
    "leiden_colors_unpermuted", list(merfish_small5.uns["leiden_colors"])
)
assert sorted(leiden_order) == list(range(len(base_leiden_colors))), (
    "leiden_order must be a permutation of the stored palette indices"
)
merfish_small5.uns["leiden_colors"] = [base_leiden_colors[i] for i in leiden_order]

## Domain123 Baseline vs. Augmentation

In [None]:
# Baseline per-sample metrics; "tensor(...)" string cells are unwrapped to floats.
test_results_baseline = pd.read_csv(
    "../data/domain/results/domain123/baseline/2025-04-16_08-55-57/csv/version_0/test_results.csv"
).map(
    lambda cell: (
        float(cell.rpartition("(")[2].rstrip(")"))
        if isinstance(cell, str) and "tensor" in cell
        else cell
    )
)
test_results_baseline

In [None]:
# Baseline + SpatialNoise + FeatureNoise per-sample metrics; tensor reprs -> floats.
test_results_spatial_noise_feature_noise = pd.read_csv(
    "../data/domain/results/domain123/augmentation/2025-05-19_08-56-13/csv/version_0/test_results.csv"
).map(
    lambda cell: (
        float(cell.rpartition("(")[2].rstrip(")"))
        if isinstance(cell, str) and "tensor" in cell
        else cell
    )
)
test_results_spatial_noise_feature_noise

In [None]:
# DropImportance per-sample metrics; tensor reprs -> floats.
test_results_drop_importance = pd.read_csv(
    "../data/domain/results/domain123/augmentation/2025-05-19_08-48-55/csv/version_0/test_results.csv"
).map(
    lambda cell: (
        float(cell.rpartition("(")[2].rstrip(")"))
        if isinstance(cell, str) and "tensor" in cell
        else cell
    )
)
test_results_drop_importance

In [None]:
# DropImportance + SpatialNoise + FeatureNoise per-sample metrics; tensor reprs -> floats.
test_results_drop_importance_feature_noise_spatial_noise = pd.read_csv(
    "../data/domain/results/domain123/augmentation/2025-05-19_08-58-11/csv/version_0/test_results.csv"
).map(
    lambda cell: (
        float(cell.rpartition("(")[2].rstrip(")"))
        if isinstance(cell, str) and "tensor" in cell
        else cell
    )
)
test_results_drop_importance_feature_noise_spatial_noise

In [None]:
# ... + ShufflePositions per-sample metrics; tensor reprs -> floats.
test_results_drop_importance_feature_noise_spatial_noise_shuffle = pd.read_csv(
    "../data/domain/results/domain123/augmentation/2025-05-20_09-55-01/csv/version_0/test_results.csv"
).map(
    lambda cell: (
        float(cell.rpartition("(")[2].rstrip(")"))
        if isinstance(cell, str) and "tensor" in cell
        else cell
    )
)
test_results_drop_importance_feature_noise_spatial_noise_shuffle

In [None]:
# ... + AddEdgesByFeatureSimilarity per-sample metrics; tensor reprs -> floats.
test_results_drop_importance_feature_noise_spatial_noise_addedges = pd.read_csv(
    "../data/domain/results/domain123/augmentation/2025-05-20_13-22-07/csv/version_0/test_results.csv"
).map(
    lambda cell: (
        float(cell.rpartition("(")[2].rstrip(")"))
        if isinstance(cell, str) and "tensor" in cell
        else cell
    )
)
test_results_drop_importance_feature_noise_spatial_noise_addedges

In [None]:
baseline_df = test_results_baseline
augmentation_dfs = {
    "Baseline + Spatial Noise + Feature Noise": test_results_spatial_noise_feature_noise,
    "DropImportance": test_results_drop_importance,
    "DropImportance + Spatial Noise + Feature Noise": test_results_drop_importance_feature_noise_spatial_noise,
    "DropImportance + Spatial Noise + Feature Noise + ShufflePositions": test_results_drop_importance_feature_noise_spatial_noise_shuffle,
    "DropImportance + Spatial Noise + Feature Noise + AddEdges": test_results_drop_importance_feature_noise_spatial_noise_addedges,
}

metrics = ["nmi", "homogeneity", "completeness"]
results = {}

raw_pvals = []
test_info = []

for aug_name, aug_df in augmentation_dfs.items():
    results[aug_name] = {}
    # Pair each augmented score with the baseline score of the SAME sample. Series
    # arithmetic aligns on the row index, so relying on CSV row order would silently
    # mis-pair samples if the two files list them in different orders.
    paired = baseline_df[["sample_name", *metrics]].merge(
        aug_df[["sample_name", *metrics]],
        on="sample_name",
        suffixes=("_baseline", "_aug"),
        validate="one_to_one",
    )
    if len(paired) != len(baseline_df) or len(paired) != len(aug_df):
        raise ValueError(
            f"{aug_name}: sample names do not match the baseline one-to-one "
            f"({len(paired)} paired; {len(baseline_df)} baseline, {len(aug_df)} augmented)"
        )

    for metric in metrics:
        baseline_values = paired[f"{metric}_baseline"]
        aug_values = paired[f"{metric}_aug"]
        differences = aug_values - baseline_values

        # Paired t-test when the differences pass Shapiro-Wilk, else Wilcoxon signed-rank.
        p_normal = shapiro(differences).pvalue

        if p_normal > 0.05:
            stat, p = ttest_rel(baseline_values, aug_values)
            test = "t-test"
        else:
            stat, p = wilcoxon(baseline_values, aug_values)
            test = "wilcoxon"

        # Store raw result
        results[aug_name][metric] = {"test": test, "statistic": stat, "p_value": p}
        raw_pvals.append(p)
        test_info.append((aug_name, metric))

# Benjamini-Hochberg FDR correction across all augmentation x metric comparisons.
_, corrected_pvals, _, _ = multipletests(raw_pvals, method="fdr_bh")
for (aug_name, metric), p_corr in zip(test_info, corrected_pvals):
    results[aug_name][metric]["p_value_corrected"] = p_corr

print(results)

In [None]:
# R-style significance codes, checked from the strictest threshold down.
significance_codes = [(0.001, "***"), (0.01, "**"), (0.05, "*"), (0.1, ".")]

summary_rows = []
for aug, metric_results in results.items():
    row = {"Augmentation": aug}
    for metric in ["nmi", "homogeneity", "completeness"]:
        mean_val = augmentation_dfs[aug][metric].mean()
        p = metric_results[metric]["p_value"]
        p_corr = metric_results[metric]["p_value_corrected"]
        # Stars reflect the BH-corrected p-value: 15 comparisons are run, so the raw
        # p-values would overstate significance.
        stars = next((code for threshold, code in significance_codes if p_corr < threshold), "")
        row[metric.upper()] = f"{mean_val:.3f}{stars}"
        row[f"p_value_{metric}"] = p
        row[f"p_value_corrected_{metric}"] = p_corr
    summary_rows.append(row)

summary_df = pd.DataFrame(summary_rows)
summary_df

In [None]:
# Mean of each clustering metric per training mode, one row per mode in display order.
df = pd.DataFrame(
    [
        {
            "mode": mode,
            "NMI": frame["nmi"].mean(),
            "HOM": frame["homogeneity"].mean(),
            "COM": frame["completeness"].mean(),
        }
        for mode, frame in {
            "Baseline": test_results_baseline,
            "Baseline + Spatial Noise + Feature Noise": test_results_spatial_noise_feature_noise,
            "DropImportance": test_results_drop_importance,
            "DropImportance + Spatial Noise + Feature Noise": test_results_drop_importance_feature_noise_spatial_noise,
            "DropImportance + Spatial Noise + Feature Noise + ShufflePositions": test_results_drop_importance_feature_noise_spatial_noise_shuffle,
            "DropImportance + Spatial Noise + Feature Noise + AddEdges": test_results_drop_importance_feature_noise_spatial_noise_addedges,
        }.items()
    ]
)
df

In [None]:
# Long format (one row per mode x metric); "hue" duplicates the mode label for colouring.
df_melted = df.melt(
    id_vars="mode", value_vars=["NMI", "HOM", "COM"], var_name="Metric", value_name="Score"
).assign(
    hue=lambda frame: frame["mode"],
    mode=lambda frame: frame["mode"].astype(str),
)
df_melted

In [None]:
sns.set(style="whitegrid")
fig, axes = plt.subplots(1, 3, figsize=(6, 3), dpi=300, sharey=True)
fig.suptitle("Domain Identification Datasets 1-3", fontsize=14)

metrics = ["NMI", "HOM", "COM"]
titles = ["NMI", "HOM", "COM"]

# Bar colours and legend entries both follow the row order of `df`, so the shared
# figure legend cannot drift out of sync with the bars (or with the mode names).
labels = df["mode"].astype(str).tolist()
palette = sns.color_palette("Blues_d", n_colors=len(labels))

for ax, metric, title in zip(axes, metrics, titles):
    sns.barplot(
        data=df_melted[df_melted["Metric"] == metric],
        x="mode",
        y="Score",
        ax=ax,
        palette=palette,
        hue="hue",
        order=labels,
        hue_order=labels,
        # With `hue` set, seaborn would add a legend to every panel; the single
        # figure-level legend below replaces them.
        legend=False,
    )

    ax.set_title(title, fontsize=12)
    ax.set_xlabel("")
    ax.set_ylabel("Score" if metric == "NMI" else "")
    ax.set_xticks([])
    ax.set_ylim(0.5, 0.7)

handles = [Patch(color=color, label=label) for color, label in zip(palette, labels)]

fig.legend(
    handles,
    labels,
    loc="lower center",
    bbox_to_anchor=(0.5, -0.45),
    title="Augmentation Mode",
    title_fontsize="11",
    fontsize="10",
)

plt.tight_layout()
plt.show()

In [None]:
# Baseline run: AnnData for sample MERFISH_small5.
merfish_small5_baseline = sc.read_h5ad(
    filename="../data/domain/results/domain123/baseline/2025-04-16_08-55-57/adata_files/MERFISH_small5.h5ad"
)
merfish_small5_baseline

In [None]:
# Baseline + SpatialNoise + FeatureNoise run: AnnData for sample MERFISH_small5.
merfish_small5_spatial_noise_feature_noise = sc.read_h5ad(
    filename="../data/domain/results/domain123/augmentation/2025-05-19_08-56-13/adata_files/MERFISH_small5.h5ad"
)
merfish_small5_spatial_noise_feature_noise

In [None]:
# DropImportance run: AnnData for sample MERFISH_small5.
merfish_small5_drop_importance = sc.read_h5ad(
    filename="../data/domain/results/domain123/augmentation/2025-05-19_08-48-55/adata_files/MERFISH_small5.h5ad"
)
merfish_small5_drop_importance

In [None]:
# DropImportance + SpatialNoise + FeatureNoise run: AnnData for sample MERFISH_small5.
merfish_small5_drop_importance_feature_noise_spatial_noise = sc.read_h5ad(
    filename="../data/domain/results/domain123/augmentation/2025-05-19_08-58-11/adata_files/MERFISH_small5.h5ad"
)
merfish_small5_drop_importance_feature_noise_spatial_noise

In [None]:
# ... + ShufflePositions run: AnnData for sample MERFISH_small5.
merfish_small5_drop_importance_feature_noise_spatial_noise_shuffle = sc.read_h5ad(
    filename="../data/domain/results/domain123/augmentation/2025-05-20_09-55-01/adata_files/MERFISH_small5.h5ad"
)
merfish_small5_drop_importance_feature_noise_spatial_noise_shuffle

In [None]:
# ... + AddEdgesByFeatureSimilarity run: AnnData for sample MERFISH_small5.
merfish_small5_drop_importance_feature_noise_spatial_noise_addedges = sc.read_h5ad(
    filename="../data/domain/results/domain123/augmentation/2025-05-20_13-22-07/adata_files/MERFISH_small5.h5ad"
)
merfish_small5_drop_importance_feature_noise_spatial_noise_addedges

In [None]:
# One panel per entry: the ground truth (from the baseline object) followed by the
# Leiden clustering of each training mode.
adata_objects = [
    merfish_small5_baseline,
    merfish_small5_baseline,
    merfish_small5_spatial_noise_feature_noise,
    merfish_small5_drop_importance,
    merfish_small5_drop_importance_feature_noise_spatial_noise,
    merfish_small5_drop_importance_feature_noise_spatial_noise_shuffle,
    merfish_small5_drop_importance_feature_noise_spatial_noise_addedges,
]
titles = [
    "Ground Truth",
    "Baseline \n(NMI: 0.68)",
    "Baseline + Spatial Noise + Feature Noise \n(NMI: 0.69)",
    "DropImportance \n(NMI: 0.64)",
    "DropImportance + Spatial Noise \n+ Feature Noise (NMI: 0.70)",
    "DropImportance + Spatial Noise \n+ Feature Noise + ShufflePositions \n(NMI: 0.70)",
    "DropImportance + Spatial Noise \n+ Feature Noise + AddEdges \n(NMI: 0.67)",
]
color_keys = ["domain_annotation", "leiden", "leiden", "leiden", "leiden", "leiden", "leiden"]

# Negate every spatial column except x (the y coordinate for 2-D data) exactly once per
# object. The `.uns` flag keeps this idempotent when the cell is re-run (e.g. after the
# colour-reordering cell below) and when the same object appears twice in the list.
for adata in adata_objects:
    if not adata.uns.get("spatial_y_flipped", False):
        adata.obsm["spatial"][:, 1:] *= -1
        adata.uns["spatial_y_flipped"] = True

fig, axes = plt.subplots(2, 4, figsize=(8, 6), dpi=300)
for ax, adata, color_key, title in zip(axes.flatten(), adata_objects, color_keys, titles):
    sc.pl.embedding(
        adata,
        basis="spatial",
        color=color_key,
        size=20,
        title=None,
        ax=ax,
        show=False,
        legend_loc=None,
    )

    ax.set_title(title, fontsize=8)
    ax.set_xlabel("")
    ax.set_ylabel("")

# Hide grid cells that have no panel.
for unused_ax in axes.flatten()[len(adata_objects) :]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# Permute each run's Leiden palette (presumably to align cluster colours with the
# ground-truth domains -- confirm). Base palettes are cached in `.uns` on first run, so
# re-executing the cell re-applies each permutation to the same colours instead of
# compounding it.
leiden_orders = [
    (merfish_small5_baseline, [1, 0, 3, 4, 2, 7, 5, 6]),
    (merfish_small5_spatial_noise_feature_noise, [1, 0, 6, 2, 7, 3, 5, 4]),
    (merfish_small5_drop_importance, [1, 2, 4, 3, 0, 7, 5, 6]),
    (merfish_small5_drop_importance_feature_noise_spatial_noise, [0, 7, 3, 1, 2, 4, 5, 6]),
    (merfish_small5_drop_importance_feature_noise_spatial_noise_shuffle, [1, 3, 0, 2, 7, 5, 4, 6]),
    (merfish_small5_drop_importance_feature_noise_spatial_noise_addedges, [1, 0, 2, 6, 7, 3, 4, 5]),
]
for adata, order in leiden_orders:
    base_colors = adata.uns.setdefault(
        "leiden_colors_unpermuted", list(adata.uns["leiden_colors"])
    )
    assert sorted(order) == list(range(len(base_colors))), (
        "order must be a permutation of the stored palette indices"
    )
    adata.uns["leiden_colors"] = [base_colors[i] for i in order]

## Domain4 on Domain123 Baseline

In [None]:
# Domain4 samples evaluated with the domain123 baseline run; tensor reprs -> floats.
test_results = pd.read_csv(
    "../data/domain/results/domain4/domain4_domain123_baseline/2025-04-18_11-34-38/csv/version_0/test_results.csv"
).map(
    lambda cell: (
        float(cell.rpartition("(")[2].rstrip(")"))
        if isinstance(cell, str) and "tensor" in cell
        else cell
    )
)
test_results

In [None]:
# Long format: one row per (sample, metric).
test_results_melted = test_results.melt(
    id_vars="sample_name",
    value_vars=["nmi", "homogeneity", "completeness"],
    var_name="Metric",
    value_name="Score",
)

sns.set(style="whitegrid")
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(4, 2), sharex=False, dpi=300)

# Per-panel configuration, one entry per metric. Each panel is zoomed to its own x range
# (x axes are not shared) and every tick lies inside its limits.
metrics = ["nmi", "homogeneity", "completeness"]
titles = ["NMI", "HOM", "COM"]
xlims_list = [[0.58, 0.70], [0.49, 0.67], [0.59, 0.74]]
xticks_list = [
    [0.59, 0.60, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66, 0.67, 0.68, 0.69],
    [0.51, 0.53, 0.55, 0.57, 0.59, 0.61, 0.63, 0.65],
    [0.60, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66, 0.67, 0.68, 0.69, 0.70, 0.71, 0.72, 0.73],
]
xtick_labels_list = [
    ["", "0.60", "", "", "", "", "0.65", "", "", "", ""],
    ["", "0.53", "", "", "", "", "", "0.65"],
    ["0.60", "", "", "", "", "", "", "", "", "", "", "", "0.73", ""],
]

for ax, metric, title, xlims, xticks, xtick_labels in zip(
    axes, metrics, titles, xlims_list, xticks_list, xtick_labels_list
):
    metric_scores = test_results_melted.loc[test_results_melted["Metric"] == metric]

    # Box summary with the individual samples overlaid as black points.
    sns.boxplot(data=metric_scores, x="Score", y="Metric", orient="h", dodge=False, ax=ax)
    sns.stripplot(
        data=metric_scores,
        x="Score",
        y="Metric",
        orient="h",
        jitter=True,
        size=5,
        color="black",
        ax=ax,
    )

    ax.set_title(title, fontsize=12)
    ax.set_xlabel("Score")
    ax.set_ylabel("")
    ax.set_yticklabels([])

    ax.set_xlim(xlims)
    ax.set_xticks(xticks)
    ax.set_xticklabels(xtick_labels)

plt.tight_layout()
plt.show()

In [None]:
# AnnData for sample Xenium2 written by the domain4-on-domain123-baseline run.
xenium2 = sc.read_h5ad(
    filename="../data/domain/results/domain4/domain4_domain123_baseline/2025-04-18_11-34-38/adata_files/Xenium2.h5ad"
)
xenium2

In [None]:
domain_label = "domain_annotation"
title = "Xenium Ground Truth"

# Keep the cells left of the mean x coordinate. `.copy()` materialises the subset so
# writing the palette to `.uns` below does not go through anndata's implicit
# view-to-copy conversion (which emits an ImplicitModificationWarning).
x_midpoint = xenium2.obsm["spatial"][:, 0].mean()
left_half_xenium = xenium2[xenium2.obsm["spatial"][:, 0] < x_midpoint].copy()

# scanpy pairs `<key>_colors` with the column's categories in order; size the palette
# from those categories (`unique()` would also count NaN).
unique_classes = left_half_xenium.obs[domain_label].astype("category").cat.categories
num_classes = len(unique_classes)

palette = sns.color_palette("tab20", num_classes)  # colours repeat past 20 classes
left_half_xenium.uns[f"{domain_label}_colors"] = [mcolors.rgb2hex(c) for c in palette]

fig, ax = plt.subplots(figsize=(4, 5), dpi=300)
sc.pl.embedding(
    left_half_xenium,
    basis="spatial",
    color=domain_label,
    size=5,
    title=title,
    legend_loc=None,
    show=False,
    ax=ax,
)
plt.show()

In [None]:
domain_label = "leiden"
title = "NMI: 0.60, HOM: 0.57, COM: 0.62"

# Keep the cells left of the mean x coordinate. `.copy()` materialises the subset so
# writing the palette to `.uns` below does not go through anndata's implicit
# view-to-copy conversion (which emits an ImplicitModificationWarning).
x_midpoint = xenium2.obsm["spatial"][:, 0].mean()
left_half_xenium = xenium2[xenium2.obsm["spatial"][:, 0] < x_midpoint].copy()

# scanpy pairs `<key>_colors` with the column's categories in order; size the palette
# from those categories (`unique()` would also count NaN).
unique_classes = left_half_xenium.obs[domain_label].astype("category").cat.categories
num_classes = len(unique_classes)

palette = sns.color_palette("tab20", num_classes)  # colours repeat past 20 classes
left_half_xenium.uns[f"{domain_label}_colors"] = [mcolors.rgb2hex(c) for c in palette]

fig, ax = plt.subplots(figsize=(4, 5), dpi=300)
sc.pl.embedding(
    left_half_xenium,
    basis="spatial",
    color=domain_label,
    size=5,
    title=title,
    legend_loc=None,
    show=False,
    ax=ax,
)
plt.show()

## Domain7 Baseline

In [None]:
# Domain7 baseline per-sample metrics; tensor reprs -> floats.
test_results = pd.read_csv(
    "../data/domain/results/domain7/baseline/2025-04-22_08-31-58/csv/version_0/test_results.csv"
).map(
    lambda cell: (
        float(cell.rpartition("(")[2].rstrip(")"))
        if isinstance(cell, str) and "tensor" in cell
        else cell
    )
)
test_results

In [None]:
def assign_dataset(sample_name):
    """Return the dataset7.x label for a sample name.

    The first ABCA identifier (checked in order ABCA-1 .. ABCA-4) that occurs as a
    substring of ``sample_name`` decides the label; names matching none of them
    map to ``"Unknown"``.
    """
    for abca_id, label in (
        ("ABCA-1", "dataset7.1"),
        ("ABCA-2", "dataset7.2"),
        ("ABCA-3", "dataset7.3"),
        ("ABCA-4", "dataset7.4"),
    ):
        if abca_id in sample_name:
            return label
    return "Unknown"


# Tag each sample with its dataset7.x group (used for grouping and colouring below).
test_results = test_results.assign(dataset=test_results["sample_name"].map(assign_dataset))
test_results

In [None]:
test_results_melted = test_results.melt(
    id_vars="dataset",
    value_vars=["nmi", "homogeneity", "completeness"],
    var_name="Metric",
    value_name="Score",
)
# Mean NMI per dataset, with "dataset" as a regular column for the bar plot.
data_set_nmi_mean = test_results.groupby("dataset")["nmi"].mean().reset_index()
category_order = ["dataset7.1", "dataset7.2", "dataset7.3", "dataset7.4"]
hue_order = ["dataset7.1", "dataset7.2", "dataset7.3", "dataset7.4"]
palette = sns.color_palette("tab10")[0:4]
nmi_scores = test_results_melted[test_results_melted["Metric"] == "nmi"]

# Dense tick grids with a label on every fifth tick only:
# bar panel 0.00-0.68 in steps of 0.04, box panel 0.20-0.72 in steps of 0.02.
bar_yticks = [round(0.04 * i, 2) for i in range(18)]
bar_yticklabels = [f"{tick:.1f}" if i % 5 == 0 else "" for i, tick in enumerate(bar_yticks)]
box_yticks = [round(0.20 + 0.02 * i, 2) for i in range(27)]
box_yticklabels = [f"{tick:.1f}" if i % 5 == 0 else "" for i, tick in enumerate(box_yticks)]

sns.set(style="whitegrid")
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(2, 5), sharex=True, dpi=300)

# Top panel: mean NMI per dataset.
sns.barplot(
    data=data_set_nmi_mean,
    x="dataset",
    y="nmi",
    ax=axes[0],
    palette=palette,
    order=category_order,
    hue="dataset",
    hue_order=hue_order,
    legend=False,
)
axes[0].set_title("Baseline Performance", fontsize=12)
axes[0].set_xlabel("Dataset")
axes[0].set_ylabel("NMI")
axes[0].set_yticks(bar_yticks)
axes[0].set_yticklabels(bar_yticklabels)

# Bottom panel: per-sample NMI distribution per dataset.
sns.boxplot(
    x="dataset",
    y="Score",
    data=nmi_scores,
    ax=axes[1],
    palette=palette,
    order=category_order,
    hue="dataset",
    hue_order=hue_order,
    legend=False,
)
# Pass the boxplot's category order so each sample's point is drawn over its own
# dataset's box, independent of the row order in the CSV.
sns.stripplot(
    x="dataset",
    y="Score",
    data=nmi_scores,
    ax=axes[1],
    order=category_order,
    color="black",
    size=5,
    jitter=True,
)
axes[1].set_xlabel("Dataset")
axes[1].set_ylabel("NMI")
axes[1].set_yticks(box_yticks)
axes[1].set_yticklabels(box_yticklabels)
axes[1].tick_params(axis="x", rotation=90, bottom=True)

plt.tight_layout()
plt.show()

In [None]:
# AnnData for sample Zhuang-ABCA-4.003 from the Domain7 baseline results directory.
abca4_003 = sc.read_h5ad(
    filename="../data/domain/results/domain7/baseline/2025-04-22_08-31-58/adata_files/Zhuang-ABCA-4.003.h5ad"
)
abca4_003

In [None]:
# Read this sample's scores from the section's `test_results` instead of hard-coding
# them, so the title always reflects the run that produced `abca4_003`.
sample_id = "Zhuang-ABCA-4.003"
sample_scores = test_results.loc[
    test_results["sample_name"].str.contains(sample_id, regex=False),
    ["nmi", "homogeneity", "completeness"],
]
# Fail loudly rather than titling the figure with numbers from another sample or run.
assert len(sample_scores) == 1, (
    f"expected one test_results row for {sample_id}, found {len(sample_scores)}"
)
sample_nmi, sample_hom, sample_com = sample_scores.iloc[0]

plot_merfish_zhuang_large(
    abca4_003, "domain_annotation", title=f"{sample_id} Ground Truth", figsize=(10, 6)
)
plot_merfish_zhuang_large(
    abca4_003,
    "leiden",
    title=f"{sample_id} NMI: {sample_nmi:.2f}, HOM: {sample_hom:.2f}, COM: {sample_com:.2f}",
    figsize=(10, 6),
)

In [None]:
# AnnData for sample Zhuang-ABCA-2.036 from the Domain7 baseline results directory.
abca2_036 = sc.read_h5ad(
    filename="../data/domain/results/domain7/baseline/2025-04-22_08-31-58/adata_files/Zhuang-ABCA-2.036.h5ad"
)
abca2_036

In [None]:
# Read this sample's scores from the section's `test_results` instead of hard-coding
# them, so the title always reflects the run that produced `abca2_036`.
sample_id = "Zhuang-ABCA-2.036"
sample_scores = test_results.loc[
    test_results["sample_name"].str.contains(sample_id, regex=False),
    ["nmi", "homogeneity", "completeness"],
]
# Fail loudly rather than titling the figure with numbers from another sample or run.
assert len(sample_scores) == 1, (
    f"expected one test_results row for {sample_id}, found {len(sample_scores)}"
)
sample_nmi, sample_hom, sample_com = sample_scores.iloc[0]

plot_merfish_zhuang_large(
    abca2_036, "domain_annotation", size=10, title=f"{sample_id} Ground Truth"
)
plot_merfish_zhuang_large(
    abca2_036,
    "leiden",
    size=10,
    title=f"{sample_id} NMI: {sample_nmi:.2f}, HOM: {sample_hom:.2f}, COM: {sample_com:.2f}",
)

## Domain 7 Augmentation - SpatialNoise + FeatureNoise

In [None]:
# Per-sample test metrics of the Domain7 SpatialNoise + FeatureNoise run. Cells that
# are strings containing "tensor" are parsed as floats from the text between the
# last "(" and the trailing ")"; every other cell is left untouched.
test_results = pd.read_csv(
    filepath_or_buffer="../data/domain/results/domain7/augmentation/2025-05-22_14-30-23/csv/version_0/test_results.csv"
).map(
    lambda cell: float(cell.split("(")[-1].rstrip(")"))
    if isinstance(cell, str) and "tensor" in cell
    else cell
)
test_results

In [None]:
def assign_dataset(sample_name):
    """Return the dataset7.x label for a sample name.

    ABCA identifiers are tried in order ABCA-1 .. ABCA-4 and the first one found as
    a substring of ``sample_name`` wins; otherwise ``"Unknown"`` is returned.
    """
    labels_by_id = {
        "ABCA-1": "dataset7.1",
        "ABCA-2": "dataset7.2",
        "ABCA-3": "dataset7.3",
        "ABCA-4": "dataset7.4",
    }
    return next(
        (label for abca_id, label in labels_by_id.items() if abca_id in sample_name),
        "Unknown",
    )


# Tag each sample with its dataset7.x group (used for grouping and colouring below).
test_results = test_results.assign(dataset=test_results["sample_name"].map(assign_dataset))
test_results

In [None]:
test_results_melted = test_results.melt(
    id_vars="dataset",
    value_vars=["nmi", "homogeneity", "completeness"],
    var_name="Metric",
    value_name="Score",
)
# Mean NMI per dataset, with "dataset" as a regular column for the bar plot.
data_set_nmi_mean = test_results.groupby("dataset")["nmi"].mean().reset_index()
category_order = ["dataset7.1", "dataset7.2", "dataset7.3", "dataset7.4"]
hue_order = ["dataset7.1", "dataset7.2", "dataset7.3", "dataset7.4"]
palette = sns.color_palette("tab10")[0:4]
nmi_scores = test_results_melted[test_results_melted["Metric"] == "nmi"]

# Dense tick grids with a label on every fifth tick only:
# bar panel 0.00-0.68 in steps of 0.04, box panel 0.20-0.72 in steps of 0.02.
bar_yticks = [round(0.04 * i, 2) for i in range(18)]
bar_yticklabels = [f"{tick:.1f}" if i % 5 == 0 else "" for i, tick in enumerate(bar_yticks)]
box_yticks = [round(0.20 + 0.02 * i, 2) for i in range(27)]
box_yticklabels = [f"{tick:.1f}" if i % 5 == 0 else "" for i, tick in enumerate(box_yticks)]

sns.set(style="whitegrid")
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(2, 5), sharex=True, dpi=300)

# Top panel: mean NMI per dataset.
sns.barplot(
    data=data_set_nmi_mean,
    x="dataset",
    y="nmi",
    ax=axes[0],
    palette=palette,
    order=category_order,
    hue="dataset",
    hue_order=hue_order,
    legend=False,
)
axes[0].set_title("Noise Performance", fontsize=12)
axes[0].set_xlabel("Dataset")
axes[0].set_ylabel("NMI")
axes[0].set_yticks(bar_yticks)
axes[0].set_yticklabels(bar_yticklabels)

# Bottom panel: per-sample NMI distribution per dataset.
sns.boxplot(
    x="dataset",
    y="Score",
    data=nmi_scores,
    ax=axes[1],
    palette=palette,
    order=category_order,
    hue="dataset",
    hue_order=hue_order,
    legend=False,
)
# Pass the boxplot's category order so each sample's point is drawn over its own
# dataset's box, independent of the row order in the CSV.
sns.stripplot(
    x="dataset",
    y="Score",
    data=nmi_scores,
    ax=axes[1],
    order=category_order,
    color="black",
    size=5,
    jitter=True,
)
axes[1].set_xlabel("Dataset")
axes[1].set_ylabel("NMI")
axes[1].set_yticks(box_yticks)
axes[1].set_yticklabels(box_yticklabels)
axes[1].tick_params(axis="x", rotation=90, bottom=True)

plt.tight_layout()
plt.show()

In [None]:
# AnnData for sample Zhuang-ABCA-4.003 from the SpatialNoise + FeatureNoise run directory.
abca4_003 = sc.read_h5ad(
    filename="../data/domain/results/domain7/augmentation/2025-05-22_14-30-23/adata_files/Zhuang-ABCA-4.003.h5ad"
)
abca4_003

In [None]:
# The previously hard-coded title repeated the baseline run's numbers verbatim; read
# this run's scores from the section's `test_results` instead.
sample_id = "Zhuang-ABCA-4.003"
sample_scores = test_results.loc[
    test_results["sample_name"].str.contains(sample_id, regex=False),
    ["nmi", "homogeneity", "completeness"],
]
# Fail loudly rather than titling the figure with numbers from another sample or run.
assert len(sample_scores) == 1, (
    f"expected one test_results row for {sample_id}, found {len(sample_scores)}"
)
sample_nmi, sample_hom, sample_com = sample_scores.iloc[0]

plot_merfish_zhuang_large(
    abca4_003, "domain_annotation", title=f"{sample_id} Ground Truth", figsize=(10, 6)
)
plot_merfish_zhuang_large(
    abca4_003,
    "leiden",
    title=f"{sample_id} NMI: {sample_nmi:.2f}, HOM: {sample_hom:.2f}, COM: {sample_com:.2f}",
    figsize=(10, 6),
)

In [None]:
# AnnData for sample Zhuang-ABCA-2.036 from the SpatialNoise + FeatureNoise run directory.
abca2_036 = sc.read_h5ad(
    filename="../data/domain/results/domain7//augmentation/2025-05-22_14-30-23/adata_files/Zhuang-ABCA-2.036.h5ad"
)
abca2_036

In [None]:
# The previously hard-coded title repeated the baseline run's numbers verbatim; read
# this run's scores from the section's `test_results` instead.
sample_id = "Zhuang-ABCA-2.036"
sample_scores = test_results.loc[
    test_results["sample_name"].str.contains(sample_id, regex=False),
    ["nmi", "homogeneity", "completeness"],
]
# Fail loudly rather than titling the figure with numbers from another sample or run.
assert len(sample_scores) == 1, (
    f"expected one test_results row for {sample_id}, found {len(sample_scores)}"
)
sample_nmi, sample_hom, sample_com = sample_scores.iloc[0]

plot_merfish_zhuang_large(
    abca2_036, "domain_annotation", size=10, title=f"{sample_id} Ground Truth"
)
plot_merfish_zhuang_large(
    abca2_036,
    "leiden",
    size=10,
    title=f"{sample_id} NMI: {sample_nmi:.2f}, HOM: {sample_hom:.2f}, COM: {sample_com:.2f}",
)

## Domain 7 Augmentation - DropImportance + SpatialNoise + FeatureNoise

In [None]:
# Per-sample test metrics of the Domain7 DropImportance + SpatialNoise + FeatureNoise
# run. Cells that are strings containing "tensor" are parsed as floats from the text
# between the last "(" and the trailing ")"; every other cell is left untouched.
test_results = pd.read_csv(
    filepath_or_buffer="../data/domain/results/domain7/augmentation/2025-05-23_08-10-53/csv/version_0/test_results.csv"
).map(
    lambda cell: float(cell.split("(")[-1].rstrip(")"))
    if isinstance(cell, str) and "tensor" in cell
    else cell
)
test_results

In [None]:
def assign_dataset(sample_name):
    """Return the dataset7.x label for a sample name.

    Looks for "ABCA-1" .. "ABCA-4" (in that order) inside ``sample_name`` and maps the
    first hit to dataset7.1 .. dataset7.4; returns ``"Unknown"`` when none occurs.
    """
    for index in range(1, 5):
        if f"ABCA-{index}" in sample_name:
            return f"dataset7.{index}"
    return "Unknown"


# Tag each sample with its dataset7.x group (used for grouping and colouring below).
test_results = test_results.assign(dataset=test_results["sample_name"].map(assign_dataset))
test_results

In [None]:
test_results_melted = test_results.melt(
    id_vars="dataset",
    value_vars=["nmi", "homogeneity", "completeness"],
    var_name="Metric",
    value_name="Score",
)
# Mean NMI per dataset, with "dataset" as a regular column for the bar plot.
data_set_nmi_mean = test_results.groupby("dataset")["nmi"].mean().reset_index()
category_order = ["dataset7.1", "dataset7.2", "dataset7.3", "dataset7.4"]
hue_order = ["dataset7.1", "dataset7.2", "dataset7.3", "dataset7.4"]
palette = sns.color_palette("tab10")[0:4]
nmi_scores = test_results_melted[test_results_melted["Metric"] == "nmi"]

# Dense tick grids with a label on every fifth tick only:
# bar panel 0.00-0.68 in steps of 0.04, box panel 0.20-0.72 in steps of 0.02.
bar_yticks = [round(0.04 * i, 2) for i in range(18)]
bar_yticklabels = [f"{tick:.1f}" if i % 5 == 0 else "" for i, tick in enumerate(bar_yticks)]
box_yticks = [round(0.20 + 0.02 * i, 2) for i in range(27)]
box_yticklabels = [f"{tick:.1f}" if i % 5 == 0 else "" for i, tick in enumerate(box_yticks)]

sns.set(style="whitegrid")
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(2, 5), sharex=True, dpi=300)

# Top panel: mean NMI per dataset.
sns.barplot(
    data=data_set_nmi_mean,
    x="dataset",
    y="nmi",
    ax=axes[0],
    palette=palette,
    order=category_order,
    hue="dataset",
    hue_order=hue_order,
    legend=False,
)
# This section is the DropImportance run; the title previously duplicated the
# SpatialNoise + FeatureNoise figure's "Noise Performance".
axes[0].set_title("DropImportance + Noise Performance", fontsize=12)
axes[0].set_xlabel("Dataset")
axes[0].set_ylabel("NMI")
axes[0].set_yticks(bar_yticks)
axes[0].set_yticklabels(bar_yticklabels)

# Bottom panel: per-sample NMI distribution per dataset.
sns.boxplot(
    x="dataset",
    y="Score",
    data=nmi_scores,
    ax=axes[1],
    palette=palette,
    order=category_order,
    hue="dataset",
    hue_order=hue_order,
    legend=False,
)
# Pass the boxplot's category order so each sample's point is drawn over its own
# dataset's box, independent of the row order in the CSV.
sns.stripplot(
    x="dataset",
    y="Score",
    data=nmi_scores,
    ax=axes[1],
    order=category_order,
    color="black",
    size=5,
    jitter=True,
)
axes[1].set_xlabel("Dataset")
axes[1].set_ylabel("NMI")
axes[1].set_yticks(box_yticks)
axes[1].set_yticklabels(box_yticklabels)
axes[1].tick_params(axis="x", rotation=90, bottom=True)

plt.tight_layout()
plt.show()

In [None]:
# AnnData for sample Zhuang-ABCA-4.003 from the DropImportance + noise run directory.
abca4_003 = sc.read_h5ad(
    filename="../data/domain/results/domain7/augmentation/2025-05-23_08-10-53/adata_files/Zhuang-ABCA-4.003.h5ad"
)
abca4_003

In [None]:
# The previously hard-coded title repeated the baseline run's numbers verbatim; read
# this run's scores from the section's `test_results` instead.
sample_id = "Zhuang-ABCA-4.003"
sample_scores = test_results.loc[
    test_results["sample_name"].str.contains(sample_id, regex=False),
    ["nmi", "homogeneity", "completeness"],
]
# Fail loudly rather than titling the figure with numbers from another sample or run.
assert len(sample_scores) == 1, (
    f"expected one test_results row for {sample_id}, found {len(sample_scores)}"
)
sample_nmi, sample_hom, sample_com = sample_scores.iloc[0]

plot_merfish_zhuang_large(
    abca4_003, "domain_annotation", title=f"{sample_id} Ground Truth", figsize=(10, 6)
)
plot_merfish_zhuang_large(
    abca4_003,
    "leiden",
    title=f"{sample_id} NMI: {sample_nmi:.2f}, HOM: {sample_hom:.2f}, COM: {sample_com:.2f}",
    figsize=(10, 6),
)

In [None]:
# AnnData for sample Zhuang-ABCA-2.036 from the DropImportance + noise run directory.
abca2_036 = sc.read_h5ad(
    filename="../data/domain/results/domain7/augmentation/2025-05-23_08-10-53/adata_files/Zhuang-ABCA-2.036.h5ad"
)
abca2_036

In [None]:
# The previously hard-coded title repeated the baseline run's numbers verbatim; read
# this run's scores from the section's `test_results` instead.
sample_id = "Zhuang-ABCA-2.036"
sample_scores = test_results.loc[
    test_results["sample_name"].str.contains(sample_id, regex=False),
    ["nmi", "homogeneity", "completeness"],
]
# Fail loudly rather than titling the figure with numbers from another sample or run.
assert len(sample_scores) == 1, (
    f"expected one test_results row for {sample_id}, found {len(sample_scores)}"
)
sample_nmi, sample_hom, sample_com = sample_scores.iloc[0]

plot_merfish_zhuang_large(
    abca2_036, "domain_annotation", size=10, title=f"{sample_id} Ground Truth"
)
plot_merfish_zhuang_large(
    abca2_036,
    "leiden",
    size=10,
    title=f"{sample_id} NMI: {sample_nmi:.2f}, HOM: {sample_hom:.2f}, COM: {sample_com:.2f}",
)

## Domain7 Baseline vs. Augmentation

In [None]:
# Domain7 baseline per-sample metrics; string cells containing "tensor" are parsed
# as floats from the text between the last "(" and the trailing ")".
test_results_baseline = pd.read_csv(
    filepath_or_buffer="../data/domain/results/domain7/baseline/2025-04-22_08-31-58/csv/version_0/test_results.csv"
).map(
    lambda cell: float(cell.split("(")[-1].rstrip(")"))
    if isinstance(cell, str) and "tensor" in cell
    else cell
)
test_results_baseline

In [None]:
# Domain7 SpatialNoise + FeatureNoise per-sample metrics; string cells containing
# "tensor" are parsed as floats from the text between the last "(" and the trailing ")".
test_results_noise = pd.read_csv(
    filepath_or_buffer="../data/domain/results/domain7/augmentation/2025-05-22_14-30-23/csv/version_0/test_results.csv"
).map(
    lambda cell: float(cell.split("(")[-1].rstrip(")"))
    if isinstance(cell, str) and "tensor" in cell
    else cell
)
test_results_noise

In [None]:
# Domain7 DropImportance + SpatialNoise + FeatureNoise per-sample metrics; string cells
# containing "tensor" are parsed as floats from the text between the last "(" and ")".
test_results_dropimportance_noise = pd.read_csv(
    filepath_or_buffer="../data/domain/results/domain7/augmentation/2025-05-23_08-10-53/csv/version_0/test_results.csv"
).map(
    lambda cell: float(cell.split("(")[-1].rstrip(")"))
    if isinstance(cell, str) and "tensor" in cell
    else cell
)
test_results_dropimportance_noise

In [None]:
baseline_df = test_results_baseline
augmentation_dfs = {
    "Baseline + Spatial Noise + Feature Noise": test_results_noise,
    "DropImportance + Spatial Noise + Feature Noise": test_results_dropimportance_noise,
}

metrics = ["nmi", "homogeneity", "completeness"]
results = {}

raw_pvals = []
test_info = []

for aug_name, aug_df in augmentation_dfs.items():
    # The paired tests below compare the baseline and augmented score of the *same*
    # sample, so align both runs on sample_name explicitly instead of relying on the
    # two CSVs listing samples in the same row order.
    paired = baseline_df[["sample_name", *metrics]].merge(
        aug_df[["sample_name", *metrics]],
        on="sample_name",
        how="inner",
        suffixes=("_baseline", "_aug"),
        validate="one_to_one",
    )
    if len(paired) != len(baseline_df) or len(paired) != len(aug_df):
        raise ValueError(
            f"{aug_name}: baseline and augmentation runs do not cover the same samples"
        )

    results[aug_name] = {}
    for metric in metrics:
        baseline_values = paired[f"{metric}_baseline"]
        aug_values = paired[f"{metric}_aug"]
        differences = aug_values - baseline_values

        # Shapiro-Wilk on the paired differences chooses between a paired t-test
        # (differences look normal) and the Wilcoxon signed-rank test.
        p_normal = shapiro(differences).pvalue

        if p_normal > 0.05:
            stat, p = ttest_rel(baseline_values, aug_values)
            test = "t-test"
        else:
            stat, p = wilcoxon(baseline_values, aug_values)
            test = "wilcoxon"

        # Store raw result
        results[aug_name][metric] = {"test": test, "statistic": stat, "p_value": p}
        raw_pvals.append(p)
        test_info.append((aug_name, metric))

# Benjamini-Hochberg FDR correction across every (augmentation, metric) test.
_, corrected_pvals, _, _ = multipletests(raw_pvals, method="fdr_bh")
for (aug_name, metric), p_corr in zip(test_info, corrected_pvals):
    results[aug_name][metric]["p_value_corrected"] = p_corr

print(results)

In [None]:
# Significance marks use the Benjamini-Hochberg corrected p-values, because every
# (augmentation, metric) pair was tested; raw p-values are kept in their own column.
star_thresholds = [(0.001, "***"), (0.01, "**"), (0.05, "*"), (0.1, ".")]

summary_rows = []
for aug, metrics_dict in results.items():
    row = {"Augmentation": aug}
    for metric in ["nmi", "homogeneity", "completeness"]:
        mean_val = augmentation_dfs[aug][metric].mean()
        metric_result = metrics_dict[metric]
        p = metric_result["p_value"]
        p_corr = metric_result["p_value_corrected"]
        # First (strictest) threshold the corrected p-value falls under; "" if none.
        stars = next((mark for cutoff, mark in star_thresholds if p_corr < cutoff), "")
        row[metric.upper()] = f"{mean_val:.3f}{stars}"
        row[f"p_value_{metric}"] = p
        row[f"p_value_corrected_{metric}"] = p_corr
        row[f"test_{metric}"] = metric_result["test"]
    summary_rows.append(row)

summary_df = pd.DataFrame(summary_rows)
summary_df

In [None]:
# Mean score of each Domain7 run: one row per run ("mode"), one column per metric.
domain7_runs = {
    "Baseline": test_results_baseline,
    "Baseline + Spatial Noise + Feature Noise": test_results_noise,
    "DropImportance + Spatial Noise + Feature Noise": test_results_dropimportance_noise,
}
domain7_metric_columns = {"NMI": "nmi", "HOM": "homogeneity", "COM": "completeness"}

df = pd.DataFrame(
    {
        "mode": list(domain7_runs),
        **{
            label: [run[column].mean() for run in domain7_runs.values()]
            for label, column in domain7_metric_columns.items()
        },
    }
)
df

In [None]:
# Long format (mode, Metric, Score); "hue" duplicates "mode" so bars are coloured per run.
df_melted = df.melt(
    id_vars="mode",
    value_vars=["NMI", "HOM", "COM"],
    var_name="Metric",
    value_name="Score",
).assign(hue=lambda long_df: long_df["mode"])
df_melted

In [None]:
sns.set(style="whitegrid")
fig, axes = plt.subplots(1, 3, figsize=(6, 3), dpi=300, sharey=True)
fig.suptitle("Domain Identification Dataset 7", fontsize=14)

metrics = ["NMI", "HOM", "COM"]
titles = ["NMI", "HOM", "COM"]

for ax, metric, title in zip(axes, metrics, titles):
    sns.barplot(
        data=df_melted[df_melted["Metric"] == metric],
        x="mode",
        y="Score",
        ax=ax,
        palette="Blues_d",
        hue="hue",
        # The figure gets one shared legend below; suppress per-panel legends.
        legend=False,
    )

    ax.set_title(title, fontsize=12)
    ax.set_xlabel("")
    ax.set_ylabel("Score" if metric == "NMI" else "")
    ax.set_xticks([])
    ax.set_ylim([0.5, 0.65])

labels = [
    "Baseline",
    "Baseline + SpatialNoise + FeatureNoise",
    "DropImportance + SpatialNoise + FeatureNoise",
]
# Same palette and level count as the bars, so patch colours match the bars in order.
palette = sns.color_palette("Blues_d", n_colors=len(labels))
handles = [Patch(color=palette[i], label=labels[i]) for i in range(len(labels))]

fig.legend(
    handles,
    labels,
    loc="lower center",
    bbox_to_anchor=(0.5, -0.25),
    title="Augmentation Mode",
    title_fontsize="11",
    fontsize="10",
)

plt.tight_layout()
plt.show()

In [None]:
# AnnData for sample Zhuang-ABCA-4.001 from the Domain7 baseline results directory.
abca4_001_baseline = sc.read_h5ad(
    filename="../data/domain/results/domain7/baseline/2025-04-22_08-31-58/adata_files/Zhuang-ABCA-4.001.h5ad"
)
abca4_001_baseline

In [None]:
# AnnData for sample Zhuang-ABCA-4.001 from the SpatialNoise + FeatureNoise run directory.
abca4_001_noise = sc.read_h5ad(
    filename="../data/domain/results/domain7/augmentation/2025-05-22_14-30-23/adata_files/Zhuang-ABCA-4.001.h5ad"
)
abca4_001_noise

In [None]:
# AnnData for sample Zhuang-ABCA-4.001 from the DropImportance + noise run directory.
abca4_001_dropimportance_noise = sc.read_h5ad(
    filename="../data/domain/results/domain7/augmentation/2025-05-23_08-10-53/adata_files/Zhuang-ABCA-4.001.h5ad"
)
abca4_001_dropimportance_noise

In [None]:
# Each panel: (AnnData, obs column used for colouring, panel title).
panels = [
    (abca4_001_baseline, "domain_annotation", "Ground Truth"),
    (abca4_001_baseline, "leiden", "Baseline \n(NMI: 0.61)"),
    (abca4_001_noise, "leiden", "Baseline + Spatial Noise + Feature Noise \n(NMI: 0.62)"),
    (
        abca4_001_dropimportance_noise,
        "leiden",
        "DropImportance + Spatial Noise \n+ Feature Noise (NMI: 0.63)",
    ),
]

fig, axes = plt.subplots(2, 2, figsize=(6, 6), dpi=300)
for ax, (adata, domain_label, title) in zip(axes.flatten(), panels):
    # One tab20 colour per distinct value of the colouring column.
    n_levels = len(adata.obs[domain_label].unique())
    adata.uns[f"{domain_label}_colors"] = [
        mcolors.rgb2hex(rgb) for rgb in sns.color_palette("tab20", n_levels)
    ]

    sc.pl.embedding(
        adata,
        basis="spatial",
        color=domain_label,
        size=5,
        ax=ax,
        show=False,
        legend_loc=None,
    )

    ax.set_title(title, fontsize=10)
    ax.set_xlabel("")
    ax.set_ylabel("")
    # 0-11 frame with the y axis inverted, equal aspect and no ticks.
    ax.set_ylim(11, 0)
    ax.set_xlim(0, 11)
    ax.axis("equal")
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()
plt.show()

In [None]:
# AnnData for sample Zhuang-ABCA-2.039 from the Domain7 baseline results directory.
abca2_039_baseline = sc.read_h5ad(
    filename="../data/domain/results/domain7/baseline/2025-04-22_08-31-58/adata_files/Zhuang-ABCA-2.039.h5ad"
)
abca2_039_baseline

In [None]:
# AnnData for sample Zhuang-ABCA-2.039 from the SpatialNoise + FeatureNoise run directory.
abca2_039_noise = sc.read_h5ad(
    filename="../data/domain/results/domain7//augmentation/2025-05-22_14-30-23/adata_files/Zhuang-ABCA-2.039.h5ad"
)
abca2_039_noise

In [None]:
# AnnData for sample Zhuang-ABCA-2.039 from the DropImportance + noise run directory.
abca2_039_dropimportance_noise = sc.read_h5ad(
    filename="../data/domain/results/domain7/augmentation/2025-05-23_08-10-53/adata_files/Zhuang-ABCA-2.039.h5ad"
)
abca2_039_dropimportance_noise

In [None]:
# Each panel: (AnnData, obs column used for colouring, panel title).
panels = [
    (abca2_039_baseline, "domain_annotation", "Ground Truth"),
    (abca2_039_baseline, "leiden", "Baseline \n(NMI: 0.68)"),
    (abca2_039_noise, "leiden", "Baseline + Spatial Noise + Feature Noise \n(NMI: 0.69)"),
    (
        abca2_039_dropimportance_noise,
        "leiden",
        "DropImportance + Spatial Noise \n+ Feature Noise (NMI: 0.69)",
    ),
]

fig, axes = plt.subplots(2, 2, figsize=(6, 6), dpi=300)
for ax, (adata, domain_label, title) in zip(axes.flatten(), panels):
    # One tab20 colour per distinct value of the colouring column.
    n_levels = len(adata.obs[domain_label].unique())
    adata.uns[f"{domain_label}_colors"] = [
        mcolors.rgb2hex(rgb) for rgb in sns.color_palette("tab20", n_levels)
    ]

    sc.pl.embedding(
        adata,
        basis="spatial",
        color=domain_label,
        size=5,
        ax=ax,
        show=False,
        legend_loc=None,
    )

    ax.set_title(title, fontsize=10)
    ax.set_xlabel("")
    ax.set_ylabel("")
    # 0-11 frame with the y axis inverted, equal aspect and no ticks.
    ax.set_ylim(11, 0)
    ax.set_xlim(0, 11)
    ax.axis("equal")
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()
plt.show()

## Domain4 on Domain7 Baseline

In [None]:
# Per-sample test metrics of the Domain4 model on the Domain7 baseline run. String
# cells containing "tensor" are parsed as floats from the text between the last "("
# and the trailing ")"; every other cell is left untouched.
test_results = pd.read_csv(
    filepath_or_buffer="../data/domain/results/domain4/domain4_domain7_baseline/2025-04-23_13-02-20/csv/version_0/test_results.csv"
).map(
    lambda cell: float(cell.split("(")[-1].rstrip(")"))
    if isinstance(cell, str) and "tensor" in cell
    else cell
)
test_results

In [None]:
test_results_melted = test_results.melt(
    id_vars="sample_name",
    value_vars=["nmi", "homogeneity", "completeness"],
    var_name="Metric",
    value_name="Score",
)

sns.set(style="whitegrid")
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(4, 2), sharex=False, dpi=300)

metrics = ["nmi", "homogeneity", "completeness"]
titles = ["NMI", "HOM", "COM"]
xlims_list = [[0.58, 0.70], [0.49, 0.67], [0.59, 0.74]]
xticks_list = [
    [0.59, 0.60, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66, 0.67, 0.68, 0.69],
    [0.51, 0.53, 0.55, 0.57, 0.59, 0.61, 0.63, 0.65],
    [0.60, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66, 0.67, 0.68, 0.69, 0.70, 0.71, 0.72, 0.73],
]
# Tick values that receive a visible label; all other ticks are drawn unlabelled.
# Labels are formatted from the tick value itself, so a label can no longer sit on
# the wrong tick (the hand-written list put "0.73" on the 0.72 tick).
labeled_ticks_list = [
    [0.60, 0.65],
    [0.53, 0.65],
    [0.60, 0.73],
]
xtick_labels_list = [
    [
        f"{tick:.2f}" if any(abs(tick - shown) < 1e-9 for shown in labeled) else ""
        for tick in xticks
    ]
    for xticks, labeled in zip(xticks_list, labeled_ticks_list)
]

for ax, metric, title, xlims, xticks, xtick_labels in zip(
    axes, metrics, titles, xlims_list, xticks_list, xtick_labels_list
):
    metric_scores = test_results_melted[test_results_melted["Metric"] == metric]
    sns.boxplot(
        x="Score",
        y="Metric",
        data=metric_scores,
        ax=ax,
        orient="h",
        dodge=False,
    )
    sns.stripplot(
        x="Score",
        y="Metric",
        data=metric_scores,
        ax=ax,
        color="black",
        size=5,
        jitter=True,
        orient="h",
    )
    ax.set_title(title, fontsize=12)
    ax.set_xlabel("Score")
    ax.set_ylabel("")
    ax.set_yticklabels([])

    ax.set_xlim(xlims)
    ax.set_xticks(xticks)
    ax.set_xticklabels(xtick_labels)

plt.tight_layout()
plt.show()

In [None]:
# Xenium2 AnnData from the domain4_domain7_baseline results directory.
xenium2 = sc.read_h5ad(
    filename="../data/domain/results/domain4/domain4_domain7_baseline/2025-04-23_13-02-20/adata_files/Xenium2.h5ad"
)
xenium2

In [None]:
domain_label = "domain_annotation"
title = "Xenium Ground Truth"

# Restrict the plot to cells left of the mean x-coordinate (left half of the section).
x_midpoint = xenium2.obsm["spatial"][:, 0].mean()
# .copy() turns the boolean-mask view into a standalone AnnData, so writing the
# colour palette to .uns below does not implicitly convert a view to an actual
# object (which anndata reports with an ImplicitModificationWarning).
left_half_xenium = xenium2[xenium2.obsm["spatial"][:, 0] < x_midpoint].copy()

unique_classes = left_half_xenium.obs[domain_label].unique()
num_classes = len(unique_classes)

# One tab20 colour per label value present in the subset.
palette = sns.color_palette("tab20", num_classes)
left_half_xenium.uns[f"{domain_label}_colors"] = [mcolors.rgb2hex(c) for c in palette]

fig, ax = plt.subplots(figsize=(4, 5), dpi=300)
sc.pl.embedding(
    left_half_xenium,
    basis="spatial",
    color=domain_label,
    size=5,
    title=title,
    legend_loc=None,
    show=False,
    ax=ax,
)
plt.show()

In [None]:
domain_label = "leiden"
title = "NMI: 0.62, HOM: 0.60, COM: 0.65"

# Restrict the plot to cells left of the mean x-coordinate (left half of the section).
x_midpoint = xenium2.obsm["spatial"][:, 0].mean()
# .copy() turns the boolean-mask view into a standalone AnnData, so writing the
# colour palette to .uns below does not implicitly convert a view to an actual
# object (which anndata reports with an ImplicitModificationWarning).
left_half_xenium = xenium2[xenium2.obsm["spatial"][:, 0] < x_midpoint].copy()

unique_classes = left_half_xenium.obs[domain_label].unique()
num_classes = len(unique_classes)

# One tab20 colour per cluster present in the subset.
palette = sns.color_palette("tab20", num_classes)
left_half_xenium.uns[f"{domain_label}_colors"] = [mcolors.rgb2hex(c) for c in palette]

fig, ax = plt.subplots(figsize=(4, 5), dpi=300)
sc.pl.embedding(
    left_half_xenium,
    basis="spatial",
    color=domain_label,
    size=5,
    title=title,
    legend_loc=None,
    show=False,
    ax=ax,
)
plt.show()

## Phenotype NSCLC Baseline

## Phenotype NSCLC Augmentation - FeatureNoise

## Phenotype NSCLC Augmentation - DropImportance

## Phenotype NSCLC Augmentation - DropImportance + FeatureNoise

## Phenotype NSCLC Augmentation - DropImportance + FeatureNoise + ShufflePositions

## Phenotype NSCLC Augmentation - DropImportance + FeatureNoise + AddEdgesByCellType

## Phenotype NSCLC Baseline vs. Augmentation

In [None]:
# Test metrics of the NSCLC phenotype baseline run (read later via .item(), i.e. one row).
test_results_baseline = pd.read_csv(
    filepath_or_buffer="../data/phenotype/results/nsclc/baseline/2025-05-22_15-17-27/csv/version_0/metrics.csv"
)
test_results_baseline

In [None]:
# Test metrics of the NSCLC Baseline + FeatureNoise run.
test_results_noise = pd.read_csv(
    filepath_or_buffer="../data/phenotype/results/nsclc/augmentation/2025-05-22_15-17-27_baseline_noise/csv/version_0/metrics.csv"
)
test_results_noise

In [None]:
# Test metrics of the NSCLC DropImportance run.
test_results_dropimportance = pd.read_csv(
    filepath_or_buffer="../data/phenotype/results/nsclc/augmentation/2025-05-22_16-40-46_dropimportance/csv/version_0/metrics.csv"
)
test_results_dropimportance

In [None]:
# Test metrics of the NSCLC DropImportance + FeatureNoise run.
test_results_dropimportance_noise = pd.read_csv(
    filepath_or_buffer="../data/phenotype/results/nsclc/augmentation/2025-05-23_08-11-57_dropimportance_noise/csv/version_0/metrics.csv"
)
test_results_dropimportance_noise

In [None]:
# Test metrics of the NSCLC DropImportance + FeatureNoise + ShufflePositions run.
test_results_dropimportance_noise_shuffle = pd.read_csv(
    filepath_or_buffer="../data/phenotype/results/nsclc/augmentation/2025-05-22_16-14-18_dropimportance_noise_shuffle/csv/version_0/metrics.csv"
)
test_results_dropimportance_noise_shuffle

In [None]:
# Test metrics of the NSCLC DropImportance + FeatureNoise + AddEdgesByCellType run.
test_results_dropimportance_noise_addedges = pd.read_csv(
    filepath_or_buffer="../data/phenotype/results/nsclc/augmentation/2025-05-23_15-27-46_dropimportance_noise_addedges/csv/version_0/metrics.csv"
)
test_results_dropimportance_noise_addedges

In [None]:
# One row per NSCLC run ("mode"), one column per test metric. Each metrics CSV is
# expected to hold a single row; .item() raises otherwise.
nsclc_runs = {
    "Baseline": test_results_baseline,
    "Baseline + Feature Noise": test_results_noise,
    "DropImportance": test_results_dropimportance,
    "DropImportance + Feature Noise": test_results_dropimportance_noise,
    "DropImportance + Feature Noise + ShufflePositions": test_results_dropimportance_noise_shuffle,
    "DropImportance + Feature Noise + AddEdges": test_results_dropimportance_noise_addedges,
}
nsclc_metric_columns = {
    "Accuracy": "test/accuracy",
    "Balanced Accuracy": "test/balanced_accuracy",
    "AUROC": "test/auroc",
    "F1 score": "test/f1",
}

df = pd.DataFrame(
    {
        "mode": list(nsclc_runs),
        **{
            label: [run[column].item() for run in nsclc_runs.values()]
            for label, column in nsclc_metric_columns.items()
        },
    }
)
df

In [None]:
# Long format (mode, Metric, Score); "hue" duplicates "mode" so bars are coloured per run.
df_melted = df.melt(
    id_vars="mode",
    value_vars=["Accuracy", "Balanced Accuracy", "AUROC", "F1 score"],
    var_name="Metric",
    value_name="Score",
).assign(hue=lambda long_df: long_df["mode"])
df_melted

In [None]:
sns.set(style="whitegrid")
fig, axes = plt.subplots(1, 4, figsize=(6, 3), dpi=300, sharey=True)
fig.suptitle("Phenotype Prediction - Relapse NSCLC", fontsize=14)

metrics = ["Accuracy", "Balanced Accuracy", "AUROC", "F1 score"]
titles = ["Accuracy", "Balanced Accuracy", "AUROC", "F1 score"]

for ax, metric, title in zip(axes, metrics, titles):
    sns.barplot(
        data=df_melted[df_melted["Metric"] == metric],
        x="mode",
        y="Score",
        ax=ax,
        palette="Blues_d",
        hue="hue",
        # The figure gets one shared legend below; suppress per-panel legends.
        legend=False,
    )

    ax.set_title(title, fontsize=12)
    ax.set_xlabel("")
    # Label only the leftmost panel. The old check `metric == "NMI"` was copied from the
    # domain figure and never matched these metrics, so no y label was drawn at all.
    ax.set_ylabel("Score" if metric == metrics[0] else "")
    ax.set_xticks([])
    ax.set_ylim(0.4, 0.65)

labels = [
    "Baseline",
    "Baseline + FeatureNoise",
    "DropImportance",
    "DropImportance + FeatureNoise",
    "DropImportance + Feature Noise + ShufflePositions",
    "DropImportance + Feature Noise + AddEdges",
]
# Same palette and level count as the bars, so patch colours match the bars in order.
palette = sns.color_palette("Blues_d", n_colors=len(labels))
handles = [Patch(color=palette[i], label=labels[i]) for i in range(len(labels))]

fig.legend(
    handles,
    labels,
    loc="lower center",
    bbox_to_anchor=(0.5, -0.45),
    title="Augmentation Mode",
    title_fontsize="11",
    fontsize="10",
)

plt.tight_layout()
plt.show()

## Runtime & Memory Benchmarking - Domain

In [None]:
df = pd.read_csv("../data/benchmark/domain_benchmark_results.csv")

sns.set(style="whitegrid")
augmentations = df["augmentation"].unique()
palette = sns.color_palette("tab10", n_colors=len(augmentations))
linestyles = ["-", "--", "-.", ":"]
markers = ["o", "s", "^", "D", "P", "X", "*", "+", "x"]

fig, ax = plt.subplots(figsize=(10, 5), dpi=300)
for i, aug in enumerate(augmentations):
    # Sort by graph size so each line runs left-to-right whatever the CSV row order is.
    group = df[df["augmentation"] == aug].sort_values("num_nodes")
    ax.plot(
        group["num_nodes"],
        group["avg_time_s"],
        label=aug,
        color=palette[i % len(palette)],
        linestyle=linestyles[i % len(linestyles)],
        marker=markers[i % len(markers)],
        linewidth=2,
    )

ax.set_xscale("log")
ax.set_xlabel("Number of Nodes (log scale)", fontsize=12)
ax.set_ylabel("Average Runtime (s)", fontsize=12)
ax.set_title("Domain Augmentation Runtime vs Graph Size", fontsize=14)

ax.grid(True, which="both", linestyle="--", linewidth=0.5)
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), fontsize=8)

plt.tight_layout()
plt.show()

In [None]:
sns.set(style="whitegrid")
augmentations = df["augmentation"].unique()
palette = sns.color_palette("tab10", n_colors=len(augmentations))
linestyles = ["-", "--", "-.", ":"]
markers = ["o", "s", "^", "D", "P", "X", "*", "+", "x"]

fig, ax = plt.subplots(figsize=(10, 5), dpi=300)
for i, aug in enumerate(augmentations):
    # Sort by graph size so each line runs left-to-right whatever the CSV row order is.
    group = df[df["augmentation"] == aug].sort_values("num_nodes")
    ax.plot(
        group["num_nodes"],
        group["max_memory_mb"],
        label=aug,
        color=palette[i % len(palette)],
        linestyle=linestyles[i % len(linestyles)],
        marker=markers[i % len(markers)],
        linewidth=2,
    )

ax.set_xscale("log")
ax.set_xlabel("Number of Nodes (log scale)", fontsize=12)
ax.set_ylabel("Max Memory Usage (MB)", fontsize=12)
ax.set_title("Domain Augmentation Memory Usage vs Graph Size", fontsize=14)

ax.grid(True, which="both", linestyle="--", linewidth=0.5)
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), fontsize=8)

plt.tight_layout()
plt.show()

In [None]:
df = pd.read_csv("../data/benchmark/domain_benchmark_results.csv")

sns.set(style="whitegrid")
augmentations = df["augmentation"].unique()
palette = sns.color_palette("tab10", n_colors=len(augmentations))
linestyles = ["-", "--", "-.", ":"]
markers = ["o", "s", "^", "D", "P", "X", "*", "+", "x"]

# Left panel: runtime; right panel: peak memory.
panels = [
    {
        "column": "avg_time_s",
        "ylabel": "Average Runtime (s)",
        "title": "Domain Augmentation Runtime vs Graph Size",
    },
    {
        "column": "max_memory_mb",
        "ylabel": "Max Memory Usage (MB)",
        "title": "Domain Augmentation Memory Usage vs Graph Size",
    },
]

fig, axes = plt.subplots(1, 2, figsize=(12, 5), dpi=300, sharex=True)

for i, aug in enumerate(augmentations):
    # Sort by graph size so each line runs left-to-right whatever the CSV row order is.
    group = df[df["augmentation"] == aug].sort_values("num_nodes")
    for ax, panel in zip(axes, panels):
        ax.plot(
            group["num_nodes"],
            group[panel["column"]],
            label=aug,
            color=palette[i % len(palette)],
            linestyle=linestyles[i % len(linestyles)],
            marker=markers[i % len(markers)],
            linewidth=2,
        )

for ax, panel in zip(axes, panels):
    ax.set_xscale("log")
    ax.set_xlabel("Number of Nodes (log scale)", fontsize=12)
    ax.set_ylabel(panel["ylabel"], fontsize=12)
    ax.set_title(panel["title"], fontsize=14)
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)

# Both panels draw the same augmentations, so one shared legend suffices.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(
    handles,
    labels,
    loc="lower center",
    bbox_to_anchor=(0.5, -0.2),
    ncol=3,
    fontsize=10,
    frameon=True,
    title="Augmentation Mode",
    title_fontsize=12,
)

plt.tight_layout()
plt.show()

## Runtime & Memory Benchmarking - Phenotype

In [None]:
# Phenotype benchmark: average runtime vs graph size, one line per augmentation mode.
df = pd.read_csv("../data/benchmark/phenotype_benchmark_results.csv")

sns.set(style="whitegrid")
augmentations = df["augmentation"].unique()
palette = sns.color_palette("tab10", n_colors=len(augmentations))
linestyles = ["-", "--", "-.", ":"]
markers = ["o", "s", "^", "D", "P", "X", "*", "+", "x"]

fig, ax = plt.subplots(figsize=(10, 5), dpi=300)
for i, aug in enumerate(augmentations):
    # Sort by x so each line connects points in order of graph size,
    # regardless of the row order in the benchmark CSV.
    group = df[df["augmentation"] == aug].sort_values("num_nodes")
    ax.plot(
        group["num_nodes"],
        group["avg_time_s"],
        label=aug,
        color=palette[i % len(palette)],
        linestyle=linestyles[i % len(linestyles)],
        marker=markers[i % len(markers)],
        linewidth=2,
    )

ax.set_xscale("log")
# ax.set_yscale("log")  # optional: log-scale the runtime axis too
ax.set_xlabel("Number of Nodes (log scale)", fontsize=12)
ax.set_ylabel("Average Runtime (s)", fontsize=12)
ax.set_title("Phenotype Augmentation Runtime vs Graph Size", fontsize=14)

ax.grid(True, which="both", linestyle="--", linewidth=0.5)
# Legend sits to the right of the axes so it never covers data points.
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), fontsize=8)

fig.tight_layout()
plt.show()

In [None]:
# Phenotype benchmark: peak memory vs graph size, one line per augmentation mode.
# Relies on `df` having been defined by an earlier cell (the phenotype benchmark
# results, per the title below) -- run the preceding cells first.
sns.set(style="whitegrid")
augmentations = df["augmentation"].unique()
palette = sns.color_palette("tab10", n_colors=len(augmentations))
linestyles = ["-", "--", "-.", ":"]
markers = ["o", "s", "^", "D", "P", "X", "*", "+", "x"]

fig, ax = plt.subplots(figsize=(10, 5), dpi=300)
for i, aug in enumerate(augmentations):
    # Sort by x so each line connects points in order of graph size,
    # regardless of the row order in the benchmark CSV.
    group = df[df["augmentation"] == aug].sort_values("num_nodes")
    ax.plot(
        group["num_nodes"],
        group["max_memory_mb"],
        label=aug,
        color=palette[i % len(palette)],
        linestyle=linestyles[i % len(linestyles)],
        marker=markers[i % len(markers)],
        linewidth=2,
    )

ax.set_xscale("log")
# ax.set_yscale("log")  # optional: log-scale the memory axis too
ax.set_xlabel("Number of Nodes (log scale)", fontsize=12)
ax.set_ylabel("Max Memory Usage (MB)", fontsize=12)
ax.set_title("Phenotype Augmentation Memory Usage vs Graph Size", fontsize=14)

ax.grid(True, which="both", linestyle="--", linewidth=0.5)
# Legend sits to the right of the axes so it never covers data points.
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), fontsize=8)

fig.tight_layout()
plt.show()

In [None]:
# Phenotype benchmark: runtime (left) and peak memory (right) vs graph size,
# with one shared legend of augmentation modes below the two panels.
df = pd.read_csv("../data/benchmark/phenotype_benchmark_results.csv")

sns.set(style="whitegrid")
augmentations = df["augmentation"].unique()
palette = sns.color_palette("tab10", n_colors=len(augmentations))
linestyles = ["-", "--", "-.", ":"]
markers = ["o", "s", "^", "D", "P", "X", "*", "+", "x"]

# (CSV column, y-axis label, panel title) for each subplot, left to right.
panels = [
    ("avg_time_s", "Average Runtime (s)", "Phenotype Augmentation Runtime vs Graph Size"),
    (
        "max_memory_mb",
        "Max Memory Usage (MB)",
        "Phenotype Augmentation Memory Usage vs Graph Size",
    ),
]

fig, axes = plt.subplots(1, 2, figsize=(12, 5), dpi=300, sharex=True)

for ax, (metric, ylabel, title) in zip(axes, panels):
    for i, aug in enumerate(augmentations):
        # Sort by x so each line connects points in order of graph size,
        # regardless of the row order in the benchmark CSV.
        group = df[df["augmentation"] == aug].sort_values("num_nodes")
        ax.plot(
            group["num_nodes"],
            group[metric],
            label=aug,
            color=palette[i % len(palette)],
            linestyle=linestyles[i % len(linestyles)],
            marker=markers[i % len(markers)],
            linewidth=2,
        )
    ax.set_xscale("log")
    # ax.set_yscale("log")  # optional: log-scale the y axis too
    ax.set_xlabel("Number of Nodes (log scale)", fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)

# Both panels plot the same augmentations in the same style, so the left
# panel's handles serve as the shared figure legend.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(
    handles,
    labels,
    loc="lower center",
    bbox_to_anchor=(0.5, -0.2),
    ncol=3,
    fontsize=10,
    frameon=True,
    title="Augmentation Mode",
    title_fontsize=12,
)

plt.tight_layout()
plt.show()