In [None]:
# Setup: make the local project/library checkouts importable, configure global plotting
# defaults and the dataset path, then import the analysis helpers.
import sys
# These paths presumably provide the `ese` / `ionpy` packages imported further down,
# so the appends must run before those imports.
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# Global seaborn theme, applied to every figure drawn later in this kernel.
sns.set_style("darkgrid")
sns.set_context("talk")

import os 
# Colon-separated dataset search path (a single entry here). Presumably read by the
# project's data-loading code -- confirm where DATAPATH is consumed.
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))

from ese.experiment.analysis.analyze_inf import load_cal_inference_stats
# ionpy's results-loader class (instantiated below as `rs`).
from ionpy.analysis import ResultsLoader
from pathlib import Path
# NOTE(review): `root` and `rs` are not referenced by any later cell of this notebook;
# remove them if nothing else depends on them.
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# Reload edited project modules automatically, so code changes apply without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Enables the %%yaml cell magic used for the results config below.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

# Results config, passed as `results_cfg=` to load_cal_inference_stats in the next cell.
log:
    # Directory holding the inference outputs; the groups below are presumably
    # subdirectories of it -- confirm in load_cal_inference_stats.
    root: /storage/vbutoi/scratch/ESE/inference
    inference_groups: 
        - '05_15_24_SpineWeb_SuppWAug'
        - '05_15_24_SW_FixedAugExp'
        # - '05_15_24_SW_FixedAugExp_v2'
    
calibration:
    # YAML file that presumably defines which calibration metrics are loaded.
    metric_cfg_file: "/storage/vbutoi/projects/ESE/ese/experiment/configs/inference/Calibration_Metrics.yaml"

# Loader flags; their exact semantics live in load_cal_inference_stats. The names suggest:
# add Dice-loss rows, drop rows whose metric is NaN, keep columns shared by all rows,
# and skip the equal-rows-per-config assertion -- confirm there.
options:
    add_dice_loss_rows: True
    drop_nan_metric_rows: True 
    remove_shared_columns: False
    equal_rows_per_cfg_assert: False 

In [None]:
# Load the calibration inference statistics for the groups configured in `results_cfg`.
# NOTE(review): load_cached=False presumably bypasses any cached copy and rebuilds the
# frame on every run -- confirm in ese.experiment.analysis.analyze_inf.
aug_image_info_df = load_cal_inference_stats(load_cached=False, results_cfg=results_cfg)

In [None]:
import ast

def support_augmentations(support_augs):
    """Turn a row's support-set augmentations into a single display label.

    Args:
        support_augs: the augmentations applied to the support set, either as the
            string repr of a tuple/list of names (e.g. "('Affine', 'Elastic')"),
            which is parsed with ``ast.literal_eval``, or as an actual tuple/list
            of names. A literal None (e.g. the string "None") or a single bare
            name (e.g. "'Affine'") is also accepted.

    Returns:
        "None" when there are no augmentations, otherwise the names joined by
        ', ' in the order they are stored (the order is not normalized).
    """
    # literal_eval only evaluates Python literals, so parsing column strings is safe.
    parsed = ast.literal_eval(support_augs) if isinstance(support_augs, str) else support_augs
    # Treat a literal None like an empty collection instead of crashing on list(None).
    if parsed is None:
        return "None"
    # A single bare name would otherwise be iterated character by character.
    if isinstance(parsed, str):
        parsed = [parsed]
    aug_list = [str(aug) for aug in parsed]
    if not aug_list:
        return "None"
    return ', '.join(aug_list)

def exp_name(exp_root):
    """Return the last '/'-separated component of an experiment root path.

    Args:
        exp_root: the experiment directory, as a str or a path-like object
            (converted with ``str``). A trailing '/' is ignored, so
            '/a/b/exp/' gives 'exp' rather than ''.

    Returns:
        The final path component, used as the experiment's display name.
    """
    return str(exp_root).rstrip('/').split('/')[-1]

# Derive two columns on the loaded stats frame via `augment`, a DataFrame extension
# (not a pandas method -- presumably registered by ionpy). The pivot in the next cell
# indexes on 'support_augmentations' and 'exp_name' without reassigning this frame, so
# `augment` presumably adds a column named after each function, in place, feeding the
# matching column(s) (e.g. `exp_root`) as arguments -- confirm against ionpy's `augment`.
aug_image_info_df.augment(support_augmentations)
aug_image_info_df.augment(exp_name)

In [None]:
# Long -> wide: every distinct `image_metric` becomes its own column holding
# `metric_score`, keyed by the four identifier columns, so Dice can be compared with
# calibration metrics (e.g. ECE) row by row. pivot raises ValueError if any key
# combination is duplicated.
aug_df_pivot = (
    aug_image_info_df
    .pivot(
        columns='image_metric',
        index=['exp_name', 'data_id', 'sup_idx', 'support_augmentations'],
        values='metric_score',
    )
    .reset_index()
)

In [None]:
# Fix the x-axis order of the augmentation combinations for the plots below:
# no augmentation first, then single augmentations, pairs, and the full triple.
support_aug_order = [
    'None',
    'Affine',
    'Elastic',
    'HorizontalFlip',
    'Affine, Elastic',
    'Affine, HorizontalFlip',
    'Elastic, HorizontalFlip',
    'Affine, Elastic, HorizontalFlip',
]
present_augs = set(aug_df_pivot['support_augmentations'].dropna().unique())
# Labels outside the order list would silently become NaN in the Categorical, so fail
# loudly on them. Combinations that simply were not run in the loaded groups are fine
# (cat.reorder_categories would raise on those).
unexpected_augs = present_augs - set(support_aug_order)
assert not unexpected_augs, f"Unexpected support augmentation labels: {sorted(unexpected_augs)}"
aug_df_pivot['support_augmentations'] = pd.Categorical(
    aug_df_pivot['support_augmentations'],
    categories=[aug for aug in support_aug_order if aug in present_augs],
)

In [None]:
# Dice distribution per support-augmentation combination, one boxen per experiment.
g = sns.catplot(
    data=aug_df_pivot,
    x='support_augmentations',
    y='Dice',
    hue='exp_name',
    kind='boxen',
    showfliers=False,
    aspect=2.5,
    height=10,
    palette='Set2',
)
# Rotate the x tick labels for readability. Go through the grid's Axes instead of
# plt.xticks, which only affects whichever Axes pyplot considers "current".
for ax in g.axes.flat:
    ax.tick_params(axis='x', labelrotation=45)
# Dice is bounded in [0, 1].
g.set(ylim=(0, 1))
# No row/col faceting, so `g.ax` is the single Axes; `;` suppresses the Text repr.
g.ax.set_title('Dice Score vs. Support Augmentations');

In [None]:
# Same Dice-vs-augmentation comparison as a grid with one row of panels per data_id value.
g = sns.catplot(
    # NOTE(review): `.select()` is called without any filter criteria. It is not a
    # pandas DataFrame method (removed in pandas 1.0), so presumably an ionpy extension;
    # confirm it is a no-op here or add the intended filter.
    data=aug_df_pivot.select(),
    x='support_augmentations',
    y='Dice',
    hue='exp_name',
    kind='boxen',
    row='data_id',
    showfliers=False,
    aspect=3,
    height=6,
    palette='Set2',
)
# Rotate the x tick labels on every facet. plt.xticks would only touch whichever
# Axes pyplot currently considers "current", not the whole grid.
for ax in g.axes.flat:
    ax.tick_params(axis='x', labelrotation=45)
# Dice is bounded in [0, 1]; apply the same limits to every facet.
g.set(ylim=(0, 1))
# Figure-level title over all rows, with headroom so it does not overlap the top panel.
g.fig.suptitle("Dice vs Augmentations", fontsize=25)
g.fig.subplots_adjust(top=0.9)