In [None]:
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
sns.set_style("darkgrid")
sns.set_context("talk")

import os 
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))

from ese.experiment.analysis.analyze_inf import load_cal_inference_stats
# Results loader object does everything
from ionpy.analysis import ResultsLoader
from pathlib import Path
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 2
# For using yaml configs.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

# Parsed by the %%yaml magic into `results_cfg`, which a later cell passes to
# load_cal_inference_stats. Presumably `root` is the inference log directory
# and `inference_groups` the experiment folders to read -- TODO confirm.
log:
    root: /storage/vbutoi/scratch/ESE/inference
    inference_groups: 
        - '06_04_24_WMH_DifferenceExps'

# Loader switches. Their exact effect is defined inside
# load_cal_inference_stats, which is not part of this notebook -- TODO confirm.
options:
    add_dice_loss_rows: True
    drop_nan_metric_rows: True 
    remove_shared_columns: False
    equal_rows_per_cfg_assert: False 

In [None]:
# Load the inference statistics described by the YAML config cell.
# load_cached=False presumably bypasses any cached copy -- confirm in analyze_inf.
inference_df = load_cal_inference_stats(results_cfg=results_cfg, load_cached=False)

In [None]:
# inference_df stores one metric per row (`image_metric` / `metric_score`).
# Spread the metrics into their own columns so Dice can be compared with ECE
# for the same prediction.
inference_df_piv = (
    inference_df
    .pivot(
        index=[
            'exp_name',
            'calibrator',
            'split',
            'dataset_name',
            'data_id',
            'pretrained_exp_root',
            'pred_hash',
        ],
        columns='image_metric',
        values='metric_score',
    )
    .reset_index()
)

In [None]:
# Fix the order in which calibrators and splits appear in the plots below.
# reorder_categories requires the listed values to match the column's
# categories exactly and raises otherwise.
inference_df_piv['calibrator'] = (
    inference_df_piv['calibrator']
    .astype('category')
    .cat.reorder_categories(['Uncalibrated', 'TempScaling', 'LTS'])
)

inference_df_piv['split'] = (
    inference_df_piv['split']
    .astype('category')
    .cat.reorder_categories(['val', 'cal'])
)

In [None]:
# Dice per calibrator as bars with standard-deviation error bars,
# one panel per split.
sns.catplot(
    data=inference_df_piv,
    kind='bar',
    x='calibrator',
    y='Dice',
    col='split',
    errorbar='sd',
    height=5,
    aspect=2,
    legend_out=False,
)

In [None]:
# Image_ECE per calibrator as bars with standard-deviation error bars,
# one panel per split.
sns.catplot(
    data=inference_df_piv,
    kind='bar',
    x='calibrator',
    y='Image_ECE',
    col='split',
    errorbar='sd',
    height=5,
    aspect=2,
    legend_out=False,
)

In [None]:
# List the columns available after the pivot (id columns plus one per metric).
inference_df_piv.columns

In [None]:
# Long-form table of the three volume measurements for each prediction.
#
# inference_df holds one row per (prediction, image_metric) pair, so the
# volume columns of a prediction are repeated once per metric row. Drop those
# exact repeats before melting; otherwise every volume is counted several
# times in whatever is plotted from this frame.
volume_id_cols = ['data_id', 'calibrator', 'pretrained_exp_root', 'split', 'pred_hash', 'dataset_name']
volume_cols = ['gt_volume', 'soft_volume', 'hard_volume']
inference_df_melted = pd.melt(
    inference_df.drop_duplicates(subset=volume_id_cols + volume_cols),
    id_vars=volume_id_cols,
    value_vars=volume_cols,
    var_name='Volume_Type',
    value_name='Volume',
)

In [None]:
# Volume distributions per dataset, coloured by volume type, with one row per
# calibrator and one column per split.
# `errorbar` is intentionally not passed: catplot only forwards it to the
# bar/point kinds, so it has no effect on a boxen plot.
sns.catplot(
    data=inference_df_melted,
    x='dataset_name',
    y='Volume',
    hue='Volume_Type',
    kind='boxen',
    col='split',
    row='calibrator',
    height=5,
    aspect=3,
    legend_out=True,
)


In [None]:
# Stack the two volume-error columns into a single `Measurement_Error`
# column, tagged by `Pred_Type`, keeping Image_ECE alongside each row.
inference_df_piv_melted = pd.melt(
    inference_df_piv,
    id_vars=[
        'data_id',
        'calibrator',
        'pretrained_exp_root',
        'split',
        'pred_hash',
        'dataset_name',
        'Image_ECE',
    ],
    value_vars=['SoftVolumeError', 'HardVolumeError'],
    var_name='Pred_Type',
    value_name='Measurement_Error',
)

In [None]:
# Scatter of Measurement_Error against Image_ECE, coloured by calibrator,
# one panel per split.
# NOTE(review): rows for both Pred_Type values are drawn with the same
# markers; consider style='Pred_Type' if they should be distinguishable.
sns.relplot(
    data=inference_df_piv_melted,
    x='Image_ECE',
    y='Measurement_Error',
    col='split',
    hue='calibrator',
    alpha=0.8,
    height=5,
)

In [None]:
# Density of Image_ECE, one panel per split.
#
# Plot from the pivoted frame (one row per prediction). In
# inference_df_piv_melted every Image_ECE value appears twice (once per
# Pred_Type), which would double-count each sample in the density estimate.
g = sns.FacetGrid(
    inference_df_piv,
    col='split',
    height=5,
    aspect=1.5,
    sharex=False,
    sharey=False,
)

# Draw a filled KDE of Image_ECE in every panel.
g.map(sns.kdeplot, 'Image_ECE', fill=True, alpha=0.5)

# No hue is mapped on this grid, so there is nothing to show in a legend.

# Leave head-room above the panels for the figure-level title.
g.figure.subplots_adjust(top=0.9)
g.figure.suptitle('KDE Plots of Image ECE')

plt.show()

In [None]:
import numpy as np

# Grid for the measurement-error densities: one row per split, one column
# per calibrator, with Pred_Type as the hue inside each panel. Axes are not
# shared, so every panel gets its own x/y range.
g = sns.FacetGrid(
    inference_df_piv_melted,
    row='split',
    col='calibrator',
    hue='Pred_Type',
    height=8,
    aspect=1.2,
    sharex=False,
    sharey=False,
)

def kde_with_mean(data, **kwargs):
    """Draw a filled KDE of `data` with reference lines on the current axes.

    Written to be passed to `FacetGrid.map`. `kwargs` are forwarded to
    `sns.kdeplot`; if they contain `color`, the mean marker uses it.

    Two dashed vertical lines are added: one at the mean of `data` (in the
    series colour, or black if none was given) and a black one at zero.
    """
    sns.kdeplot(data, fill=True, alpha=0.5, **kwargs)
    ax = plt.gca()
    series_color = kwargs.get('color', 'k')
    ax.axvline(np.mean(data), linestyle='--', color=series_color)
    # Zero marks the no-error reference point.
    ax.axvline(0, linestyle='--', color='black', linewidth=1.5)

# Draw one density (with mean and zero markers) per hue level in every panel.
g.map(kde_with_mean, 'Measurement_Error')

# Legend for the grid's hue variable.
g.add_legend()

# Leave head-room above the panels for the figure-level title.
g.figure.subplots_adjust(top=0.9)
# The grid is faceted by calibrator (columns) and split (rows), not by dataset.
g.figure.suptitle('KDE Plots of Measurement Error by Calibrator, Split, and Prediction Type')

plt.show()