In [None]:
import sys
# Extend the module search path with local checkouts before the project imports below.
# NOTE(review): presumably these directories provide the `ese` and `ionpy` packages — confirm.
# These are absolute, machine-specific paths; the notebook will not import cleanly elsewhere.
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# Global seaborn look for every figure in this notebook.
sns.set_style("darkgrid")
sns.set_context("talk")

import os 
# DATAPATH is a ':'-separated list of dataset roots (a single entry here).
# It is set before the project imports; which library reads it is not visible in this notebook.
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))

from ese.experiment.analysis.analyze_inf import load_cal_inference_stats
# Results loader object does everything
from ionpy.analysis import ResultsLoader
from pathlib import Path
# `root` and `rs` are not referenced by any of the cells shown here.
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting: autoreload mode 2 re-imports modules
# before each cell runs, so edits to the project code are picked up live.
%load_ext autoreload
%autoreload 2
# For using yaml configs: provides the %%yaml cell magic used in the next cell.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

# This cell is stored in the Python variable `results_cfg`, which is handed to
# load_cal_inference_stats(results_cfg=...) in the following cell.
log:
    # Location of the inference results and the run group(s) to load.
    # NOTE(review): presumably each group is a sub-directory of `root` — confirm against the loader.
    root: /storage/vbutoi/scratch/ESE/inference
    inference_groups: 
        - '05_27_24_SW_SoftVols'

# Loader flags. Their exact semantics are defined inside load_cal_inference_stats
# and cannot be verified from this notebook.
options:
    add_dice_loss_rows: True
    drop_nan_metric_rows: True 
    remove_shared_columns: False
    equal_rows_per_cfg_assert: False 

In [None]:
# Load the inference statistics described by `results_cfg` into a DataFrame.
# NOTE(review): load_cached=False presumably forces a rebuild instead of reading a cache — confirm in the loader.
inference_df = load_cal_inference_stats(results_cfg=results_cfg, load_cached=False)

In [None]:
# List the columns available in the long-format inference table.
inference_df.columns

In [None]:
# To relate Dice to ECE we need one column per metric, so reshape the table:
# each distinct `image_metric` value becomes a column filled from `metric_score`,
# with one row per (exp_name, data_id, sup_idx, pred_hash). The key columns are
# restored as regular columns by reset_index().
pred_key_cols = ['exp_name', 'data_id', 'sup_idx', 'pred_hash']
inference_df_piv = inference_df.pivot(
    index=pred_key_cols,
    columns='image_metric',
    values='metric_score',
).reset_index()

In [None]:
# Letter-value ("boxen") plot of the Dice score distribution for each data_id.
# The former `errorbar='sd'` argument was removed: error bars only apply to the
# estimate-based kinds (bar / point) and are never used when kind='boxen', so it
# suggested a standard-deviation display that was not actually drawn.
sns.catplot(
    data=inference_df_piv,
    x='data_id',
    y='Dice',
    kind='boxen',
    height=5,
    aspect=3,
    legend_out=False,
)

In [None]:
# Letter-value ("boxen") plot of the Image_ECE distribution for each data_id.
# The former `errorbar='sd'` argument was removed: error bars only apply to the
# estimate-based kinds (bar / point) and are never used when kind='boxen', so it
# suggested a standard-deviation display that was not actually drawn.
sns.catplot(
    data=inference_df_piv,
    x='data_id',
    y='Image_ECE',
    kind='boxen',
    height=5,
    aspect=3,
    legend_out=False,
)

In [None]:
# List the columns of the pivoted table (key columns plus one column per metric).
inference_df_piv.columns

In [None]:
# `inference_df` is long-format: it holds one row per prediction AND per image_metric
# (that is what the pivot on `image_metric` relies on). Melting it directly would
# repeat every volume once per metric row, inflating the sample size seen by the
# plots and weighting predictions unevenly when some metric rows were dropped.
# Keep a single row per prediction first.
# NOTE(review): assumes gt_volume / soft_volume / hard_volume are identical across
# the metric rows of one prediction — confirm against the loader.
volumes_per_pred = inference_df.drop_duplicates(subset=['exp_name', 'data_id', 'sup_idx', 'pred_hash'])

# Stack the three volume columns into a single `Volume` column, labelled by `Volume_Type`.
inference_df_melted = pd.melt(
    volumes_per_pred,
    id_vars=['data_id', 'sup_idx', 'pred_hash'],
    value_vars=['gt_volume', 'soft_volume', 'hard_volume'],
    var_name='Volume_Type',
    value_name='Volume',
)

In [None]:
# Letter-value ("boxen") plot of the volume distribution for each data_id,
# split by volume type (gt / soft / hard) via `hue`.
# The former `errorbar='sd'` argument was removed: error bars only apply to the
# estimate-based kinds (bar / point) and are never used when kind='boxen', so it
# suggested a standard-deviation display that was not actually drawn.
sns.catplot(
    data=inference_df_melted,
    x='data_id',
    y='Volume',
    hue='Volume_Type',
    kind='boxen',
    height=5,
    aspect=3,
    legend_out=True,
)


In [None]:
# Stack the two volume-error metrics into one `Measurement_Error` column,
# labelled by `Pred_Type`, while carrying Image_ECE along as an identifier
# so error can be plotted against calibration.
inference_df_piv_melted = inference_df_piv.melt(
    id_vars=['data_id', 'sup_idx', 'pred_hash', 'Image_ECE'],
    value_vars=['SoftVolumeError', 'HardVolumeError'],
    var_name='Pred_Type',
    value_name='Measurement_Error',
)

In [None]:
# Scatter of measurement error against Image_ECE: one panel per data_id,
# coloured by which prediction type (soft / hard) the error belongs to.
sns.relplot(
    data=inference_df_piv_melted,
    x='Image_ECE', y='Measurement_Error',
    hue='Pred_Type', col='data_id',
    alpha=0.8, height=5,
)

In [None]:
# Density of the measurement error, one filled curve per Pred_Type.
# common_norm=False normalises each curve on its own, so the two
# distributions are comparable in shape regardless of their row counts.
kde_fig, kde_ax = plt.subplots(figsize=(12, 8))

sns.kdeplot(
    data=inference_df_piv_melted,
    x='Measurement_Error',
    hue='Pred_Type',
    common_norm=False,
    fill=True,
    alpha=0.5,
    ax=kde_ax,
)

In [None]:
# For every (data_id, sup_idx) group, find the row with the lowest Image_ECE.
# Note: idxmin() returns the *row index labels* of inference_df_piv, not the
# pred_hash values themselves; look the hashes up with
# inference_df_piv.loc[min_ece_preds, 'pred_hash'] when they are needed.
ece_per_subject = inference_df_piv.groupby(['data_id', 'sup_idx'])['Image_ECE']
min_ece_preds = ece_per_subject.idxmin()

In [None]:
# Display the Series of inference_df_piv row labels, indexed by (data_id, sup_idx).
min_ece_preds