In [None]:
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

from ionpy.analysis import ResultsLoader
from ese.experiment.experiment import CalibrationExperiment
import seaborn as sns
sns.set_style("darkgrid")

import os 
# DATAPATH is a ':'-separated list of dataset roots (a single entry for now);
# CUDA_VISIBLE_DEVICES restricts this kernel to GPU 3.
os.environ.update({
    'DATAPATH': ':'.join(['/storage/vbutoi/datasets']),
    "CUDA_VISIBLE_DEVICES": '3',
})

%load_ext yamlmagic
%load_ext autoreload
%autoreload 2

In [None]:
# Base directory that the run folders loaded below live under.
root = "/storage/vbutoi/scratch/ESE"
# Loader used in the following cells to read run configs and metrics and to
# rebuild an experiment from them.
rs = ResultsLoader()

In [None]:
# Folder with the WMH augmentation runs.
path = "{}/WMH_aug_runs".format(root)

# First read the run configs found under `path`, then the metrics that belong
# to those configs.
dfc = rs.load_configs(path, properties=False)
df = rs.load_metrics(dfc)

In [None]:
# Rebuild one run from `df` as a CalibrationExperiment on the GPU.
# NOTE(review): presumably this selects the run with the best "val-dice_score"
# and restores its "max-val-dice_score" checkpoint — confirm against ionpy.
best_exp = rs.get_experiment(
    exp_class=CalibrationExperiment,
    df=df,
    metric="val-dice_score",
    checkpoint="max-val-dice_score",
    device="cuda",
)

In [None]:
# Sanity-check the selected experiment by looking at its loss curves before
# evaluating calibration.
best_exp.vis_loss_curves(height=6)

In [None]:
%%yaml dataset_cfg 

# Dataset evaluated by get_dataset_perf: `_class` is the dotted path of the
# dataset class, every other key is passed to that class as a keyword argument.
_class: ese.experiment.datasets.WMH
annotator: observer_o12
axis: 0
# NOTE(review): this selects the train split, but the results computed from it
# are stored in a variable called `val_perf` — confirm which split is intended.
split: train 
num_slices: 1
slicing: midslice 
task: Amsterdam 
version: 0.2

In [None]:
from ionpy.experiment.util import absolute_import
from torch.utils.data import DataLoader
from ionpy.util.torchutils import to_device
import torch
import numpy as np
from tqdm.notebook import tqdm
from ionpy.metrics import dice_score

def get_dataset_perf(exp, dataset_config):
    """
    Runs `exp.model` over every sample of a dataset (batch size 1, no shuffling)
    and collects predictions, error maps and calibration data for each sample.

    Args:
        exp: experiment object exposing `model` and `device`.
        dataset_config: config object with a `to_dict()` method. Its `_class`
            entry is the dotted path of the dataset class; all remaining entries
            are passed to that class as keyword arguments.
    Returns:
        list of dict, one per sample, with these keys in this order:
            "image", "label", "pred | Dice: <score>", "loss",
            "calibration error", "pixel_wise_pred",
            "foreground_pixel_wise_pred".
        The two `*_pixel_wise_pred` entries are (2 x N) arrays: row 0 is the
        confidence of the predicted class, row 1 is 1.0 where the hard
        prediction matches the label and 0.0 otherwise. The foreground variant
        only keeps pixels whose label equals 1.
    """
    dataset_kwargs = dataset_config.to_dict()
    dataset_cls = absolute_import(dataset_kwargs.pop("_class"))
    loader = DataLoader(
        dataset_cls(**dataset_kwargs),
        batch_size=1,
        shuffle=False,
        drop_last=False,
    )

    def as_image(tensor):
        # Move to host memory and drop the singleton batch/channel dimensions.
        return tensor.cpu().numpy().squeeze()

    items = []
    with torch.no_grad():
        for _, batch in tqdm(enumerate(loader)):
            x, y = to_device(batch, exp.device)
            logits = exp.model(x)

            # Foreground / background probabilities and the thresholded mask.
            prob_fg = torch.sigmoid(logits)
            prob_bg = 1 - prob_fg
            hard_pred = (prob_fg > 0.5).float()
            predicted_fg = hard_pred == 1

            # Scalar Dice (computed from the raw model output) and signed error map.
            dice_met = np.round(dice_score(logits, y).cpu().numpy(), 3)
            loss_image = y - prob_fg

            # 1.0 where the hard prediction agrees with the label, else 0.0.
            correct = (y == hard_pred).float()

            # Per-pixel |correctness - probability of the predicted class|.
            calibration_image = torch.where(
                predicted_fg,
                (correct - prob_fg).abs(),
                (correct - prob_bg).abs(),
            )

            # Confidence of whichever class was predicted at each pixel, paired
            # with its correctness and flattened to (2 x num_pixels).
            confidence = torch.where(predicted_fg, prob_fg, prob_bg)
            pixel_wise_pred = torch.stack((confidence, correct), dim=0).view(2, -1)

            # Same pairing, restricted to pixels whose label is foreground.
            foreground_pixels = (y == 1)
            foreground_pixel_wise_pred = torch.stack(
                (confidence[foreground_pixels], correct[foreground_pixels]),
                dim=0,
            ).view(2, -1)

            # Key order matters: the visualisation cell maps keys to columns.
            items.append({
                "image": as_image(x),
                "label": as_image(y),
                "pred | Dice: " + str(dice_met): as_image(hard_pred),
                "loss": as_image(loss_image),
                "calibration error": as_image(calibration_image),
                "pixel_wise_pred": as_image(pixel_wise_pred),
                "foreground_pixel_wise_pred": foreground_pixel_wise_pred.cpu().numpy(),
            })
    return items

In [None]:
from ionpy.util import Config

# val_perf is a list with one dict per dataset sample; each dict holds the
# image, label, hard prediction (keyed with its Dice score), loss map,
# calibration-error map and the flattened per-pixel confidence/correctness pairs.
val_perf = get_dataset_perf(exp=best_exp, dataset_config=Config(dataset_cfg))

In [None]:
import matplotlib.pyplot as plt

num_subjs = len(val_perf)

# One row per subject, five panels per row:
# - The image
# - The ground truth
# - The hard prediction (its Dice score is part of the panel title)
# - The delta between the ground truth and the soft prediction
# - The calibration error |correct - confidence of the predicted class|

# squeeze=False keeps `axarr` two-dimensional, so a single subject can be
# indexed as axarr[0, k] just like any other number of subjects.
f, axarr = plt.subplots(num_subjs, 5, figsize=(25, 5*num_subjs), squeeze=False)

# Colormaps for the error-style panels; everything else is shown in gray.
color_dict = {
    "loss": "twilight",
    "calibration error": "plasma" 
}
# Flattened arrays used for the calibration plots, not images.
skip_keys = ["pixel_wise_pred", "foreground_pixel_wise_pred"]

for i, subj in enumerate(val_perf):
    # Filter first so the column index only counts the keys that are drawn,
    # independent of where the skipped keys sit in the dict.
    shown_keys = [key for key in subj.keys() if key not in skip_keys]
    for k_idx, key in enumerate(shown_keys):
        panel = axarr[i, k_idx]
        im = panel.imshow(subj[key], cmap=color_dict.get(key, "gray"))
        panel.axis("off")
        panel.set_title(key)
        f.colorbar(im, ax=panel)
plt.show()

In [None]:
def ECE(accuracies, confidences, num_bins=10, round_to=5):
    """
    Calculates the Expected Calibration Error (ECE) of a set of predictions.

    Confidences are grouped into `num_bins` equal-width bins over [0, 1]. Each
    non-empty bin contributes |mean confidence - mean accuracy|, weighted by the
    fraction of all samples that fall into it.

    Args:
        accuracies: array-like with one entry per sample, 1 where the prediction
            was correct and 0 where it was not.
        confidences: array-like of the same shape with the confidence the model
            assigned to each prediction, in [0, 1].
        num_bins (int): number of confidence interval bins.
        round_to (int): number of decimals the result is rounded to.
    Returns:
        float: Expected Calibration Error (0.0 when there are no samples).
    """
    accuracies = np.asarray(accuracies)
    confidences = np.asarray(confidences)

    # Nothing to average over; avoids mean-of-empty warnings below.
    if confidences.size == 0:
        return np.round(0.0, round_to)

    bin_boundaries = np.linspace(0, 1, num_bins+1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]
    last_bin = num_bins - 1

    ece = 0.0

    for bin_idx, (bin_lower, bin_upper) in enumerate(zip(bin_lowers, bin_uppers)):
        # Bins are [lower, upper), except the last one which is [lower, 1.0].
        # Without the closed upper edge, samples with confidence exactly 1.0
        # (common for saturated sigmoids) would not be counted in any bin.
        if bin_idx == last_bin:
            in_bin = np.logical_and(confidences >= bin_lower, confidences <= bin_upper)
        else:
            in_bin = np.logical_and(confidences >= bin_lower, confidences < bin_upper)

        # Calculate |confidence - accuracy| in each bin, weighted by bin mass.
        prop_in_bin = np.mean(in_bin)
        if prop_in_bin > 0:
            accuracy_in_bin = np.mean(accuracies[in_bin])
            avg_confidence_in_bin = np.mean(confidences[in_bin])
            ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

    return np.round(ece, round_to)


def plot_calibration_plot(
    accuracies,
    confidences,
    num_plots=6,
    bin_multiplier=20,
    style='bar'
    ):
    """
    Draws a row of reliability diagrams at increasing bin resolutions.

    Panel k (1-based) splits [0, 1] into `bin_multiplier * k` equal-width
    confidence bins and shows the mean accuracy of the samples in each bin
    against the ideal diagonal. The panel title reports the bin count and the
    ECE computed with that same number of bins.

    Args:
        accuracies: array-like with one entry per sample, 1 where the prediction
            was correct and 0 where it was not.
        confidences: array-like of the same shape with the model confidence of
            each prediction, in [0, 1].
        num_plots (int): number of panels (bin resolutions) to draw.
        bin_multiplier (int): bin count of the first panel; panel k uses k times
            as many bins.
        style (str): 'bar' for bars (with translucent red ideal bars behind
            them) or 'line' for a line through the per-bin accuracies.
    Raises:
        ValueError: if `style` is neither 'bar' nor 'line'.
    """
    # Fail before any figure is created.
    if style not in ('bar', 'line'):
        raise ValueError("Style must be bar or line.")

    accuracies = np.asarray(accuracies)
    confidences = np.asarray(confidences)

    # One issue is that we want to look maybe with different bin sizes, so
    # lets look at a range. squeeze=False keeps `axes` two-dimensional so that
    # num_plots=1 can be indexed the same way as any other count.
    f, axes = plt.subplots(1, num_plots, figsize=(num_plots*6, 6), squeeze=False)
    for g_idx in range(1, num_plots+1):
        ax = axes[0, g_idx-1]

        # Calculate the bins and spacing: num_bins bins need num_bins + 1 edges.
        num_bins = bin_multiplier * g_idx
        bin_edges = np.linspace(0, 1, num_bins + 1)
        bin_lowers = bin_edges[:-1]
        bin_uppers = bin_edges[1:]
        bin_widths = bin_uppers - bin_lowers
        bin_centers = (bin_lowers + bin_uppers) / 2

        # For each bin, calculate the mean accuracy within the bin. Bins are
        # [lower, upper) except the last, which also includes 1.0. Empty bins
        # stay NaN and are simply not drawn.
        bar_heights = np.full(num_bins, np.nan)
        for b_idx, (lower, upper) in enumerate(zip(bin_lowers, bin_uppers)):
            if b_idx == num_bins - 1:
                in_bin = (confidences >= lower) & (confidences <= upper)
            else:
                in_bin = (confidences >= lower) & (confidences < upper)
            if in_bin.any():
                bar_heights[b_idx] = np.mean(accuracies[in_bin])

        if style == 'bar':
            # align='edge' makes every bar span exactly its own bin; the ideal
            # bar reaches the diagonal at the bin centre.
            ax.bar(bin_lowers, bin_centers, width=bin_widths, align='edge', color='red', alpha=0.2)
            ax.bar(bin_lowers, bar_heights, width=bin_widths, align='edge', color='blue')
        else:
            ax.plot(bin_centers, bar_heights)

        ax.plot([0, 1], [0, 1], linestyle='dotted', linewidth=2, color='gray')
        # Title and ECE use the bin count actually plotted in this panel.
        ax.set_title(f"{num_bins} Bins, ECE: {ECE(accuracies, confidences, num_bins=num_bins)}")
        ax.set_xlabel("Confidence")
        ax.set_ylabel("Accuracy")

    plt.show()

# 1. What if we mix all calibration errors from all of the pixels of all of the images? Here is that histogram.

In [None]:
# Pool the (confidence, correctness) pairs of every pixel of every sample:
# row 0 holds the confidences, row 1 the correctness flags.
all_pixel_preds = np.concatenate(
    [subj["pixel_wise_pred"] for subj in val_perf],
    axis=1,
)
plot_calibration_plot(
    confidences=all_pixel_preds[0],
    accuracies=all_pixel_preds[1],
    num_plots=5,
    style='bar',
)

In [None]:
# Same pooled per-pixel data (row 0: confidence, row 1: correctness), drawn as
# lines instead of bars.
all_pixel_preds = np.concatenate(
    [subj["pixel_wise_pred"] for subj in val_perf],
    axis=1,
)
plot_calibration_plot(
    confidences=all_pixel_preds[0],
    accuracies=all_pixel_preds[1],
    num_plots=5,
    style='line',
)

# 2. Let's do the same experiment EXCEPT only look at positive pixel regions.

In [None]:
# Pool only the pixels whose label is foreground
# (row 0: confidence, row 1: correctness).
all_pixel_preds = np.concatenate(
    [subj["foreground_pixel_wise_pred"] for subj in val_perf],
    axis=1,
)
plot_calibration_plot(
    confidences=all_pixel_preds[0],
    accuracies=all_pixel_preds[1],
    num_plots=5,
    style='bar',
)

In [None]:
# Foreground-only per-pixel data (row 0: confidence, row 1: correctness),
# drawn as lines instead of bars.
all_pixel_preds = np.concatenate(
    [subj["foreground_pixel_wise_pred"] for subj in val_perf],
    axis=1,
)
plot_calibration_plot(
    confidences=all_pixel_preds[0],
    accuracies=all_pixel_preds[1],
    num_plots=5,
    style='line',
)

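## Far worse! Restricted to foreground pixels, the model is much more poorly calibrated than when all pixels are pooled.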