In [None]:
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

import seaborn as sns
from ionpy.analysis import ResultsLoader
from ese.experiment.experiment.ese_exp import CalibrationExperiment
sns.set_style("darkgrid")

import os 
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))

# Results loader object does everything
rs = ResultsLoader()
root = "/storage/vbutoi/scratch/ESE/"

%load_ext yamlmagic
%load_ext autoreload
%autoreload 2

In [None]:
from ese.experiment.analysis.inference import load_cal_inference_stats

# `root` already ends with "/", so an f-string like f"{root}/inference/..." would
# produce a "//" in the path; os.path.join handles either form of `root`.
inference_path = os.path.join(root, "inference", "11_13_23_WMH_SUME_Analysis")

# Load the calibration inference logs found under `inference_path`.
cal_inference_info = load_cal_inference_stats(
    log_dir=inference_path,
)

In [None]:
# Peek at which artifacts the loader returned; later cells read
# 'metadata', 'pixel_info_dicts', and 'image_info_df'.
cal_inference_info.keys()

In [None]:
# Metadata table from the inference logs (it carries the `log_set` column
# that `select_pixel_dict` reads).
metadata = cal_inference_info["metadata"]

In [None]:
# List the metadata's keys (column labels if this is a DataFrame); presumably these
# include config fields like "dataset.split" used for selection below -- verify.
metadata.keys()

In [None]:
def select_pixel_dict(pixel_meter_logdict, metadata, kwargs):
    """Return the pixel-meter entry for the single log set matching ``kwargs``.

    ``metadata`` is filtered with ``metadata.select(**kwargs)`` (assumed to be an
    attribute-equality filter -- confirm against the ionpy helper). Exactly one row
    must survive, and that row's ``log_set`` value is used to index
    ``pixel_meter_logdict``.

    Args:
        pixel_meter_logdict: Mapping from log-set identifier to its pixel meters.
        metadata: Metadata table exposing ``select(**filters)`` and a ``log_set`` column.
        kwargs: Filters forwarded to ``metadata.select``. Keys may contain dots
            (e.g. ``{"dataset.split": "cal"}``) and so cannot be written as
            Python keyword arguments, which is why they arrive as a dict.

    Returns:
        ``pixel_meter_logdict[log_set]`` for the matching log set.

    Raises:
        AssertionError: If zero, or more than one, log set matches ``kwargs``.
        KeyError: If the matching ``log_set`` is not a key of ``pixel_meter_logdict``.
    """
    # Keep the filtered frame under its own name rather than rebinding the parameter.
    selected = metadata.select(**kwargs)
    # Report the real match count and the filters, so "no match" (e.g. a mistyped
    # split name) is distinguishable from "duplicate log sets".
    n_matches = len(selected)
    assert n_matches == 1, (
        f"Expected exactly one log set matching {kwargs}, found {n_matches}."
    )
    log_set = selected["log_set"].iloc[0]
    return pixel_meter_logdict[log_set]

## Pixel-level Analysis

In [None]:
# from ese.experiment.analysis.err_diagrams import viz_accuracy_vs_confidence

# for split in ["train", "val", "cal"]: 
#     split_preds_dict = select_pixel_dict(
#         pixel_meter_logdict=cal_inference_info["pixel_info_dicts"], 
#         metadata=cal_inference_info["metadata"],
#         kwargs={"dataset.split": split}
#     ) 
#     # Plot the accuracy vs confidence for this split.
#     viz_accuracy_vs_confidence(
#         split_preds_dict,
#         title=f"WMH Confidence vs Accuracy per (Bin and Predicted Label, {split} split)",
#         x="pred_label",
#         col="bin_num",
#         kind="bar",
#         add_avg=False,
#         facet_kws={'sharey': False, 'sharex': False}
#         )

In [None]:
from ese.experiment.analysis.err_diagrams import viz_accuracy_vs_confidence

# Only the calibration split is plotted; extend the tuple (e.g. "train", "val")
# to sweep more splits.
for split in ("cal",):
    # Pixel meters for the single log set whose `dataset.split` equals this split.
    pixel_meter_dict = select_pixel_dict(
        pixel_meter_logdict=cal_inference_info["pixel_info_dicts"],
        metadata=cal_inference_info["metadata"],
        kwargs={"dataset.split": split},
    )
    # Bar plot of confidence vs. accuracy: x = `num_neighbors`, one column per
    # `bin_num`, with independent x/y axes across facets.
    viz_accuracy_vs_confidence(
        pixel_meter_dict,
        title=f"WMH Confidence vs Accuracy per (Bin and Num Neighbors, split: {split})",
        x="num_neighbors",
        col="bin_num",
        kind="bar",
        add_avg=False,
        add_proportion=True,
        facet_kws={'sharey': False, 'sharex': False},
    )

## Image-level Analysis

In [None]:
# Image-level results table from the inference logs.
image_info_df = cal_inference_info["image_info_df"]

In [None]:
from ese.experiment.analysis.utils import reorder_splits

# `drop_constant` presumably drops columns holding a single value across all rows
# (ionpy DataFrame helper -- confirm); `reorder_splits` presumably puts the dataset
# splits into a canonical order -- confirm in ese.experiment.analysis.utils.
varying_image_df = image_info_df.drop_constant()
unique_image_df = reorder_splits(varying_image_df)

In [None]:
# Distribution of per-image `qual_score`, one KDE row per `qual_metric` (hue is the
# same metric, so each row gets its own color). x is shared so scores are comparable
# across rows; y is not, since densities can differ widely between metrics.
qual_score_grid = sns.FacetGrid(
    unique_image_df,
    hue="qual_metric",
    row="qual_metric",
    sharex=True,
    sharey=False,
)
qual_score_grid = qual_score_grid.map(sns.kdeplot, "qual_score", fill=True)

# Tighten spacing between the stacked facets.
qual_score_grid.fig.tight_layout()

In [None]:
# Calibration-metric "type" = last space-separated token of the `cal_metric` name
# (e.g. "A B" -> "B"); used as the facet row in the correlation plot below.
# Vectorized `.str` ops replace the per-row lambda and yield NaN (instead of raising)
# for missing names; `assign` keeps the cell free of in-place mutation.
unique_image_df = unique_image_df.assign(
    cal_metric_type=unique_image_df["cal_metric"].str.split(" ").str[-1]
)

In [None]:
# from ese.experiment.analysis.err_diagrams import viz_cal_metric_corr

# viz_cal_metric_corr(
#     unique_image_df,
#     title="WMH Calibration Metric NEGATED Correlation",
#     negate=True,
#     height=7
# )

In [None]:
from ese.experiment.analysis.err_diagrams import viz_cal_metric_corr

# Calibration-metric correlation plot, split by `cal_metric_type` (passed as `row`).
# `negate=True` is what the "NEGATED" in the title refers to.
viz_cal_metric_corr(
    unique_image_df,
    row="cal_metric_type",
    negate=True,
    height=7,
    title="WMH Calibration Metric NEGATED Correlation",
)