# Result Analysis for Interactive Segmentation Models

This Jupyter Notebook processes and analyzes the evaluation results of different interactive segmentation models.
It aggregates Dice Similarity Coefficients (DSC) across multiple folds and test sets to compare segmentation performance.

### Workflow Overview
1. **Data Loading**: Reads evaluation results from Excel files.
2. **Aggregation**: Computes lesion-wise and global DSC statistics.
3. **Visualization**: Generates summary plots using `plotly`.
4. **Comparison**: Compares models across different evaluation modes.

### Models Analyzed
- **SW-FastEdit**
- **DINs (Deep Interactive Networks)**
- **SAM2 (Segment Anything Model 2)**

Each model is evaluated in **lesion-wise corrective** and **global (scan-wise) corrective** modes using test sets 1 (high tumor burden) and 3 (low tumor burden).

---

In [None]:
import os
from tqdm import tqdm
import numpy as np
import pandas as pd
import plotly.graph_objects as go

### Results aggregation
Read the Excel files containing the interactive segmentation performance for each case, fold, interaction scenario, and model.

#### Averaged metrics

In [None]:
# Experiment configuration consumed by the aggregation cells below.
folds = list(range(1, 4))  # fold indices 1, 2, 3 (folder names are built as "fold_<k>")
test_sets = ["TestSet_1", "TestSet_3"]
# Evaluation modes; note the processing functions hard-code their own mode directory.
modes = ["lesion_wise_corrective", "global_corrective"]
models = ["SW-FastEdit", "DINs", "SAM2"]
# Root directory laid out as <model>/<mode>/<test_set>/fold_<k>/<case>/*.xlsx
base_path = "/home/gkolokolnikov/PhD_project/nf_segmentation_interactive/NFInteractiveSegmentationBenchmarking/evaluation/results/metrics"


In [None]:
def process_lesion_wise_metrics(model, test_set, folds):
    """
    Processes lesion-wise DSC (Dice Similarity Coefficient) metrics for a given model across multiple folds.

    For every case folder under
    ``<base_path>/<model>/lesion_wise_corrective/<test_set>/fold_<k>`` this function reads
    ``lesion_metrics.xlsx``, ``global_metrics_after_single_interaction.xlsx`` and
    ``global_metrics.xlsx`` (each only if present). It collects one value per fold
    and metric, computes the per-case mean and standard deviation across folds, and
    finally averages those per-case statistics over all cases.

    Args:
        model (str): The name of the model being analyzed (e.g., "DINs", "SAM2").
        test_set (str): Identifier for the test dataset (e.g., "TestSet_1").
        folds (list): List of fold numbers to process (e.g., [1, 2, 3]).

    Returns:
        dict: ``'Model'`` and ``'TestSet'`` plus, for each metric, the case-averaged
        ``'<metric> (mean)'`` and ``'<metric> (std)'`` values. A metric with no data
        for any case is NaN.

    Raises:
        FileNotFoundError: If a fold directory does not exist.
    """
    print(">>> Processing lesion-wise mode metrics across folds...")
    metric_names = [
        'Lesion DSC after 1 interaction',
        'Lesion DSC after 3 interactions',
        'Global DSC after 1 interaction',
        'Global DSC after all interactions',
    ]
    case_metrics = {}

    # Step 1: Collect metrics across folds for each case
    for fold in folds:
        folder_path = os.path.join(base_path, model, "lesion_wise_corrective", test_set, f"fold_{fold}")

        for case in tqdm(os.listdir(folder_path)):
            lesion_file = os.path.join(folder_path, case, "lesion_metrics.xlsx")
            global_file_1 = os.path.join(folder_path, case, "global_metrics_after_single_interaction.xlsx")
            global_file_all = os.path.join(folder_path, case, "global_metrics.xlsx")

            # Every case gets an entry, even if none of its files exist in this fold.
            if case not in case_metrics:
                case_metrics[case] = {name: [] for name in metric_names}
            metrics = case_metrics[case]

            # Lesion DSC: each row is averaged over its columns. Row 0 is never read
            # (presumably the state before any interaction — confirm).
            if os.path.exists(lesion_file):
                lesion_df = pd.read_excel(lesion_file)
                if len(lesion_df) >= 3:
                    metrics['Lesion DSC after 1 interaction'].append(lesion_df.iloc[1].values.mean())
                    # NOTE(review): row 2 is used for "after 3 interactions" while row 1 is
                    # "after 1 interaction"; process_global_metrics uses row 3 for 3
                    # interactions. Confirm the row layout of lesion_metrics.xlsx.
                    metrics['Lesion DSC after 3 interactions'].append(lesion_df.iloc[2].values.mean())
                elif len(lesion_df) >= 2:
                    # Row 1 must exist; with fewer than 3 rows the last row stands in
                    # for the "3 interactions" value. Single-row sheets are skipped.
                    metrics['Lesion DSC after 1 interaction'].append(lesion_df.iloc[1].values.mean())
                    metrics['Lesion DSC after 3 interactions'].append(lesion_df.iloc[-1].values.mean())

            # Global DSC is read from the 'dsc_global' column of row 1 of each file.
            if os.path.exists(global_file_1):
                global_df_1 = pd.read_excel(global_file_1)
                if len(global_df_1) > 1:
                    metrics['Global DSC after 1 interaction'].append(global_df_1.iloc[1]['dsc_global'])

            if os.path.exists(global_file_all):
                global_df_all = pd.read_excel(global_file_all)
                if len(global_df_all) > 1:
                    metrics['Global DSC after all interactions'].append(global_df_all.iloc[1]['dsc_global'])

    # Step 2: Compute mean and std for each case across folds.
    # np.nan (rather than None) keeps every column numeric even if a metric is
    # missing for all cases.
    aggregated_results = []
    for metrics in case_metrics.values():
        row = {}
        for name in metric_names:
            values = metrics[name]
            row[f'{name} (mean)'] = np.mean(values) if values else np.nan
            row[f'{name} (std)'] = np.std(values) if values else np.nan
        aggregated_results.append(row)

    # Step 3: Average the per-case statistics over all cases (NaNs are skipped).
    aggregated_results_pd = pd.DataFrame(aggregated_results).mean()
    averaged_aggregated_results = {
        'Model': model,
        'TestSet': test_set,
    }
    averaged_aggregated_results.update(aggregated_results_pd.items())

    print(">>> Finished processing lesion-wise mode metrics.")
    return averaged_aggregated_results


In [None]:
def process_global_metrics(model, test_set, folds):
    """
    Aggregates global-corrective-mode DSC values for one model and test set.

    For each fold, every case folder under
    ``<base_path>/<model>/global_corrective/<test_set>/fold_<k>`` is visited. Its
    ``lesion_metrics.xlsx`` (if present) is read, and rows 1, 3, 20 and 60 are each
    averaged across their columns. A row is used only if the sheet is long enough.
    Per-case mean/std across folds are then averaged over all cases.

    Args:
        model (str): The name of the model being analyzed (e.g., "DINs", "SAM2").
        test_set (str): Identifier for the test dataset (e.g., "TestSet_1").
        folds (list): List of fold numbers to process (e.g., [1, 2, 3]).

    Returns:
        dict: ``'Model'``, ``'TestSet'`` and the case-averaged
        ``'<metric> (mean)'`` / ``'<metric> (std)'`` values.
    """
    print(">>> Processing global mode metrics across folds...")

    # Row index in the spreadsheet -> metric label.
    checkpoints = {
        1: 'Global DSC after 1 interaction',
        3: 'Global DSC after 3 interactions',
        20: 'Global DSC after 20 interactions',
        60: 'Global DSC after 60 interactions',
    }
    per_case = {}

    for fold in folds:
        fold_dir = os.path.join(base_path, model, "global_corrective", test_set, f"fold_{fold}")

        for case_name in tqdm(os.listdir(fold_dir)):
            # NOTE(review): despite the "Global DSC" labels, this mode reads
            # lesion_metrics.xlsx and averages each row over its columns — confirm
            # this file holds the intended global values.
            metrics_path = os.path.join(fold_dir, case_name, "lesion_metrics.xlsx")
            collected = per_case.setdefault(case_name, {label: [] for label in checkpoints.values()})

            if not os.path.exists(metrics_path):
                continue

            sheet = pd.read_excel(metrics_path)
            n_rows = len(sheet)
            for row_idx, label in checkpoints.items():
                if n_rows > row_idx:
                    collected[label].append(sheet.iloc[row_idx].values.mean())

    # Per-case mean/std across folds (None when a case has no value for a metric).
    summary_rows = []
    for collected in per_case.values():
        summary = {}
        for label in checkpoints.values():
            samples = collected[label]
            summary[f'{label} (mean)'] = np.mean(samples) if samples else None
            summary[f'{label} (std)'] = np.std(samples) if samples else None
        summary_rows.append(summary)

    # Average the per-case statistics over all cases.
    column_means = pd.DataFrame(summary_rows).mean()
    result = {'Model': model, 'TestSet': test_set}
    result.update(column_means.items())
    print(">>> Finished processing global mode metrics.")

    return result

In [None]:
# Process all models and test sets, collecting one summary dict per
# (model, test set) for each interaction scenario.
lesion_wise_results = []
global_results = []

for model in models:
    print(f"> Started processing data for model: {model}...")

    for test_set in test_sets:
        print(f">> Started processing set: {test_set}...")

        # Compute lesion-wise and global metrics for the current model and test set
        lesion_wise_results.append(process_lesion_wise_metrics(model, test_set, folds))
        global_results.append(process_global_metrics(model, test_set, folds))

        print(">> Finished processing.")
    print("> Finished processing.")

# One row per (model, test set); columns are the averaged metric statistics.
lesion_wise_df = pd.DataFrame(lesion_wise_results)
global_df = pd.DataFrame(global_results)

# Save the aggregated results to Excel files (read back by the plotting cells).
lesion_wise_df.to_excel("lesion_wise_results.xlsx", index=False)
global_df.to_excel("global_results.xlsx", index=False)

print("Aggregation complete. Results saved to Excel.")


Show the results for the lesion-wise interaction scenario.

In [None]:
lesion_wise_df[lesion_wise_df["TestSet"].eq("TestSet_1")]  # aggregated lesion-wise metrics, test set 1

In [None]:
lesion_wise_df[lesion_wise_df["TestSet"].eq("TestSet_3")]  # aggregated lesion-wise metrics, test set 3

Show the results for the global (scan-wise) interaction scenario.

In [None]:
global_df[global_df["TestSet"].eq("TestSet_1")]  # aggregated scan-wise metrics, test set 1

In [None]:
global_df[global_df["TestSet"].eq("TestSet_3")]  # aggregated scan-wise metrics, test set 3

#### Per-case metrics

In [None]:
# Per-case export: one row per "<test set>_<case>" and one "<model>_<fold>" column
# holding the rounded 'dsc_global' value read from each case's global_metrics.xlsx.

# Define the base directory containing evaluation results
base_dir = "/home/gkolokolnikov/PhD_project/nf_segmentation_interactive/NFInteractiveSegmentationBenchmarking/evaluation/results/metrics"

# Models and fold directories for this export. Separate names are used so this cell
# does not overwrite `models` / `folds` (integer folds) used by the aggregation cells.
case_models = ["DINs", "SW-FastEdit", "SAM2"]
case_folds = ["fold_1", "fold_2", "fold_3"]


def extract_last_dsc(file_path):
    """
    Reads an Excel file and returns the 'dsc_global' value from its second row.

    Despite the name, the value comes from row index 1, not from the last row.

    Args:
        file_path (str): Path to the Excel file containing global metrics.

    Returns:
        float: The 'dsc_global' value of row index 1.
    """
    df = pd.read_excel(file_path)
    return df.iloc[1]['dsc_global']


def _collect_case_dsc(mode):
    """
    Collects rounded DSC values for every case of every model under one evaluation mode.

    Args:
        mode (str): Mode directory name, e.g. "global_corrective" or "lesion_wise_corrective".

    Returns:
        list[dict]: One record per "<test set>_<case>" with a "TestSet_Case" key and one
        "<model>_<fold>" key per metrics file found, in first-seen order.
    """
    merged = {}
    for model in case_models:
        print(f"> Processing model: {model}")
        mode_path = os.path.join(base_dir, model, mode)
        for test_set in os.listdir(mode_path):
            print(f">> Processing test_set: {test_set}")
            test_set_path = os.path.join(mode_path, test_set)
            for fold in case_folds:
                print(f">>> Processing fold {fold}")
                fold_path = os.path.join(test_set_path, fold)
                if not os.path.exists(fold_path):
                    continue
                for case in tqdm(os.listdir(fold_path)):
                    metrics_file = os.path.join(fold_path, case, "global_metrics.xlsx")
                    if not os.path.exists(metrics_file):
                        continue
                    test_set_case = f"{test_set}_{case}"
                    record = merged.setdefault(test_set_case, {"TestSet_Case": test_set_case})
                    record[f"{model}_{fold}"] = round(extract_last_dsc(metrics_file), 2)
    return list(merged.values())


# Process results for the global corrective approach
print("Processing global corrective approach...")
global_case_df = pd.DataFrame(_collect_case_dsc("global_corrective"))

# Process results for the lesion-wise corrective approach
print("Processing lesion-wise corrective approach...")
lesion_wise_case_df = pd.DataFrame(_collect_case_dsc("lesion_wise_corrective"))

# Save results to Excel files
lesion_wise_case_df.to_excel("lesion_wise_corrective_all_cases.xlsx", index=False)
global_case_df.to_excel("global_corrective_all_cases.xlsx", index=False)

print("Excel files created: lesion_wise_corrective_all_cases.xlsx and global_corrective_all_cases.xlsx")


### Results visualization
Show how the performance changes as the number of interaction points increases.

#### Per-lesion

In [None]:
# Figure settings; font_size_value, width and height are also read by the per-scan figure cell.
font_size_value = 16
scale = 2  # raster export scale factor
width = 350
height = 450
pos_y = -0.2  # panel-label position in paper coordinates (below the plot area)

# Aggregated lesion-wise results, restricted to test set 1
file_path = 'lesion_wise_results.xlsx'  # Replace with your file path
df = pd.read_excel(file_path)
df_testset1 = df[df['TestSet'] == 'TestSet_1']

# One trace per model, each with its own dash pattern and colour
models = ['DINs', 'SW-FastEdit', 'SAM2']
line_styles = ['solid', 'dash', 'dot']
colors = ['blue', 'green', 'red']
mean_columns = ['Lesion DSC after 1 interaction (mean)', 'Lesion DSC after 3 interactions (mean)']
std_columns = ['Lesion DSC after 1 interaction (std)', 'Lesion DSC after 3 interactions (std)']
paper_font = dict(family="Times New Roman", size=font_size_value, color="black")

traces = []
for model, dash_style, line_color in zip(models, line_styles, colors):
    model_row = df_testset1[df_testset1['Model'] == model]
    dsc_means = [model_row[column].values[0] for column in mean_columns]
    dsc_stds = [model_row[column].values[0] for column in std_columns]
    traces.append(
        go.Scatter(
            x=[1, 3],  # interaction counts
            y=dsc_means,
            mode='lines+markers',
            name=model,
            line=dict(color=line_color, dash=dash_style, width=2),
            marker=dict(size=8, symbol='circle'),
            error_y=dict(type='data', array=dsc_stds, visible=True, thickness=1.5, width=2),
        )
    )

fig = go.Figure(traces)

# Publication layout: panel label "(a)" below the axes, legend hidden.
fig.update_layout(
    annotations=[
        dict(
            text="(a)",
            x=0.5,
            y=pos_y,
            xref="paper",
            yref="paper",
            showarrow=False,
            font=paper_font,
        )
    ],
    xaxis=dict(
        title="Number of Interactions",
        title_font=paper_font,
        tickvals=[0, 1, 2, 3, 4, 5],
        range=[0.5, 3.5],
        tickfont=paper_font,
    ),
    yaxis=dict(
        title="Per-lesion Dice Similarity Score",
        title_font=paper_font,
        range=[0.0, 0.5],
        tickfont=paper_font,
    ),
    showlegend=False,
    template="plotly",
    # The right margin is reserved even though the legend is hidden.
    # NOTE(review): presumably this keeps the plot area aligned with the per-scan panel — confirm.
    margin=dict(l=50, r=150, t=20, b=80),
    width=width,
    height=height,
)

In [None]:
per_lesion_svg = "Per_lesion_DSC_vs_interactions.svg"; fig.write_image(per_lesion_svg)  # vector export of panel (a)

#### Per-scan

In [None]:
# Per-scan figure (panel b). Reuses font_size_value, width, height and pos_y from the
# per-lesion figure cell so both panels share the same styling.

# Load the aggregated global (scan-wise) results
file_path = 'global_results.xlsx'  # Replace with your file path
df = pd.read_excel(file_path)

# Filter data for TestSet_1
df_testset1 = df[df['TestSet'] == 'TestSet_1']

# Extract data for each model
models = ['DINs', 'SW-FastEdit', 'SAM2']
traces = []

# Define line styles and colors for publication quality
line_styles = ['solid', 'dash', 'dot']
colors = ['blue', 'green', 'red']

for i, model in enumerate(models):
    model_data = df_testset1[df_testset1['Model'] == model]

    x = [1, 3]  # Interactions: 1 and 3
    y = [
        model_data['Global DSC after 1 interaction (mean)'].values[0],
        model_data['Global DSC after 3 interactions (mean)'].values[0]
    ]
    error_y = [
        model_data['Global DSC after 1 interaction (std)'].values[0],
        model_data['Global DSC after 3 interactions (std)'].values[0]
    ]

    # Add a trace for the model
    traces.append(go.Scatter(
        x=x,
        y=y,
        mode='lines+markers',
        name=model,
        line=dict(color=colors[i], dash=line_styles[i], width=2),
        marker=dict(size=8, symbol='circle'),
        error_y=dict(
            type='data',
            array=error_y,
            visible=True,
            thickness=1.5,
            width=2
        ),
    ))

# Create the figure
fig = go.Figure(traces)

# Update layout for research paper formatting
fig.update_layout(
    annotations=[
        dict(
            text="(b)",  # Panel label
            x=0.5,  # Center horizontally
            y=pos_y,  # Same vertical position as the per-lesion panel label
            xref="paper",
            yref="paper",
            showarrow=False,
            font=dict(family="Times New Roman", size=font_size_value, color="black"),
        )
    ],
    xaxis=dict(
        title="Number of Interactions",
        title_font=dict(family="Times New Roman", size=font_size_value, color="black"),
        tickvals=[0, 1, 2, 3, 4, 5],
        range=[0.5, 3.5],  # Set x-axis range
        tickfont=dict(family="Times New Roman", size=font_size_value, color="black"),
    ),
    yaxis=dict(
        title="Per-scan Dice Similarity Score",
        title_font=dict(family="Times New Roman", size=font_size_value, color="black"),
        range=[0.0, 0.5],  # Set y-axis range
        tickfont=dict(family="Times New Roman", size=font_size_value, color="black"),
    ),
    legend=dict(
        title=dict(
            text="Models",
            font=dict(family="Times New Roman", size=font_size_value, color="black")  # Same font as the axes
        ),
        font=dict(family="Times New Roman", size=font_size_value, color="black"),
        orientation="v",  # Vertical legend
        x=1.02,  # Position legend outside the plot on the right
        y=0.5,
        xanchor="left",
        yanchor="middle",
    ),
    template="plotly",
    margin=dict(l=50, r=150, t=20, b=80),  # Right margin leaves room for the legend
    width=width,
    height=height,
)

In [101]:
# Vector export of panel (b). For a raster version instead use:
#   fig.write_image("Per_scan_DSC_vs_interactions.png", scale=scale)
per_scan_svg = "Per_scan_DSC_vs_interactions.svg"; fig.write_image(per_scan_svg)