In [None]:
import sys
# Make the local project checkouts and shared libraries importable.
sys.path.extend([
    '/storage/vbutoi/projects',
    '/storage/vbutoi/libraries',
    '/storage/vbutoi/projects/ESE',
    '/storage/vbutoi/projects/UniverSeg',
])

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")

import os
# DATAPATH is a ':'-separated list of dataset roots (a single entry here).
os.environ['DATAPATH'] = ':'.join([
    '/storage/vbutoi/datasets',
])

# Results loader object does everything
from pathlib import Path
from ionpy.analysis import ResultsLoader
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 
# For using yaml configs.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

# EXPERIMENT SETS:
# - WMH Calibration (no weighting): 01_29_24_WMH_Base
# - WMH Only Foreground Loss: 01_29_24_WMH_Foreground
# - WMH Balanced Loss: 01_21_24_Balanced_CE_Calibrators
# - CityScapes Calibration: 01_30_24_CityScapes_WeightedEnsembles

# Options describing which logged inference results to analyse.
# The whole config is passed to `load_cal_inference_stats` as `results_cfg`,
# and this `log` section alone is passed to `load_upperbound_df`.
log:
    root: /storage/vbutoi/scratch/ESE/inference
    add_dice_loss_rows: True
    load_pixel_meters: True 
    drop_nan_metric_rows: True 
    remove_shared_columns: False 
    add_baseline_rows: True 
    equal_rows_per_cfg_assert: False
    # Which inference run to analyse; presumably a sub-directory of `root` -- TODO confirm.
    inference_group: "01_31_24_CityScapes_AllCalibrators"
    # min_fg_pixels: 100
    
# Settings for the calibration metrics.
# NOTE(review): their exact meaning is defined by the loader / metric code, not visible here.
calibration:
    num_bins: 15
    square_diff: False 
    neighborhood_width: 3

# cal_metrics:
#     - ECE:
#         _fn: ese.experiment.metrics.ece.ece_loss
#     - Edge-ECE:
#         _fn: ese.experiment.metrics.ece.edge_ece_loss
#     # - CW-ECE:
#     #     _fn: ese.experiment.metrics.ece.cw_ece_loss
#     - ELM:
#         _fn: ese.experiment.metrics.elm.elm_loss

#     - Foreground-ECE:
#         _fn: ese.experiment.metrics.ece.ece_loss
#         ignore_index: 0
#     - Foreground-Edge-ECE:
#         _fn: ese.experiment.metrics.ece.edge_ece_loss
#         ignore_index: 0       
#     # - Foreground-CW-ECE:
#     #     _fn: ese.experiment.metrics.ece.cw_ece_loss
#     #     ignore_index: 0
#     - Foreground-ELM:
#         _fn: ese.experiment.metrics.elm.elm_loss
#         ignore_index: 0

In [None]:
from ese.experiment.analysis.analyze_inf import load_cal_inference_stats

# Build the per-image statistics frame described by `results_cfg`,
# asking the loader for its cached version (load_cached=True).
image_info_df = load_cal_inference_stats(results_cfg=results_cfg, load_cached=True)

In [None]:
# Show the distinct calibrators present in the loaded results.
image_info_df['calibrator'].unique()

## We are going to remove the cases where there are very few pixels, because they cause unrealistic outliers.

In [None]:
# Sort the image_info_df by method name, so everything appears nicely
image_info_df = image_info_df.sort_values(by=['method_name', 'calibrator'])

# Preferred display order of the calibrators, with 'Uncalibrated' first.
preferred_calibrator_order = [
    'Uncalibrated',
    'Vanilla',
    'Temperature_Scaling',
    'Vector_Scaling',
    'Dirichlet_Scaling',
    'LTS',
    'NECTAR_Scaling',
]
calibrator_col = image_info_df['calibrator'].astype('category')
present_calibrators = list(calibrator_col.cat.categories)
# `reorder_categories` requires exactly the categories that exist in the data,
# so build the order from what is present: known calibrators in the preferred
# order first, then any calibrator missing from the list above. This keeps the
# cell working for experiment sets that use a different subset of calibrators.
calibrator_order = (
    [c for c in preferred_calibrator_order if c in present_calibrators]
    + [c for c in present_calibrators if c not in preferred_calibrator_order]
)
image_info_df['calibrator'] = calibrator_col.cat.reorder_categories(calibrator_order)

In [None]:
# Show the distinct method names present in the loaded results.
image_info_df['method_name'].unique()

In [None]:
# Show the distinct values of the `ensemble_w_metric` column.
image_info_df['ensemble_w_metric'].unique()

In [None]:
# Select only the rows corresponding to group methods. Take an explicit copy so
# the column assignment below writes to an independent frame rather than to a
# slice of the previous one (avoids pandas' SettingWithCopyWarning).
image_info_df = image_info_df[image_info_df['model_type'] == 'group'].copy()
# group_methods_df = image_info_df

# Preferred display order of the methods. Entries that do not occur in the
# filtered data (e.g. the individual UNets) are skipped below, so this list
# no longer has to be hand-edited to match the data exactly.
preferred_method_order = [
    'Average UNet',
    'Ensemble (mean, logits)',
    'Ensemble (mean, probs)',
    'Ensemble (product, probs)',
    'UNet (seed=40)',
    'UNet (seed=41)',
    'UNet (seed=42)',
    'UNet (seed=43)',
]
method_col = image_info_df['method_name'].astype('category')
present_methods = list(method_col.cat.categories)
# `reorder_categories` needs exactly the existing categories: known methods in
# the preferred order first, followed by any method not listed above.
method_order = (
    [m for m in preferred_method_order if m in present_methods]
    + [m for m in present_methods if m not in preferred_method_order]
)
image_info_df['method_name'] = method_col.cat.reorder_categories(method_order)

# image_info_df['ensemble_w_metric'] = image_info_df['ensemble_w_metric'].astype('category')
# image_info_df['ensemble_w_metric'] = image_info_df['ensemble_w_metric'].cat.reorder_categories([
#     'None',
#     'val-loss',
#     'val-dice_score',
#     'val-ece_loss',
#     'val-elm_loss'
# ])

## Let's look at the calibration scores of our models.

In [None]:
# Bar plot of ECE for every calibrator, with one bar per method.
g = sns.catplot(
    kind="bar",
    data=image_info_df,
    x="calibrator", y="ECE", hue="method_name",
    height=6, aspect=2,
)
# # Set the y-axis limits
# g.set(ylim=(0.0, 0.25))
g.fig.suptitle("ECE by Calibration Method and Model Class")
# Widen the column spacing and leave head-room so the title clears the axes.
g.fig.subplots_adjust(wspace=0.5, top=0.9)

In [None]:
# Bar plot of Foreground-ECE for every calibrator, with one bar per method.
g = sns.catplot(
    kind="bar",
    data=image_info_df,
    x="calibrator", y="Foreground-ECE", hue="method_name",
    height=6, aspect=2,
)
# # Set the y-axis limits
# g.set(ylim=(0.0, 0.25))
g.fig.suptitle("Foreground ECE by Calibration Method and Model Class")
# Leave head-room so the title clears the axes.
g.fig.subplots_adjust(top=0.9)

In [None]:
# g = sns.catplot(
#     data=image_info_df,
#     x="calibrator",
#     y="CW-ECE",
#     hue="method_name",
#     kind="bar",
#     height=6,
#     aspect=2
# )
# # Set column spacing
# # # Set the y-axis limits
# # g.set(ylim=(0.0, 0.25))
# g.fig.suptitle("CW ECE by Calibration Method and Model Class")
# # Move the title slightly up
# g.fig.subplots_adjust(top=0.9)

In [None]:
# Bar plot of Edge-ECE for every calibrator, with one bar per method.
g = sns.catplot(
    kind="bar",
    data=image_info_df,
    x="calibrator", y="Edge-ECE", hue="method_name",
    height=6, aspect=2,
)
# # Set the y-axis limits
# g.set(ylim=(0.0, 0.25))
g.fig.suptitle("Edge ECE by Calibration Method and Model Class")
# Leave head-room so the title clears the axes.
g.fig.subplots_adjust(top=0.9)

In [None]:
# Bar plot of ELM for every calibrator, with one bar per method.
g = sns.catplot(
    kind="bar",
    data=image_info_df,
    x="calibrator", y="ELM", hue="method_name",
    height=6, aspect=2,
)
# # Set the y-axis limits
# g.set(ylim=(0.0, 0.25))
g.fig.suptitle("ELM by Calibration Method and Model Class")
# Leave head-room so the title clears the axes.
g.fig.subplots_adjust(top=0.9)

## Now we can look at the quality averages themselves, first looking slice-wise.

In [None]:
# Mean metric score per (method, calibrator, metric, ensemble) combination.
# `method_name` and `calibrator` are categorical at this point, so pass
# observed=True: otherwise pandas emits a row (with a NaN mean) for every
# combination of categories, including ones that never occur in the data,
# and recent pandas versions warn about the changing default.
table_df = (
    image_info_df
    .groupby(["method_name", "calibrator", "image_metric", "ensemble"], observed=True)['metric_score']
    .mean()
    .reset_index()
)

# Per-image rows for each quality metric, sorted by method name so the
# methods appear in a consistent order in the figures.
dice_table = image_info_df[image_info_df["image_metric"] == "Dice"].sort_values(by=['method_name'])
dice_loss_table = image_info_df[image_info_df["image_metric"] == "Dice Loss"].sort_values(by=['method_name'])
hd95_table = image_info_df[image_info_df["image_metric"] == "HD95"].sort_values(by=['method_name'])

In [None]:
from ese.experiment.analysis.analysis_utils.inference_utils import load_upperbound_df

# Upper-bound results for the uncalibrated UNets, located via the `log`
# section of the config. The result may be None; later cells check for that.
upperbound_df = load_upperbound_df(results_cfg['log'])

In [None]:
if upperbound_df is not None:
    # One frame per quality metric, keeping only rows that have a score.
    dice_ub_df = upperbound_df[upperbound_df["image_metric"] == "Dice"].dropna(subset=['metric_score'])
    dice_loss_ub_df = upperbound_df[upperbound_df["image_metric"] == "Dice Loss"].dropna(subset=['metric_score'])
    hd_ub_df = upperbound_df[upperbound_df["image_metric"] == "HD95"].dropna(subset=['metric_score'])

In [None]:
# Inspect the loaded upper-bound results (other cells guard against this being None).
upperbound_df

In [None]:
# from ese.experiment.analysis.analysis_utils.plot_utils import plot_upperbound_line

# g = sns.catplot(
#     data=dice_loss_table,
#     x="calibrator",
#     y="metric_score",
#     hue="method_name",
#     # hue="ensemble_w_metric",
#     # col="method_name",
#     kind="bar",
#     height=6,
#     aspect=2
# )
# num_calibrators = len(image_info_df['calibrator'].unique())
# if upperbound_df is not None:
#     plot_upperbound_line(dice_loss_ub_df["metric_score"], num_calibrators=num_calibrators)
# # Set the title of the bar plot
# g.fig.suptitle("WMH Dice Loss for Different Calibration Methods (Per Slice)")
# # Give the title a bit of spacing from the plot
# g.fig.subplots_adjust(top=0.90)
# # Set the y axis to be between 0.5 and 1.0
# calibrators_width = num_calibrators - 1
# g.set(xlim=(-0.8, calibrators_width + 0.8))
# g.set(ylim=(0.15, 0.3))

In [None]:
from ese.experiment.analysis.analysis_utils.plot_utils import plot_upperbound_line

# Bar plot of the Dice Loss rows: calibrators on the x axis, one bar per method.
# (Alternative groupings: hue="ensemble_w_metric", col="method_name".)
g = sns.catplot(
    kind="bar",
    data=dice_loss_table,
    x="calibrator", y="metric_score", hue="method_name",
    height=6, aspect=2,
)
num_calibrators = len(image_info_df['calibrator'].unique())
calibrators_width = num_calibrators - 1
# Overlay the upper-bound reference when one was loaded.
if upperbound_df is not None:
    plot_upperbound_line(dice_loss_ub_df["metric_score"], num_calibrators=num_calibrators)
# Title, with a bit of spacing between it and the plot.
g.fig.suptitle("CityScapes Dice Loss for Different Calibration Methods (Per Slice)")
g.fig.subplots_adjust(top=0.90)
# Pad the x axis slightly beyond the first/last calibrator and zoom the y axis
# to the 0.17-0.35 range.
g.set(xlim=(-0.8, calibrators_width + 0.8), ylim=(0.17, 0.35))

In [None]:
# g = sns.catplot(
#     data=hd95_table,
#     x="calibrator",
#     y="metric_score",
#     hue="method_name",
#     # hue="ensemble_w_metric",
#     # col="method_name",
#     kind="bar",
#     height=6,
#     aspect=2
# )
# num_calibrators = len(image_info_df['calibrator'].unique())
# if upperbound_df is not None:
#     plot_upperbound_line(hd_ub_df["metric_score"], num_calibrators=num_calibrators)
# # Set the title of the bar plot
# g.fig.suptitle("WMH Hausdorff Distance for Different Calibration Methods (Per Slice)")
# # Give the title a bit of spacing from the plot
# g.fig.subplots_adjust(top=0.90)
# # Set the y axis to be between 0.5 and 1.0
# calibrators_width = num_calibrators - 1
# g.set(xlim=(-0.8, calibrators_width + 0.8))
# g.set(ylim=(3.5, 10))

In [None]:
# Bar plot of the HD95 rows: calibrators on the x axis, one bar per method.
# (Alternative groupings: hue="ensemble_w_metric", col="method_name".)
g = sns.catplot(
    kind="bar",
    data=hd95_table,
    x="calibrator", y="metric_score", hue="method_name",
    height=6, aspect=2,
)
num_calibrators = len(image_info_df['calibrator'].unique())
calibrators_width = num_calibrators - 1
# Overlay the upper-bound reference when one was loaded.
if upperbound_df is not None:
    plot_upperbound_line(hd_ub_df["metric_score"], num_calibrators=num_calibrators)
# Title, with a bit of spacing between it and the plot.
g.fig.suptitle("CityScapes Hausdorff Distance for Different Calibration Methods (Per Slice)")
g.fig.subplots_adjust(top=0.90)
# Pad the x axis slightly beyond the first/last calibrator and zoom the y axis
# to the 27-52 range.
g.set(xlim=(-0.8, calibrators_width + 0.8), ylim=(27, 52))

## Now we want to consider these averaged within subjects.

In [None]:
# subj_image_info_df = image_info_df.groupby(["method_name", "model_class", "qual_metric", "ensemble", "data_id"])['qual_score'].mean().reset_index()

# subject_dice_table = subj_image_info_df[subj_image_info_df["qual_metric"] == "Dice"]
# subject_hd95_table = subj_image_info_df[subj_image_info_df["qual_metric"] == "HD95"]
# # sort these by method name so they are consistent in the tables
# subject_dice_table = subject_dice_table.sort_values(by=['method_name'])
# subject_hd95_table = subject_hd95_table.sort_values(by=['method_name'])

In [None]:
# g = sns.catplot(
#     data=subject_hd95_table,
#     x="model_class",
#     y="qual_score",
#     hue="method_name",
#     kind="bar",
#     height=4,
#     aspect=2
# )
# # Set the title of the bar plot
# g.fig.suptitle("Hausdorff Distance for Different Calibration Methods (Per Subject)")
# # Give the title a bit of spacing from the plot
# g.fig.subplots_adjust(top=0.90)
# # Set the y axis to be between 4 and 8
# g.set(ylim=(4, 14))

In [None]:
# g = sns.catplot(
#     data=subject_dice_table,
#     x="model_class",
#     y="qual_score",
#     hue="method_name",
#     kind="bar",
#     height=4,
#     aspect=2
# )
# # Set the title of the bar plot
# g.fig.suptitle("Dice Score for Different Calibration Methods (Per Subject)")
# # Give the title a bit of spacing from the plot
# g.fig.subplots_adjust(top=0.90)
# # Set the y axis to be between 0.5 and 1.0
# g.set(ylim=(0.6, 0.9))

## Make some tables to show these relationships

In [None]:
# # Custom formatting function to display 3 significant digits
# def format_sigfigs(x, num_sigfigs):
#     if isinstance(x, (int, float)):
#         format_str = '{:.' + str(num_sigfigs) + 'g}'
#         return format_str.format(x)  # Using format to display in scientific notation with specified significant digits
#     else:
#         return x  # Return the value as is if it's not numeric

# # Applying the formatting function to the pivot table
# formatted_dice_table = dice_table.applymap(format_sigfigs, num_sigfigs=3)
# # Applying the formatting function to the pivot table
# formatted_hd95_table = hd95_table.applymap(format_sigfigs, num_sigfigs=4)

In [None]:
# formatted_dice_table.pivot(index='method_name', columns='model_class', values='qual_score')

In [None]:
# formatted_hd95_table.pivot(index='method_name', columns='model_class', values='qual_score')

## Let's look first at the distribution of errors per configuration.

In [None]:
# from ese.experiment.analysis.plot_utils import build_ensemble_vs_individual_cmap

In [None]:
# dice_image_df = image_info_df[image_info_df['qual_metric'] == 'Dice']
# # Use seaborn to create KDE plot for each configuration
# g = sns.displot(
#     data=dice_image_df.sort_values('configuration'), 
#     x='qual_score', 
#     hue='configuration', 
#     kind='kde',
#     palette=build_ensemble_vs_individual_cmap(dice_image_df),
#     alpha=0.8
#     )

In [None]:
# dice_image_subject_df = dice_image_df.groupby(['configuration', 'data_id'])['qual_score'].mean().reset_index()
# g = sns.displot(
#     data=dice_image_subject_df.sort_values('configuration'), 
#     x='qual_score', 
#     hue='configuration', 
#     kind='kde',
#     palette=build_ensemble_vs_individual_cmap(dice_image_df),
#     alpha=0.8
#     )