In [None]:
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

import os 
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from pathlib import Path
import matplotlib.pyplot as plt
# Ionpy imports
from ionpy.analysis import ResultsLoader
# Local imports
from ese.analysis.analyze_inf import load_cal_inference_stats
from ese.analysis.analysis_utils.plot_utils import get_prop_color_palette
from ese.analysis.analysis_utils.parse_sweep import get_global_optimal_parameter, get_per_subject_optimal_values
# Global seaborn look applied to every figure drawn in this notebook.
sns.set_style("darkgrid")
sns.set_context("talk")
# Colon-separated list of dataset roots exported through the DATAPATH env var
# (presumably read by the project's dataset loaders — not visible here).
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))
# NOTE(review): `root` and `rs` are not referenced by any visible cell of this
# notebook; results are loaded via load_cal_inference_stats instead.
root = Path("/storage/vbutoi/scratch/ESE")
pd.set_option('display.max_rows', 50)
rs = ResultsLoader()

# Reload edited project modules automatically, without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Enables the %%yaml cell magic used to define results_cfg.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

# Inference results directory (log.root) and the sweep groups to load from it.
log:
    root: '/storage/vbutoi/scratch/ESE/inference/11_05_24_UVS_InContext_CrossEval'
    # Sweep sub-groups to load. Presumably Sweep_Threshold pairs with
    # x_key='threshold' and Sweep_Temperature with x_key='temperature' in the
    # plot-configuration cell — keep the two in sync.
    inference_group: 
        - 'Sweep_Threshold'
        # - 'Sweep_Temperature'

# Loader sanity-check options (semantics defined by load_cal_inference_stats).
options:
    verify_graceful_exit: True
    equal_rows_per_cfg_assert: False 

### Plot configuration: this cell controls what every plotting cell below draws (threshold sweep vs. temperature sweep).

In [None]:
# Controls what every plotting cell below draws, so the individual plot cells
# never need editing. Set SWEEP to 'temperature' (and load 'Sweep_Temperature'
# in results_cfg) to switch sweeps instead of commenting/uncommenting blocks.
SWEEP = 'threshold'

# Per-sweep plot settings:
#   x_key        - swept parameter column plotted on the x-axis
#   y_key        - error metric plotted on the y-axis
#   xtick_range  - x tick positions
#   cmap         - colormap passed to get_prop_color_palette
#   aspect       - facet width/height ratio for sns.relplot
#   x_lims/y_lims- axis ranges
PLOT_SETTINGS = {
    'threshold': dict(
        x_key='threshold',
        y_key='hard_RAVE',
        xtick_range=np.arange(0, 1.1, 0.1),
        cmap='viridis_r',
        aspect=1,
        x_lims=(0, 1),
        y_lims=(0.0, 2),
    ),
    'temperature': dict(
        x_key='temperature',
        y_key='soft_RAVE',
        xtick_range=np.arange(0, 3.1, 0.1),
        cmap='magma_r',
        aspect=2,
        x_lims=(0, 3.0),
        y_lims=(-0.5, 2),
    ),
}

# Export the selected settings under the global names the plot cells use.
_active = PLOT_SETTINGS[SWEEP]
x_key = _active['x_key']
y_key = _active['y_key']
xtick_range = _active['xtick_range']
cmap = _active['cmap']
aspect = _active['aspect']
x_lims = _active['x_lims']
y_lims = _active['y_lims']

# Load, filter, and plot the inference results

In [None]:
# Load the calibration-inference statistics described by results_cfg.
# load_cached=True presumably reuses a previously cached result when present.
inference_df = load_cal_inference_stats(results_cfg=results_cfg, load_cached=True)

In [None]:
# Keep only rows scored with the Dice image metric, and expose their
# metric_score under a dedicated 'Dice' column.
is_dice_row = inference_df['image_metric'] == 'Dice'
inference_df = inference_df[is_dice_row].rename(columns={'metric_score': 'Dice'})

In [None]:
# Label every row with its data split; all loaded rows are tagged 'train' here.
inference_df = inference_df.assign(split='train')

In [None]:
# for ikey in inference_df.keys():
#     print(ikey)

In [None]:
# List the distinct inference tasks present in the loaded results.
unique_tasks = inference_df['inference_data_task'].unique()
for task_name in unique_tasks:
    print(task_name)

In [None]:
# Columns retained for analysis, grouped by role. Concatenating the groups in
# this order reproduces the original column order of exp_df.
metric_cols = [
    'soft_abs_area_estimation_error',
    'hard_abs_area_estimation_error',
    'soft_RAVE',
    'hard_RAVE',
    'log_soft_RAVE',
    'log_hard_RAVE',
    'Dice',
]
setting_cols = ['inference_data_task', 'loss_func_class', 'threshold', 'temperature']
volume_cols = ['hard_volume', 'soft_volume', 'gt_volume']
id_cols = ['data_id', 'split']
cols_to_keep = metric_cols + setting_cols + volume_cols + id_cols

# Project onto those columns and collapse rows that become exact duplicates.
exp_df = (
    inference_df
    .loc[:, cols_to_keep]
    .drop_duplicates()
    .reset_index(drop=True)
)

In [None]:
# Order rows by split so that a 'cal' split (if present) precedes 'train'
# ('cal' < 'train' lexicographically). A stable sort keeps the original row
# order within each split; pandas' default quicksort does not guarantee that.
# NOTE(review): only split='train' is assigned in this notebook, so the sort
# currently leaves the row order unchanged.
exp_df = exp_df.sort_values('split', ascending=True, kind='stable')

In [None]:
# Facet grid (one panel per inference task) of y_key against the swept x_key.
# sns.relplot creates and sizes its own figure (via height/aspect), so no
# separate plt.figure() is needed — one would only leave an empty figure behind.
g = sns.relplot(
    data=exp_df,
    x=x_key,
    y=y_key,
    col='inference_data_task',
    col_wrap=5,
    kind='line',
    height=10,
    aspect=aspect,
)
# Dashed red reference line: at 1.01 for the temperature sweep, otherwise at the
# 0.5 threshold. NOTE(review): 1.01 rather than 1.0 — confirm this is intended.
ref_x = 1.01 if x_key == 'temperature' else 0.5
for ax in g.axes.flat:
    ax.axvline(x=ref_x, color='r', linestyle='--')

g.set(xticks=xtick_range, xlim=x_lims, ylim=y_lims)
# Global title for the whole grid, with headroom so it does not overlap the panels.
g.figure.suptitle(f'{y_key} vs {x_key}', fontsize=30)
g.figure.subplots_adjust(top=0.9)
plt.show()

In [None]:
# Facet grid (one panel per inference task) of Dice against the swept x_key.
# sns.relplot creates and sizes its own figure (via height/aspect), so no
# separate plt.figure() is needed — one would only leave an empty figure behind.
g = sns.relplot(
    data=exp_df,
    x=x_key,
    y='Dice',
    col='inference_data_task',
    col_wrap=5,
    kind='line',
    height=10,
    aspect=aspect,
)
# Dashed red reference line: at 1.01 for the temperature sweep, otherwise at the
# 0.5 threshold. NOTE(review): 1.01 rather than 1.0 — confirm this is intended.
ref_x = 1.01 if x_key == 'temperature' else 0.5
for ax in g.axes.flat:
    ax.axvline(x=ref_x, color='r', linestyle='--')

# Dice lies in [0, 1], so the y-range is fixed rather than taken from y_lims.
g.set(xticks=xtick_range, xlim=x_lims, ylim=(0, 1))
# Global title for the whole grid, with headroom so it does not overlap the panels.
g.figure.suptitle(f'Dice vs {x_key}', fontsize=30)
g.figure.subplots_adjust(top=0.9)
plt.show()

In [None]:
# One figure per inference task: y_key vs x_key with one line per subject
# (data_id). The palette comes from get_prop_color_palette keyed on gt_volume,
# presumably shading each subject by its ground-truth volume.
for task_name in exp_df['inference_data_task'].unique():
    task_df = exp_df[exp_df['inference_data_task'] == task_name]
    g = sns.relplot(
        data=task_df,
        x=x_key,
        y=y_key,
        hue='data_id',
        kind='line',
        height=8,
        aspect=aspect,
        legend=False,
        palette=get_prop_color_palette(
            task_df,
            hue_key='data_id',
            magnitude_key='gt_volume',
            cmap=cmap,
        ),
    )
    # Use the same x-range as the faceted plots; title the single panel with the task.
    g.set(xticks=xtick_range, xlim=x_lims, ylim=y_lims, title=task_name)
    plt.show()

In [None]:
# One figure per inference task: Dice vs x_key with one line per subject
# (data_id), colored with the same gt_volume-keyed palette and cmap as the
# per-task y_key plots.
for task_name in exp_df['inference_data_task'].unique():
    task_df = exp_df[exp_df['inference_data_task'] == task_name]
    g = sns.relplot(
        data=task_df,
        x=x_key,
        y='Dice',
        hue='data_id',
        kind='line',
        height=8,
        aspect=aspect,
        legend=False,
        palette=get_prop_color_palette(
            task_df,
            hue_key='data_id',
            magnitude_key='gt_volume',
            cmap=cmap,
        ),
    )
    # Dice lies in [0, 1]; y_lims is sized for the RAVE metrics, not Dice.
    # Title each figure with its task so the per-task plots are distinguishable.
    g.set(xticks=xtick_range, xlim=x_lims, ylim=(0, 1), title=task_name)
    plt.show()