In [None]:
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

# Standard imports
import os 
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path
import matplotlib.pyplot as plt
# Ionpy imports
from ionpy.analysis import ResultsLoader
# Local imports
from ese.analysis.analyze_inf import load_cal_inference_stats

# Seaborn style and context applied to every figure drawn in this notebook.
sns.set_style("darkgrid")
sns.set_context("talk")
# DATAPATH is assembled as a ':'-joined list of directories (one entry here).
# NOTE(review): presumably consumed by the dataset-loading code — confirm.
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))
# Project scratch directory and a results loader.
# NOTE(review): neither `root` nor `rs` is used in the cells visible here — confirm they are still needed.
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 2
# For using yaml configs.
%load_ext yamlmagic

  warn(



In [None]:
%%yaml results_cfg 

# Which inference logs to load.
log:
    # Directories that hold the inference runs.
    root:
        - '/storage/vbutoi/scratch/ESE/inference/11_05_24_UVS_InContext_CrossEval'
    # Run groups to load; commented-out entries are skipped.
    # NOTE(review): presumably sub-directories of each root — confirm.
    inference_group: 
        - 'Base'
        # - 'Optimal_Dice_Threshold'
        - 'Optimal_RAVE_Threshold'

# NOTE(review): loader options; their exact effect is defined by load_cal_inference_stats — confirm before changing.
options:
    verify_graceful_exit: True
    equal_rows_per_cfg_assert: False 

<IPython.core.display.Javascript object>

# Plotting Calls

In [None]:
# Load the inference statistics described by results_cfg into a dataframe.
# NOTE(review): load_cached=False presumably forces a fresh load instead of reusing a cache — confirm.
inference_df = load_cal_inference_stats(
    results_cfg=results_cfg,
    load_cached=False
)

Loading log configs: 100%|██████████| 30/30 [00:00<00:00, 77.83it/s]
Loading image stats: 100%|██████████| 30/30 [00:08<00:00,  3.46it/s]


/storage/vbutoi/scratch/ESE/inference/11_05_24_UVS_InContext_CrossEval/Base                    20241105_110332-FGSY-77badd7cd1734748825ac49135b4887f     180
                                                                                               20241105_110335-M77Y-ae115739b9a20dea65cb92d15593735d     440
                                                                                               20241105_110340-BJ3D-671fba858a9084204bb3c56753544af8     400
                                                                                               20241105_110344-CWQB-d638adc2f532357a939ea4ad2b048c72      40
                                                                                               20241105_110429-X95K-3971a71503dc3f0b78e7de4003e9ea54      60
                                                                                               20241105_110433-DNA8-a9eee47fa99be0ed74ef8dd24f97f945     120
                                                          

In [None]:
# for ikey in inference_df.keys():
#     print(ikey)

In [None]:
# Show the distinct values of the inference_data_task column in the loaded results.
inference_df['inference_data_task'].unique()

array(['PanDental/v1/XRay/0', 'SCD/VIS_human/MRI/2', 'WBC/CV/EM/0',
       'SCD/LAF_Pre/MRI/2', 'SpineWeb/Dataset7/MR/0',
       'PanDental/v2/XRay/0', 'SCD/LAS/MRI/2', 'SCD/VIS_pig/MRI/2',
       'SCD/LAF_Post/MRI/2', 'WBC/JTSC/EM/0',
       'STARE/retrieved_2021_12_06/Retinal/0', 'ACDC/Challenge2017/MRI/2'],
      dtype=object)

In [None]:
def method_group(log_root):
    """Map a log root path to the display name of its method family.

    Only the last '/'-separated component of ``log_root`` is inspected. The
    keywords are tested in a fixed priority order (Base, Threshold,
    Temperature) and the first one contained in that component wins.

    Raises:
        ValueError: if none of the keywords occurs in the last component.
    """
    leaf = log_root.rsplit('/', 1)[-1]
    keyword_to_group = (
        ("Base", "Base"),
        ("Threshold", "Threshold Tuned Hard"),
        ("Temperature", "Temperature Tuned Soft"),
    )
    for keyword, group_name in keyword_to_group:
        if keyword in leaf:
            return group_name
    raise ValueError("Unknown method")

def pred_volume(method_group, hard_volume, soft_volume):
    """Pick the predicted volume that matches a row's method group.

    Returns ``hard_volume`` for "Threshold Tuned Hard", ``soft_volume`` for
    "Temperature Tuned Soft", and ``None`` for every other group.
    """
    if method_group == "Threshold Tuned Hard":
        return hard_volume
    if method_group == "Temperature Tuned Soft":
        return soft_volume
    # Any other group gets no predicted volume from this function.
    return None

def dataset(inference_data_class):
    """Return the text after the last '.' of a dotted class path.

    A string without any '.' is returned unchanged.
    """
    _, _, short_name = inference_data_class.rpartition('.')
    return short_name

# Add the derived columns; each augment call adds a column named after the function.
# Order matters: pred_volume takes a method_group argument and the following
# cells filter on a method_group column, so method_group must be added first.
# NOTE(review): assumes inference_df has a log_root column — confirm against the loader.
inference_df.augment(method_group)
inference_df.augment(dataset)
inference_df.augment(pred_volume)

In [None]:
# Separate the rows of the base methods from the rows of the tuned methods.
base_rows = inference_df[inference_df['method_group'] == 'Base']
tuned_df = inference_df[inference_df['method_group'] != 'Base'].copy()

# Every base row is used twice: once scored with its hard volume ...
hard_thresh_df = base_rows.assign(
    pred_volume=base_rows['hard_volume'],
    method_group='Base Hard',
)
# ... and once scored with its soft volume.
soft_thresh_df = base_rows.assign(
    pred_volume=base_rows['soft_volume'],
    method_group='Base Soft',
)

# Stack the tuned rows and both base variants into one dataframe.
methods_df = pd.concat([tuned_df, hard_thresh_df, soft_thresh_df])

In [None]:
# Every row must carry both a predicted and a ground-truth volume (no NaNs).
assert methods_df['pred_volume'].notna().all()
assert methods_df['gt_volume'].notna().all()

In [None]:
def VE(pred_volume, gt_volume):
    """Signed volume error: predicted volume minus ground-truth volume.

    Positive values mean the prediction is larger than the ground truth,
    negative values mean it is smaller. This is the un-normalised numerator
    of RVE.
    """
    return pred_volume - gt_volume

def RVE(pred_volume, gt_volume):
    """Relative volume error: the signed volume error divided by gt_volume."""
    signed_error = pred_volume - gt_volume
    return signed_error / gt_volume

def RAVE(pred_volume, gt_volume):
    """Relative absolute volume error: |pred_volume - gt_volume| / gt_volume."""
    absolute_error = np.abs(pred_volume - gt_volume)
    return absolute_error / gt_volume

def log_VE(VE):
    """Return log(VE + 1), substituting -3 when the logarithm is -infinity."""
    shifted_log = np.log(VE + 1)
    # log(0) is -inf (reached when VE == -1); use the finite sentinel -3 instead.
    return -3 if shifted_log == -np.inf else shifted_log

def loss_func(loss_func_class):
    """Translate a loss class path into its short display name.

    Raises:
        ValueError: if ``loss_func_class`` is not one of the known class paths.
    """
    known_losses = (
        ("ese.losses.PixelCELoss", "CrossEntropy"),
        ("ese.losses.SoftDiceLoss", "SoftDice"),
    )
    for class_path, display_name in known_losses:
        if loss_func_class == class_path:
            return display_name
    raise ValueError("Unknown loss function")

# Add one derived column per function, in this order (log_VE takes the VE
# column as input, so VE has to come before it).
for _column_fn in (loss_func, VE, RVE, RAVE, log_VE):
    methods_df.augment(_column_fn)

In [None]:
# Keep only the rows whose image_metric is 'Dice' and expose their
# metric_score under the column name 'Dice'.
methods_df = (
    methods_df[methods_df['image_metric'] == 'Dice']
    .rename(columns={'metric_score': 'Dice'})
)

In [None]:
# Columns used by the analysis below: volumes, metrics, then grouping keys.
cols_to_keep = [
    "pred_volume", "gt_volume",
    "Dice", "VE", "RVE", "log_VE", "RAVE",
    "loss_func", "dataset", "split", "data_id", "method_group",
]
# Restrict to those columns, drop exact duplicate rows and renumber the index from 0.
analyis_df = (
    methods_df.loc[:, cols_to_keep]
    .drop_duplicates()
    .reset_index(drop=True)
)

In [None]:
# Display the pruned analysis dataframe.
analyis_df

# Look at Method Performance

In [None]:
# Work on an independent copy of the validation-split rows only.
val_analyis_df = analyis_df.loc[analyis_df['split'] == 'val'].copy()

In [None]:
# Preferred display order of the methods (method_group is used as the hue in the plots below).
method_order = [
    'Base Hard',
    'Base Soft',
    'Threshold Tuned Hard',
    'Temperature Tuned Soft'
]
# reorder_categories only accepts exactly the categories that already exist, so
# restrict the preferred order to the methods present in this run (for example,
# a run without any temperature-tuned group). A method that is present but not
# listed above still raises a ValueError, which keeps unknown groups visible.
present_methods = set(val_analyis_df['method_group'].unique())
val_analyis_df['method_group'] = val_analyis_df['method_group'].astype('category')
val_analyis_df['method_group'] = val_analyis_df['method_group'].cat.reorder_categories(
    [name for name in method_order if name in present_methods]
)

In [None]:
# Box plots of RAVE: one panel per dataset, loss function on the x-axis,
# one box per method group; outlier points are not drawn.
g = sns.catplot(
    data=val_analyis_df,
    kind='box',
    x='loss_func',
    y='RAVE',
    hue='method_group',
    col='dataset',
    height=6,
    aspect=1.5,
    sharey=False,
    showfliers=False,
)
# Reserve space at the top of the figure for the overall title.
g.figure.subplots_adjust(top=0.8)
g.figure.suptitle("Relative Absolute Volumetric Error (RAVE) by Method, Loss Function, and Dataset", fontsize=23)

In [None]:
# RAVE per validation example, one line per method group; the frame is passed
# without any prior sorting.
g = sns.relplot(
    data=val_analyis_df,
    kind='line',
    x='data_id',
    y='RAVE',
    hue='method_group',
    row='loss_func',
    col='dataset',
    height=6,
    aspect=1.5,
    facet_kws={'sharex': False, 'sharey': False},
)
# Hide the per-example tick labels on the x-axis.
g.set(xticklabels=[])
# This frame is not sorted, so the x-axis label must not claim a Dice ordering.
g.set_axis_labels("Data-Id (Unsorted)", "RAVE")
# Reserve space at the top of the figure for the overall title.
g.figure.subplots_adjust(top=0.8)
g.figure.suptitle("Unsorted Relative Absolute Volumetric Error (RAVE) by Dataset (col) and Loss Function (row)", fontsize=23)
plt.show()

In [None]:
# Order the validation rows by increasing Dice score before plotting.
vad_sortby_dice = val_analyis_df.sort_values('Dice')

g = sns.relplot(
    data=vad_sortby_dice,
    kind='line',
    x='data_id',
    y='RAVE',
    hue='method_group',
    row='loss_func',
    col='dataset',
    height=6,
    aspect=1.5,
    facet_kws={'sharex': False, 'sharey': False},
)
# Hide the per-example tick labels on the x-axis.
g.set(xticklabels=[])
g.set_axis_labels("Data-Id Ordered by Dice Score (Increasing)", "RAVE")
# Reserve space at the top of the figure for the overall title.
g.figure.subplots_adjust(top=0.8)
g.figure.suptitle("Sort by Dice Relative Absolute Volumetric Error (RAVE) by Dataset (col) and Loss Function (row)", fontsize=23)
plt.show()

In [None]:
# Order the validation rows by increasing ground-truth volume before plotting.
vad_sortby_gtvol = val_analyis_df.sort_values(by='gt_volume', ascending=True)

g = sns.relplot(
    data=vad_sortby_gtvol,
    kind='line',
    x='data_id',
    y='RAVE',
    hue='method_group',
    row='loss_func',
    col='dataset',
    height=6,
    aspect=1.5,
    facet_kws={'sharex': False, 'sharey': False},
)
# Hide the per-example tick labels on the x-axis.
g.set(xticklabels=[])
# The rows are ordered by gt_volume, so the x-axis label says so (not Dice).
g.set_axis_labels("Data-Id Ordered by GT Amount (Increasing)", "RAVE")
# Reserve space at the top of the figure for the overall title.
g.figure.subplots_adjust(top=0.8)
g.figure.suptitle("Sort by GT Size Relative Absolute Volumetric Error (RAVE) by Dataset (col) and Loss Function (row)", fontsize=23)
plt.show()

In [None]:
# Order the validation rows by increasing ground-truth volume before plotting.
vad_sortby_gtvol = val_analyis_df.sort_values(by='gt_volume', ascending=True)

g = sns.relplot(
    data=vad_sortby_gtvol,
    kind='line',
    x='data_id',
    y='RVE',
    hue='method_group',
    row='loss_func',
    col='dataset',
    height=6,
    aspect=1.5,
    facet_kws={'sharex': False, 'sharey': False},
)
# Hide the per-example tick labels on the x-axis.
g.set(xticklabels=[])
# The plotted quantity is the signed RVE, so the y-axis label must read RVE (not RAVE).
g.set_axis_labels("Data-Id Ordered by GT Amount (Increasing)", "RVE")
# Reserve space at the top of the figure for the overall title.
g.figure.subplots_adjust(top=0.8)
g.figure.suptitle("Sort by GT Size Relative Volumetric Error (RVE) by Dataset (col) and Loss Function (row)", fontsize=23)
plt.show()

# Investigating what's so bad

In [None]:
# First 12 SoftDice rows of the gt_volume-sorted frame.
vad_sortby_gtvol.loc[vad_sortby_gtvol['loss_func'] == 'SoftDice'].head(12)

In [None]:
# First 12 CrossEntropy rows of the gt_volume-sorted frame.
vad_sortby_gtvol.loc[vad_sortby_gtvol['loss_func'] == 'CrossEntropy'].head(12)

In [None]:
# Mean gt_volume over the validation rows, one value per dataset.
gt_volume_means = (
    val_analyis_df
    .groupby('dataset')['gt_volume']
    .mean()
)

In [None]:
# Display the per-dataset mean ground-truth volumes.
gt_volume_means

In [None]:
# Give every row the mean gt_volume of its dataset by looking the row's
# dataset up in gt_volume_means (a Series indexed by dataset).
val_analyis_df['mean_gt_volume'] = val_analyis_df['dataset'].map(gt_volume_means)

In [None]:
# We want to visualize the distribution of ground-truth volumes by dataset, normalized by the mean.
def norm_gt_volume(gt_volume, mean_gt_volume):
    """Return gt_volume expressed as a multiple of mean_gt_volume."""
    ratio_to_mean = gt_volume / mean_gt_volume
    return ratio_to_mean

# Add the norm_gt_volume column computed by the function of the same name.
val_analyis_df.augment(norm_gt_volume)

# Visualize using KDE plots in a facet grid: one panel per dataset, each with its own y-axis.
g = sns.FacetGrid(
    data=val_analyis_df,
    col='dataset',
    height=6,
    aspect=1.5,
    sharey=False,
)
g.map(sns.kdeplot, 'norm_gt_volume', fill=True)