In [None]:
import os
import sys
from pathlib import Path

# Expose the project checkouts and shared libraries to the import system.
sys.path.extend([
    '/storage/vbutoi/projects',
    '/storage/vbutoi/libraries',
    '/storage/vbutoi/projects/ESE',
    '/storage/vbutoi/projects/UniverSeg',
])

import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("darkgrid")

# DATAPATH is built as a ':'-separated list of dataset roots (one root for now).
# NOTE(review): presumably read by the dataset-loading code — confirm the consumer.
os.environ['DATAPATH'] = ':'.join(('/storage/vbutoi/datasets',))

# Imported after the path setup above, presumably because ionpy lives under one
# of those directories.
from ionpy.analysis import ResultsLoader

# Root directory that the inference paths used later are built from.
root = Path("/storage/vbutoi/scratch/ESE")
# Results loader object does everything.
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 
# For using yaml configs.
%load_ext yamlmagic

In [None]:
from ese.experiment.analysis.inference import load_cal_inference_stats

# One log directory per (Individual | Ensemble) x calibration-method combination,
# all stored under the same experiment folder.
inference_paths = [
    root / f"inference/01_09_24_ExpandedMetrics/WMH_{group}_{calibrator}"
    for group in ("Individual", "Ensemble")
    for calibrator in (
        "Uncalibrated",
        "TempScaling",
        "VectorScaling",
        "DirichletScaling",
        "LTS",
    )
]

inference_info_dict = load_cal_inference_stats(
    log_dirs=inference_paths,
    load_image_df=True,
    load_pixel_meters_dict=False,
)
# Missing values are replaced by the literal string 'None'; later cells compare
# against that string (e.g. `pretrained_model_class == "None"`).
image_info_df = inference_info_dict['image_info_df'].fillna('None')

In [None]:
# List the distinct values of the 'model._class' column in the loaded results.
image_info_df['model._class'].unique()

In [None]:
# Show only the rows whose model class is the Temperature_Scaling calibrator.
image_info_df.loc[
    image_info_df['model._class'] == "ese.experiment.models.calibrators.Temperature_Scaling"
]

In [None]:
# Peek at the 'model._pretrained_class' value recorded for the first row.
image_info_df['model._pretrained_class'].iloc[0]

In [None]:
# Preview the first ten rows of image_info_df.
image_info_df.head(10)

In [None]:
# Copy the dotted config columns into short, identifier-style aliases.
# NOTE(review): the helper functions defined below take parameters with exactly
# these names, presumably so `augment` can match them to columns — confirm.
column_aliases = {
    "ensemble": "model.ensemble",
    "pre_softmax": "model.ensemble_pre_softmax",
    "combine_fn": "model.ensemble_combine_fn",
    "pretrained_seed": "experiment.pretrained_seed",
    "model_class": "model._class",
    "pretrained_model_class": "model._pretrained_class",
}
for alias, source_column in column_aliases.items():
    image_info_df[alias] = image_info_df[source_column]

def _flag_is_set(flag):
    """Interpret a config flag as a boolean, treating missing values as False."""
    # The results frame is passed through `fillna('None')`, so a missing flag
    # arrives as the string "None" -- which is truthy and must be special-cased.
    if flag is None or (isinstance(flag, str) and flag == "None"):
        return False
    # NaN is the only float that differs from itself; treat it as missing too.
    if isinstance(flag, float) and flag != flag:
        return False
    return bool(flag)


def method_name(model_class, pretrained_model_class, pretrained_seed, ensemble, pre_softmax, combine_fn):
    """Build a human-readable label for the method that produced a row.

    Args:
        model_class: Dotted path of the model class.
        pretrained_model_class: Dotted path of the pretrained model class, or
            the string "None" when there is none.
        pretrained_seed: Seed value shown in the label of non-ensemble rows.
        ensemble: Whether the row comes from an ensemble. Missing values
            ("None", None, NaN) count as False.
        pre_softmax: Whether the ensemble combines before the softmax. Missing
            values count as False.
        combine_fn: Name of the ensemble combination function.

    Returns:
        "Ensemble (<combine_fn>, pre|post)" for ensembles, otherwise
        "<ClassName> (seed=<pretrained_seed>)", where the class name is the last
        dotted component of the pretrained class if present, else of the model
        class.
    """
    if _flag_is_set(ensemble):
        softmax_modifier = "pre" if _flag_is_set(pre_softmax) else "post"
        return f"Ensemble ({combine_fn}, {softmax_modifier})"

    # Prefer the pretrained class for the label when one was recorded.
    base_class = model_class if pretrained_model_class == "None" else pretrained_model_class
    return f"{base_class.split('.')[-1]} (seed={pretrained_seed})"

def model_class(model_class):
    """Turn a dotted model class path into a short label.

    Any path containing "UNet" is labelled "Uncalibrated"; every other path is
    reduced to its last dotted component.
    """
    return "Uncalibrated" if "UNet" in model_class else model_class.rpartition('.')[2]

def configuration(method_name, model_class):
    """Join the method label and the model-class label into one key, separated by '_'."""
    return "{}_{}".format(method_name, model_class)

# Pass each helper defined above to `augment`.
# NOTE(review): `augment` is not defined in this notebook. Presumably it calls the
# function row-wise, matching parameter names to columns, and stores the result
# in place under the function's name — confirm against its implementation.
# If so, order matters: `configuration` takes `method_name` and `model_class`
# as inputs, so it has to run last.
image_info_df.augment(method_name)
image_info_df.augment(model_class)
image_info_df.augment(configuration)

## Now we can look at the averages themselves.

In [None]:
# Mean quality score for every (method, model class, metric, ensemble flag) group.
table_df = (
    image_info_df
    .groupby(["method_name", "model_class", "qual_metric", "ensemble"])['qual_score']
    .mean()
    .reset_index()
)
# Split the long table by metric so each one can be plotted and tabulated on its own.
dice_table = table_df.loc[table_df["qual_metric"] == "Dice"]
hd95_table = table_df.loc[table_df["qual_metric"] == "HD95"]

In [None]:
# Preview the first rows of the mean-Dice table.
dice_table.head()

In [None]:
# Grouped bar chart of the HD95 table: model class on the x-axis, one bar per method.
bar_layout = dict(
    x="model_class",
    y="qual_score",
    hue="method_name",
    kind="bar",
    height=4,
    aspect=2,
)
g = sns.catplot(data=hd95_table, **bar_layout)
# Title the figure and leave some headroom between the title and the axes.
hd95_fig = g.fig
hd95_fig.suptitle("Hausdorff Distance for Different Calibration Methods")
hd95_fig.subplots_adjust(top=0.90)

In [None]:
# Grouped bar chart of the Dice table: model class on the x-axis, one bar per method.
bar_layout = dict(
    x="model_class",
    y="qual_score",
    hue="method_name",
    kind="bar",
    height=4,
    aspect=2,
)
g = sns.catplot(data=dice_table, **bar_layout)
# Title the figure and leave some headroom between the title and the axes.
dice_fig = g.fig
dice_fig.suptitle("Dice Score for Different Calibration Methods")
dice_fig.subplots_adjust(top=0.90)

In [None]:
# Custom formatting function to display a fixed number of significant digits
def format_sigfigs(x, num_sigfigs):
    """Format a numeric value to `num_sigfigs` significant digits.

    Args:
        x: Any table cell value.
        num_sigfigs: Number of significant digits to keep.

    Returns:
        For ints and floats, the string produced by the general ('g') format
        with `num_sigfigs` significant digits. Every other value, including
        booleans, is returned unchanged.
    """
    # bool is a subclass of int; without this guard True/False would be
    # rendered as '1'/'0' (this affects flag columns such as `ensemble`).
    if isinstance(x, bool) or not isinstance(x, (int, float)):
        return x  # Return the value as is if it's not numeric
    # 'g' keeps the requested significant digits and only switches to
    # scientific notation when the exponent is large or very small.
    return '{:.{}g}'.format(x, num_sigfigs)

# `DataFrame.applymap` was deprecated in pandas 2.1 in favour of `DataFrame.map`;
# use whichever element-wise method this pandas version provides. The lookup is
# done on the class so a column that happens to be called "map" cannot shadow it.
elementwise = "map" if hasattr(type(dice_table), "map") else "applymap"

# Format every cell of the long-format score tables (these are pivoted later);
# non-numeric cells pass through `format_sigfigs` unchanged.
formatted_dice_table = getattr(dice_table, elementwise)(format_sigfigs, num_sigfigs=3)
formatted_hd95_table = getattr(hd95_table, elementwise)(format_sigfigs, num_sigfigs=4)

In [None]:
# Reshape to a method x model-class grid of formatted mean Dice scores.
formatted_dice_table.pivot(index='method_name', columns='model_class', values='qual_score')

In [None]:
# Reshape to a method x model-class grid of formatted mean HD95 scores.
formatted_hd95_table.pivot(index='method_name', columns='model_class', values='qual_score')

## Next, let's look at the distribution of Dice scores per configuration.

In [None]:
def build_ensemble_vs_individual_cmap(dice_image_df):
    """Map every configuration to a colour that shows whether it is an ensemble.

    Configurations on rows with `ensemble == False` get colours from the
    "rocket" palette; those on rows with `ensemble == True` get colours from
    the "mako" palette. Each palette is sized to the number of distinct
    configurations in its group.

    Args:
        dice_image_df: Frame with an `ensemble` column and a `configuration`
            column.

    Returns:
        Dict from configuration to colour, individual entries first. A
        configuration present in both groups keeps the ensemble colour.
    """
    ensemble_flag = dice_image_df['ensemble']
    # Distinct configurations per group, in order of first appearance.
    individual_configs = dice_image_df[ensemble_flag == False]['configuration'].unique()
    ensemble_configs = dice_image_df[ensemble_flag == True]['configuration'].unique()

    # One colour per configuration, from a separate palette for each group.
    individual_palette = sns.color_palette("rocket", len(individual_configs))
    ensemble_palette = sns.color_palette("mako", len(ensemble_configs))

    color_map = dict(zip(individual_configs, individual_palette))
    color_map.update(zip(ensemble_configs, ensemble_palette))
    return color_map

In [None]:
# Keep only the Dice rows, then draw one KDE curve per configuration.
dice_image_df = image_info_df.loc[image_info_df['qual_metric'] == 'Dice']
kde_options = dict(
    x='qual_score',
    hue='configuration',
    kind='kde',
    alpha=0.8,
)
g = sns.displot(
    data=dice_image_df.sort_values('configuration'),
    palette=build_ensemble_vs_individual_cmap(dice_image_df),
    **kde_options,
)

In [None]:
# Average the Dice score within each (configuration, data_id) pair before plotting.
# NOTE(review): `data_id` presumably identifies a subject, per the variable name — confirm.
dice_image_subject_df = (
    dice_image_df
    .groupby(['configuration', 'data_id'])['qual_score']
    .mean()
    .reset_index()
)
kde_options = dict(
    x='qual_score',
    hue='configuration',
    kind='kde',
    alpha=0.8,
)
# The palette is built from the un-aggregated frame, as in the per-image plot.
g = sns.displot(
    data=dice_image_subject_df.sort_values('configuration'),
    palette=build_ensemble_vs_individual_cmap(dice_image_df),
    **kde_options,
)