In [None]:
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
sns.set_style("darkgrid")
sns.set_context("talk")

import os 
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))

from ese.experiment.analysis.analyze_inf import load_cal_inference_stats
# Results loader object does everything
from ionpy.analysis import ResultsLoader
from pathlib import Path
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 2
# For using yaml configs.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

log:
    root: /storage/vbutoi/scratch/ESE/inference
    inference_groups: 
        - '07_03_24_OCTA_6M_Lab255_InterpolationSettings'

options:
    add_dice_loss_rows: True
    drop_nan_metric_rows: True 
    remove_shared_columns: False
    equal_rows_per_cfg_assert: True 

In [None]:
# Load the calibration inference statistics described by `results_cfg`.
# NOTE(review): load_cached=False presumably bypasses any cached copy — confirm in analyze_inf.
inference_df = load_cal_inference_stats(results_cfg=results_cfg, load_cached=False)

In [None]:
# Numeric dataset version per row: parsed from the 'dataset.version' column when it
# exists, otherwise every row gets the default 1.0.
if 'dataset.version' in inference_df.columns:
    # Vectorised cast instead of an element-wise float() lambda.
    inference_df['dataset_version'] = inference_df['dataset.version'].astype(float)
else:
    inference_df['dataset_version'] = 1.0

def loss_func(loss_func_class):
    """Short loss-function label derived from a dotted loss class path.

    Parameters
    ----------
    loss_func_class : str or None
        Dotted class path such as ``"pkg.module.PixelCELoss"``. Both ``None``
        and the string ``"None"`` are treated as "no loss class set".

    Returns
    -------
    str
        ``'Combo'`` when no loss class is set, otherwise the last dotted
        component of the path.
    """
    # The string "None" (a stringified config value) and a real None mean the same
    # thing; previously a real None crashed on `.split`.
    if loss_func_class is None or loss_func_class == "None":
        return 'Combo'
    return loss_func_class.split('.')[-1]
    
def resolution(dataset_version):
    """Map a dataset version to its resolution label.

    ``0.1`` maps to ``"128"`` and ``1.0`` maps to ``"400"`` (exact equality).

    Raises
    ------
    ValueError
        For any other version.
    """
    for known_version, label in ((0.1, "128"), (1.0, "400")):
        if dataset_version == known_version:
            return label
    raise ValueError(f"Unknown dataset version {dataset_version}")

# Add 'loss_func' and 'resolution' columns in place (augment appears to name the new
# column after the function and fill its parameters from same-named columns).
inference_df.augment(loss_func)
inference_df.augment(resolution)

# Rows to discard:
#   * pretrained seed 42 — that run crashed;
#   * nearest-neighbour upsampling combined with align_corners=True.
crashed_seed = inference_df['experiment_pretrained_seed'] == 42
nearest_with_align = (
    inference_df['experiment_resolution_output_mode'].eq('nearest')
    & inference_df['experiment_resolution_output_align_corners'].eq(True)
)
inference_df = inference_df[~crashed_seed & ~nearest_with_align].reset_index(drop=True)

# Library Funcs

In [None]:
def prepare_error_df(raw_df, groupby_keys, value_vars, var_name, value_name):
    """Melt ``raw_df`` to long format and add an absolute-value column.

    Parameters
    ----------
    raw_df : pd.DataFrame
        Wide-format input; a copy is melted, so it is left untouched.
    groupby_keys : list
        Identifier columns kept as-is (``id_vars`` of ``pd.melt``).
    value_vars : list
        Columns unpivoted into rows.
    var_name : str
        Name of the column that receives the former column names.
    value_name : str
        Name of the column that receives the unpivoted values.

    Returns
    -------
    pd.DataFrame
        Long-format frame with an extra ``f"absolute {value_name}"`` column.
    """
    long_df = pd.melt(
        raw_df.copy(),
        id_vars=groupby_keys,
        value_vars=value_vars,
        var_name=var_name,
        value_name=value_name,
    )
    abs_column = f'absolute {value_name}'
    return long_df.assign(**{abs_column: long_df[value_name].abs()})


def calibrator(model_pretrained_exp_root):
    """Calibration label for a run, read from its pretrained-experiment path.

    Returns ``"SVLS"`` if the path contains the substring ``"SVLS"``,
    otherwise ``"Uncal"``.
    """
    return "SVLS" if "SVLS" in model_pretrained_exp_root else "Uncal"
    

def upsample_cfg(experiment_resolution_output_mode, experiment_resolution_output_align_corners):
    """Combine upsampling mode and align_corners flag into one label, e.g. ``"bilinear align=True"``."""
    return "{} align={}".format(
        experiment_resolution_output_mode,
        experiment_resolution_output_align_corners,
    )


def process_method_names(input_df, value_name):
    """Attach a human-readable 'proportion type' label to each row.

    Rows from calibrated models (``calibrator != 'Uncal'``) whose
    'proportion_type' is the hard variant (``f"hard {value_name}"``) are
    dropped first. The remaining rows are then labelled, sorted by 'data_id',
    de-duplicated and re-indexed.

    Parameters
    ----------
    input_df : pd.DataFrame
        Must contain 'calibrator', 'loss_func', 'proportion_type' and
        'data_id' columns. It is not modified.
    value_name : str
        Suffix of the proportion types; ``f"hard {value_name}"`` marks the
        hard variant.

    Returns
    -------
    pd.DataFrame
        Filtered frame with a new 'proportion type' column. An input that
        becomes empty after filtering yields an empty frame, not an error.
    """
    def _method_label(calibrator, loss_func, proportion_type):
        # Only PixelCELoss is labelled "CE"; every other loss is labelled "Dice".
        # NOTE(review): that includes 'Combo' — confirm this is intended.
        loss_label = "CE" if loss_func == "PixelCELoss" else "Dice"
        # Both "new gt" error types share one label, regardless of calibrator or loss.
        if proportion_type in ("new gt error", "new gt relative error"):
            return "New GT"
        # Uncalibrated rows keep the first word of their proportion type (e.g. "hard").
        if calibrator == "Uncal":
            return "Uncal " + proportion_type.split(" ")[0] + f" ({loss_label})"
        # Calibrated rows are always labelled "soft": their hard-variant rows
        # were removed by the filter below.
        return calibrator + f" soft ({loss_label})"

    # Drop the hard-variant rows of every calibrated (non-'Uncal') model.
    calibrated_hard = (
        (input_df['calibrator'] != 'Uncal')
        & (input_df['proportion_type'] == f'hard {value_name}')
    )
    df = input_df[~calibrated_hard]

    # A comprehension instead of DataFrame.apply(axis=1): apply on an empty frame
    # returns a DataFrame, which cannot be assigned to a single column.
    labels = [
        _method_label(cal, loss, ptype)
        for cal, loss, ptype in zip(df['calibrator'], df['loss_func'], df['proportion_type'])
    ]
    df = df.assign(**{'proportion type': labels})

    # Sort by data_id, drop exact duplicate rows and renumber the index.
    return df.sort_values(by="data_id").drop_duplicates().reset_index(drop=True)

In [None]:
# Derive two label columns on inference_df in place: 'calibrator' ("SVLS"/"Uncal")
# and 'upsample_cfg' (e.g. "bilinear align=True").
inference_df.augment(calibrator)
inference_df.augment(upsample_cfg)

# Columns needed to compare Dice loss and Image ECE across methods.
metric_cols = [
    "calibrator",
    "data_id",
    "experiment_pretrained_seed",
    "image_metric",
    "loss_func",
    "metric_score",
    "model_pretrained_exp_root",
    "split",
    "upsample_cfg",
]

# Keep only those columns; rows that differed only in the dropped columns
# collapse into one via drop_duplicates.
metric_df = (
    inference_df
    .loc[:, metric_cols]
    .drop_duplicates()
    .reset_index(drop=True)
)

In [None]:
def train_method(calibrator, loss_func):
    """Label combining calibrator and loss, e.g. ``"Uncal (PixelCELoss)"``.

    The parameter names deliberately match the 'calibrator' and 'loss_func'
    columns (and so shadow the module-level functions of the same names).
    """
    loss_suffix = f" ({loss_func})"
    return calibrator + loss_suffix

# Add a 'train_method' column such as "Uncal (PixelCELoss)"; augment appears to fill
# train_method's parameters from the same-named 'calibrator' and 'loss_func' columns.
metric_df.augment(train_method)

In [None]:
# Sanity check: distinct training-method labels present in metric_df.
pd.unique(metric_df['train_method'])

In [None]:
# Sanity check: distinct image metrics available for plotting.
pd.unique(metric_df['image_metric'])

In [None]:
def shortened_train_method(train_method):
    """Abbreviate a train-method label for compact plot axes.

    'UC'/'SV' abbreviate the Uncal/SVLS calibrator and 'CE'/'SD' abbreviate
    PixelCELoss/SoftDiceLoss.

    Raises
    ------
    ValueError
        For any label not listed below. NOTE(review): this includes
        "... (Combo)" labels, which loss_func can produce.
    """
    abbreviations = (
        ('Uncal (PixelCELoss)', 'UC-CE'),
        ('Uncal (SoftDiceLoss)', 'UC-SD'),
        ('SVLS (PixelCELoss)', 'SV-CE'),
        ('SVLS (SoftDiceLoss)', 'SV-SD'),
    )
    for full_label, short_label in abbreviations:
        if train_method == full_label:
            return short_label
    raise ValueError(f"Unknown train method {train_method}")
    
# Add a 'shortened_train_method' column (e.g. 'UC-CE') derived from 'train_method';
# shortened_train_method raises ValueError for any label it does not list.
metric_df.augment(shortened_train_method)

In [None]:
# Colour per upsampling config, used as the hue palette in the plots below.
custom_palette = {
    'bilinear align=True': 'cornflowerblue',
    'bilinear align=False': 'darkblue',
    'nearest align=False': "orangered",
}
# Fixed display orders so every plot lists methods/configs the same way.
TRAIN_METHOD_ORDER = ['UC-CE', 'UC-SD', 'SV-CE', 'SV-SD']
# The palette keys (dicts keep insertion order) double as the upsample_cfg order,
# so the colours and the category order cannot drift apart.
UPSAMPLE_CFG_ORDER = list(custom_palette)


def _as_category(series, categories):
    """Convert ``series`` to an unordered categorical with exactly ``categories``.

    A listed category that is absent from the data is kept as an empty
    category, so a subset missing a method does not raise (as
    ``cat.reorder_categories`` would). Values outside ``categories`` still
    raise ValueError, because ``astype(CategoricalDtype)`` would otherwise
    silently turn them into NaN.
    """
    unexpected = set(series.dropna().unique()) - set(categories)
    if unexpected:
        raise ValueError(f"Unexpected values in {series.name!r}: {sorted(map(str, unexpected))}")
    return series.astype(pd.CategoricalDtype(categories=categories))


metric_df['shortened_train_method'] = _as_category(metric_df['shortened_train_method'], TRAIN_METHOD_ORDER)
metric_df['upsample_cfg'] = _as_category(metric_df['upsample_cfg'], UPSAMPLE_CFG_ORDER)

In [None]:
# Box plots of per-image Dice loss: one panel per pretrained seed, training method
# on the x-axis, one coloured box per upsampling configuration.
dice_df = metric_df[metric_df['image_metric'] == 'Dice Loss']

g = sns.catplot(
    data=dice_df,
    kind="box",
    x="shortened_train_method",
    y="metric_score",
    hue="upsample_cfg",
    palette=custom_palette,
    col="experiment_pretrained_seed",
    col_wrap=4,
    sharex=False,
    height=6,
)
g.set_ylabels("Dice Loss")

# Dashed red reference line at y = 0 in every panel.
for ax in g.axes.flat:
    ax.axhline(0, ls='--', color='red')

# Leave room above the panels for the title and add vertical space between rows;
# adjust the grid's own figure instead of pyplot's "current figure".
g.fig.subplots_adjust(top=0.85, hspace=0.25)
g.fig.suptitle('Dice Loss per Method for Different Upsample Methods(Vessel)', fontsize=30)

plt.show()

In [None]:
# Box plots of per-image ECE (the 'Image_ECE' rows): one panel per pretrained seed,
# training method on the x-axis, one coloured box per upsampling configuration.
ece_df = metric_df[metric_df['image_metric'] == 'Image_ECE']

g = sns.catplot(
    data=ece_df,
    kind="box",
    x="shortened_train_method",
    y="metric_score",
    hue="upsample_cfg",
    palette=custom_palette,
    col="experiment_pretrained_seed",
    col_wrap=4,
    sharex=False,
    height=6,
)
g.set_ylabels("Image ECE")

# Dashed red reference line at y = 0 in every panel.
for ax in g.axes.flat:
    ax.axhline(0, ls='--', color='red')

# Leave room above the panels for the title and add vertical space between rows;
# adjust the grid's own figure instead of pyplot's "current figure".
g.fig.subplots_adjust(top=0.85, hspace=0.25)
g.fig.suptitle('Image ECE per Method for Different Upsample Methods(Vessel)', fontsize=30)

plt.show()