# Experiments

In [None]:
import copy
import json
from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from scipy.stats import wilcoxon

from data_processing import (load_binary_mask,
                             extract_connected_components,
                             match_lesions)
from evaluation_pipeline import (evaluate_case,
                                 evaluate_fold,
                                 aggregate_fold_results,
                                 evaluate_experiment)
from exporting import (export_case_results_to_csv)
from metrics import (compute_detection_and_lesion_metrics,
                     compute_scan_dice)
from statistical_tests import compare_models_wilcoxon

In [None]:
# Minimum lesion size forwarded as `size_threshold` to `evaluate_experiment` in every
# experiment below (units are defined by that function — presumably voxels, TODO confirm).
lesion_size_threshold: int = 2

## Experiment 1: Ablation Study

Paths and mapping

In [None]:
# Every prediction / ground-truth folder of this project lives under one results root.
_results_root = Path(
    "/home/gkolokolnikov/PhD_project/nf_segmentation_interactive/"
    "NFInteractiveSegmentationBenchmarkingPrivate/evaluation/results"
)
_predictions_root = _results_root / "predictions"

# Ablation variant -> prediction sub-folder (all evaluated lesion-wise, corrective)
_ablation_folders = {
    "SAM2_original": "MOIS_SAM2_sam_automatic_lesion_wise",
    "SAM2_exemplars": "MOIS_SAM2_sam_exemplar_lesion_wise",
    "MOIS_no_pointer": "MOIS_SAM2_no_pointer_lesion_wise",
    "MOIS_no_temporal": "MOIS_SAM2_no_temporal_lesion_wise",
    "MOIS_proposed": "MOIS_SAM2_main_lesion_wise",
}
model_paths = {
    variant: _predictions_root / folder / "lesion_wise_corrective"
    for variant, folder in _ablation_folders.items()
}

gt_root = _results_root / "ground_truth"

# Test set i is evaluated on fold i only ...
sets_and_folds = {f"TestSet_{i}": [i] for i in (1, 2, 3)}

# ... and its ground truth is stored under ValSet_i.
gt_set_map = {f"TestSet_{i}": f"ValSet_{i}" for i in (1, 2, 3)}

Run evaluation for each model

In [None]:
# Evaluate each ablation variant against the same ground truth and fold layout.
experiment_results = {}

for variant, predictions_dir in model_paths.items():
    print(f"Evaluating {variant}...")
    experiment_results[variant] = evaluate_experiment(
        pred_root=predictions_dir,
        gt_root=gt_root,
        sets_and_folds=sets_and_folds,
        gt_set_map=gt_set_map,
        size_threshold=lesion_size_threshold,
    )

Statistical analysis

In [None]:
def mean_std_report(values):
    """Format ``values`` as "mean ± std" (4 decimals), ignoring ``None`` entries.

    Returns "N/A" when no value remains. The std is ``np.std`` (population, ddof=0).
    """
    values = [v for v in values if v is not None]
    if not values:
        return "N/A"
    return f"{np.mean(values):.4f} ± {np.std(values):.4f}"


# Per-model summary, pooling the cases of every evaluated test set
# (the same pool the Wilcoxon tests below use).
for model_name, result in experiment_results.items():
    scan_dice_all = []
    lesion_f1_all = []
    lesion_dice_all = []

    for set_results in result.values():
        for case in set_results["cases"]:
            # Scan Dice (a missing value is kept as None and ignored by mean_std_report)
            scan_dice_all.append(case.get("scan_dice", None))

            # Lesion Detection F1
            lesion_f1_all.append(case['detection_metrics']['f1_score'])

            # Lesion Dice: mean of the non-zero lesion-level DSCs of this case
            lesion_dice = [d for d in case.get("lesion_dscs", []) if d > 0]
            if lesion_dice:
                lesion_dice_all.append(np.mean(lesion_dice))

    print(f"\nModel: {model_name}")
    print(f"  Scan Dice: {mean_std_report(scan_dice_all)}")
    # The value reported here is the detection F1 score, not a detection rate (recall).
    print(f"  Lesion Detection F1: {mean_std_report(lesion_f1_all)}")
    print(f"  Lesion Dice: {mean_std_report(lesion_dice_all)}")

In [None]:
def _mean_positive_lesion_dice(case):
    """Return the mean of the strictly positive lesion-wise DSCs of ``case``, or None.

    When no positive value exists (including when every entry is 0) None is returned:
    ``np.mean([])`` would yield NaN, which slips past ``None`` filters and poisons
    the Wilcoxon test.
    """
    positive = [d for d in (case.get("lesion_dscs", []) or []) if d > 0.0]
    return float(np.mean(positive)) if positive else None


def _drop_incomplete_cases(scores_by_model):
    """Align per-case score lists across models for a paired test.

    ``scores_by_model`` maps model name -> list of per-case values (None = missing),
    where position i refers to the same case for every model. Cases missing for any
    model are removed from all models so positional pairing stays valid. If models have
    different case counts, positional alignment is impossible and each list is only
    stripped of its None entries.
    """
    case_counts = {len(values) for values in scores_by_model.values()}
    if len(case_counts) != 1:
        return {model: [v for v in values if v is not None]
                for model, values in scores_by_model.items()}
    n_cases = case_counts.pop()
    complete = [i for i in range(n_cases)
                if all(values[i] is not None for values in scores_by_model.values())]
    return {model: [values[i] for i in complete]
            for model, values in scores_by_model.items()}


# Per-case scores of every model, pooled over all test sets.
# NOTE(review): assumes evaluate_experiment lists the same cases in the same order for
# every model (the paired test relies on it) — TODO confirm.
raw_scan_dice, raw_f1, raw_lesion_dice = {}, {}, {}
for model_name, sets in experiment_results.items():
    cases = [case for set_results in sets.values() for case in set_results["cases"]]
    raw_scan_dice[model_name] = [case.get("scan_dice") for case in cases]
    raw_f1[model_name] = [case["detection_metrics"]["f1_score"] for case in cases]
    raw_lesion_dice[model_name] = [_mean_positive_lesion_dice(case) for case in cases]

scan_dice_scores_by_model = _drop_incomplete_cases(raw_scan_dice)
f1_scores_by_model = _drop_incomplete_cases(raw_f1)
# Lesion Dice is undefined for cases without a positive lesion DSC; such cases are
# dropped for all models at once (not per model) so the pairing stays aligned.
lesion_dice_scores_by_model = _drop_incomplete_cases(raw_lesion_dice)


# Scan Dice
wilcoxon_results_dice = compare_models_wilcoxon(scan_dice_scores_by_model, alpha=0.05, correction='bonferroni')
df_dice = pd.DataFrame(wilcoxon_results_dice)
print("\n=== Wilcoxon: Scan Dice ===")
display(df_dice)

# Lesion Detection F1
wilcoxon_results_f1 = compare_models_wilcoxon(f1_scores_by_model, alpha=0.05, correction='bonferroni')
df_f1 = pd.DataFrame(wilcoxon_results_f1)
print("\n=== Wilcoxon: Lesion Detection F1 ===")
display(df_f1)

# Lesion-wise Dice
wilcoxon_results_lesion = compare_models_wilcoxon(lesion_dice_scores_by_model, alpha=0.05, correction='bonferroni')
df_lesion = pd.DataFrame(wilcoxon_results_lesion)
print("\n=== Wilcoxon: Lesion-wise Dice ===")
display(df_lesion)

Saving the results

In [None]:
# Persist the raw per-model results of Experiment 1 as pretty-printed JSON.
with Path("experiment_1.json").open("w") as out_file:
    json.dump(experiment_results, out_file, indent=2)

## Experiment 2: Effect of the number of prompted lesions

Paths and mapping

In [None]:
# Experiment 2 reuses the validation layout: test set i is evaluated on fold i and its
# ground truth is stored under ValSet_i.
gt_root = Path(
    "/home/gkolokolnikov/PhD_project/nf_segmentation_interactive/"
    "NFInteractiveSegmentationBenchmarkingPrivate/evaluation/results/ground_truth"
)
gt_set_map = {f"TestSet_{i}": f"ValSet_{i}" for i in range(1, 4)}
sets_and_folds = {f"TestSet_{i}": [i] for i in range(1, 4)}

# Prediction folders of the stream-exemplar model for 1–10 prompted lesions
_num_ex_predictions_root = Path(
    "/home/gkolokolnikov/PhD_project/nf_segmentation_interactive/"
    "NFInteractiveSegmentationBenchmarkingPrivate/evaluation/results/predictions"
)
model_config = {
    f"num_ex_{i}": (
        _num_ex_predictions_root
        / f"MOIS_SAM2_stream_exemplars_num_ex_{i}"
        / "lesion_wise_corrective"
    )
    for i in range(1, 11)
}

Running the evaluation for each model

In [None]:
# Evaluate the stream-exemplar model once per number of prompted lesions (1–10).
# Paths and the set/fold mapping come from the configuration cell above instead of
# being duplicated here as literals (which had left `model_config` unused).
prompt_counts = list(range(1, 11))

experiment_results_prompted = {}

for num in prompt_counts:
    model_name = f"MOIS_SAM2_num_{num}"
    model_path = model_config[f"num_ex_{num}"]

    print(f"\n=== Evaluating model: {model_name} ===")
    # Keyed by the integer prompt count (becomes a string key after a JSON round-trip).
    experiment_results_prompted[num] = evaluate_experiment(
        pred_root=model_path,
        gt_root=gt_root,
        sets_and_folds=sets_and_folds,
        gt_set_map=gt_set_map,
        size_threshold=lesion_size_threshold,
    )

In [None]:
# Persist Experiment 2 results; note json.dump stores the integer prompt-count keys as strings.
with Path("experiment_2.json").open("w") as out_file:
    json.dump(experiment_results_prompted, out_file, indent=2)

In [None]:
def mean_std(metrics):
    """Average the per-test-set mean and std of one aggregated metric.

    ``metrics`` is a list of ``{"mean": ..., "std": ...}`` dicts (one per test set);
    entries whose mean is None are ignored. Returns ``(0.0, 0.0)`` when none is usable.
    Note: the returned std is the average of the per-set stds, not a pooled std.
    """
    valid = [m for m in metrics if m["mean"] is not None]
    if not valid:
        return 0.0, 0.0
    mean = np.mean([m["mean"] for m in valid])
    std = np.mean([m["std"] for m in valid])
    return mean, std


# Collect metrics
# Keys are ints in memory but strings after a JSON round-trip (json.dump stringifies
# dict keys), so map each integer prompt count back to whichever key is present.
key_by_count = {int(k): k for k in experiment_results_prompted}
num_prompts = sorted(key_by_count)
scan_dice_means, scan_dice_stds = [], []
lesion_dice_means, lesion_dice_stds = [], []
f1_means, f1_stds = [], []

for n in num_prompts:
    sets = experiment_results_prompted[key_by_count[n]]

    # Aggregated metrics of every evaluated test set for this prompt count
    scan = [sets[s]["aggregated"]["scan_dice"] for s in sets]
    lesion = [sets[s]["aggregated"]["lesion_dice"] for s in sets]
    f1 = [sets[s]["aggregated"]["lesion_detection_f1"] for s in sets]

    scan_mean, scan_std = mean_std(scan)
    lesion_mean, lesion_std = mean_std(lesion)
    f1_mean, f1_std = mean_std(f1)

    scan_dice_means.append(scan_mean)
    scan_dice_stds.append(scan_std)
    lesion_dice_means.append(lesion_mean)
    lesion_dice_stds.append(lesion_std)
    f1_means.append(f1_mean)
    f1_stds.append(f1_std)

In [None]:
# Prepare x values (number of prompts)
# NOTE: `x` is the same list object as `num_prompts` (the sorted prompt counts); the
# band traces below rely on list semantics, where `x + x[::-1]` concatenates.
x = num_prompts

# Helper to create upper/lower bounds
def bounds(mean_list, std_list):
    """Return ``(lower, upper)`` band edges ``mean ∓ std``, clipped to ``[0, 1]``.

    Inputs are paired element-wise; extra elements of the longer list are ignored.
    """
    lower, upper = [], []
    for centre, spread in zip(mean_list, std_list):
        lower.append(max(centre - spread, 0))
        upper.append(min(centre + spread, 1))
    return lower, upper

# Shaded ±1 std bands, clipped to [0, 1]
scan_dice_lower, scan_dice_upper = bounds(scan_dice_means, scan_dice_stds)
lesion_dice_lower, lesion_dice_upper = bounds(lesion_dice_means, lesion_dice_stds)
f1_lower, f1_upper = bounds(f1_means, f1_stds)

# One panel per metric: (column, trace name, mean, lower, upper, line colour, band colour)
panels = [
    (1, "Scan Dice", scan_dice_means, scan_dice_lower, scan_dice_upper,
     'blue', 'rgba(0, 0, 255, 0.2)'),
    (2, "Lesion Dice", lesion_dice_means, lesion_dice_lower, lesion_dice_upper,
     'green', 'rgba(0, 255, 0, 0.2)'),
    (3, "Lesion F1", f1_means, f1_lower, f1_upper,
     'red', 'rgba(255, 0, 0, 0.2)'),
]

# 1 row x 3 columns
fig = make_subplots(rows=1, cols=3, subplot_titles=(
    "Scan-wise Dice vs #Prompts",
    "Lesion-wise Dice vs #Prompts",
    "Lesion Detection F1 vs #Prompts"
))

for column, trace_name, means, lower, upper, line_colour, band_colour in panels:
    # Mean curve
    fig.add_trace(
        go.Scatter(x=x, y=means, mode='lines+markers', name=trace_name,
                   line=dict(color=line_colour)),
        row=1, col=column,
    )
    # Band polygon: upper edge left-to-right, then lower edge right-to-left
    fig.add_trace(
        go.Scatter(x=x + x[::-1], y=upper + lower[::-1],
                   fill='toself', fillcolor=band_colour,
                   line=dict(color='rgba(255,255,255,0)'), showlegend=False),
        row=1, col=column,
    )

fig.update_layout(height=500, width=1200,
                  title_text="Segmentation Performance vs Prompt Count",
                  showlegend=False)

# Uniform axis settings across panels
for column in range(1, 4):
    fig.update_xaxes(range=[0, 10], dtick=1, row=1, col=column)
    fig.update_yaxes(range=[0.0, 1.0], row=1, col=column)

fig.show()

In [None]:
# Inspect the mean scan-wise Dice per prompt count (one value per entry of num_prompts)
scan_dice_means

In [None]:
# Inspect the mean lesion-wise Dice per prompt count
lesion_dice_means

In [None]:
# Inspect the mean lesion-detection F1 per prompt count
f1_means

In [None]:
# Band edges (mean ± std, clipped to [0, 1]) for the exported single-metric figures.
# Uses the `bounds` helper defined for the interactive plot above (it was redefined
# here verbatim) and the top-level plotly imports (`go`, `pio`) instead of re-importing
# the deprecated `plotly.graph_objs` alias. PNG export requires: pip install kaleido
scan_lower, scan_upper = bounds(scan_dice_means, scan_dice_stds)
lesion_lower, lesion_upper = bounds(lesion_dice_means, lesion_dice_stds)
f1_lower, f1_upper = bounds(f1_means, f1_stds)

# Helper function to create and save each plot
def save_plot(x, y_mean, y_lower, y_upper, line_color, fill_color, filename):
    """Draw one metric curve with a shaded band and export it as a 400x400 image.

    Parameters
    ----------
    x : sequence of numbers
        X positions (prompt counts).
    y_mean, y_lower, y_upper : sequences of float
        Mean curve and lower/upper band edges, aligned with ``x``.
    line_color, fill_color : str
        Plotly colour strings for the curve and the band.
    filename : str
        Output path passed to ``plotly.io.write_image`` (needs the kaleido package).
    """
    # Work on plain lists: the band polygon is built with `+` and `[::-1]`, which would
    # add element-wise instead of concatenating if NumPy arrays were passed in.
    x, y_mean, y_lower, y_upper = (list(v) for v in (x, y_mean, y_lower, y_upper))

    fig = go.Figure()

    # Main line
    fig.add_trace(go.Scatter(
        x=x, y=y_mean,
        line=dict(color=line_color, width=3),
        marker=dict(size=6),
        showlegend=False
    ))

    # Fill area for std: upper edge forward, lower edge backward
    fig.add_trace(go.Scatter(
        x=x + x[::-1],
        y=y_upper + y_lower[::-1],
        fill='toself',
        fillcolor=fill_color,
        line=dict(color='rgba(255,255,255,0)'),
        hoverinfo="skip",
        showlegend=False
    ))

    # Bare axes (no tick labels); labels are added when composing the final figure.
    fig.update_layout(
        width=400, height=400,
        margin=dict(l=10, r=10, t=10, b=10),
        xaxis=dict(range=[0, 11], showticklabels=False, showgrid=True, zeroline=True),
        yaxis=dict(range=[0.0, 1.0], showticklabels=False, showgrid=True, zeroline=True),
    )

    # scale=4 renders at 4x the 400x400 layout size, i.e. a 1600x1600 px raster
    pio.write_image(fig, filename, width=400, height=400, scale=4)

# Save all three plots: (mean, lower, upper, line colour, band colour, output file)
export_specs = [
    (scan_dice_means, scan_lower, scan_upper,
     'rgb(255, 0, 0)', 'rgba(255, 0, 0, 0.2)', 'scanwise_dice.png'),
    (lesion_dice_means, lesion_lower, lesion_upper,
     'rgb(0, 0, 255)', 'rgba(0, 0, 255, 0.2)', 'lesionwise_dice.png'),
    (f1_means, f1_lower, f1_upper,
     'rgb(0, 128, 0)', 'rgba(0, 128, 0, 0.2)', 'lesion_detection_f1.png'),
]

for mean_values, lower_values, upper_values, curve_colour, band_colour, out_file in export_specs:
    save_plot(
        x=num_prompts,
        y_mean=mean_values,
        y_lower=lower_values,
        y_upper=upper_values,
        line_color=curve_colour,
        fill_color=band_colour,
        filename=out_file,
    )


## Experiment 3: Benchmarking against other algorithms

Paths and mapping

In [None]:
# Root paths
_results_root = Path(
    "/home/gkolokolnikov/PhD_project/nf_segmentation_interactive/"
    "NFInteractiveSegmentationBenchmarkingPrivate/evaluation/results"
)
pred_root_base = _results_root / "predictions"
gt_root = _results_root / "ground_truth"

# Display name -> (prediction sub-folder, interaction protocol sub-folder)
_benchmark_layout = {
    "UNet": ("unet", "global_non_corrective"),
    "nnUNet": ("nnunet", "global_non_corrective"),
    "DINs": ("DINs", "lesion_wise_corrective"),
    "SW-FastEdit": ("SW-FastEdit", "lesion_wise_corrective"),
    "SAM2": ("SAM2", "lesion_wise_corrective"),
    "VISTA3D (Interactive)": ("VISTA", "lesion_wise_corrective"),
    "VISTA3D (Automatic)": ("VISTA", "global_non_corrective"),
    "MOIS-SAM2": ("MOIS_SAM2", "lesion_wise_corrective"),
}
model_paths = {
    name: pred_root_base / folder / protocol
    for name, (folder, protocol) in _benchmark_layout.items()
}

# All four test sets: folds 1–3, ground truth in an identically named folder
_benchmark_test_sets = [f"TestSet_{i}" for i in range(1, 5)]
sets_and_folds = {test_set: [1, 2, 3] for test_set in _benchmark_test_sets}
gt_set_map = {test_set: test_set for test_set in _benchmark_test_sets}

Evaluation

In [None]:
# Evaluate every benchmarked method on all configured test sets and folds.
all_experiment_results = {}
for method_name, predictions_dir in model_paths.items():
    print(f"\nEvaluating model: {method_name}")
    all_experiment_results[method_name] = evaluate_experiment(
        pred_root=predictions_dir,
        gt_root=gt_root,
        sets_and_folds=sets_and_folds,
        gt_set_map=gt_set_map,
        size_threshold=lesion_size_threshold,
    )

In [None]:
def _format_mean_std(stat):
    """Format an aggregated ``{"mean", "std"}`` dict as "m ± s" (3 decimals); missing parts print as N/A."""
    stat = stat or {}
    mean, std = stat.get("mean"), stat.get("std")
    if mean is None:
        return "N/A"
    std_text = "N/A" if std is None else f"{std:.3f}"
    return f"{mean:.3f} ± {std_text}"


def print_benchmark_summary(results_dict, datasets=("TestSet_1", "TestSet_2", "TestSet_3", "TestSet_4")):
    """Print scan DSC, lesion-detection F1 and lesion DSC (mean ± std) per dataset and model.

    Parameters
    ----------
    results_dict : mapping
        ``model_name -> {dataset: {"aggregated": {metric: {"mean", "std"}}}}``.
    datasets : iterable of str, optional
        Dataset keys to report, in order. Defaults to the four benchmark test sets.

    A model lacking a dataset or metric (or with a None mean) is printed as N/A;
    formatting None with ``:.3f`` would otherwise raise ``TypeError``.
    """
    print("\n=== Benchmarking Summary Per Dataset ===")
    for dataset in datasets:
        print(f"\n--- {dataset} ---")
        for model_name, result in results_dict.items():
            agg = (result.get(dataset) or {}).get("aggregated") or {}
            print(f"{model_name:25s} | "
                  f"Scan DSC: {_format_mean_std(agg.get('scan_dice'))} | "
                  f"F1: {_format_mean_std(agg.get('lesion_detection_f1'))} | "
                  f"Lesion DSC: {_format_mean_std(agg.get('lesion_dice'))}")

print_benchmark_summary(results_dict=all_experiment_results)

Save the results

In [None]:
# Persist the Experiment 3 benchmark results as pretty-printed JSON.
with Path("all_experiment_results.json").open("w") as out_file:
    json.dump(all_experiment_results, out_file, indent=2)

In [None]:
# Raw nested benchmark results, for inspection only (large output — consider clearing before sharing)
all_experiment_results

Statistical significance

In [None]:
# Per-case scores for the paired Wilcoxon tests: metric -> testset -> model -> [values].
# A case is kept only if *every* model has a value for it, so the i-th entries of any two
# models refer to the same case. (Dropping missing values per model could silently shift
# the pairing, or make lengths differ so the comparison is skipped.) If models have
# different case counts for a test set, positional pairing is impossible; missing values
# are then dropped per model instead.
metric_scores_by_testset = {
    'scan_dice': defaultdict(dict),
    'lesion_dscs': defaultdict(dict),
    'lesion_f1': defaultdict(dict),
}

# Same layout, but with None kept as a placeholder for a missing per-case value
raw_scores_by_testset = {metric: defaultdict(dict) for metric in metric_scores_by_testset}

for model, testsets in all_experiment_results.items():
    for testset, data in testsets.items():
        case_results = data.get("cases", [])
        scan_scores, lesion_scores, f1_scores = [], [], []
        for case in case_results:
            scan_scores.append(case.get("scan_dice", None))

            # Mean of the positive lesion DSCs. With no positive value (including all-zero
            # lists) record None: np.mean([]) would give NaN, which a None-filter misses.
            positive = [d for d in (case.get("lesion_dscs", []) or []) if d > 0.0]
            lesion_scores.append(float(np.mean(positive)) if positive else None)

            f1_scores.append(case.get("detection_metrics", {}).get("f1_score", None))

        raw_scores_by_testset["scan_dice"][testset][model] = scan_scores
        raw_scores_by_testset["lesion_dscs"][testset][model] = lesion_scores
        raw_scores_by_testset["lesion_f1"][testset][model] = f1_scores

# Keep only complete cases per (metric, test set) so pairs stay aligned
for metric, by_testset in raw_scores_by_testset.items():
    for testset, by_model in by_testset.items():
        case_counts = {len(values) for values in by_model.values()}
        if len(case_counts) == 1:
            n_cases = case_counts.pop()
            complete = [i for i in range(n_cases)
                        if all(values[i] is not None for values in by_model.values())]
            for model, values in by_model.items():
                metric_scores_by_testset[metric][testset][model] = [values[i] for i in complete]
        else:
            for model, values in by_model.items():
                metric_scores_by_testset[metric][testset][model] = [v for v in values if v is not None]


In [None]:

def wilcoxon_vs_reference(metric_data, reference_model='MOIS-SAM2', alpha=0.05):
    """Paired two-sided Wilcoxon signed-rank tests of every model against a reference.

    Parameters
    ----------
    metric_data : mapping
        ``testset -> {model_name: list of per-case scores}``. Lists are paired by
        position, so entry i of every model must belong to the same case. None/NaN
        entries are allowed and are dropped pair-wise.
    reference_model : str
        Model all others are compared to; test sets without it are skipped.
    alpha : float
        Significance level. No multiple-comparison correction is applied.

    Returns
    -------
    list of dict
        One record per comparison with keys ``testset``, ``reference``,
        ``compared_to``, ``p_value`` and ``significant``.
    """
    comparisons = []
    for testset, model_scores in metric_data.items():
        if reference_model not in model_scores:
            continue
        # dtype=float turns None into NaN so missing values can be masked uniformly
        ref_scores = np.asarray(model_scores[reference_model], dtype=float)
        for model, scores in model_scores.items():
            if model == reference_model:
                continue
            if len(scores) != len(ref_scores):
                # Unequal lengths cannot be paired by position; report rather than skip silently
                print(f"[wilcoxon_vs_reference] {testset}: skipping {reference_model} vs {model} "
                      f"(unpaired lengths {len(ref_scores)} vs {len(scores)})")
                continue
            other_scores = np.asarray(scores, dtype=float)
            # Drop pairs with a missing value on either side; NaNs passed to scipy would
            # yield a NaN p-value that silently reads as "not significant".
            paired = np.isfinite(ref_scores) & np.isfinite(other_scores)
            if not paired.any():
                p = 1.0
            else:
                try:
                    _, p = wilcoxon(ref_scores[paired], other_scores[paired])
                except ValueError:
                    # scipy may reject degenerate input (e.g. all differences zero):
                    # treat as no evidence of a difference.
                    p = 1.0
            comparisons.append({
                "testset": testset,
                "reference": reference_model,
                "compared_to": model,
                "p_value": float(p),
                "significant": bool(p < alpha),
            })
    return comparisons

# Reference-vs-model tests for each metric, stacked into a single table.
scan_results, lesion_results, f1_results = (
    wilcoxon_vs_reference(metric_scores_by_testset[metric_key])
    for metric_key in ("scan_dice", "lesion_dscs", "lesion_f1")
)

df_scan, df_lesion, df_f1 = (
    pd.DataFrame(records).assign(metric=label)
    for records, label in (
        (scan_results, "Scan Dice"),
        (lesion_results, "Lesion Dice"),
        (f1_results, "Lesion Detection F1"),
    )
)

wilcoxon_df = pd.concat([df_scan, df_lesion, df_f1], ignore_index=True)
display(wilcoxon_df)




Domain shift analysis

In [None]:
# Organize data: {model: {testset: [scan_dice values]}}
def extract_scan_dice_scores(data):
    """Group per-case scan-wise Dice scores by model and by each of the four test sets.

    Test sets missing for a model yield an empty list; cases without a "scan_dice"
    key are skipped.
    """
    test_sets = ["TestSet_1", "TestSet_2", "TestSet_3", "TestSet_4"]
    return {
        model_name: {
            test_set: [
                case["scan_dice"]
                for case in model_data.get(test_set, {}).get("cases", [])
                if "scan_dice" in case
            ]
            for test_set in test_sets
        }
        for model_name, model_data in data.items()
    }

# Works on the in-memory Experiment 3 results, or on the same dictionary re-loaded
# from "all_experiment_results.json".
model_testset_dice = extract_scan_dice_scores(data=all_experiment_results)

In [None]:
# Models compared in the domain-shift box plot
box_models = ['nnUNet', 'DINs', 'SAM2', 'MOIS-SAM2']

# Long format: one row per (model, test set, case) scan-wise DSC
rows = [
    {"Model": model, "Test Set": testset, "Scan-wise DSC": score}
    for model, testsets in model_testset_dice.items()
    if model in box_models
    for testset, scores in testsets.items()
    for score in scores
]
df_box = pd.DataFrame(rows)

# Grouped box plot: one group per model, one box colour per test set
fig = px.box(
    df_box,
    x="Model",
    y="Scan-wise DSC",
    color="Test Set",
    width=1000,
    height=600,
)

# Bare axes for the paper figure (labels are added when composing the final figure)
fig.update_layout(
    yaxis=dict(range=[0.0, 1.0], showticklabels=False),
    xaxis=dict(showticklabels=False),
    margin=dict(l=10, r=10, t=10, b=10),
    boxmode='group',
    template="plotly_white",
)

# scale=4 renders at 4x the 1000x600 layout size
pio.write_image(fig, "domain_generalization.png", scale=4)

fig.show()


## Experiment 4: Interaction efficiency

Paths and mapping

In [None]:
# Experiment 4 configuration, meant to be re-runnable after a kernel reset (once the
# imports and `lesion_size_threshold` cells have run). It therefore also defines
# `gt_root`, which the evaluation cell below uses, instead of relying on a value left
# over from an earlier experiment.

# Test sets and folds
sets_and_folds = {
    "TestSet_1": [1, 2, 3],
    "TestSet_2": [1, 2, 3],
    "TestSet_3": [1, 2, 3],
    "TestSet_4": [1, 2, 3]
}

# Ground truth folder mapping
gt_set_map = {
    "TestSet_1": "TestSet_1",
    "TestSet_2": "TestSet_2",
    "TestSet_3": "TestSet_3",
    "TestSet_4": "TestSet_4"
}

# Base paths
base_path = Path("/home/gkolokolnikov/PhD_project/nf_segmentation_interactive/NFInteractiveSegmentationBenchmarkingPrivate/evaluation/results/predictions")
gt_root = Path("/home/gkolokolnikov/PhD_project/nf_segmentation_interactive/NFInteractiveSegmentationBenchmarkingPrivate/evaluation/results/ground_truth")

# Model folders: SAM2 and MOIS-SAM2 with 1–10 prompted lesions
sam2_paths = {
    f"SAM2_{i}": base_path / f"SAM2_num_lesions_{i}" / "lesion_wise_corrective"
    for i in range(1, 11)
}
mois_sam2_paths = {
    f"MOIS_SAM2_{i}": base_path / f"MOIS_SAM2_num_lesions_{i}" / "lesion_wise_corrective"
    for i in range(1, 11)
}
interaction_models = {**sam2_paths, **mois_sam2_paths}

Evaluation of models

In [None]:
# Evaluate all 20 models (SAM2 and MOIS-SAM2 with 1–10 prompted lesions)
interaction_results = {}

for run_name, predictions_dir in interaction_models.items():
    print(f"Evaluating {run_name}...")
    interaction_results[run_name] = evaluate_experiment(
        pred_root=predictions_dir,
        gt_root=gt_root,
        sets_and_folds=sets_and_folds,
        gt_set_map=gt_set_map,
        size_threshold=lesion_size_threshold,
    )

Export results to JSON

In [None]:
# Persist the Experiment 4 results so the analysis below can run without re-evaluating.
with Path("experiment_4.json").open("w") as out_file:
    json.dump(interaction_results, out_file, indent=2)

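Aggregation of the scan-wise DSC (results are reloaded from `experiment_4.json`, so the evaluation above does not need to be re-run)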

In [None]:
# Reload the Experiment 4 results written above.
interaction_results = json.loads(Path("experiment_4.json").read_text())

In [None]:
# testset -> {"SAM2": [...], "MOIS_SAM2": [...]}, each a list of (prompt_count, mean, std)
scan_dice_summary = defaultdict(lambda: {"SAM2": [], "MOIS_SAM2": []})

for model_key, model_result in interaction_results.items():
    # Result keys look like "<family>_<prompt count>"; ignore anything else.
    for family in ("MOIS_SAM2", "SAM2"):
        prefix = f"{family}_"
        if model_key.startswith(prefix):
            break
    else:
        continue
    prompt_count = int(model_key.replace(prefix, ""))

    for testset, testset_result in model_result.items():
        scan_dice_stats = testset_result["aggregated"]["scan_dice"]
        scan_dice_summary[testset][family].append(
            (prompt_count, scan_dice_stats["mean"], scan_dice_stats["std"])
        )

# Order every curve by prompt count
for per_family in scan_dice_summary.values():
    for family in ("SAM2", "MOIS_SAM2"):
        per_family[family].sort()


Plot the results

In [None]:
def plot_scan_dice_single_testset(scan_dice_summary, testset_name):
    """Show scan-wise DSC (mean curve with ±1 std band) vs number of prompted lesions.

    Parameters
    ----------
    scan_dice_summary : mapping
        ``testset -> {"SAM2": [...], "MOIS_SAM2": [...]}``, each list holding
        ``(prompt_count, mean, std)`` tuples sorted by prompt count.
    testset_name : str
        Test set to plot.
    """
    fig = go.Figure()

    for model in ["SAM2", "MOIS_SAM2"]:
        # Skip points with a missing aggregate; `m + s` below would raise on None.
        data = [d for d in scan_dice_summary[testset_name][model]
                if d[1] is not None and d[2] is not None]
        if not data:
            continue
        x = [d[0] for d in data]
        y = [d[1] for d in data]
        std = [d[2] for d in data]

        # Same hue for the curve and its band (the curve used Plotly's default palette,
        # which did not match the fixed band colours).
        rgb = '0,100,200' if model == "SAM2" else '0,200,100'

        # Mean curve
        fig.add_trace(go.Scatter(
            x=x, y=y, mode="lines+markers",
            name=model,
            line=dict(shape='spline', color=f'rgba({rgb},1)'),
            marker=dict(color=f'rgba({rgb},1)'),
        ))

        # Shaded error band: upper edge forward, lower edge backward
        fig.add_trace(go.Scatter(
            x=x + x[::-1],
            y=[m + s for m, s in zip(y, std)] + [m - s for m, s in zip(y[::-1], std[::-1])],
            fill='toself',
            fillcolor=f'rgba({rgb},0.1)',
            line=dict(color='rgba(255,255,255,0)'),
            hoverinfo="skip",
            showlegend=False
        ))

    fig.update_layout(
        title=f"Scan-wise DSC vs Prompted Lesions – {testset_name}",
        xaxis=dict(title="Number of Prompted Lesions", range=[0, 10], dtick=1),
        yaxis=dict(title="Scan-wise DSC", range=[0.0, 1.0]),
        height=500,
        width=600,
    )
    fig.show()

In [None]:
plot_scan_dice_single_testset(scan_dice_summary, testset_name="TestSet_1")

In [None]:
plot_scan_dice_single_testset(scan_dice_summary, testset_name="TestSet_2")

In [None]:
plot_scan_dice_single_testset(scan_dice_summary, testset_name="TestSet_3")

In [None]:
plot_scan_dice_single_testset(scan_dice_summary, testset_name="TestSet_4")

In [None]:
def plot_scan_dice_single_testset(scan_dice_summary, testset_name, save_path=None):
    """Render the compact paper version of the scan-wise DSC vs prompted-lesions figure.

    Parameters
    ----------
    scan_dice_summary : mapping
        ``testset -> {"SAM2": [...], "MOIS_SAM2": [...]}``, each list holding
        ``(prompt_count, mean, std)`` tuples sorted by prompt count.
    testset_name : str
        Test set to plot.
    save_path : str or None, optional
        Output PNG path. If None, the figure is displayed instead of saved. This keeps
        the two-argument calls made earlier in the notebook working if they are re-run
        after this cell has redefined the function.
    """
    fig = go.Figure()

    for model in ["SAM2", "MOIS_SAM2"]:
        # Skip points with a missing aggregate; `m + s` below would raise on None.
        data = [d for d in scan_dice_summary[testset_name][model]
                if d[1] is not None and d[2] is not None]
        if not data:
            continue
        x = [d[0] for d in data]
        y = [d[1] for d in data]
        std = [d[2] for d in data]

        # Choose color for MOIS-SAM2 and SAM2
        color = 'rgba(0,100,200,1)' if model == "SAM2" else 'rgba(0,200,100,1)'

        # Mean curve
        fig.add_trace(go.Scatter(
            x=x, y=y, mode="lines+markers",
            name=model,
            line=dict(shape='spline', color=color, width=3),
            marker=dict(color=color),
        ))

        # Shaded error band: upper edge forward, lower edge backward
        fig.add_trace(go.Scatter(
            x=x + x[::-1],
            y=[m + s for m, s in zip(y, std)] + [m - s for m, s in zip(y[::-1], std[::-1])],
            fill='toself',
            fillcolor='rgba(0,100,200,0.1)' if model == "SAM2" else 'rgba(0,200,100,0.1)',
            line=dict(color='rgba(255,255,255,0)'),
            hoverinfo="skip",
            showlegend=False
        ))

    # Bare axes, no legend: labels are added when composing the final figure
    fig.update_layout(
        xaxis=dict(range=[0, 11], dtick=2, showticklabels=False),
        yaxis=dict(range=[0.0, 1.0], showticklabels=False),
        width=400, height=400,
        showlegend=False,
        margin=dict(l=10, r=10, t=10, b=10),
    )

    if save_path is None:
        fig.show()
        return

    # scale=4 renders at 4x the 400x400 layout size, i.e. a 1600x1600 px PNG
    pio.write_image(fig, save_path, format="png", width=400, height=400, scale=4)

    print(f"Saved: {save_path}")


In [None]:
plot_scan_dice_single_testset(scan_dice_summary, "TestSet_1", save_path='exp_4_ts1.png')

In [None]:
plot_scan_dice_single_testset(scan_dice_summary, "TestSet_2", save_path='exp_4_ts2.png')

In [None]:
plot_scan_dice_single_testset(scan_dice_summary, "TestSet_3", save_path='exp_4_ts3.png')

In [None]:
plot_scan_dice_single_testset(scan_dice_summary, "TestSet_4", save_path='exp_4_ts4.png')