In [None]:
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")
sns.set_context("talk")

import os 
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))

# Results loader object does everything
from ionpy.analysis import ResultsLoader
from pathlib import Path
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 2
# For using yaml configs.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

# Configuration bound to `results_cfg` by the %%yaml magic above and passed to
# load_cal_inference_stats(results_cfg=...) in the following cell.
# NOTE(review): the meaning of each key is defined by load_cal_inference_stats,
# which is not visible in this notebook -- the notes below are assumptions.
log:
    # presumably the directory that holds the inference runs -- TODO confirm
    root: /storage/vbutoi/scratch/ESE/inference
    # presumably run-group names looked up under `root` -- TODO confirm
    inference_groups: 
        - "03_21_24_RandomCircles_TET_pt2"
    
calibration:
    # Path to a YAML file; presumably lists the calibration metrics to load -- TODO confirm
    metric_cfg_file: "/storage/vbutoi/projects/ESE/ese/experiment/configs/inference/Calibration_Metrics.yaml"

# Boolean switches for the loader; verify their semantics against the loader's source.
options:
    add_baseline_rows: False 
    load_pixel_meters: False 
    add_dice_loss_rows: True
    drop_nan_metric_rows: True 
    load_groupavg_metrics: False
    remove_shared_columns: False
    equal_rows_per_cfg_assert: True 

In [None]:
from ese.experiment.analysis.analyze_inf import load_cal_inference_stats

# Build the per-image statistics frame from the YAML config defined above.
# NOTE(review): load_cached=False presumably bypasses any cached result and
# recomputes from the logs -- confirm against the loader.
image_info_df = load_cal_inference_stats(results_cfg=results_cfg, load_cached=False)

In [None]:
# Display the distinct values stored in the 'ensemble.member_temps' column.
# NOTE(review): the grouping cell further down reads a 'member_temps' column
# (no 'ensemble.' prefix) -- confirm that both columns exist in the frame.
image_info_df['ensemble.member_temps'].unique()

In [None]:
import ast

def temp_1(member_temps):
    """Return element 0 of the Python literal encoded in ``member_temps``.

    ``member_temps`` is parsed with ``ast.literal_eval``; presumably it is a
    string such as ``"(t1, t2)"`` holding the ensemble member temperatures
    (verify against the loaded data).
    """
    parsed_temps = ast.literal_eval(member_temps)
    return parsed_temps[0]


def temp_2(member_temps):
    """Return element 1 of the Python literal encoded in ``member_temps``.

    ``member_temps`` is parsed with ``ast.literal_eval``; presumably it is a
    string such as ``"(t1, t2)"`` holding the ensemble member temperatures
    (verify against the loaded data).
    """
    parsed_temps = ast.literal_eval(member_temps)
    return parsed_temps[1]

# Register temp_1 / temp_2 with the frame's `augment` helper; later cells read
# 'temp_1' and 'temp_2' columns.
# NOTE(review): the return values are discarded, so this relies on `augment`
# modifying image_info_df in place and naming each new column after the
# function -- confirm against the helper's implementation.
image_info_df.augment(temp_1)
image_info_df.augment(temp_2)

In [None]:
# Select only the rows corresponding to group methods.
# The explicit .copy() makes the filtered frame own its data, so the column
# assignments below do not write into a slice of the unfiltered frame (which
# triggers pandas' SettingWithCopyWarning).
image_info_df = image_info_df[image_info_df['model_type'] == 'group'].copy()

# Convert to categoricals with a fixed category order.
# reorder_categories raises ValueError unless the list names exactly the
# categories present in the column, so these lists also act as a check on the
# values found in the filtered rows.
image_info_df['method_name'] = (
    image_info_df['method_name']
    .astype('category')
    .cat.reorder_categories([
        'Ensemble (mean, probs)',
    ])
)

image_info_df['split'] = (
    image_info_df['split']
    .astype('category')
    .cat.reorder_categories([
        'val',
        'cal',
    ])
)

# Let's look at how the Dice score varies as a function of the two temperatures.

In [None]:
def group_by_config(in_df):
    """Average ``metric_score`` over all rows that share one configuration.

    Rows are grouped by the configuration columns listed below and the
    ``metric_score`` column is averaged within each group.

    Parameters
    ----------
    in_df : pandas.DataFrame
        Must contain the eight grouping columns and ``metric_score``.

    Returns
    -------
    pandas.DataFrame
        One row per configuration, with the grouping columns followed by the
        mean ``metric_score``. Rows containing NaN are removed and the index
        is reset to 0..n-1.
    """
    config_keys = [
        'ensemble_hash',
        'method_name',
        'calibrator',
        'split',
        'member_temps',
        'temp_1',
        'temp_2',
        'num_ensemble_members',
    ]
    # observed=True: some keys may be categorical. Without it pandas builds a
    # group for every combination of category levels, including ones that
    # never occur; those come back as all-NaN rows that were only being
    # removed by the dropna below. Restricting to observed combinations gives
    # the same rows without the blow-up and without depending on the
    # changing pandas default.
    mean_scores = (
        in_df
        .groupby(config_keys, observed=True)
        .agg({'metric_score': 'mean'})
        .reset_index()
    )
    # Drop configurations whose mean is NaN (no valid scores in the group).
    return mean_scores.dropna().reset_index(drop=True)

# Per-configuration mean Dice, computed for the 'val' split first and then
# the 'cal' split.
val_dice_metric_df, cal_dice_metric_df = (
    group_by_config(image_info_df.select(image_metric='Dice', split=split_name))
    for split_name in ('val', 'cal')
)

In [None]:
# Reshape each split's table into a temp_1 (rows) x temp_2 (columns) grid of
# mean metric_score values.
val_pivot_df, cal_pivot_df = (
    split_df.pivot_table(
        index='temp_1',
        columns='temp_2',
        values='metric_score',
        aggfunc='mean',
    )
    for split_df in (val_dice_metric_df, cal_dice_metric_df)
)

In [None]:
# Heatmap of the validation pivot: rows are temp_1, columns are temp_2, and
# each cell is annotated with its value to three decimals.
plt.figure(figsize=(15, 12))  # figure size in inches; adjust as desired
g = sns.heatmap(val_pivot_df, annot=True, fmt='.3f', cmap='coolwarm')
g.set_title('Validation Split Dice Score vs Temp 1 vs Temp 2')

In [None]:
# Heatmap of the calibration pivot: rows are temp_1, columns are temp_2, and
# each cell is annotated with its value to three decimals.
plt.figure(figsize=(15, 12))  # figure size in inches; adjust as desired
g = sns.heatmap(cal_pivot_df, annot=True, fmt='.3f', cmap='coolwarm')
g.set_title('Calibration Split Dice Score vs Temp 1 vs Temp 2')