In [None]:
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")
sns.set_context("talk")

import os 
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))

# Results loader object does everything
from ionpy.analysis import ResultsLoader
from pathlib import Path
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 2
# For using yaml configs.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

# Location of the logged inference runs.
log:
  root: /storage/vbutoi/scratch/ESE/inference
  # Inference run groups to load (presumably looked up under `root`).
  inference_groups: ["03_25_24_RandomCircles_RandomMiscalEnsembles"]

calibration:
  # Config file describing the calibration metrics (per its name).
  metric_cfg_file: "/storage/vbutoi/projects/ESE/ese/experiment/configs/inference/Calibration_Metrics.yaml"

# Loader switches; this whole config is passed to load_cal_inference_stats.
options:
  add_baseline_rows: false
  load_pixel_meters: false
  add_dice_loss_rows: true
  drop_nan_metric_rows: true
  load_groupavg_metrics: false
  remove_shared_columns: false
  equal_rows_per_cfg_assert: true

In [None]:
from ese.experiment.analysis.analyze_inf import load_cal_inference_stats

# Build the per-image inference statistics table described by `results_cfg`
# (defined by the %%yaml cell). load_cached=False presumably forces a fresh load
# rather than reusing a cached copy.
image_info_df = load_cal_inference_stats(results_cfg=results_cfg, load_cached=False)

In [None]:
import ast
import numpy as np

def _parse_member_temps(member_temps):
    """Return the per-member temperatures as a sequence of numbers.

    Values stored as a Python-literal string (e.g. ``"[0.9, 1.1]"``) are parsed
    with ``ast.literal_eval``. Values that are already sequences (list, tuple,
    ndarray) are returned unchanged, because ``ast.literal_eval`` raises on
    non-string input.
    """
    if isinstance(member_temps, str):
        return ast.literal_eval(member_temps)
    return member_temps


def temp_mean(member_temps):
    """Mean of the per-member temperatures in ``member_temps``.

    The function and parameter names are kept as-is because the result is passed
    to ``image_info_df.augment`` and a ``temp_mean`` column is read later.
    """
    return np.mean(_parse_member_temps(member_temps))


def temp_variance(member_temps):
    """Population variance (``np.var``, ddof=0) of the per-member temperatures."""
    return np.var(_parse_member_temps(member_temps))

# Derive the temp_mean / temp_variance columns used by the plots below
# (augment presumably names each new column after the function it is given).
for temp_stat_fn in (temp_mean, temp_variance):
    image_info_df.augment(temp_stat_fn)

In [None]:
# Distinct member-temperature variances present in the loaded results.
image_info_df.loc[:, 'temp_variance'].unique()

In [None]:
# Select only the rows corresponding to group methods. `.copy()` gives the filtered
# frame its own data. Without it, the column assignments below write to a slice of
# the unfiltered table and trigger pandas' SettingWithCopyWarning.
image_info_df = image_info_df[image_info_df['model_type'] == 'group'].copy()

# Fix the category order used by plot legends / facets. reorder_categories raises
# a ValueError if the listed names differ from the categories present in the data.
# An unexpected method or split therefore fails loudly here rather than silently
# in a plot.
image_info_df['method_name'] = (
    image_info_df['method_name']
    .astype('category')
    .cat.reorder_categories(['Ensemble (mean, probs)'])
)

image_info_df['split'] = (
    image_info_df['split']
    .astype('category')
    .cat.reorder_categories(['val', 'cal'])
)

In [None]:
# Per-image metrics available for the plots below.
image_info_df.loc[:, 'image_metric'].unique()

In [None]:
image_metric = 'Image_ECE'
metric_df = image_info_df[image_info_df['image_metric'] == image_metric]
assert not metric_df.empty, f"no rows with image_metric == {image_metric!r}"

# One panel per split. Each point is a row placed by the mean/variance of its
# member temperatures and coloured by its metric score.
g = sns.relplot(
    data=metric_df,
    x='temp_mean',
    y='temp_variance',
    hue='metric_score',
    col='split',
    kind='scatter',
    height=8,
    aspect=1.5,
    s=100,
    alpha=0.7,
)
# Name the metric in each panel title. Otherwise the per-metric scatter figures in
# this notebook are indistinguishable (the legend only says "metric_score").
g.set_titles(col_template=f"{image_metric} | split = {{col_name}}")

In [None]:
image_metric = 'Dice'
metric_df = image_info_df[image_info_df['image_metric'] == image_metric]
assert not metric_df.empty, f"no rows with image_metric == {image_metric!r}"

# One panel per split. Each point is a row placed by the mean/variance of its
# member temperatures and coloured by its metric score.
g = sns.relplot(
    data=metric_df,
    x='temp_mean',
    y='temp_variance',
    hue='metric_score',
    col='split',
    kind='scatter',
    height=8,
    aspect=1.5,
    s=100,
    alpha=0.7,
    # Fixed colour range for the scores; presumably chosen to spread the observed
    # Dice values. Scores outside it saturate at the end colours of the palette.
    hue_norm=(0.70, 0.80),
)
# Name the metric in each panel title so this figure is identifiable on its own.
g.set_titles(col_template=f"{image_metric} | split = {{col_name}}")

In [None]:
image_metric = 'Dice'
metric_df = image_info_df[image_info_df['image_metric'] == image_metric]
assert not metric_df.empty, f"no rows with image_metric == {image_metric!r}"

# kind='line' aggregates metric_score over all rows sharing a temp_variance value.
# By default seaborn uses the mean plus a confidence band, so this collapses over
# temp_mean.
g = sns.relplot(
    data=metric_df,
    x='temp_variance',
    y='metric_score',
    col='split',
    kind='line',
    height=8,
    aspect=1.5,
)
g.set_titles(col_template=f"{image_metric} vs. temperature variance | split = {{col_name}}")
g.set_axis_labels('temp_variance', image_metric)

In [None]:
image_metric = 'HD95'
metric_df = image_info_df[image_info_df['image_metric'] == image_metric]
assert not metric_df.empty, f"no rows with image_metric == {image_metric!r}"

# One panel per split. Each point is a row placed by the mean/variance of its
# member temperatures and coloured by its metric score.
g = sns.relplot(
    data=metric_df,
    x='temp_mean',
    y='temp_variance',
    hue='metric_score',
    col='split',
    kind='scatter',
    height=8,
    aspect=1.5,
    s=100,
    alpha=0.7,
)
# Name the metric in each panel title so this figure is identifiable on its own.
g.set_titles(col_template=f"{image_metric} | split = {{col_name}}")

In [None]:
image_metric = 'Accuracy'
metric_df = image_info_df[image_info_df['image_metric'] == image_metric]
assert not metric_df.empty, f"no rows with image_metric == {image_metric!r}"

# One panel per split. Each point is a row placed by the mean/variance of its
# member temperatures and coloured by its metric score.
g = sns.relplot(
    data=metric_df,
    x='temp_mean',
    y='temp_variance',
    hue='metric_score',
    col='split',
    kind='scatter',
    height=8,
    aspect=1.5,
    s=100,
    alpha=0.7,
)
# Name the metric in each panel title so this figure is identifiable on its own.
g.set_titles(col_template=f"{image_metric} | split = {{col_name}}")

In [None]:
image_metric = 'BoundaryIOU'
metric_df = image_info_df[image_info_df['image_metric'] == image_metric]
assert not metric_df.empty, f"no rows with image_metric == {image_metric!r}"

# One panel per split. Each point is a row placed by the mean/variance of its
# member temperatures and coloured by its metric score.
g = sns.relplot(
    data=metric_df,
    x='temp_mean',
    y='temp_variance',
    hue='metric_score',
    col='split',
    kind='scatter',
    height=8,
    aspect=1.5,
    s=100,
    alpha=0.7,
)
# Name the metric in each panel title so this figure is identifiable on its own.
g.set_titles(col_template=f"{image_metric} | split = {{col_name}}")