In [None]:
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

# Standard imports
import os 
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path
import matplotlib.pyplot as plt
# Ionpy imports
from ionpy.analysis import ResultsLoader
# Local imports
from ese.analysis.analyze_inf import load_cal_inference_stats

# Plot styling shared by every figure in this notebook.
sns.set_style("darkgrid")
sns.set_context("talk")
# Colon-joined list of dataset directories exposed via DATAPATH (a single entry
# here); presumably read by the ESE/ionpy data loaders.
# NOTE(review): this is set *after* the imports above — if any of those modules
# reads DATAPATH at import time, this assignment must move before them; confirm.
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))
# NOTE(review): `root` and `rs` are not referenced in the cells shown here.
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 2
# For using yaml configs.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

# Inference result directories to load, one per benchmark dataset.
log:
    root:
        - '/storage/vbutoi/scratch/ESE/inference/10_26_24_OCTA_Benchmark'
        - '/storage/vbutoi/scratch/ESE/inference/10_26_24_ISLES_Benchmark'
        - '/storage/vbutoi/scratch/ESE/inference/10_26_24_WMH_Benchmark'
        - '/storage/vbutoi/scratch/ESE/inference/10_26_24_Roads_FULLRES_Benchmark'
    # Inference sub-groups to load. method_group() later matches the substrings
    # "Base" / "Threshold" / "Temperature" in the log directory name, so these
    # names presumably become the last component of each log_root — confirm.
    inference_group: 
        - 'Base'
        - 'Optimal_RAVE_Threshold'
        - 'Optimal_RAVE_Temperature'

# Loader options consumed by load_cal_inference_stats (semantics defined there).
options:
    verify_graceful_exit: True
    equal_rows_per_cfg_assert: False 

# Load and Prepare Results for Plotting

In [None]:
# load_cached=False presumably forces a fresh read of the inference statistics
# rather than reusing a cached copy; flip to True for faster iteration.
LOAD_CACHED = False
inference_df = load_cal_inference_stats(results_cfg=results_cfg, load_cached=LOAD_CACHED)

In [4]:
def method_group(log_root):
    """Map an inference log directory to its method-group label.

    The label is chosen from the final path component of ``log_root``:
    "Base" -> "Base", "Threshold" -> "Threshold Tuned Hard",
    "Temperature" -> "Temperature Tuned Soft" (checked in that order).

    ``Path(...).name`` is used instead of ``split('/')[-1]`` so that a trailing
    slash (or a ``pathlib.Path`` argument) still yields the directory name
    rather than an empty string.

    Raises:
        ValueError: if the final component matches none of the known groups.
    """
    suffix = Path(log_root).name
    if "Base" in suffix:
        return "Base"
    if "Threshold" in suffix:
        return "Threshold Tuned Hard"
    if "Temperature" in suffix:
        return "Temperature Tuned Soft"
    raise ValueError(f"Unknown method group for log_root {log_root!r} (suffix {suffix!r})")

def pred_volume(method_group, hard_volume, soft_volume):
    """Select the predicted volume that matches a tuned method group.

    Threshold-tuned rows use ``hard_volume``; temperature-tuned rows use
    ``soft_volume``. Any other group (in practice "Base") gets ``None``; the
    Base rows are given a ``pred_volume`` explicitly in a later cell.
    """
    return (
        hard_volume if method_group == "Threshold Tuned Hard"
        else soft_volume if method_group == "Temperature Tuned Soft"
        else None
    )

def dataset(experiment_inf_dataset_name):
    """Return the inference dataset name unchanged.

    Exists so that ``augment`` exposes ``experiment_inf_dataset_name`` under a
    shorter column (presumably named after this function — confirm against
    ionpy's ``augment``).
    """
    name = experiment_inf_dataset_name
    return name

# Derive analysis columns on inference_df.
# NOTE(review): augment presumably fills each function parameter from the
# same-named column, so pred_volume (which takes `method_group`) must run after
# the method_group column has been added — keep this order.
inference_df.augment(dataset)
inference_df.augment(method_group)
inference_df.augment(pred_volume)

In [5]:
# Split off the untuned "Base" rows; everything else is already tuned.
base_rows = inference_df[inference_df['method_group'] == 'Base']
tuned_df = inference_df[inference_df['method_group'] != 'Base'].copy()


def _base_variant(base_df, volume_col, group_label):
    """Copy ``base_df`` with ``pred_volume`` taken from ``volume_col``.

    Base rows have no tuned prediction, so each one is evaluated twice: once on
    its hard volume and once on its soft volume, under distinct group labels.
    """
    variant = base_df.copy()
    variant['pred_volume'] = variant[volume_col]
    variant['method_group'] = group_label
    return variant


hard_thresh_df = _base_variant(base_rows, 'hard_volume', 'Base Hard')
soft_thresh_df = _base_variant(base_rows, 'soft_volume', 'Base Soft')

# ignore_index=True: both Base variants are copies of the same rows, so keeping
# the original labels would leave duplicate index entries in methods_df.
methods_df = pd.concat([tuned_df, hard_thresh_df, soft_thresh_df], ignore_index=True)

In [6]:
# Every row needs both volumes before VE / RAVE are computed below; report
# which column failed and how many rows are missing.
for _col in ('pred_volume', 'gt_volume'):
    _n_missing = int(methods_df[_col].isna().sum())
    assert _n_missing == 0, f"{_n_missing} NaN value(s) in '{_col}'"

In [None]:
def VE(pred_volume, gt_volume):
    """Absolute volume error ``|pred_volume - gt_volume|``."""
    signed_error = pred_volume - gt_volume
    return np.abs(signed_error)

def RAVE(pred_volume, gt_volume):
    """Relative absolute volume error ``|pred_volume - gt_volume| / gt_volume``.

    NOTE(review): nothing guards ``gt_volume == 0``. For numeric scalars,
    ``np.abs`` returns a NumPy scalar, so the division follows NumPy rules and
    yields inf/nan with a RuntimeWarning rather than raising.
    """
    abs_error = np.abs(pred_volume - gt_volume)
    return abs_error / gt_volume

# Value reported for log(VE) when VE is exactly zero (where log would be -inf).
LOG_VE_ZERO_SENTINEL = -3


def log_VE(VE):
    """Natural log of the absolute volume error, finite for a zero error.

    ``np.log(0)`` is ``-inf`` and emits a RuntimeWarning, so an exact-zero
    error is short-circuited to ``LOG_VE_ZERO_SENTINEL`` (-3) before taking the
    log. Non-zero inputs return ``np.log(VE)`` unchanged (NaN stays NaN).

    NOTE(review): log(VE) < -3 whenever 0 < VE < e**-3 (~0.05), so a tiny
    non-zero error can land *below* the zero sentinel — confirm that is intended.
    """
    if VE == 0:
        return LOG_VE_ZERO_SENTINEL
    return np.log(VE)

# Fully-qualified loss class path -> short display name used in plots.
_LOSS_FUNC_NAMES = {
    "ese.losses.PixelCELoss": "CrossEntropy",
    "ese.losses.SoftDiceLoss": "SoftDice",
}


def loss_func(loss_func_class):
    """Map a loss-function class path to its short display name.

    Raises:
        ValueError: if ``loss_func_class`` is not a known loss class; the
            offending value is included in the message to ease debugging.
    """
    try:
        return _LOSS_FUNC_NAMES[loss_func_class]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, preserving the ValueError contract.
        raise ValueError(f"Unknown loss function: {loss_func_class!r}") from None

# Derive loss-name and error-metric columns on methods_df.
# NOTE(review): assuming augment maps parameters to same-named columns,
# log_VE reads the `VE` column, so it must stay after methods_df.augment(VE).
methods_df.augment(loss_func)
methods_df.augment(VE)
methods_df.augment(RAVE)
methods_df.augment(log_VE)

In [None]:
# Bookkeeping to make sure we aren't double counting: print the number of
# distinct data_ids in each split.
for split_name in methods_df['split'].unique():
    in_split = methods_df['split'] == split_name
    n_ids = len(methods_df.loc[in_split, 'data_id'].unique())
    print(split_name, n_ids)

In [9]:
# Keep only the columns needed for the analysis below.
# (`analyis_df` is misspelled, but later cells refer to it by this name.)
cols_to_keep = [
    "pred_volume", "gt_volume",                                   # volumes
    "VE", "log_VE", "RAVE",                                       # error metrics
    "loss_func", "dataset", "split", "data_id", "method_group",   # identifiers
]
# Collapse rows that are identical on these columns, then renumber the index.
analyis_df = (
    methods_df[cols_to_keep]
    .drop_duplicates()
    .reset_index(drop=True)
)

In [None]:
# Inspect the pruned per-sample analysis frame.
analyis_df

# Look at Method Performance

In [11]:
# Restrict the method comparison to the validation split.
is_val = analyis_df['split'] == 'val'
val_analyis_df = analyis_df.loc[is_val]

In [None]:
# Box plot of RAVE per loss function and method, one facet per dataset.
# `data=` is passed by keyword: on older seaborn releases the first positional
# argument of catplot is `x`, not the DataFrame.
g = sns.catplot(
    data=val_analyis_df,
    x='loss_func',
    y='RAVE',
    hue='method_group',
    kind='box',
    col='dataset',
    aspect=1.5,
    height=6,
    showfliers=False,
    sharey=False,
)
g.set_axis_labels("Loss Function", "RAVE")
# Leave headroom above the facets for the figure-level title.
g.fig.subplots_adjust(top=0.8)
g.fig.suptitle("Relative Absolute Volumetric Error (RAVE) by Method, Loss Function, and Dataset", fontsize=23)
# Render the figure without echoing the suptitle's Text(...) repr.
plt.show()