In [None]:
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")
sns.set_context("talk")

import os 
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))

# Results loader object does everything
from ionpy.analysis import ResultsLoader
from pathlib import Path
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 2
# For using yaml configs.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

# Parsed by the yamlmagic extension into the `results_cfg` variable, which is
# later passed to load_cal_inference_stats (and its `log` section to
# load_upperbound_df).

# Location of the inference logs and the inference group to load.
log:
    root: /storage/vbutoi/scratch/ESE/inference
    inference_group: "03_03_24_RandomCircles_Shr2Ind"
    
# YAML file listing the calibration metrics.
calibration:
    metric_cfg_file: "/storage/vbutoi/projects/ESE/ese/experiment/configs/inference/Calibration_Metrics.yaml"

# Loader switches. NOTE(review): their exact semantics are defined by the
# loader in ese.experiment.analysis, not in this notebook — confirm there.
options:
    add_baseline_rows: True 
    load_pixel_meters: False 
    add_dice_loss_rows: True
    drop_nan_metric_rows: True 
    load_groupavg_metrics: False
    remove_shared_columns: False
    equal_rows_per_cfg_assert: True 

In [None]:
from ese.experiment.analysis.analyze_inf import load_cal_inference_stats

# Build the inference stats frame from `results_cfg`; load_cached=False is
# passed explicitly, so the cached-load path is not requested.
image_info_df = load_cal_inference_stats(load_cached=False, results_cfg=results_cfg)

In [None]:
# Which calibrators are present in the loaded results?
image_info_df.loc[:, 'calibrator'].unique()

## We are going to remove the cases where there are very few pixels, because they cause unrealistic outliers.

In [None]:
# Order rows by method name, then calibrator, so everything appears nicely.
image_info_df = image_info_df.sort_values(by=['method_name', 'calibrator'])
# Turn 'calibrator' into a categorical whose category order puts
# 'Uncalibrated' first. reorder_categories requires this list to contain
# exactly the categories present in the column.
image_info_df['calibrator'] = (
    image_info_df['calibrator']
    .astype('category')
    .cat.reorder_categories([
        'Uncalibrated',
        'FT_CE',
        'FT_Dice',
        'TempScaling',
        'NectarScaling',
    ])
)

In [None]:
# Which methods are present before filtering down to the group methods?
image_info_df.loc[:, 'method_name'].unique()

In [None]:
# Select only the rows corresponding to group methods.
# The explicit .copy() makes the filtered frame independent of the frame it
# was sliced from, so the column assignments below do not trigger pandas'
# SettingWithCopyWarning (assigning into a boolean-mask slice).
image_info_df = image_info_df[image_info_df['model_type'] == 'group'].copy()

# Fix the display order of the methods. reorder_categories requires this list
# to contain exactly the categories present after filtering.
image_info_df['method_name'] = image_info_df['method_name'].astype('category')
image_info_df['method_name'] = image_info_df['method_name'].cat.reorder_categories([
    'Average UNet',
    'Ensemble (mean, probs)', 
    'Ensemble (product, probs)', 
    # 'UNet (seed=40)', 
    # 'UNet (seed=41)', 
    # 'UNet (seed=42)', 
    # 'UNet (seed=43)', 
])

# Fix the order of the splits: 'val' before 'cal'.
image_info_df['split'] = image_info_df['split'].astype('category')
image_info_df['split'] = image_info_df['split'].cat.reorder_categories([
    'val',
    'cal'
])

# Let's look at the calibration scores of our models.

## ECE Metrics

In [None]:
# Bar plot of ECE per calibrator: one bar per method, one facet column per split.
g = sns.catplot(
    kind="bar",
    data=image_info_df,
    x="calibrator", y="ECE",
    hue="method_name",
    col="split",
    sharex=False,
    height=8, aspect=3,
)
g.fig.suptitle("ECE by Calibration Method and Model Class", fontsize=25)
# Shrink the subplot area from the top so the suptitle has room above the facets.
g.fig.subplots_adjust(top=0.85)

In [None]:
# g = sns.catplot(
#     data=image_info_df,
#     x="calibrator",
#     y="CW-ECE",
#     hue="method_name",
#     row="split",
#     kind="bar",
#     height=8,
#     aspect=3,
#     sharex=False
# )
# # Set column spacing
# # # Set the y-axis limits
# # g.set(ylim=(0.0, 0.001))
# g.fig.suptitle("CW-ECE by Calibration Method and Model Class")
# # Move the title slightly up
# g.fig.subplots_adjust(top=0.9)

In [None]:
# g = sns.catplot(
#     data=image_info_df,
#     x="calibrator",
#     y="Uniform-CW-ECE",
#     hue="method_name",
#     row="split",
#     kind="bar",
#     height=8,
#     aspect=3,
#     sharex=False
# )
# # Set column spacing
# # # Set the y-axis limits
# # g.set(ylim=(0.0, 0.001))
# g.fig.suptitle("Uniform-CW-ECE by Calibration Method and Model Class")
# # Move the title slightly up
# g.fig.subplots_adjust(top=0.9)

In [None]:
# g = sns.catplot(
#     data=image_info_df,
#     x="calibrator",
#     y="Edge-ECE",
#     hue="method_name",
#     row="split",
#     kind="bar",
#     height=8,
#     aspect=3,
#     sharex=False
# )
# # Set column spacing
# # # Set the y-axis limits
# # g.set(ylim=(0.0, 0.05))
# g.fig.suptitle("Edge ECE by Calibration Method and Model Class")
# # Move the title slightly up
# g.fig.subplots_adjust(top=0.9)

In [None]:
# g = sns.catplot(
#     data=image_info_df,
#     x="calibrator",
#     y="ECW-ECE",
#     hue="method_name",
#     row="split",
#     kind="bar",
#     height=8,
#     aspect=3,
#     sharex=False
# )
# # Set column spacing
# # # Set the y-axis limits
# # g.set(ylim=(0.0, 0.004))
# g.fig.suptitle("Edge CW-ECE by Calibration Method and Model Class")
# # Move the title slightly up
# g.fig.subplots_adjust(top=0.9)

In [None]:
# g = sns.catplot(
#     data=image_info_df,
#     x="calibrator",
#     y="Uniform-ECW-ECE",
#     hue="method_name",
#     row="split",
#     kind="bar",
#     height=8,
#     aspect=3,
#     sharex=False
# )
# # Set column spacing
# # # Set the y-axis limits
# # g.set(ylim=(0.0, 0.05))
# g.fig.suptitle("Uniform ECW-ECE by Calibration Method and Model Class")
# # Move the title slightly up
# g.fig.subplots_adjust(top=0.9)

## ELM Metrics

In [None]:
# g = sns.catplot(
#     data=image_info_df,
#     x="calibrator",
#     y="ELM",
#     hue="method_name",
#     row="split",
#     kind="bar",
#     height=8,
#     aspect=3,
#     sharex=False
# )
# # Set column spacing
# # # Set the y-axis limits
# # g.set(ylim=(0.0, 0.001))
# g.fig.suptitle("ELM by Calibration Method and Model Class")
# # Move the title slightly up
# g.fig.subplots_adjust(top=0.9)

In [None]:
# g = sns.catplot(
#     data=image_info_df,
#     x="calibrator",
#     y="Edge-ELM",
#     hue="method_name",
#     row="split",
#     kind="bar",
#     height=8,
#     aspect=3,
#     sharex=False
# )
# # Set column spacing
# # # Set the y-axis limits
# # g.set(ylim=(0.0, 0.001))
# g.fig.suptitle("Edge-ELM by Calibration Method and Model Class")
# # Move the title slightly up
# g.fig.subplots_adjust(top=0.9)

In [None]:
# g = sns.catplot(
#     data=image_info_df,
#     x="calibrator",
#     y="Uniform-ELM",
#     hue="method_name",
#     row="split",
#     kind="bar",
#     height=8,
#     aspect=3,
#     sharex=False
# )
# # Set column spacing
# # # Set the y-axis limits
# # g.set(ylim=(0.0, 0.3))
# g.fig.suptitle("Uniform-ELM by Calibration Method and Model Class")
# # Move the title slightly up
# g.fig.subplots_adjust(top=0.9)

In [None]:
# g = sns.catplot(
#     data=image_info_df,
#     x="calibrator",
#     y="Uniform-Edge-ELM",
#     hue="method_name",
#     row="split",
#     kind="bar",
#     height=8,
#     aspect=3,
#     sharex=False
# )
# # Set column spacing
# # # Set the y-axis limits
# # g.set(ylim=(0.0, 0.3))
# g.fig.suptitle("Uniform-Edge-ELM by Calibration Method and Model Class")
# # Move the title slightly up
# g.fig.subplots_adjust(top=0.9)

## Now we can look at the quality averages themselves, first looking slice-wise.

In [None]:
# Which image-level quality metrics are available in the 'image_metric' column?
image_info_df.loc[:, "image_metric"].unique()

In [None]:
# Load the pickled df corresponding to the upper-bound of the uncalibrated UNets.
from ese.experiment.analysis.analysis_utils.inference_utils import load_upperbound_df

# Only the `log` section of the config is handed to the loader. The result may
# be None: the plotting cells further down check for that before using it.
upperbound_df = load_upperbound_df(results_cfg["log"])

In [None]:
# Display the upper-bound frame loaded in the previous cell for inspection.
upperbound_df

In [None]:
from ese.experiment.analysis.analysis_utils.plot_utils import plot_upperbound_line

# Number of distinct calibrators; the plotting cells below use it for the
# upper-bound line and for the x-axis limits.
num_calibrators = len(image_info_df["calibrator"].unique())

In [None]:
# Bar plot of the Dice score per calibrator: one bar per method, one facet per split.
g = sns.catplot(
    data=image_info_df[image_info_df["image_metric"] == "Dice"],
    x="calibrator",
    y="metric_score",
    hue="method_name",
    col="split",
    kind="bar",
    height=8,
    aspect=3,
    sharex=False
)
if upperbound_df is not None:
    # `dice_ub_df` was never defined anywhere in the notebook, so this branch
    # raised NameError on a fresh kernel. Derive it here from upperbound_df.
    # NOTE(review): assumes upperbound_df has the same long format as
    # image_info_df (an 'image_metric' column) — confirm against load_upperbound_df.
    dice_ub_df = upperbound_df[upperbound_df["image_metric"] == "Dice"]
    plot_upperbound_line(
        graph=g, 
        plot_df=dice_ub_df, 
        y="metric_score", 
        num_calibrators=num_calibrators, 
        col="split"
    )
# Set the title of the bar plot
g.fig.suptitle("Dice for Different Calibration Methods", fontsize=25)
# Give the title a bit of spacing from the plot
g.fig.subplots_adjust(top=0.90)
# Pad the x axis by 0.8 on each side of the calibrator positions 0..num_calibrators-1.
calibrators_width = num_calibrators - 1
g.set(xlim=(-0.8, calibrators_width + 0.8))
# g.set(ylim=(0.75, 1.0))

In [None]:
# Bar plot of HD95 per calibrator: one bar per method, one facet per split.
g = sns.catplot(
    data=image_info_df[image_info_df["image_metric"] == "HD95"],
    x="calibrator",
    y="metric_score",
    hue="method_name",
    col="split",
    kind="bar",
    height=8,
    aspect=3
)
if upperbound_df is not None:
    # `hd95_ub_df` was never defined anywhere in the notebook, so this branch
    # raised NameError on a fresh kernel. Derive it here from upperbound_df.
    # NOTE(review): assumes upperbound_df has the same long format as
    # image_info_df (an 'image_metric' column) — confirm against load_upperbound_df.
    hd95_ub_df = upperbound_df[upperbound_df["image_metric"] == "HD95"]
    plot_upperbound_line(
        graph=g, 
        plot_df=hd95_ub_df, 
        y="metric_score", 
        num_calibrators=num_calibrators, 
        col="split"
    )
# Set the title of the bar plot
g.fig.suptitle("HD95 for Different Calibration Methods", fontsize=25)
# Give the title a bit of spacing from the plot
g.fig.subplots_adjust(top=0.85)
# Pad the x axis by 0.8 on each side of the calibrator positions 0..num_calibrators-1.
calibrators_width = num_calibrators - 1
g.set(xlim=(-0.8, calibrators_width + 0.8))
# g.set(ylim=(3.0, 12))

In [None]:
# Bar plot of Boundary IoU per calibrator: one bar per method, one facet per split.
g = sns.catplot(
    data=image_info_df[image_info_df["image_metric"] == "BoundaryIOU"],
    x="calibrator",
    y="metric_score",
    hue="method_name",
    col="split",
    kind="bar",
    height=8,
    aspect=3,
    sharex=False
)
if upperbound_df is not None:
    # `boundaryiou_df` was never defined anywhere in the notebook, so this
    # branch raised NameError on a fresh kernel. Derive it here from upperbound_df.
    # NOTE(review): assumes upperbound_df has the same long format as
    # image_info_df (an 'image_metric' column) — confirm against load_upperbound_df.
    boundaryiou_df = upperbound_df[upperbound_df["image_metric"] == "BoundaryIOU"]
    plot_upperbound_line(
        graph=g, 
        plot_df=boundaryiou_df, 
        y="metric_score", 
        num_calibrators=num_calibrators, 
        col="split"
    )
# Set the title of the bar plot
g.fig.suptitle("Boundary IoU for Different Calibration Methods", fontsize=25)
# Give the title a bit of spacing from the plot
g.fig.subplots_adjust(top=0.85)
# Pad the x axis by 0.8 on each side of the calibrator positions 0..num_calibrators-1.
calibrators_width = num_calibrators - 1
g.set(xlim=(-0.8, calibrators_width + 0.8))
# g.set(ylim=(0.1, 0.3))

# Looking at Ensemble Variance Under Different Definitions

In [None]:
# Keep every row except those of the "Average UNet" method.
ensemble_info_df = image_info_df.loc[image_info_df['method_name'] != "Average UNet"]

In [None]:
# g = sns.catplot(
#     data=ensemble_info_df[ensemble_info_df["image_metric"] == "Avg-PW Soft-Dice"],
#     x="calibrator",
#     y="metric_score",
#     col="split",
#     kind="bar",
#     height=8,
#     aspect=3,
# )
# # Set the title of the bar plot
# g.fig.suptitle("Average Pairwise Soft Dice for Different Calibration Methods", fontsize=25)
# # Give the title a bit of spacing from the plot
# g.fig.subplots_adjust(top=0.85)
# # Set the y axis to be between 0.5 and 1.0
# calibrators_width = num_calibrators - 1
# g.set(xlim=(-0.8, calibrators_width + 0.8))
# # g.set(ylim=(0.8, 1.0))

In [None]:
# g = sns.catplot(
#     data=ensemble_info_df[ensemble_info_df["image_metric"] == "Avg-PW Hard-Dice"],
#     x="calibrator",
#     y="metric_score",
#     col="split",
#     kind="bar",
#     height=8,
#     aspect=3,
# )
# # Set the title of the bar plot
# g.fig.suptitle("Average Pairwise Hard Dice for Different Calibration Methods", fontsize=25)
# # Give the title a bit of spacing from the plot
# g.fig.subplots_adjust(top=0.85)
# # Set the y axis to be between 0.5 and 1.0
# calibrators_width = num_calibrators - 1
# g.set(xlim=(-0.8, calibrators_width + 0.8))
# # g.set(ylim=(0.8, 1.0))

In [None]:
# g = sns.catplot(
#     data=ensemble_info_df[ensemble_info_df["image_metric"] == "Ensemble-VAR"],
#     x="calibrator",
#     y="metric_score",
#     col="split",
#     kind="bar",
#     height=8,
#     aspect=3,
# )
# # Set the title of the bar plot
# g.fig.suptitle("Ensemble Variance Pixel-Probs for Different Calibration Methods", fontsize=25)
# # Give the title a bit of spacing from the plot
# g.fig.subplots_adjust(top=0.85)
# # Set the y axis to be between 0.5 and 1.0
# calibrators_width = num_calibrators - 1
# g.set(xlim=(-0.8, calibrators_width + 0.8))
# # g.set(ylim=(0.8, 1.0))

In [None]:
# Bar plot of the "Ambiguity" metric per calibrator, one facet column per split.
g = sns.catplot(
    kind="bar",
    data=ensemble_info_df[ensemble_info_df["image_metric"] == "Ambiguity"],
    x="calibrator", y="metric_score",
    col="split",
    height=8, aspect=3,
)
# Figure title, with the subplot area shrunk from the top to make room for it.
g.fig.suptitle("Pixel-wise Ensemble Ambiguity for Different Calibration Methods", fontsize=25)
g.fig.subplots_adjust(top=0.85)
# Pad the x axis by 0.8 beyond the first and last calibrator positions.
calibrators_width = num_calibrators - 1
g.set(xlim=(-0.8, calibrators_width + 0.8))
# Fix the y axis to the range 0.0 - 0.015.
g.set(ylim=(0.0, 0.015))