In [None]:
import sys
# Make the local project and library checkouts importable (same paths, same order as before).
sys.path.extend([
    '/storage/vbutoi/projects',
    '/storage/vbutoi/libraries',
    '/storage/vbutoi/projects/ESE',
    '/storage/vbutoi/projects/UniverSeg',
])

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")

import os
# Colon-separated list of dataset roots; presumably read by the data-loading code
# outside this notebook (the consumer is not shown here).
dataset_roots = [
    '/storage/vbutoi/datasets',
]
os.environ['DATAPATH'] = ':'.join(dataset_roots)
del dataset_roots

# Results loader object does everything
from ionpy.analysis import ResultsLoader
from pathlib import Path
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 
# For using yaml configs.
%load_ext yamlmagic

In [None]:
# Experiment name for this analysis.
# NOTE(review): `exp_base` is not referenced by any later cell, and it differs from the
# "01_14_24_EnsembleAnalysis" experiment that `results_cfg` actually loads. Confirm which
# experiment is intended, or drop this variable.
exp_base = "01_10_24_wLabAmounts"

In [None]:
def sort_by_calibrator(
    image_info_df,
    calibrator_order=('Uncalibrated', 'Temperature_Scaling', 'Vector_Scaling', 'Dirichlet_Scaling', 'LTS'),
):
    """Return a copy of `image_info_df` with its rows ordered by calibrator.

    Rows follow `calibrator_order` ('Uncalibrated' first by default). Calibrators not listed
    there are placed after the listed ones, in their order of first appearance. The sort is
    stable, so rows keep their relative order within each calibrator.

    Args:
        image_info_df: DataFrame with a 'calibrator' column. It is not modified.
        calibrator_order: Preferred ordering of calibrator names.

    Returns:
        A new DataFrame with a fresh RangeIndex. Its 'calibrator' column has object dtype.
    """
    preferred = list(calibrator_order)
    # Unlisted calibrators go last instead of raising (categorical reordering would raise
    # whenever the set of values differs from `calibrator_order`).
    extras = [name for name in pd.unique(image_info_df['calibrator']) if name not in preferred]
    rank = {name: position for position, name in enumerate(preferred + extras)}
    return (
        image_info_df
        .sort_values('calibrator', key=lambda column: column.map(rank), kind='stable')
        .astype({'calibrator': object})
        .reset_index(drop=True)
    )

In [None]:
%%yaml results_cfg 

# Parsed into the Python variable `results_cfg` and passed to load_cal_inference_stats in the
# next cell. The exact semantics of each key live in that function (not shown here).
log:
    # Presumably the inference_paths below are resolved relative to this root — confirm.
    root: /storage/vbutoi/scratch/ESE/inference
    load_pixel_meters: True 
    remove_shared_columns: True
    # Presumably drops rows whose metric value is NaN — confirm in load_cal_inference_stats.
    drop_nan_metric_rows: True
    # Presumably excludes images with fewer foreground pixels than this — confirm.
    min_fg_pixels: 100
    # Individual-model and ensemble runs, each uncalibrated and with four calibrators.
    inference_paths:
        - "01_14_24_EnsembleAnalysis/WMH_Individual_Uncalibrated"
        - "01_14_24_EnsembleAnalysis/WMH_Individual_TempScaling"
        - "01_14_24_EnsembleAnalysis/WMH_Individual_VectorScaling"
        - "01_14_24_EnsembleAnalysis/WMH_Individual_DirichletScaling"
        - "01_14_24_EnsembleAnalysis/WMH_Individual_LTS"
        - "01_14_24_EnsembleAnalysis/WMH_Ensemble_Uncalibrated"
        - "01_14_24_EnsembleAnalysis/WMH_Ensemble_TempScaling"
        - "01_14_24_EnsembleAnalysis/WMH_Ensemble_VectorScaling"
        - "01_14_24_EnsembleAnalysis/WMH_Ensemble_DirichletScaling"
        - "01_14_24_EnsembleAnalysis/WMH_Ensemble_LTS"
    
calibration:
    # NOTE(review): presumably the confidence range [0.5, 1.0] split into num_bins bins — confirm.
    conf_interval:
        - 0.5
        - 1.
    num_bins: 10
    square_diff: False 
    neighborhood_width: 3

# Each entry is a metric name whose `_fn` is the dotted import path of the function computing it.
# Extra keys (e.g. ignore_index) are presumably forwarded to that function as keyword arguments.
# NOTE(review): the plotting cells reference columns such as 'delta_Image_ECE' and
# 'delta_Image_Foreground-ECE'; presumably load_cal_inference_stats renames these metrics — confirm.
cal_metrics:
    - ECE:
        _fn: ese.experiment.metrics.ece.ece_loss
    - CW_ECE:
        _fn: ese.experiment.metrics.ece.cw_ece_loss
    - Edge_ECE:
        _fn: ese.experiment.metrics.ece.edge_ece_loss
    - ELM:
        _fn: ese.experiment.metrics.elm.elm_loss
    # Foreground_* variants: ignore_index 0 presumably excludes the background label.
    - Foreground_ECE:
        _fn: ese.experiment.metrics.ece.ece_loss
        ignore_index: 0
    - Foreground_CW_ECE:
        _fn: ese.experiment.metrics.ece.cw_ece_loss
        ignore_index: 0
    - Foreground_Edge_ECE:
        _fn: ese.experiment.metrics.ece.edge_ece_loss
        ignore_index: 0       
    - Foreground_ELM:
        _fn: ese.experiment.metrics.elm.elm_loss
        ignore_index: 0

In [None]:
from ese.experiment.analysis.inference import load_cal_inference_stats

# Load the per-image inference statistics described by the `results_cfg` yaml cell.
image_info_df = load_cal_inference_stats(results_cfg=results_cfg)

## Apply the same standardization to our dataframe: add derived label columns (metric type, method name, calibrator, configuration, and model type).

In [None]:
# Add extra variable names.
####################################################################

def method_name(model_class, pretrained_model_class, pretrained_seed, ensemble, pre_softmax, combine_fn):
    """Human-readable method label for one row.

    Ensembles are labelled "Ensemble (<combine_fn>, pre|post)", depending on `pre_softmax`.
    Individual models are labelled "<ClassName> (seed=<pretrained_seed>)". <ClassName> is the
    last dotted component of `pretrained_model_class`, or of `model_class` when
    `pretrained_model_class` is the string "None".
    """
    if ensemble:
        return f"Ensemble ({combine_fn}, {'pre' if pre_softmax else 'post'})"
    backbone = model_class if pretrained_model_class == "None" else pretrained_model_class
    short_name = backbone.split('.')[-1]
    return f"{short_name} (seed={pretrained_seed})"

def calibrator(model_class):
    """Calibrator label: "Uncalibrated" for UNet classes, else the last dotted name component."""
    return "Uncalibrated" if "UNet" in model_class else model_class.rsplit('.', 1)[-1]

def configuration(method_name, calibrator):
    """Combine a method label and a calibrator name into a single "<method>_<calibrator>" key."""
    return "{}_{}".format(method_name, calibrator)

def model_type(ensemble):
    """'group' for ensemble rows, 'individual' otherwise (decided by truthiness of `ensemble`)."""
    if not ensemble:
        return 'individual'
    return 'group'

def metric_type(image_metric):
    """Classify a metric name: 'calibration' if it mentions ECE or ELM, else 'quality'."""
    is_calibration = any(tag in image_metric for tag in ('ECE', 'ELM'))
    return 'calibration' if is_calibration else 'quality'

# Add one derived column per function defined above.
# NOTE(review): presumably ionpy's `augment` binds each function parameter to the DataFrame
# column of the same name and stores the result in a column named after the function — confirm.
# Under that assumption, the order matters: `configuration` takes `method_name` and `calibrator`
# arguments, so it must run after those two columns have been added.
image_info_df.augment(metric_type)
image_info_df.augment(method_name)
image_info_df.augment(calibrator)
image_info_df.augment(configuration)
image_info_df.augment(model_type)

## Let's check whether better calibration (lower ECE/ELM) leads to better ensembles. Note that this is not a conclusive result, because the number of samples per image used to compute ECE/ELM is not sufficient to give statistically reliable estimates.

### First, for each slice and each model configuration, compute the delta between that configuration's performance on the slice and the average uncalibrated UNet performance on the same slice.

In [None]:
# Keep only individual (non-ensemble) models, then average their scores over pretrained seeds.
unet_info_df = image_info_df[image_info_df['ensemble'] == False].reset_index(drop=True)
# Group everything we need. 'pretrained_seed' is deliberately not a key, so seeds get averaged.
group_keys = [
    'data_id', 'slice_idx', 'image_metric', 'calibrator',
    'model_class', 'metric_type', 'pretrained_model_class',
]
seed_mean_aggregations = {'metric_score': 'mean', 'num_lab_0_pixels': 'mean', 'num_lab_1_pixels': 'mean'}
average_unet_group = (
    unet_info_df
    .groupby(group_keys)
    .agg(seed_mean_aggregations)
    .reset_index()
    .assign(pretrained_seed='Average', model_type='group')
)
del seed_mean_aggregations

In [None]:
def method_name(pretrained_model_class, model_class):
    """Label for a seed-averaged individual model, e.g. "UNet (seed=Average)".

    Uses the last dotted component of `pretrained_model_class`, or of `model_class` when
    `pretrained_model_class` is the string "None". This intentionally shadows the six-argument
    `method_name` defined earlier, because the averaged rows have no ensemble columns.
    """
    source_class = model_class if pretrained_model_class == "None" else pretrained_model_class
    return f"{source_class.split('.')[-1]} (seed=Average)"

def configuration(method_name, calibrator):
    """Build the "<method>_<calibrator>" key (identical to the earlier `configuration` definition)."""
    return "{}_{}".format(method_name, calibrator)

# Derive label columns for the seed-averaged rows.
# NOTE(review): presumably `augment` binds function parameters to same-named columns and stores
# the result under the function's name — confirm in ionpy. If so, `configuration` (which takes a
# `method_name` argument) must run after the `method_name` column has been added.
average_unet_group.augment(method_name)
average_unet_group.augment(configuration)

In [None]:
# Append the seed-averaged UNet rows so they can serve as the per-slice baseline.
# Re-running this cell appends them again; the uniqueness assertion in Step 1 would then fail.
image_info_df = pd.concat((image_info_df, average_unet_group), ignore_index=True)

### Next, add to each row a column holding the difference between the row's metric_score and the mean uncalibrated UNet metric_score for the same image metric on the same slice.

In [None]:
# Step 1: Select the baseline rows, i.e. the seed-averaged, uncalibrated UNet.
average_unet_row = image_info_df[
    (image_info_df['pretrained_seed'] == 'Average') & (image_info_df['calibrator'] == 'Uncalibrated')
]
# Fail clearly if no baseline exists; otherwise `.size()` is empty, and the check below
# would fail with a misleading "only one row" message.
assert len(average_unet_row) > 0,\
    "No seed-averaged 'Uncalibrated' baseline rows found in image_info_df."
# For the same data_id, slice_idx, and image_metric there must be exactly one baseline row.
merge_unet_cols = average_unet_row.groupby(['data_id', 'slice_idx', 'image_metric']).size()
duplicated_keys = merge_unet_cols[merge_unet_cols > 1]
assert duplicated_keys.empty,\
    f"There should be only one row for each data_id, slice_idx, and image_metric combination, got {duplicated_keys}."

In [None]:
# Step 2: Attach the baseline metric_score to every row, matching on 'image_metric', 'data_id',
# and 'slice_idx'. Rows without a baseline keep NaN (left join).
merge_columns = ['image_metric', 'data_id', 'slice_idx']
merged_df = pd.merge(
    image_info_df,
    average_unet_row[merge_columns + ['metric_score']],
    on=merge_columns,
    how='left',
    suffixes=('', '_average_unet'),
    # The baseline must have at most one row per key; otherwise the join would silently duplicate rows.
    validate='many_to_one',
)

In [None]:
# Step 3: Calculate the difference to the baseline.
merged_df['delta'] = merged_df['metric_score'] - merged_df['metric_score_average_unet']  # Current - Baseline
# Drop the helper baseline column now that delta has been computed.
image_info_df = merged_df.drop(columns=['metric_score_average_unet'])
# Fill missing labels (e.g. columns the averaged UNet rows lack) with the string 'None', but only
# in non-numeric columns. Filling numeric columns with a string would turn delta/metric_score
# into object dtype wherever a NaN exists, which breaks the mean aggregation in pivot_table.
non_numeric_cols = image_info_df.select_dtypes(exclude='number').columns
image_info_df = image_info_df.fillna({col: 'None' for col in non_numeric_cols})

In [None]:
# Sanity check: baseline rows are compared against themselves, so their delta must be 0.
base_rows = image_info_df.loc[
    image_info_df['pretrained_seed'].eq('Average') & image_info_df['calibrator'].eq('Uncalibrated')
]
assert base_rows['delta'].eq(0).all(),\
    f"Delta from base should be 0 for the unet group, got {base_rows['delta']}."

## Now we can look at trends! We make scatterplots of each configuration's change in calibration score against its change in segmentation quality (Dice, HD95), both measured relative to the baseline.

In [None]:
# Keep only group-level rows (ensembles and seed-averaged UNets); individual-seed rows are dropped.
image_info_df = image_info_df.loc[image_info_df['model_type'].eq('group')].reset_index(drop=True)

In [None]:
# Pivot to one row per (configuration, method_name, calibrator, data_id, slice_idx). The columns form
# a (value, metric_type, image_metric) MultiIndex holding the mean metric_score and delta per metric.
pivot_df = pd.pivot_table(
    image_info_df,
    values=['metric_score', 'delta'],
    index=['configuration', 'method_name', 'calibrator', 'data_id', 'slice_idx'],
    columns=['metric_type', 'image_metric'],
    aggfunc='mean',
).reset_index()

In [None]:
# Preview the pivoted table rather than rendering the whole frame.
pivot_df.head()

In [None]:
# Flatten the (value, metric_type, image_metric) MultiIndex columns into plain names:
#   ('delta', <type>, <metric>)        -> 'delta_<metric>'
#   (<index column>, '', '')           -> '<index column>'
#   ('metric_score', <type>, <metric>) -> '<metric>'
# Guarded so that re-running the cell is a no-op. On already-flat string columns, col[0]/col[-1]
# would index characters and silently produce garbage names.
if isinstance(pivot_df.columns, pd.MultiIndex):
    new_cols = []
    for col in pivot_df.columns.values:
        if col[0] == 'delta':
            new_cols.append(f'delta_{col[-1]}')
        elif col[-1] == '':
            new_cols.append(col[0])
        else:
            new_cols.append(col[-1])
    # Duplicate names would make column selection return DataFrames and break the plots below.
    assert len(set(new_cols)) == len(new_cols),\
        f"Flattened column names collide: {sorted({c for c in new_cols if new_cols.count(c) > 1})}"
    # Set the column names to be the lowest non empty level per column in the multi-index
    pivot_df.columns = new_cols

In [None]:
# Inspect the flattened column names (e.g. to find the 'delta_<metric>' columns plotted below).
pivot_df.columns

In [None]:
# Order rows by calibrator (intended: 'Uncalibrated' first) for the plots below.
pivot_df_sorted = sort_by_calibrator(image_info_df=pivot_df)

## Looking at change in calibration vs change in Dice.

In [None]:
# Change in Image_ECE vs change in Dice, both relative to the seed-averaged uncalibrated UNet.
# One facet per (calibrator row, method column). Titles are blanked because hue/style already
# identify method and calibrator in the legend.
g = sns.relplot(
    data=pivot_df_sorted,
    x='delta_Image_ECE',
    y='delta_Dice',
    col='method_name',
    row='calibrator',
    hue='method_name',
    style='calibrator',
    height=2.5,
)
# Trailing ';' suppresses the FacetGrid repr in the cell output.
g.set_titles("");

In [None]:
# Change in foreground ECE vs change in Dice, both relative to the seed-averaged uncalibrated UNet.
# NOTE(review): results_cfg names this metric 'Foreground_ECE'; confirm the loaded column really is
# 'Image_Foreground-ECE'.
g = sns.relplot(
    data=pivot_df_sorted,
    x='delta_Image_Foreground-ECE',
    y='delta_Dice',
    col='method_name',
    row='calibrator',
    hue='method_name',
    style='calibrator',
    height=2.5,
)
# Facet titles are redundant with the hue/style legend; ';' suppresses the FacetGrid repr.
g.set_titles("");

In [None]:
# Change in Image_ELM vs change in Dice, both relative to the seed-averaged uncalibrated UNet.
g = sns.relplot(
    data=pivot_df_sorted,
    x='delta_Image_ELM',
    y='delta_Dice',
    col='method_name',
    row='calibrator',
    hue='method_name',
    style='calibrator',
    height=2.5,
)
# Facet titles are redundant with the hue/style legend; ';' suppresses the FacetGrid repr.
g.set_titles("");

In [None]:
# Change in foreground ELM vs change in Dice, both relative to the seed-averaged uncalibrated UNet.
# NOTE(review): results_cfg names this metric 'Foreground_ELM'; confirm the loaded column really is
# 'Image_Foreground-ELM'.
g = sns.relplot(
    data=pivot_df_sorted,
    x='delta_Image_Foreground-ELM',
    y='delta_Dice',
    col='method_name',
    row='calibrator',
    hue='method_name',
    style='calibrator',
    height=2.5,
)
# Facet titles are redundant with the hue/style legend; ';' suppresses the FacetGrid repr.
g.set_titles("");

## Looking at change in calibration vs change in HD95.

In [None]:
# Change in Image_ECE vs change in HD95, both relative to the seed-averaged uncalibrated UNet.
g = sns.relplot(
    data=pivot_df_sorted,
    x='delta_Image_ECE',
    y='delta_HD95',
    col='method_name',
    row='calibrator',
    hue='method_name',
    style='calibrator',
    height=2.5,
)
# Facet titles are redundant with the hue/style legend; ';' suppresses the FacetGrid repr.
g.set_titles("");

In [None]:
# Change in foreground ECE vs change in HD95, both relative to the seed-averaged uncalibrated UNet.
# NOTE(review): confirm the loaded column is 'Image_Foreground-ECE' (results_cfg says 'Foreground_ECE').
g = sns.relplot(
    data=pivot_df_sorted,
    x='delta_Image_Foreground-ECE',
    y='delta_HD95',
    col='method_name',
    row='calibrator',
    hue='method_name',
    style='calibrator',
    height=2.5,
)
# Facet titles are redundant with the hue/style legend; ';' suppresses the FacetGrid repr.
g.set_titles("");

In [None]:
# Change in Image_ELM vs change in HD95, both relative to the seed-averaged uncalibrated UNet.
g = sns.relplot(
    data=pivot_df_sorted,
    x='delta_Image_ELM',
    y='delta_HD95',
    col='method_name',
    row='calibrator',
    hue='method_name',
    style='calibrator',
    height=2.5,
)
# Facet titles are redundant with the hue/style legend; ';' suppresses the FacetGrid repr.
g.set_titles("");

In [None]:
# Change in foreground ELM vs change in HD95, both relative to the seed-averaged uncalibrated UNet.
# NOTE(review): confirm the loaded column is 'Image_Foreground-ELM' (results_cfg says 'Foreground_ELM').
g = sns.relplot(
    data=pivot_df_sorted,
    x='delta_Image_Foreground-ELM',
    y='delta_HD95',
    col='method_name',
    row='calibrator',
    hue='method_name',
    style='calibrator',
    height=2.5,
)
# Facet titles are redundant with the hue/style legend; ';' suppresses the FacetGrid repr.
g.set_titles("");