In [None]:
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
sns.set_style("darkgrid")
sns.set_context("talk")

import os 
# Export the dataset search path for this process as a colon-separated list
# of root directories. Append further roots to the list as needed.
dataset_roots = ['/storage/vbutoi/datasets']
os.environ['DATAPATH'] = ':'.join(dataset_roots)
from ese.analysis.analyze_inf import load_cal_inference_stats
from ese.analysis.analysis_utils.plot_utils import get_prop_color_palette
# ResultsLoader comes from ionpy.analysis; presumably the entry point for
# reading experiment results — TODO confirm against ionpy.
# NOTE(review): neither `root` nor `rs` is referenced by the cells visible in
# this notebook; confirm they are still needed.
from ionpy.analysis import ResultsLoader
from pathlib import Path
# Scratch directory for ESE; the inference root in the results config lives beneath it.
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 2
# For using yaml configs.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

# Selects which inference results get passed to load_cal_inference_stats.
log:
    # Directory of one benchmark's inference output; keep exactly one `root` active.
    root: '/storage/vbutoi/scratch/ESE/inference/10_09_24_OCTA_Benchmark_Exps'
    # root: '/storage/vbutoi/scratch/ESE/inference/10_09_24_ISLES_Benchmark_Exps'
    # root: '/storage/vbutoi/scratch/ESE/inference/10_10_24_Roads_Benchmark_Exps'
    # Experiment groups to load — presumably sub-directories of `root`; TODO confirm.
    inference_groups: 
        - 'Threshold_Sweep_SoftDice'
        # - 'Temperature_Sweep_SoftDice'
        # - 'Threshold_Sweep_CrossEntropy'
        # - 'Temperature_Sweep_CrossEntropy'

# Loader switches; their exact semantics are defined by the loader code,
# which is not visible in this notebook — TODO confirm before changing.
options:
    verify_graceful_exit: True
    equal_rows_per_cfg_assert: False 

### Useful cell for controlling the plotting functions.

In [3]:
# Shared plot settings: the plotting cells below read x_key, y_key, xtick_range
# and aspect from here, so switching sweeps only requires editing this cell.
# NOTE(review): the results config currently selects 'Threshold_Sweep_SoftDice'
# while the temperature preset is active here — confirm the pairing is intended.

# Threshold-sweep preset (disabled):
# x_key = 'threshold'
# y_key = 'log_hard_volume_error'
# xtick_range = np.arange(0, 1.1, 0.1)
# aspect = 1

# Temperature-sweep preset (active):
x_key, y_key = 'temperature', 'log_soft_volume_error'
aspect = 2
xtick_range = np.arange(0, 3.1, 0.1)

# Plotting Calls

In [None]:
# Build the inference-statistics DataFrame described by results_cfg (the YAML cell).
# load_cached=False presumably forces a fresh load instead of reusing a cached
# frame — confirm in ese.analysis.analyze_inf.
inference_df = load_cal_inference_stats(results_cfg=results_cfg, load_cached=False)

In [5]:
# for ikey in inference_df.keys():
#     print(ikey)

In [6]:
# Columns needed by the plots and the optimum search below: the four
# volume-error metrics, the swept settings, and the per-sample identifiers.
cols_to_keep = [
    'abs_soft_volume_error', 'log_soft_volume_error',
    'abs_hard_volume_error', 'log_hard_volume_error',
    'loss_func_class', 'temperature', 'gt_volume',
    'threshold', 'data_id', 'split',
]

# Narrow the frame to those columns, collapse rows that became identical
# after the narrowing, and renumber the index from zero.
selected_df = inference_df.loc[:, cols_to_keep]
deduplicated_df = selected_df.drop_duplicates()
exp_df = deduplicated_df.reset_index(drop=True)

In [7]:
# Order rows by split name (ascending is the default) so that the 'cal' split
# comes first — this relies on 'cal' sorting before the other split names.
exp_df = exp_df.sort_values(by='split')

In [None]:
# Error (y_key) against the swept setting (x_key): one line per loss function,
# one facet per split. With kind='line', seaborn aggregates repeated x values
# (its default estimator is the mean), so each line is the mean error curve.
#
# sns.relplot is a figure-level function: it creates its own figure, sized by
# `height` and `aspect`. A preceding plt.figure(...) call would not affect it
# and would only leave an empty extra figure in the cell output.
g = sns.relplot(
    data=exp_df,
    x=x_key,
    y=y_key,
    hue='loss_func_class',
    col='split',
    kind='line',
    height=10,
    aspect=aspect,
)

# Apply the shared tick positions to every facet; the trailing semicolon keeps
# the FacetGrid repr out of the cell output.
g.set(xticks=xtick_range);

In [None]:
# Find, for every (split, loss function) pair, the x_key value with the lowest
# mean y_key. The result keeps the name `optimal_threshold_df` for the cells
# that follow, even when x_key is not 'threshold'.
group_cols = ['split', 'loss_func_class']

# Step 1: average the metric across data_ids at each x_key value. Selecting
# y_key before .mean() keeps the other columns out of the aggregation: a
# string-valued column such as data_id makes pandas >= 2.0 raise TypeError,
# and averaging identifiers is meaningless in any case.
mean_error_df = (
    exp_df
    .groupby(group_cols + [x_key])[y_key]
    .mean()
    .reset_index()
)

# Step 2: row label of the smallest mean error within each group.
best_row_labels = mean_error_df.groupby(group_cols)[y_key].idxmin()

# Step 3: keep only the group keys and the winning x_key value.
optimal_threshold_df = (
    mean_error_df
    .loc[best_row_labels, group_cols + [x_key]]
    .reset_index(drop=True)
)

In [None]:
# Display the x_key value with the lowest mean error for each (split, loss_func_class) pair.
optimal_threshold_df

In [None]:
# Per-sample error curves: one line per data_id, one facet per split.
# The palette comes from the project helper get_prop_color_palette — presumably
# it colours each data_id by its gt_volume along the 'magma_r' colormap;
# confirm in ese.analysis.analysis_utils.plot_utils.
data_id_palette = get_prop_color_palette(
    exp_df,
    hue_key='data_id',
    magnitude_key='gt_volume',
    cmap='magma_r',
)

# Alternative faceting, currently disabled: col='loss_func_class', row='split'.
# The legend is switched off for this plot.
g = sns.relplot(
    data=exp_df,
    x=x_key, y=y_key,
    hue='data_id',
    col='split',
    kind='line',
    height=8, aspect=aspect,
    legend=False,
    palette=data_id_palette,
)

# Apply the shared tick positions to every facet.
g.set(xticks=xtick_range)