In [None]:
import sys

# Make the local project checkouts and shared libraries importable.
# Entries are appended to the end of sys.path, in the order listed.
sys.path.extend([
    '/storage/vbutoi/projects',
    '/storage/vbutoi/libraries',
    '/storage/vbutoi/projects/ESE',
    '/storage/vbutoi/projects/UniverSeg',
])

import matplotlib.pyplot as plt
import seaborn as sns
# Global seaborn look for the figures in this notebook:
# "darkgrid" axes style and the "talk" context for element scaling.
sns.set_style("darkgrid")
sns.set_context("talk")

import os

# DATAPATH holds a ':'-separated list of dataset root directories (a single entry here).
# NOTE(review): presumably consumed by the project's dataset-loading code — confirm.
os.environ['DATAPATH'] = ':'.join(['/storage/vbutoi/datasets'])

# ResultsLoader comes from ionpy; per the original note it is the object that
# handles loading of results.
from ionpy.analysis import ResultsLoader
from pathlib import Path
# Scratch directory of the ESE project. Not referenced again in the cells visible here.
root = Path("/storage/vbutoi/scratch/ESE")
# Loader instance. Not referenced again in the cells visible here.
rs = ResultsLoader()

# autoreload in mode 2 re-imports edited modules before each cell executes,
# so project code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# yamlmagic provides the %%yaml cell magic used below to define configs inline.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

# The parsed mapping is stored in the Python variable `results_cfg`, which is
# later passed to load_cal_inference_stats.

# Location of the inference logs and the experiment group(s) to read from it.
log:
    root: /storage/vbutoi/scratch/ESE/inference
    inference_groups: 
        - "04_24_24_UniverSeg_VaryThreshold"
    
# YAML file describing the calibration metrics.
calibration:
    metric_cfg_file: "/storage/vbutoi/projects/ESE/ese/experiment/configs/inference/Calibration_Metrics.yaml"

# Loader flags. NOTE(review): their exact effect is defined inside
# load_cal_inference_stats, which is not visible here — confirm before changing.
options:
    add_dice_loss_rows: True
    drop_nan_metric_rows: True 
    remove_shared_columns: False
    equal_rows_per_cfg_assert: False 

In [None]:
from ese.experiment.analysis.analyze_inf import load_cal_inference_stats

# Load the inference statistics described by `results_cfg` (defined in the %%yaml cell above).
# NOTE(review): load_cached=False presumably bypasses any cached result — confirm in analyze_inf.
image_info_df = load_cal_inference_stats(results_cfg=results_cfg, load_cached=False)

In [None]:
# Display the keys of the loaded stats (the column names, assuming a pandas DataFrame).
image_info_df.keys()

In [None]:
def task(dataset_name, label_idx):
    """Build a task identifier of the form ``<dataset_name>_<label_idx>``.

    The parameter names are kept exactly as they are because this function is
    handed to ``image_info_df.augment`` — presumably they are matched against
    column names there (TODO confirm).
    """
    return "{}_{}".format(dataset_name, label_idx)

# Run `task` through the dataframe's `augment` helper. A 'task' column is read from
# image_info_df afterwards, so the result is presumably stored under the
# function's name (TODO confirm against the augment implementation).
image_info_df.augment(task)

In [None]:
# List the distinct task identifiers present in the loaded stats.
image_info_df['task'].unique()

In [None]:
# Reshape from long format (one row per image metric) to wide format (one column
# per value of 'image_metric'), so that Dice can be compared against the other
# metrics (the original note mentions ECE). The index columns become regular
# columns again through reset_index().
df_pivot = (
    image_info_df
    .pivot(
        index=['data_id', 'sup_idx', 'task', 'dataset_name', 'threshold'],
        columns='image_metric',
        values='metric_score',
    )
    .reset_index()
)

In [None]:
# One panel per task (wrapped at 3 columns): Dice as a function of the threshold,
# coloured by dataset. Each panel gets its own x/y scale. Panel size is controlled
# by `height`.
g = sns.relplot(
    data=df_pivot,
    x='threshold',
    y='Dice',
    kind='line',
    hue='dataset_name',
    col="task",
    col_wrap=3,
    height=6,
    facet_kws={
        "sharex": False,
        "sharey": False,
    },
)
# Give the axes readable labels instead of the raw column names.
g.set_axis_labels("Threshold", "Dice score")

# Spacing between panels, plus head-room at the top for the figure title.
g.fig.subplots_adjust(wspace=0.25, hspace=0.2, top=0.9)

# The figure shows Dice against the threshold, so the title states exactly that.
g.fig.suptitle("Dice vs threshold per held-out task (#samp per subj=5, #support=8)", fontsize=25)

# Ticks every 0.1 on [0, 1] for every panel; i / 10 avoids float artefacts
# such as 0.30000000000000004 in the tick positions.
xticks = [i / 10 for i in range(11)]
for ax in g.axes.flat:
    ax.set_xticks(xticks)
    ax.set_xticklabels([f"{tick:.1f}" for tick in xticks])
