In [1]:
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")

import os 
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))

# Results loader object does everything
from ionpy.analysis import ResultsLoader
from pathlib import Path
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 
# For using yaml configs.
%load_ext yamlmagic

In [2]:
exp_base = "01_10_24_wLabAmounts"
# exp_base = "01_14_24_EnsembleAnalysis"

In [3]:
def sort_by_calibrator(image_info_df):
    # Sort the image_info_df by method name, so everything appears nicely
    image_info_df = image_info_df.sort_values(by=['method_name', 'calibrator'])
    # Make sure that the model_class 'Uncalibrated' is first
    image_info_df['calibrator'] = image_info_df['calibrator'].astype('category')
    image_info_df['calibrator'].cat.reorder_categories(['Uncalibrated', 'Temperature_Scaling', 'Vector_Scaling', 'Dirichlet_Scaling', 'LTS'])
    # Make the category normal again
    image_info_df['calibrator'] = image_info_df['calibrator'].astype('object')
    return image_info_df

In [4]:
%%yaml results_cfg 

log:
    root: /storage/vbutoi/scratch/ESE/inference
    load_pixel_meters: True 
    remove_shared_columns: True
    drop_nan_metric_rows: True
    min_fg_pixels: 100
    inference_paths:
        - "01_14_24_EnsembleAnalysis/WMH_Individual_Uncalibrated"
        - "01_14_24_EnsembleAnalysis/WMH_Individual_TempScaling"
        - "01_14_24_EnsembleAnalysis/WMH_Individual_VectorScaling"
        - "01_14_24_EnsembleAnalysis/WMH_Individual_DirichletScaling"
        - "01_14_24_EnsembleAnalysis/WMH_Individual_LTS"
        - "01_14_24_EnsembleAnalysis/WMH_Ensemble_Uncalibrated"
        - "01_14_24_EnsembleAnalysis/WMH_Ensemble_TempScaling"
        - "01_14_24_EnsembleAnalysis/WMH_Ensemble_VectorScaling"
        - "01_14_24_EnsembleAnalysis/WMH_Ensemble_DirichletScaling"
        - "01_14_24_EnsembleAnalysis/WMH_Ensemble_LTS"
    
calibration:
    conf_interval:
        - 0.5
        - 1.
    num_bins: 10
    square_diff: False 
    neighborhood_width: 3

cal_metrics:
    - ECE:
        _fn: ese.experiment.metrics.ece.ece_loss
    - CW_ECE:
        _fn: ese.experiment.metrics.ece.cw_ece_loss
    - Edge_ECE:
        _fn: ese.experiment.metrics.ece.edge_ece_loss
    - ELM:
        _fn: ese.experiment.metrics.elm.elm_loss
    - Foreground_ECE:
        _fn: ese.experiment.metrics.ece.ece_loss
        ignore_index: 0
    - Foreground_CW_ECE:
        _fn: ese.experiment.metrics.ece.cw_ece_loss
        ignore_index: 0
    - Foreground_Edge_ECE:
        _fn: ese.experiment.metrics.ece.edge_ece_loss
        ignore_index: 0       
    - Foreground_ELM:
        _fn: ese.experiment.metrics.elm.elm_loss
        ignore_index: 0

<IPython.core.display.Javascript object>

In [5]:
from ese.experiment.analysis.inference import load_cal_inference_stats

image_info_df = load_cal_inference_stats(
    results_cfg=results_cfg
)

## We are going to do the same standardization for our df. 

In [6]:
# Add extra variable names.
####################################################################

image_info_df["ensemble"] = image_info_df["model.ensemble"]
image_info_df["pre_softmax"] = image_info_df["model.ensemble_pre_softmax"]
image_info_df["combine_fn"] = image_info_df["model.ensemble_combine_fn"]
image_info_df["pretrained_seed"] = image_info_df["experiment.pretrained_seed"]
image_info_df["model_class"] = image_info_df["model._class"]
image_info_df["pretrained_model_class"] = image_info_df["model._pretrained_class"]

def method_name(model_class, pretrained_model_class, pretrained_seed, ensemble, pre_softmax, combine_fn):
    if ensemble:
        softmax_modifier = "pre" if pre_softmax else "post"
        method_name_string = f"Ensemble ({combine_fn}, {softmax_modifier})" 
    else:
        if pretrained_model_class == "None":
            method_name_string = f"{model_class.split('.')[-1]} (seed={pretrained_seed})"
        else:
            method_name_string = f"{pretrained_model_class.split('.')[-1]} (seed={pretrained_seed})"

    return method_name_string

def calibrator(model_class):
    if "UNet" in model_class:
        return "Uncalibrated"
    else:
        return model_class.split('.')[-1]

def configuration(method_name, calibrator):
    return f"{method_name}_{calibrator}"

def model_type(ensemble):
    return 'group' if ensemble else 'individual'

image_info_df.augment(method_name)
image_info_df.augment(calibrator)
image_info_df.augment(configuration)
image_info_df.augment(model_type)

# Sorting for plotting.
####################################################################
image_info_df = sort_by_calibrator(image_info_df)

In [7]:
image_info_df.head()

Unnamed: 0,data_id,slice_idx,image_metric,metric_score,num_lab_0_pixels,num_lab_1_pixels,num_bins,neighborhood_width,square_diff,log_set,...,ensemble,pre_softmax,combine_fn,pretrained_seed,model_class,pretrained_model_class,method_name,calibrator,configuration,model_type
92815,103,18,Image_ECE,0.001084,65395,141,10,3,False,20240114_172155-V6G8-57101eef185e4d2f57011bee8...,...,True,False,max,42,ese.experiment.models.calibrators.Dirichlet_Sc...,ese.experiment.models.unet.UNet,"Ensemble (max, post)",Dirichlet_Scaling,"Ensemble (max, post)_Dirichlet_Scaling",group
92816,103,18,Image_Edge-ECE,0.038892,65395,141,10,3,False,20240114_172155-V6G8-57101eef185e4d2f57011bee8...,...,True,False,max,42,ese.experiment.models.calibrators.Dirichlet_Sc...,ese.experiment.models.unet.UNet,"Ensemble (max, post)",Dirichlet_Scaling,"Ensemble (max, post)_Dirichlet_Scaling",group
92817,103,18,Image_CW-ECE,0.267799,65395,141,10,3,False,20240114_172155-V6G8-57101eef185e4d2f57011bee8...,...,True,False,max,42,ese.experiment.models.calibrators.Dirichlet_Sc...,ese.experiment.models.unet.UNet,"Ensemble (max, post)",Dirichlet_Scaling,"Ensemble (max, post)_Dirichlet_Scaling",group
92818,103,18,Image_ELM,0.001391,65395,141,10,3,False,20240114_172155-V6G8-57101eef185e4d2f57011bee8...,...,True,False,max,42,ese.experiment.models.calibrators.Dirichlet_Sc...,ese.experiment.models.unet.UNet,"Ensemble (max, post)",Dirichlet_Scaling,"Ensemble (max, post)_Dirichlet_Scaling",group
92819,103,18,Image_Foreground-ECE,0.534777,65395,141,10,3,False,20240114_172155-V6G8-57101eef185e4d2f57011bee8...,...,True,False,max,42,ese.experiment.models.calibrators.Dirichlet_Sc...,ese.experiment.models.unet.UNet,"Ensemble (max, post)",Dirichlet_Scaling,"Ensemble (max, post)_Dirichlet_Scaling",group


In [8]:
image_info_df.keys()

Index(['data_id', 'slice_idx', 'image_metric', 'metric_score',
       'num_lab_0_pixels', 'num_lab_1_pixels', 'num_bins',
       'neighborhood_width', 'square_diff', 'log_set',
       'experiment.pretrained_seed', 'log.root', 'model._class',
       'model.checkpoint', 'model.convs_per_block', 'model.ensemble',
       'model.ensemble_combine_fn', 'model.ensemble_pre_softmax',
       'model.in_channels', 'model.out_channels', 'model.pretrained_exp_root',
       'model._pretrained_class', 'model.image_channels', 'model.num_classes',
       'ECE', 'CW_ECE', 'Edge_ECE', 'ELM', 'Foreground_ECE',
       'Foreground_CW_ECE', 'Foreground_Edge_ECE', 'Foreground_ELM',
       'ensemble', 'pre_softmax', 'combine_fn', 'pretrained_seed',
       'model_class', 'pretrained_model_class', 'method_name', 'calibrator',
       'configuration', 'model_type'],
      dtype='object')

## Let's try to see if there is any hope with having better ECE/ELM makes better ensembles. Note that this isn't a conclusive result just because the number of samples per images that are used to calculate ECE/ELM are not sufficient to get actual statistical quantities.

### First thing we have to do is calculate per slice per model configuration, the delta in performance that each configuration has between that configuration's slice performance and the average un-calibrated UNet performance on that slice.

In [9]:
unet_info_df = image_info_df[image_info_df['ensemble'] == False].reset_index(drop=True)
# Group everything we need. 
group_keys = ['data_id', 'slice_idx', 'image_metric', 'calibrator', 'model_class', 'pretrained_model_class'] 
average_unet_group = unet_info_df.groupby(group_keys).agg({'metric_score': 'mean', 'num_lab_0_pixels': 'mean', 'num_lab_1_pixels': 'mean'}).reset_index()
average_unet_group['pretrained_seed'] = 'Average'
average_unet_group['model_type'] = 'group'

In [10]:
len(unet_info_df)

54595

In [11]:
len(average_unet_group)

13650

In [12]:
def method_name(pretrained_model_class, model_class):
    if pretrained_model_class == "None":
        return f"{model_class.split('.')[-1]} (seed=Average)"
    else:
        return f"{pretrained_model_class.split('.')[-1]} (seed=Average)"

def configuration(method_name, calibrator):
    return f"{method_name}_{calibrator}"

average_unet_group.augment(method_name)
average_unet_group.augment(configuration)

In [13]:
# Add this unet group back to image info df
image_info_df = pd.concat([image_info_df, average_unet_group], axis=0, ignore_index=True)
image_info_df = sort_by_calibrator(image_info_df)

In [14]:
image_info_df['configuration'].unique()

array(['Ensemble (max, post)_Dirichlet_Scaling',
       'Ensemble (max, post)_LTS',
       'Ensemble (max, post)_Temperature_Scaling',
       'Ensemble (max, post)_Uncalibrated',
       'Ensemble (max, post)_Vector_Scaling',
       'Ensemble (max, pre)_Dirichlet_Scaling', 'Ensemble (max, pre)_LTS',
       'Ensemble (max, pre)_Temperature_Scaling',
       'Ensemble (max, pre)_Uncalibrated',
       'Ensemble (max, pre)_Vector_Scaling',
       'Ensemble (mean, post)_Dirichlet_Scaling',
       'Ensemble (mean, post)_LTS',
       'Ensemble (mean, post)_Temperature_Scaling',
       'Ensemble (mean, post)_Uncalibrated',
       'Ensemble (mean, post)_Vector_Scaling',
       'Ensemble (mean, pre)_Dirichlet_Scaling',
       'Ensemble (mean, pre)_LTS',
       'Ensemble (mean, pre)_Temperature_Scaling',
       'Ensemble (mean, pre)_Uncalibrated',
       'Ensemble (mean, pre)_Vector_Scaling',
       'UNet (seed=40)_Dirichlet_Scaling', 'UNet (seed=40)_LTS',
       'UNet (seed=40)_Temperature_Scaling

### Now we have to add to each row a column that is the difference betweeen the row's metric_score and the metric_score corresponding to the same image metric as mean uncalibrated UNet performance. 

In [15]:
# Step 1: Filter the dataframe
average_unet_row = image_info_df[(image_info_df['pretrained_seed'] == 'Average') & (image_info_df['calibrator'] == 'Uncalibrated')]
# assert that for the same data_id, slice_idx, and image_metric, there is only one row
merge_unet_cols = average_unet_row.groupby(['data_id', 'slice_idx', 'image_metric']).size()
assert merge_unet_cols.max() == 1,\
    f"There should be only one row for each data_id, slice_idx, and image_metric combination, got {merge_unet_cols}."

In [16]:
# Step 2: Merge based on 'image_metric', 'subject_id', and 'slice_idx'
merge_columns = ['image_metric', 'data_id', 'slice_idx']
merged_df = pd.merge(image_info_df, average_unet_row[merge_columns + ['metric_score']], on=merge_columns, how='left', suffixes=('', '_average_unet'))

In [17]:
# Step 3: Calculate the difference
merged_df['delta_from_base'] = merged_df['metric_score'] - merged_df['metric_score_average_unet']
# Drop unnecessary columns from the merged dataframe
image_info_df = merged_df.drop(['metric_score_average_unet'], axis=1)

In [18]:
# Assert here that the delta from base from the unet group is 0
base_rows = image_info_df[(image_info_df['pretrained_seed'] == 'Average') & (image_info_df['calibrator'] == 'Uncalibrated')]
assert (base_rows['delta_from_base'] == 0).all(),\
    f"Delta from base should be 0 for the unet group, got {base_rows['delta_from_base']}."

## Now we can look at trends! We want to make some scatterplots to look at relationships between calibration scores and their relative improvement over the baseline.   

In [19]:
def metric_type(image_metric):
    if 'ECE' in image_metric or 'ELM' in image_metric:
        return 'calibration'
    else:
        return 'quality'
    
image_info_df.augment(metric_type)

In [20]:
len(image_info_df)

122843

In [21]:
# Filter rows with 'quality' and 'calibration' metric types
quality_metric_df = image_info_df[image_info_df['metric_type'] == 'quality']
calibration_metric_df = image_info_df[image_info_df['metric_type'] == 'calibration']
# Drop the metric_type columns
quality_metric_df = quality_metric_df.drop(['metric_type'], axis=1)
calibration_metric_df = calibration_metric_df.drop(['metric_type'], axis=1)

df_size = len(image_info_df)
num_qual_metrics = len(quality_metric_df['image_metric'].unique())
num_cal_metrics = len(calibration_metric_df['image_metric'].unique())
print('Num quality metrics:', num_qual_metrics) 
print('Num calibration metrics:', num_cal_metrics)
print('Original df size:', df_size)

Num quality metrics: 2
Num calibration metrics: 8
Original df size: 122843


In [22]:
merge_cols = ['configuration', 'data_id', 'slice_idx']

# Merge the two dataframes on other common columns
combined_metric_df = pd.merge(quality_metric_df, calibration_metric_df[merge_cols + ['image_metric', 'metric_score']], on=merge_cols, suffixes=('', '_calibration'))
# Rename the newly added colums
combined_metric_df = combined_metric_df.rename(columns={
    'image_metric': 'quality_metric', 
    'metric_score': 'quality_score',
    'image_metric_calibration': 'calibration_metric', 
    'metric_score_calibration': 'calibration_score'
    })
assert len(combined_metric_df) == (df_size * (num_cal_metrics, num_qual_metrics)),\
    f"Combined df should have {df_size * (num_cal_metrics * num_qual_metrics)} rows, got {len(combined_metric_df)}."

AssertionError: Combined df should have 1965488 rows, got 196504.

In [None]:
combined_metric_df.keys()

In [None]:
len(combined_metric_df)

In [None]:
combined_metric_df

In [None]:
sns.scatterplot(data=image_info_df, 
                x='calibration_score', 
                y='delta_from_base', 
                col='calibration_metric',
                row='quality_metric',
                hue='calibrator', 
                style='model_type', 
                alpha=0.5
                )