In [None]:
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")
sns.set_context("talk")

import os 
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))

# Results loader object does everything
from ionpy.analysis import ResultsLoader
from pathlib import Path
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 2
# For using yaml configs.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

# Parsed by the %%yaml magic into `results_cfg`, which is passed to load_cal_inference_stats below.
log:
    # Root log directory for inference results.
    root: /storage/vbutoi/scratch/ESE/inference
    # Inference group(s) to load. NOTE(review): presumably sub-directories of `root` — confirm in analyze_inf.
    inference_groups: 
        - "04_24_24_UniverSeg_Inference"
    
calibration:
    # Calibration-metric config file. NOTE(review): how it is consumed lives in analyze_inf — confirm there.
    metric_cfg_file: "/storage/vbutoi/projects/ESE/ese/experiment/configs/inference/Calibration_Metrics.yaml"

# Loader switches. NOTE(review): their semantics are defined in ese.experiment.analysis.analyze_inf and
# are only inferred from the key names here, e.g. drop_nan_metric_rows presumably drops rows whose metric
# score is NaN, and equal_rows_per_cfg_assert: False presumably disables a per-config row-count check.
options:
    add_dice_loss_rows: True
    drop_nan_metric_rows: True 
    remove_shared_columns: False
    equal_rows_per_cfg_assert: False 

In [None]:
from ese.experiment.analysis.analyze_inf import load_cal_inference_stats

# Build the per-image results table described by `results_cfg`.
# NOTE(review): load_cached=False presumably bypasses any cached copy, so this cell recomputes every run.
image_info_df = load_cal_inference_stats(results_cfg=results_cfg, load_cached=False)

In [None]:
image_info_df.columns

In [None]:
def task(dataset_name, label_idx):
    """Return the task identifier "<dataset_name>_<label_idx>" for a single row.

    Keep the parameter names as they are: they are what `augment` is called with
    below (presumably matched against dataframe columns of the same name).
    """
    return "{}_{}".format(dataset_name, label_idx)

# Add a `task` column built by `task(dataset_name, label_idx)`.
# NOTE(review): ionpy's DataFrame.augment appears to work in place (its result is not assigned, yet the
# following cells read image_info_df['task']); presumably it names the column after the function and
# feeds it the columns matching the parameter names — confirm against ionpy.
image_info_df.augment(task)

In [None]:
image_info_df.loc[:, 'task'].unique()

In [None]:
# Order in which the held-out tasks appear in the plots below (seaborn follows the
# category order of a categorical column).
TASK_ORDER = [
    'PanDental_Jaw_0',
    'PanDental_Mandible_0',
    'SpineWeb_0',
    'STARE_0',
    'WBC_1',
    'WBC_0',
    'ACDC_2',
    'ACDC_1',
    'ACDC_0',
]
# Validate before assigning, so a mismatch leaves image_info_df unchanged and the error names
# the offending tasks (reorder_categories' own ValueError does not list them).
task_as_category = image_info_df['task'].astype('category')
mismatched_tasks = set(task_as_category.cat.categories).symmetric_difference(TASK_ORDER)
if mismatched_tasks:
    raise ValueError(
        "TASK_ORDER does not match the tasks present in image_info_df; "
        f"mismatched tasks: {sorted(mismatched_tasks)}"
    )
image_info_df['task'] = task_as_category.cat.reorder_categories(TASK_ORDER)


In [None]:
# Letter-value ("boxen") plot of per-image Dice for each held-out task, coloured by source dataset.
dice_scores = image_info_df.loc[image_info_df['image_metric'] == 'Dice']
g = sns.catplot(
    data=dice_scores,
    kind='boxen',
    x='task',
    y='metric_score',
    hue='dataset_name',
    height=8,  # facet height in inches
    aspect=2,  # facet width = aspect * height
)
# Rotate the task names so they stay legible.
g.set_xticklabels(rotation=45)
g.set(ylabel="Dice Score")

g.fig.suptitle("Dice Score per Held-out Task (#Samp per Subj=5, #Support=8)", fontsize=25)
# Shrink the plotting area to leave headroom for the suptitle.
g.fig.subplots_adjust(top=0.9)


In [None]:
# Letter-value ("boxen") plot of per-image Image ECE for each held-out task, coloured by source dataset.
ece_scores = image_info_df.loc[image_info_df['image_metric'] == 'Image_ECE']
g = sns.catplot(
    data=ece_scores,
    x='task',
    y='metric_score',
    kind='boxen',
    hue='dataset_name',
    aspect=2,   # facet width = aspect * height
    height=8,   # facet height in inches
)
# Rotate the task names so they stay legible, and label the y axis with the metric shown.
g.set_xticklabels(rotation=45)
g.set(ylabel="Image ECE")

g.fig.suptitle("Image ECE per held-out task (#samp per subj=5, #support=8)", fontsize=25)
# Shrink the plotting area to leave headroom for the suptitle.
g.fig.subplots_adjust(top=0.9)


In [None]:
# To compare Dice with ECE per image, reshape to one column per `image_metric`.
# `pivot` raises ValueError if an index combination repeats, which surfaces duplicated results.
pivot_keys = ['data_id', 'sup_idx', 'task', 'dataset_name']
df_pivot = (
    image_info_df
    .pivot(index=pivot_keys, columns='image_metric', values='metric_score')
    .reset_index()
)

In [None]:
# Spread of Dice across the rows sharing a `data_id` (its different `sup_idx` values), one box
# per subject, with one facet per held-out task.
g = sns.catplot(
    data=df_pivot,
    x='data_id',
    y='Dice',
    hue='dataset_name',
    col="task",
    col_wrap=3,
    height=6,
    kind='box',
    sharex=False,
    sharey=False
)
# Facets have independent axes, so give their tick labels some room.
g.fig.subplots_adjust(wspace=0.25, hspace=0.2)

g.fig.suptitle("Dice-per-Subject Across Samples, Per Task (#samp per subj=5, #support=8)", fontsize=25)
# Leave headroom above the facets for the suptitle.
g.fig.subplots_adjust(top=0.9)
# Hide the subject-id tick labels. tick_params only toggles label visibility; set_xticklabels([])
# installs a fixed formatter and can trigger matplotlib's "fixed number of ticks" UserWarning
# on category-located axes.
for ax in g.axes.flat:
    ax.tick_params(axis='x', labelbottom=False)

In [None]:
# Spread of Image ECE across the rows sharing a `data_id` (its different `sup_idx` values), one box
# per subject, with one facet per held-out task.
g = sns.catplot(
    data=df_pivot,
    x='data_id',
    y='Image_ECE',
    hue='dataset_name',
    col="task",
    col_wrap=3,
    height=6,
    kind='box',
    sharex=False,
    sharey=False
)
# Facets have independent axes, so give their tick labels some room.
g.fig.subplots_adjust(wspace=0.25, hspace=0.2)

g.fig.suptitle("Image_ECE-per-Subject Across Samples, Per Task (#samp per subj=5, #support=8)", fontsize=25)
# Leave headroom above the facets for the suptitle.
g.fig.subplots_adjust(top=0.9)
# Hide the subject-id tick labels. tick_params only toggles label visibility; set_xticklabels([])
# installs a fixed formatter and can trigger matplotlib's "fixed number of ticks" UserWarning
# on category-located axes.
for ax in g.axes.flat:
    ax.tick_params(axis='x', labelbottom=False)

In [None]:
from scipy.stats import pearsonr

# Squared Pearson correlation, returning NaN where the correlation is undefined.
def calculate_r_squared(x, y):
    """Return the squared Pearson correlation coefficient (r^2) between ``x`` and ``y``.

    Pairs where either value is NaN, infinite or masked are ignored. Undefined cases
    return NaN instead of raising or warning: fewer than two valid pairs (pearsonr
    raises) or a constant input (pearsonr warns).

    Args:
        x: 1-D sequence of numbers (list, ndarray, masked array or pandas Series).
        y: 1-D sequence of numbers with the same length as ``x``.

    Returns:
        float: r^2 in [0, 1], or NaN when undefined.

    Raises:
        ValueError: if ``x`` and ``y`` have different shapes.
    """
    import numpy as np

    # Masked entries (e.g. from matplotlib offsets) become NaN so they are filtered below.
    x = np.ma.filled(np.ma.asarray(x, dtype=float), np.nan)
    y = np.ma.filled(np.ma.asarray(y, dtype=float), np.nan)
    if x.shape != y.shape:
        raise ValueError("x and y must have the same length.")

    valid = np.isfinite(x) & np.isfinite(y)
    x, y = x[valid], y[valid]
    if x.size < 2 or np.ptp(x) == 0 or np.ptp(y) == 0:
        return float("nan")

    # pearsonr returns (statistic, p-value); only the coefficient is needed.
    r = pearsonr(x, y)[0]
    return float(r) ** 2

# Scatter Dice against Image ECE for every (data_id, sup_idx) row, one facet per held-out task,
# and annotate each facet with the squared Pearson correlation between the two metrics.
g = sns.relplot(
    data=df_pivot,
    x='Dice',
    y='Image_ECE',
    kind='scatter',
    hue='dataset_name',
    col="task",
    col_wrap=3,
    height=6,
    facet_kws={
        "sharex": False,
        "sharey": False
    }
)
# Facets have independent axes, so give their tick labels some room.
g.fig.subplots_adjust(wspace=0.25, hspace=0.2)

g.fig.suptitle("Dice vs Image ECE per held-out task (#samp per subj=5, #support=8)", fontsize=25)
# Leave headroom above the facets for the suptitle.
g.fig.subplots_adjust(top=0.9)

# Compute r^2 from the rows each facet plots rather than reading points back out of the
# matplotlib artists (ax.collections[0] fails on a facet with no points). Rows missing either
# metric are dropped; facets with fewer than 2 points get no annotation since r is undefined.
for task_name, ax in zip(g.col_names, g.axes.flat):
    task_points = df_pivot.loc[df_pivot['task'] == task_name, ['Dice', 'Image_ECE']].dropna()
    if len(task_points) < 2:
        continue
    r_squared = calculate_r_squared(task_points['Dice'], task_points['Image_ECE'])
    ax.annotate(f"$r^2$ = {r_squared:.2f}", xy=(0.7, 0.85), xycoords='axes fraction', color='red', fontsize=16)

plt.show()