In [None]:
import os 
import sys
import numpy as np
import seaborn as sns
from pathlib import Path
import matplotlib.pyplot as plt
# Ionpy imports
from ionpy.analysis import ResultsLoader
# Local imports
from ese.analysis.analyze_inf import load_cal_inference_stats
from ese.analysis.analysis_utils.plot_utils import get_prop_color_palette
from ese.analysis.analysis_utils.parse_sweep import get_global_optimal_parameter, get_per_subject_optimal_values

sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')
sns.set_style("darkgrid")
sns.set_context("talk")
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 2
# For using yaml configs.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

# Selects which inference results get loaded; this mapping is bound to the
# name `results_cfg` by the %%yaml magic and passed to load_cal_inference_stats.
log:
    # Experiment directory to read from. Exactly one `root` is active at a time;
    # the commented-out entries are the OCTA and ISLES benchmark directories.
    # root: '/storage/vbutoi/scratch/ESE/inference/10_09_24_OCTA_Benchmark_Exps'
    # root: '/storage/vbutoi/scratch/ESE/inference/10_09_24_ISLES_Benchmark_Exps'
    root: '/storage/vbutoi/scratch/ESE/inference/10_10_24_Roads_Benchmark_Exps'
    # Inference groups to include (presumably sub-directories of `root` -- confirm
    # in load_cal_inference_stats). The temperature sweeps are active and the
    # threshold sweeps are commented out; this presumably should match the
    # `x_key` chosen in the plot-control cell.
    inference_groups: 
        # - 'Threshold_Sweep_SoftDice'
        # - 'Threshold_Sweep_CrossEntropy'
        - 'Temperature_Sweep_SoftDice'
        - 'Temperature_Sweep_CrossEntropy'

# Loader options; their exact semantics are defined by the loading code, which
# is not visible in this notebook.
options:
    verify_graceful_exit: True
    equal_rows_per_cfg_assert: False 

### Useful cell for controlling the plotting functions.

In [3]:
# Plot controls: this cell decides what the plotting cells below show, so a
# sweep can be switched here once instead of editing every plotting call.

# --- Alternative preset: threshold sweep (swap with the active preset to use) ---
# x_key = 'threshold'
# y_key = 'log_hard_volume_error'
# xtick_range = np.arange(0, 1.1, 0.1)
# cmap = 'viridis_r'
# aspect = 1
# y_lims = (-2.5, 12)

# --- Active preset: temperature sweep ---
# Columns plotted on the x / y axes (x_key is also the sweep key for the helpers).
x_key = 'temperature'
y_key = 'log_soft_volume_error'
# Axis formatting applied to every facet.
xtick_range = np.arange(0, 3.1, 0.1)
y_lims = (-2.5, 12)
# Facet aspect ratio and the colormap handed to get_prop_color_palette.
aspect = 2
cmap = 'magma_r'

# Plotting Calls

In [None]:
# Load the inference statistics selected by `results_cfg` (the %%yaml cell).
# load_cached=True presumably reuses a previously cached result instead of
# re-reading everything -- see ese.analysis.analyze_inf to confirm.
inference_df = load_cal_inference_stats(results_cfg=results_cfg, load_cached=True)

In [5]:
# for ikey in inference_df.keys():
#     print(ikey)

In [6]:
# Columns needed by the plots and sweep helpers below.
cols_to_keep = [
    # volume-error metrics (soft / hard, absolute / log)
    'abs_soft_volume_error', 'log_soft_volume_error',
    'abs_hard_volume_error', 'log_hard_volume_error',
    # experiment condition and swept parameters
    'loss_func_class', 'temperature',
    # per-sample information
    'gt_volume', 'threshold', 'data_id', 'split',
]
# Restrict to those columns, drop rows that are identical across all of them,
# and renumber the index from 0.
exp_df = (
    inference_df[cols_to_keep]
    .drop_duplicates()
    .reset_index(drop=True)
)

In [7]:
# Sort rows by split name (ascending, the default) so that the 'cal' split
# comes first -- presumably so it is also the first facet row in the plots below.
exp_df = exp_df.sort_values(by='split')

In [None]:
# Aggregate view: `y_key` against the swept parameter `x_key`, one line per
# loss function (hue) and one facet row per split. With kind='line', seaborn
# aggregates the repeated observations at each x value (mean by default).
#
# NOTE: relplot is a figure-level function that creates its own figure, sized
# through `height` and `aspect`. A preceding plt.figure(...) call is therefore
# not used by the plot and would only leave an empty extra figure behind.
g = sns.relplot(
    data=exp_df,
    x=x_key,
    y=y_key,
    hue='loss_func_class',
    row='split',
    kind='line',
    height=10,
    aspect=aspect,
    # Only draw the legend when sweeping temperature.
    legend=(x_key == 'temperature')
)

# Apply the shared x ticks and y limits from the plot-control cell to every facet.
g.set(xticks=xtick_range, ylim=y_lims)

In [None]:
# Optimal value of the swept parameter for each (split, loss function) group.
# The selection criterion lives in parse_sweep -- presumably the `x_key` value
# with the best aggregate `y_key`; confirm there.
get_global_optimal_parameter(
    exp_df, sweep_key=x_key, y_key=y_key, group_keys=['split', 'loss_func_class']
)

In [None]:
# Per-subject view: one `y_key`-vs-`x_key` line for every `data_id`, faceted by
# loss function (columns) and split (rows). Line colours come from
# get_prop_color_palette keyed on `gt_volume` with the configured colormap
# (presumably colour encodes ground-truth volume -- confirm in plot_utils).
# The legend is switched off for this figure.
g = sns.relplot(
    data=exp_df,
    kind='line',
    x=x_key,
    y=y_key,
    hue='data_id',
    palette=get_prop_color_palette(
        exp_df,
        hue_key='data_id',
        magnitude_key='gt_volume',
        cmap=cmap,
    ),
    # Alternative layout previously tried: col='split' instead of the row/col pair.
    row='split',
    col='loss_func_class',
    height=8,
    aspect=aspect,
    legend=False,
)

# Apply the shared x ticks and y limits from the plot-control cell to every facet.
g.set(xticks=xtick_range, ylim=y_lims)

In [None]:
# Per-subject counterpart of the global optimum query, using the same sweep
# key, metric and (split, loss function) grouping. The exact output format is
# defined in parse_sweep and not visible here.
get_per_subject_optimal_values(
    exp_df, sweep_key=x_key, y_key=y_key, group_keys=['split', 'loss_func_class']
)