In [None]:
# Make the local project checkouts importable. These are machine-specific
# absolute paths, so the project imports further down depend on this layout.
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
# Global seaborn style/context applied to every figure drawn after this point.
sns.set_style("darkgrid")
sns.set_context("talk")

import os 
# DATAPATH is built as a ':'-separated list of dataset roots (one entry here).
# It is set before the project imports below -- presumably they read it; TODO confirm.
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))

from ese.experiment.analysis.analyze_inf import load_cal_inference_stats
# Results loader object does everything
from ionpy.analysis import ResultsLoader
from pathlib import Path
# NOTE(review): `root` and `rs` are not referenced again in the cells visible here.
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 2
# For using yaml configs.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

# This mapping is bound to `results_cfg` by the %%yaml line above and is passed
# to load_cal_inference_stats in the next cell.
# NOTE(review): key semantics are defined by that loader, not by this notebook.
log:
    # presumably the directory that holds the inference runs -- verify against the loader
    root: /storage/vbutoi/scratch/ESE/inference
    inference_groups: 
        - '07_10_24_OCTA_SmoothGT'

# Loader switches; see load_cal_inference_stats for what each one does.
options:
    add_dice_loss_rows: True
    drop_nan_metric_rows: True 
    remove_shared_columns: False
    equal_rows_per_cfg_assert: False 

In [None]:
# Load the inference statistics described by `results_cfg` (the YAML cell above).
# load_cached=False presumably forces a fresh load instead of a cached copy --
# confirm in load_cal_inference_stats.
inference_df = load_cal_inference_stats(results_cfg=results_cfg, load_cached=False)

In [None]:
def loss_func(loss_func_class):
    """Return a short display name for a loss-function class path.

    Args:
        loss_func_class: Dotted class path as a string (e.g. "pkg.mod.SoftDiceLoss"),
            the literal string "None", or the value None.

    Returns:
        'Combo' when no class is given ("None" or None); otherwise the last
        dot-separated component of the path.
    """
    # A real None has no .split(); treat it like the stringified "None".
    if loss_func_class is None or loss_func_class == "None":
        return 'Combo'
    return loss_func_class.split('.')[-1]

# `augment` is defined outside this notebook. Since a "loss_func" column is selected
# from inference_df right after this call, it presumably adds a column named after
# the function, computed from the columns matching its parameter names -- TODO confirm.
inference_df.augment(loss_func)

In [None]:
# Only a handful of inference_df columns matter for this experiment.
exp_columns = [
    "data_id",
    "loss_func",
    "new_gt_proportion", # This is after our resizing and blurring
    "gt_proportion",
    "soft_proportion",
    "hard_proportion",
    "experiment_pretrained_seed", 
    "model_pretrained_exp_root",
    "split",
]
# Project onto those columns, then collapse rows that became identical.
exp_df = inference_df[exp_columns]
exp_df = exp_df.drop_duplicates()
exp_df = exp_df.reset_index(drop=True)

In [None]:
# Number of distinct examples being evaluated. dropna=False keeps the count equal
# to len(unique()), which also counts a missing id as one value.
exp_df['data_id'].nunique(dropna=False)

# Library Functions

In [None]:
def prepare_proportion_df(raw_df, groupby_keys, value_vars, var_name, value_name):
    """Reshape `raw_df` from wide to long format.

    Args:
        raw_df: Source DataFrame; it is not modified.
        groupby_keys: Columns kept as identifier columns on every output row.
        value_vars: Columns stacked into a single value column.
        var_name: Name of the output column holding the original column name.
        value_name: Name of the output column holding the stacked values.

    Returns:
        A new long-format DataFrame with one row per (input row, value column).
    """
    melt_kwargs = {
        "id_vars": groupby_keys,
        "value_vars": value_vars,
        "var_name": var_name,
        "value_name": value_name,
    }
    # melt builds a fresh frame, so the caller's DataFrame is left untouched.
    return raw_df.melt(**melt_kwargs)

def calibrator(model_pretrained_exp_root):
    """Name the calibration method encoded in a pretrained-experiment root.

    Returns "SVLS" when that substring occurs anywhere in
    `model_pretrained_exp_root`, and "Uncalibrated" otherwise.
    """
    return "SVLS" if "SVLS" in model_pretrained_exp_root else "Uncalibrated"


def process_method_names(input_df, value_name):
    """Filter rows and attach a human-readable method label.

    Rows from a calibrated model (calibrator != "Uncalibrated") whose
    'proportion_type' equals f"hard {value_name}" are dropped. Every remaining
    row gets a label in a new 'proportion type' column (note the space; the
    input column 'proportion_type' is kept as is).

    Args:
        input_df: DataFrame with 'calibrator', 'loss_func', 'proportion_type'
            and 'data_id' columns. It is not modified.
        value_name: Suffix used to recognise the hard rows to drop
            (they are matched against f"hard {value_name}").

    Returns:
        A new DataFrame sorted by 'data_id', de-duplicated, with a fresh index.
        An input that leaves no rows yields an empty frame that still carries
        the 'proportion type' column.
    """
    def _label(calibrator, loss_func, proportion_type):
        # New-GT rows share one label regardless of model or loss.
        if proportion_type in ["new gt error", "new gt relative error"]:
            return "New GT"
        elif calibrator == "Uncalibrated":
            # First word of proportion_type is the kind (e.g. "hard" / "soft").
            return "Uncalibrated " + proportion_type.split(" ")[0] + f" ({loss_func})"
        else:
            return calibrator + " soft" + f" ({loss_func})"

    # Drop all the rows where calibrator != Uncalibrated AND the proportion_type is the hard one.
    calibrated_hard = (input_df['calibrator'] != 'Uncalibrated') & (input_df['proportion_type'] == f'hard {value_name}')
    # Explicit copy so the caller's frame is never written to.
    df = input_df.loc[~calibrated_hard].copy()

    # Build the labels from zipped columns rather than a row-wise apply: on an
    # empty frame apply(axis=1) returns a DataFrame, which cannot be assigned
    # to a single column.
    df['proportion type'] = [
        _label(cal, loss, ptype)
        for cal, loss, ptype in zip(df['calibrator'], df['loss_func'], df['proportion_type'])
    ]

    # Finally, sort by data_id, drop the duplicate rows and reset the index.
    df = df.sort_values(by="data_id")
    return df.drop_duplicates().reset_index(drop=True)

# First, let's look at absolute error.

In [None]:
# Add the "calibrator" column used as an identifier below (`augment` is external;
# presumably it names the new column after the function -- TODO confirm).
exp_df.augment(calibrator)

# Identifier columns repeated on every melted row.
melt_id_columns = [
    "calibrator",
    "loss_func",
    "data_id", 
    "experiment_pretrained_seed", 
    "gt_proportion",
    "split"
]
# Proportion columns stacked into one "proportion" column, tagged by "prop_type".
melt_value_columns = ["new_gt_proportion", "soft_proportion", "hard_proportion"]

proportion_df = prepare_proportion_df(
    exp_df,
    groupby_keys=melt_id_columns,
    value_vars=melt_value_columns,
    var_name="prop_type",
    value_name="proportion",
)
# Duplicates are not expected here, but remove them defensively.
proportion_df = proportion_df.drop_duplicates()
proportion_df = proportion_df.reset_index(drop=True)

# Process the method names.
def total_method(calibrator, loss_func, prop_type):
    """Build the combined method key "<calibrator>,<loss_func>,<prop_type>"."""
    parts = (calibrator, loss_func, prop_type)
    return ",".join(f"{part}" for part in parts)
# Adds the "total_method" column that the filtering cell below selects on
# (`augment` is external; column naming after the function is presumed -- TODO confirm).
proportion_df.augment(total_method)

In [None]:
# Inspect the melted frame before any method filtering.
proportion_df

In [None]:
# We want to keep two kinds of rows (both from SoftDiceLoss models):
# - Uncalibrated, SoftDiceLoss, hard_proportion
# - SVLS, SoftDiceLoss, soft_proportion
kept_methods = [
    "Uncalibrated,SoftDiceLoss,hard_proportion",
    "SVLS,SoftDiceLoss,soft_proportion",
]
proportion_df = (
    proportion_df[proportion_df['total_method'].isin(kept_methods)]
    .sort_values(by="gt_proportion")
    # Reset only after sorting, so the index follows the sorted row order.
    .reset_index(drop=True)
)

In [None]:
# Inspect the filtered frame, now ordered by gt_proportion.
proportion_df

In [None]:
# Line plot of the predicted proportion per data_id, one line per calibrator.
g = sns.relplot(
    proportion_df,
    x="data_id",
    y="proportion",
    hue="calibrator",
    kind="line",
    height=10,
    aspect=1.5,
)
# No row/col faceting is requested, so the grid holds exactly one Axes.
ax = g.ax

# One star per data_id at its ground-truth proportion. drop_duplicates keeps the
# first row of each data_id in row order, i.e. the same ids, order and values as
# looking up `.iloc[0]` for every entry of `unique()`, without re-filtering the
# whole frame once per id.
gt_per_id = proportion_df.drop_duplicates(subset="data_id")
# NOTE(review): x positions 0..n-1 assume the plot places the ids at consecutive
# integer positions in row order (i.e. data_id is treated as categorical) -- confirm.
ax.scatter(
    np.arange(len(gt_per_id)),
    gt_per_id["gt_proportion"],
    color='magenta',
    marker='*',
    s=80,
)

# Hide the vertical grid lines (the x-axis grid) and the x ticks on every subplot.
for facet_ax in g.axes.flat:
    facet_ax.xaxis.grid(False)
    facet_ax.set_xticks([])

# Finally, put a title on the plot saying the proportion type.
ax.set_title("OCTA Predicted Proportion vs. GT Proportion (pink)")
plt.show()

In [None]:
# Box plot of the predicted proportion per data_id, split by calibrator.
g = sns.catplot(
    proportion_df,
    x="data_id",
    y="proportion",
    hue="calibrator",
    kind="box",
    height=10,
    aspect=1.5,
    boxprops={
        "edgecolor": 'dimgray',
        "linewidth": 2
    },
    whiskerprops=dict(color='dimgray')  # Make whiskers gray
)
# No row/col faceting is requested, so the grid holds exactly one Axes.
ax = g.ax

# One star per data_id at its ground-truth proportion. drop_duplicates keeps the
# first row of each data_id in row order, i.e. the same ids, order and values as
# looking up `.iloc[0]` for every entry of `unique()`, without re-filtering the
# whole frame once per id.
gt_per_id = proportion_df.drop_duplicates(subset="data_id")
# NOTE(review): x positions 0..n-1 assume the boxes are laid out in row order
# of data_id (order of first appearance) -- confirm for numeric ids.
ax.scatter(
    np.arange(len(gt_per_id)),
    gt_per_id["gt_proportion"],
    color='magenta',
    marker='*',
    s=80,
)

# Hide the vertical grid lines (the x-axis grid) and the x ticks on every subplot.
for facet_ax in g.axes.flat:
    facet_ax.xaxis.grid(False)
    facet_ax.set_xticks([])

# Finally, put a title on the plot saying the proportion type.
ax.set_title("OCTA Predicted Proportion vs. GT Proportion (pink)")

plt.show()