In [None]:
# Environment setup: make the local project/library checkouts importable,
# configure plotting defaults, set DATAPATH, and enable IPython extensions.
import sys
# Local source checkouts; presumably where the `ese` and `ionpy` packages
# imported below live (TODO: confirm).
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# Global seaborn look applied to every figure in this notebook.
sns.set_style("darkgrid")
sns.set_context("talk")

import os 
# Colon-separated list of dataset roots. Presumably read by the ese/ionpy
# dataset code -- confirm against those libraries.
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))

from ese.experiment.analysis.analyze_inf import load_cal_inference_stats
# ionpy's results loader; an instance is created as `rs` below.
from ionpy.analysis import ResultsLoader
from pathlib import Path
# NOTE(review): `root` and `rs` are not referenced in the cells shown here.
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# Reload edited modules automatically, so library changes apply without
# restarting the kernel.
%load_ext autoreload
%autoreload 2
# Provides the %%yaml cell magic used for the results config below.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

# Parsed into `results_cfg`, which the next cell passes to
# `load_cal_inference_stats`.
log:
    # Presumably the directory that contains the inference log groups -- confirm.
    root: /storage/vbutoi/scratch/ESE/inference
    # Inference group name(s) to load.
    inference_groups: 
        - '05_27_24_SW_Pairwise'

# Loader switches. Their exact semantics are defined inside
# `load_cal_inference_stats` (ese) and are not visible here; the option
# names are the only documentation in this notebook.
options:
    add_dice_loss_rows: True
    drop_nan_metric_rows: True 
    remove_shared_columns: False
    equal_rows_per_cfg_assert: False 

In [None]:
# Load the calibration inference statistics described by `results_cfg`.
# `load_cached=False` presumably forces a fresh load instead of reusing a
# cached result (TODO: confirm in ese.experiment.analysis.analyze_inf).
inference_df = load_cal_inference_stats(results_cfg=results_cfg, load_cached=False)

In [None]:
# import pickle

# # Load all of the preds so we can compare them
# loaded_pred_dict = {}
# for pred_hash_id in inference_df['pred_hash'].unique():
#     # Get the rows corresponding to the hash
#     pred_rows = inference_df[inference_df['pred_hash'] == pred_hash_id]
#     # Get the log set corresponding to the hash
#     log_root = pred_rows['root'].unique()[0]
#     log_set = pred_rows['log_set'].unique()[0]
#     # Load the prediction pickle
#     with open(f'{log_root}/{log_set}/preds/{pred_hash_id}.pkl', 'rb') as f:
#         loaded_pred_dict[pred_hash_id] = pickle.load(f)

In [None]:
import numpy as np
import itertools

# Pairwise volume comparison.
# For every unordered pair of distinct subjects (`data_id`) and every
# combination of their support sets (`sup_idx`), record how the ground-truth,
# soft, and hard volumes change from subject 1 to subject 2. The change is
# recorded both as a difference (v2 - v1) and as a quotient (v2 / v1). Each
# record also holds the Dice and image-level ECE of both predictions, and the
# error of the predicted change relative to the ground-truth change.

# Values of the `image_metric` column whose `metric_score` goes into each record.
DICE_METRIC = 'Dice'
ECE_METRIC = 'Image_ECE'


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero.

    A zero reference volume would otherwise give inf (numpy scalars, with a
    RuntimeWarning) or raise ZeroDivisionError (plain Python numbers). inf
    values distort the quotient box plots below; NaN values are ignored there.
    """
    if denominator == 0:
        return np.nan
    return numerator / denominator


def _summarize_support(sup_df, data_id, sup_idx):
    """Collect the scalar quantities for one (data_id, sup_idx) group.

    Parameters
    ----------
    sup_df : pd.DataFrame
        Rows of `inference_df` for a single subject / support-set combination.
    data_id, sup_idx
        The group key; used only in error messages.

    Returns
    -------
    dict
        Contains `gt_volume`, `soft_volume` and `hard_volume`, taken from the
        group's first row. Also contains `dice` and `ece`, taken from the first
        row whose `image_metric` matches.

    Raises
    ------
    ValueError
        If the group has no row for the Dice or the ECE metric.
    """
    def metric_score(metric_name):
        scores = sup_df.loc[sup_df['image_metric'] == metric_name, 'metric_score']
        if scores.empty:
            raise ValueError(
                f"No '{metric_name}' row for data_id={data_id!r}, sup_idx={sup_idx!r}."
            )
        return scores.iloc[0]

    return {
        'gt_volume': sup_df['gt_volume'].iloc[0],
        'soft_volume': sup_df['soft_volume'].iloc[0],
        'hard_volume': sup_df['hard_volume'].iloc[0],
        'dice': metric_score(DICE_METRIC),
        'ece': metric_score(ECE_METRIC),
    }


unique_subj_ids = inference_df['data_id'].unique()

# Summarize each (data_id, sup_idx) group once, instead of re-filtering the
# whole frame inside the nested pair loops. With sort=False the groups come
# out in order of first appearance, so support sets are visited in the same
# order that `Series.unique()` would list them.
support_summaries = {}
grouped = inference_df.groupby(['data_id', 'sup_idx'], sort=False, observed=True)
for (data_id, sup_idx), sup_df in grouped:
    support_summaries.setdefault(data_id, {})[sup_idx] = _summarize_support(
        sup_df, data_id, sup_idx
    )

pw_error_list = []
# `combinations` over the unique ids yields every unordered pair of *distinct*
# subjects exactly once, so no self-comparison check is needed.
for subj_id_1, subj_id_2 in itertools.combinations(unique_subj_ids, 2):
    sup_pairs = itertools.product(
        support_summaries[subj_id_1].items(),
        support_summaries[subj_id_2].items(),
    )
    for (sup_id_1, summ_1), (sup_id_2, summ_2) in sup_pairs:
        # Change from subject 1 to subject 2, as a difference and as a ratio.
        gt_volume_diff = summ_2['gt_volume'] - summ_1['gt_volume']
        soft_volume_diff = summ_2['soft_volume'] - summ_1['soft_volume']
        hard_volume_diff = summ_2['hard_volume'] - summ_1['hard_volume']
        gt_volume_quot = _safe_ratio(summ_2['gt_volume'], summ_1['gt_volume'])
        soft_volume_quot = _safe_ratio(summ_2['soft_volume'], summ_1['soft_volume'])
        hard_volume_quot = _safe_ratio(summ_2['hard_volume'], summ_1['hard_volume'])

        dice_1, dice_2 = summ_1['dice'], summ_2['dice']
        ece_1, ece_2 = summ_1['ece'], summ_2['ece']

        pw_error_list.append({
            'subj_id_1': subj_id_1,
            'subj_id_2': subj_id_2,
            'subj_combo': f'{subj_id_1},{subj_id_2}',
            'sup_id_1': sup_id_1,
            'sup_id_2': sup_id_2,
            'dice_1': dice_1,
            'dice_2': dice_2,
            'ece_1': ece_1,
            'ece_2': ece_2,
            'mean_dice': np.mean([dice_1, dice_2]),
            'mean_ece': np.mean([ece_1, ece_2]),
            'gt_volume_diff': gt_volume_diff,
            'gt_volume_quot': gt_volume_quot,
            'soft_volume_diff': soft_volume_diff,
            'soft_volume_quot': soft_volume_quot,
            'hard_volume_diff': hard_volume_diff,
            'hard_volume_quot': hard_volume_quot,
            # Error of the predicted change with respect to the ground-truth change.
            'soft_diff_error': soft_volume_diff - gt_volume_diff,
            'hard_diff_error': hard_volume_diff - gt_volume_diff,
            'soft_quot_error': soft_volume_quot - gt_volume_quot,
            'hard_quot_error': hard_volume_quot - gt_volume_quot,
        })

# One row per (subject pair, support-set pair).
pairwise_df = pd.DataFrame(pw_error_list)

# Test 1: On average, how do the predicted (soft/hard) volume differences compare to the ground-truth differences?

In [None]:
# Melt the dataframe
pw_diff_df = pd.melt(
    pairwise_df, 
    id_vars=[
        'subj_id_1', 
        'subj_id_2', 
        'sup_id_1', 
        'sup_id_2', 
        'subj_combo'
    ], 
    value_vars=[
        'gt_volume_diff', 
        'soft_volume_diff', 
        'hard_volume_diff'
    ], 
    var_name='Pred_Type', 
    value_name='Volume Difference'
)


In [None]:
# Create a larger figure
plt.figure(figsize=(12, 8))

# Create the boxplot with modified whiskers and without showing outliers
ax = sns.boxplot(
    x='Pred_Type',
    y='Volume Difference',
    data=pw_diff_df,
    palette="Set2",       # Use a color palette
    linewidth=2,          # Set the linewidth of the edge
    showfliers=False,     # Do not show outliers
    whis=0.5              # Shorten the whiskers to half the IQR
)

# Enhance the plot
ax.set_title('Volume Difference by Prediction Type', fontsize=20)
ax.set_xlabel('Prediction Type', fontsize=15)
ax.set_ylabel('Volume Difference', fontsize=15)
ax.tick_params(labelsize=12)

# Remove the top and right spines for a cleaner look
sns.despine()

# Show the plot
plt.show()

# Test 1.5: Per subject pair, how do the predicted (soft/hard) volume differences compare to the ground-truth differences?

In [None]:
# Create the boxplot
g = sns.catplot(
    x='subj_combo',
    y='Volume Difference',
    kind='box',
    data=pw_diff_df,
    hue='Pred_Type',
    palette="Set2",       # Use a color palette
    linewidth=2,          # Set the linewidth of the edge
    aspect=1.5,            # Adjust the aspect ratio
    height=8
)

# Test 2: On average, how do the predicted (soft/hard) volume quotients compare to the ground-truth quotients?

In [None]:
# Melt the dataframe
pw_quot_df = pd.melt(
    pairwise_df, 
    id_vars=[
        'subj_id_1', 
        'subj_id_2', 
        'sup_id_1', 
        'sup_id_2', 
        'subj_combo'
    ], 
    value_vars=[
        'gt_volume_quot', 
        'soft_volume_quot', 
        'hard_volume_quot'
    ], 
    var_name='Pred_Type', 
    value_name='Volume Quotient'
)


In [None]:
# Create a larger figure
plt.figure(figsize=(12, 8))

# Create the boxplot
ax = sns.boxplot(
    x='Pred_Type',
    y='Volume Quotient',
    data=pw_quot_df,
    palette="Set2",       # Use a color palette
    linewidth=2           # Set the linewidth of the edge
)

# Enhance the plot
ax.set_title('Volume Quotient by Prediction Type', fontsize=20)
ax.set_xlabel('Prediction Type', fontsize=15)
ax.set_ylabel('Volume Quotient', fontsize=15)
ax.tick_params(labelsize=12)

# Remove the top and right spines for a cleaner look
sns.despine()

# Show the plot
plt.show()


# Test 2.5: Per subject pair, how do the predicted (soft/hard) volume quotients compare to the ground-truth quotients?

In [None]:
# Create the boxplot
g = sns.catplot(
    x='subj_combo',
    y='Volume Quotient',
    kind='box',
    data=pw_quot_df,
    hue='Pred_Type',
    palette="Set2",       # Use a color palette
    linewidth=2,          # Set the linewidth of the edge
    aspect=1.5,            # Adjust the aspect ratio
    height=8
)