In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn, optim
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ExponentialLR
from pathlib import Path
import copy
import traceback
import os
import contextlib
from sklearn.isotonic import IsotonicRegression
from typing import Union
import pickle
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tqdm import tqdm

### Download the UCI datasets

In [None]:
# Fetch only the `data/UCI_Datasets` folder of the calibrated-quantile-uq
# repository by using a temporary git repository with sparse checkout.
!git init
!git remote add origin https://github.com/YoungseogChung/calibrated-quantile-uq.git
# Restrict the checkout to the single folder listed in .git/info/sparse-checkout.
!git config core.sparseCheckout true
!echo "data/UCI_Datasets" >> .git/info/sparse-checkout
!git pull origin master
# Move the dataset files to the location used as the default `dataset_path` below.
!mkdir -p datasets/UCI_Datasets
!mv data/UCI_Datasets/* datasets/UCI_Datasets/
# Remove the temporary checkout and git metadata.
!rm -rf data .git

[33mhint: Using 'master' as the name for the initial branch. This default branch name[m
[33mhint: is subject to change. To configure the initial branch name to use in all[m
[33mhint: [m
[33mhint: 	git config --global init.defaultBranch <name>[m
[33mhint: [m
[33mhint: Names commonly chosen instead of 'master' are 'main', 'trunk' and[m
[33mhint: 'development'. The just-created branch can be renamed via this command:[m
[33mhint: [m
[33mhint: 	git branch -m <name>[m
Initialized empty Git repository in /content/.git/
remote: Enumerating objects: 228, done.[K
remote: Counting objects: 100% (228/228), done.[K
remote: Compressing objects: 100% (161/161), done.[K
remote: Total 228 (delta 119), reused 172 (delta 63), pack-reused 0 (from 0)[K
Receiving objects: 100% (228/228), 3.68 MiB | 8.78 MiB/s, done.
Resolving deltas: 100% (119/119), done.
From https://github.com/YoungseogChung/calibrated-quantile-uq
 * branch            master     -> FETCH_HEAD
 * [new branch]      mas

### 1D interpolation function

This function linearly interpolates between the quantiles predicted by the quantile regression models, so that quantiles at additional levels can be evaluated.

In [None]:
class Interp1d(torch.autograd.Function):
    """
    Batched linear 1-D interpolation implemented as a custom autograd Function.

    The forward pass builds the interpolated values while (optionally) tracking
    gradients itself; the backward pass differentiates through that stored
    graph. Use it through `Interp1d.apply` (exposed as `interp1d` right after
    this class).
    """
    @staticmethod
    def forward(ctx, x, y, xnew, out=None):
        """
        Linear 1D interpolation on the GPU for Pytorch.
        This function returns interpolated values of a set of 1-D functions at
        the desired query points `xnew`.
        This function is working similarly to Matlab™ or scipy functions with
        the `linear` interpolation mode on, except that it parallelises over
        any number of desired interpolation problems.
        The code will run on GPU if all the tensors provided are on a cuda
        device.

        `x` is assumed to be sorted along its last dimension (it is searched
        with `torch.searchsorted`). Query points outside the range of `x` are
        handled by clamping the segment index, i.e. they are extrapolated
        linearly from the first or last segment.

        Parameters
        ----------
        x : (N, ) or (D, N) Pytorch Tensor
            A 1-D or 2-D tensor of real values.
        y : (N,) or (D, N) Pytorch Tensor
            A 1-D or 2-D tensor of real values. The length of `y` along its
            last dimension must be the same as that of `x`
        xnew : (P,) or (D, P) Pytorch Tensor
            A 1-D or 2-D tensor of real values. `xnew` can only be 1-D if
            _both_ `x` and `y` are 1-D. Otherwise, its length along the first
            dimension must be the same as that of whichever `x` and `y` is 2-D.
        out : Pytorch Tensor, same shape as `xnew`
            Tensor for the output. If None: allocated automatically.
            NOTE(review): in this implementation `out` only serves as the
            template for the index buffer; the returned tensor is a freshly
            computed one, so `out` does not appear to be filled in place.

        Returns
        -------
        Pytorch Tensor
            The interpolated values. The result is 2-D: 1-D inputs are
            promoted to a single row and are not squeezed back.

        """
        # making the vectors at least 2D
        is_flat = {}
        require_grad = {}
        v = {}
        device = []
        # small constant added to the segment widths to avoid division by zero
        eps = torch.finfo(y.dtype).eps
        for name, vec in {'x': x, 'y': y, 'xnew': xnew}.items():
            assert len(vec.shape) <= 2, 'interp1d: all inputs must be '\
                                        'at most 2-D.'
            if len(vec.shape) == 1:
                v[name] = vec[None, :]
            else:
                v[name] = vec
            # "flat" means a single row that is shared by every problem
            is_flat[name] = v[name].shape[0] == 1
            require_grad[name] = vec.requires_grad
            # collect the distinct devices seen so far
            device = list(set(device + [str(vec.device)]))
        assert len(device) == 1, 'All parameters must be on the same device.'
        device = device[0]

        # Checking for the dimensions
        assert (v['x'].shape[1] == v['y'].shape[1]
                and (
                     v['x'].shape[0] == v['y'].shape[0]
                     or v['x'].shape[0] == 1
                     or v['y'].shape[0] == 1
                    )
                ), ("x and y must have the same number of columns, and either "
                    "the same number of row or one of them having only one "
                    "row.")

        reshaped_xnew = False
        if ((v['x'].shape[0] == 1) and (v['y'].shape[0] == 1)
           and (v['xnew'].shape[0] > 1)):
            # if there is only one row for both x and y, there is no need to
            # loop over the rows of xnew because they will all have to face the
            # same interpolation problem. We should just stack them together to
            # call interp1d and put them back in place afterwards.
            original_xnew_shape = v['xnew'].shape
            v['xnew'] = v['xnew'].contiguous().view(1, -1)
            reshaped_xnew = True

        # identify the dimensions of output and check if the one provided is ok
        D = max(v['x'].shape[0], v['xnew'].shape[0])
        shape_ynew = (D, v['xnew'].shape[-1])
        if out is not None:
            if out.numel() != shape_ynew[0]*shape_ynew[1]:
                # The output provided is of incorrect shape.
                # Going for a new one
                out = None
            else:
                ynew = out.reshape(shape_ynew)
        if out is None:
            ynew = torch.zeros(*shape_ynew, device=device)

        # moving everything to the desired device in case it was not there
        # already (not handling the case things do not fit entirely, user will
        # do it if required.)
        for name in v:
            v[name] = v[name].to(device)

        # integer buffer with the output shape; it receives the searchsorted
        # indices below
        ind = ynew.long()

        # expanding xnew to match the number of rows of x in case only one xnew is
        # provided
        if v['xnew'].shape[0] == 1:
            v['xnew'] = v['xnew'].expand(v['x'].shape[0], -1)

        # calling searchsorted on the x values.
        # the squeeze is because torch.searchsorted does accept either a nd with
        # matching shapes for x and xnew or a 1d vector for x. Here we would
        # have (1,len) for x sometimes
        torch.searchsorted(v['x'].contiguous().squeeze(),
                           v['xnew'].contiguous(), out=ind)

        # the `-1` is because searchsorted looks for the index where the values
        # must be inserted to preserve order. And we want the index of the
        # preceeding value.
        ind -= 1
        # we clamp the index, because the number of intervals is x.shape-1,
        # and the left neighbour should hence be at most number of intervals
        # -1, i.e. number of columns in x -2
        ind = torch.clamp(ind, 0, v['x'].shape[1] - 1 - 1)

        # helper function to select stuff according to the found indices.
        def sel(name):
            # a single shared row is indexed directly; otherwise pick per row
            if is_flat[name]:
                return v[name].contiguous().view(-1)[ind]
            return torch.gather(v[name], 1, ind)

        # activating gradient storing for everything now
        # saved_inputs keeps one slot per input (x, y, xnew): the tensor itself
        # when it requires a gradient, None otherwise
        enable_grad = False
        saved_inputs = []
        for name in ['x', 'y', 'xnew']:
            if require_grad[name]:
                enable_grad = True
                saved_inputs += [v[name]]
            else:
                saved_inputs += [None, ]
        # assuming x are sorted in the dimension 1, computing the slopes for
        # the segments
        is_flat['slopes'] = is_flat['x']
        # now we have found the indices of the neighbors, we start building the
        # output. Hence, we start also activating gradient tracking
        # (contextlib.suppress() with no arguments acts as a no-op context)
        with torch.enable_grad() if enable_grad else contextlib.suppress():
            v['slopes'] = (
                    (v['y'][:, 1:]-v['y'][:, :-1])
                    /
                    (eps + (v['x'][:, 1:]-v['x'][:, :-1]))
                )

            # now build the linear interpolation
            ynew = sel('y') + sel('slopes')*(
                                    v['xnew'] - sel('x'))

            if reshaped_xnew:
                ynew = ynew.view(original_xnew_shape)

        # the output (with its locally built graph) is saved first, followed by
        # the three input slots; backward relies on this ordering
        ctx.save_for_backward(ynew, *saved_inputs)
        return ynew

    @staticmethod
    def backward(ctx, grad_out):
        """
        Back-propagate `grad_out` through the graph recorded in `forward`.

        Gradients are computed with `torch.autograd.grad` for the saved inputs
        that are not None; the remaining positions of the returned tuple are
        None.
        """
        # saved_tensors[0] is the forward output, the rest are the input slots
        inputs = ctx.saved_tensors[1:]
        gradients = torch.autograd.grad(
                        ctx.saved_tensors[0],
                        [i for i in inputs if i is not None],
                        grad_out, retain_graph=True)
        result = [None, ] * 5
        pos = 0
        # scatter the computed gradients back to the positions of their inputs
        for index in range(len(inputs)):
            if inputs[index] is not None:
                result[index] = gradients[pos]
                pos += 1
        return (*result,)

# functional entry point used throughout the notebook
interp1d = Interp1d.apply

### CaliPSo model ensemble definition

Implementation of the CaliPSo model as described in the paper

In [None]:
## SINGLE QUANTILE MODEL (used for both the lower and the upper members of the ensemble)
class quantile_model(nn.Module):
    """Base regressor shifted by an empirical quantile of its residuals.

    The shift is the `quantile`-level quantile of `y_cal - vanilla_model(x_cal)`
    and is recomputed on every forward pass.

    Parameters
    ----------
    X, Y : tensors
        Default calibration inputs and targets, used when `forward` receives
        no explicit calibration data.
    vanilla_model : nn.Module
        The underlying regressor.
    quantile : float
        Quantile level handed to `torch.quantile`.
    """

    def __init__(self, X, Y, vanilla_model, quantile):
        super().__init__()
        self.vanilla_model = vanilla_model
        self.quantile = quantile
        self.X = X
        self.Y = Y

    def forward(self, x, x_cal = None, y_cal = None):
        """Return `vanilla_model(x)` plus the residual-quantile offset.

        When both `x_cal` and `y_cal` are omitted, the calibration data stored
        at construction time is used instead.
        """
        prediction = self.vanilla_model(x)

        # fall back to the stored calibration set only if nothing was supplied
        if x_cal is None and y_cal is None:
            x_cal, y_cal = self.X, self.Y
        calibration_prediction = self.vanilla_model(x_cal)

        # the offset is what enforces calibration on the calibration dataset
        residuals = y_cal - calibration_prediction
        offset = torch.quantile(residuals, self.quantile, interpolation='linear')

        return offset + prediction

## Beyond Pinball Loss Base Model
class bpl_nn(nn.Module):
    """Small fully connected regressor with one residual hidden block.

    Layout: Linear(nfeatures, 64) + ReLU, one residual Linear(64, 64) + ReLU
    block, then Linear(64, 1).
    """

    def __init__(self, nfeatures):
        super().__init__()
        width = 64
        # layers are created in input -> hidden -> output order so that seeded
        # initialisation and state-dict keys stay the same
        self.input_layer = nn.Sequential(nn.Linear(nfeatures, width), nn.ReLU())
        self.hidden_layers = nn.ModuleList(
            [nn.Sequential(nn.Linear(width, width), nn.ReLU()) for _ in range(1)]
        )
        self.output_layer = nn.Sequential(nn.Linear(width, 1))

    def forward(self, input):
        """Map a batch of feature rows to one scalar prediction per row."""
        features = self.input_layer(input)
        for block in self.hidden_layers:
            # residual connection around every hidden block
            features = block(features) + features
        return self.output_layer(features)

## SPECIFY CALIBRATED QUANTILE MODEL ENSEMBLE
class quantile_model_ensemble(nn.Module):
    """Ensemble of nested lower/upper quantile models with optional recalibration.

    For every entry of `half_q_levels` the ensemble owns one lower and one
    upper `quantile_model`. `forward` evaluates them from the outermost pair
    inwards, restricting the calibration data after each pair to the points
    that fall inside the interval selected for the next pair.

    Parameters
    ----------
    X, Y : tensors
        Calibration inputs and targets (stored as plain attributes, so they are
        not moved by `.to(...)`).
    vanilla_model : callable
        Factory called as `vanilla_model(nfeatures=X.shape[1])` to build each
        base regressor.
    half_q_levels : int or torch.Tensor
        Quantile levels of the lower models; the upper models use
        `1 - half_q_levels`.
        NOTE(review): `forward` calls tensor methods (`.to`, `.flip`) on these
        levels, so the `int` form does not appear to be usable as-is.
    output_device : device
        Device the sub-models are moved to and results are returned on.
    vanilla_weights_path : path, list of paths or None
        Optional state dict(s) used to initialise the base regressors; a list
        is indexed per quantile level.
    avg_weights : bool
        If True, the loaded weights are averaged with the fresh random
        initialisation of each base regressor.
    """

    def __init__(self, X, Y, vanilla_model, half_q_levels: Union[int, torch.Tensor], output_device, vanilla_weights_path = None, avg_weights=False):
        super().__init__()
        self.output_device = output_device
        self.upper_quantile_models = nn.ModuleList()
        self.lower_quantile_models = nn.ModuleList()
        self.X = X
        self.Y = Y
        self.iso_reg = None
        self.printname = 'quantile_model'
        self.printcolor = 'C0'

        if isinstance(half_q_levels, int):
            # the integer form requires half_q_levels + 1 to be even
            assert int((half_q_levels+1)/2) == (half_q_levels+1)/2
            quantile_model_levels = range(int((half_q_levels+1)/2))
        else:
            quantile_model_levels = half_q_levels

        self.half_q_levels = quantile_model_levels

        for i in range(len(quantile_model_levels)):
            # the lower model is built first; keep this order so seeded
            # initialisation stays reproducible
            vanilla_quantile_model_lower = vanilla_model(nfeatures = X.shape[1])
            vanilla_quantile_model_upper = vanilla_model(nfeatures = X.shape[1])
            if isinstance(vanilla_weights_path, list):
                path = vanilla_weights_path[i]
            else:
                path = vanilla_weights_path
            if path is not None:
                # Initialize weights from normal regression on training set
                lower_weights = torch.load(path, weights_only=True, map_location=output_device)
                upper_weights = copy.deepcopy(lower_weights)
                if avg_weights:
                    for key in lower_weights:
                        upper_weights[key] = (lower_weights[key].to(output_device) + vanilla_quantile_model_upper.state_dict()[key].to(output_device)) / 2
                        lower_weights[key] = (lower_weights[key].to(output_device) + vanilla_quantile_model_lower.state_dict()[key].to(output_device)) / 2
                vanilla_quantile_model_lower.load_state_dict(lower_weights)
                vanilla_quantile_model_upper.load_state_dict(upper_weights)
            # every member starts as a max-residual (upper) / min-residual
            # (lower) model on whatever calibration data it is given
            q_model_upper = quantile_model(X=X, Y=Y, vanilla_model=vanilla_quantile_model_upper, quantile = 1)
            q_model_lower = quantile_model(X=X, Y=Y, vanilla_model=vanilla_quantile_model_lower, quantile = 0)
            self.upper_quantile_models.append(q_model_upper)
            self.lower_quantile_models.append(q_model_lower)

        self.to(output_device)


    def forward(self, x, quantile = None):
        """Predict quantiles at `x`.

        Parameters
        ----------
        x : tensor
            Query inputs.
        quantile : tensor, sequence of floats, None or empty list
            Quantile levels to evaluate. With None (or an empty list, the
            previous default) the raw ensemble outputs are returned: the lower
            quantiles followed by the upper quantiles in reversed model order.
            Otherwise the raw outputs are linearly interpolated at the
            requested levels with `interp1d`.
        """
        half_q_levels = self.half_q_levels
        output_device = self.output_device
        n_quantile_models = len(half_q_levels)

        quantiles_lower = []
        quantiles_upper = []

        x_cal = self.X
        y_cal = self.Y

        # probability mass of the interval spanned by the current model pair
        interval_tot = 1
        # mask of calibration points still inside the current interval
        ind_kept = torch.ones_like(y_cal, dtype=torch.bool)

        for n_mod in range(n_quantile_models):
            upper_quantile_model = self.upper_quantile_models[n_mod]
            lower_quantile_model = self.lower_quantile_models[n_mod]

            lower_quantile = lower_quantile_model(x, x_cal = x_cal, y_cal = y_cal)
            upper_quantile = upper_quantile_model(x, x_cal = x_cal, y_cal = y_cal)

            # USED EXCLUSIVELY TO COMPUTE CALIBRATING ELEMENTS
            lower_quantile_cal = lower_quantile_model(x_cal, x_cal=x_cal, y_cal=y_cal)
            upper_quantile_cal = upper_quantile_model(x_cal, x_cal=x_cal, y_cal=y_cal)

            # IN TRAINING, WE ONLY OPTIMIZE SHARPNESS FOR ENTRIES CORRESPONDING TO x_cal AND y_cal
            if self.training:
                # ALL ENTRIES THAT DO NOT CORRESPOND TO x_cal AND y_cal ARE REPLACED WITH DUMMY VALUES THAT DO NOT
                # AFFECT OPTIMIZATION
                index_x = np.all([np.isin(x.T[i].cpu(), x_cal.T[i].cpu()) for i in range(x.shape[1])], 0)
                lower_quantile[~index_x] = 0
                upper_quantile[~index_x] = 0
            else:
                if n_mod==0:
                    one_quantile = upper_quantile
                    one_quantile_cal = upper_quantile_cal
                else:
                    # Apply maximum and minimum operations to inner models as described in the paper
                    lower_quantile = torch.maximum(lower_quantile, previous_lower_quantile + delta_lower * (
                                previous_upper_quantile - previous_lower_quantile)).to(self.output_device)
                    upper_quantile = torch.minimum(upper_quantile, previous_lower_quantile + delta_upper * (
                                previous_upper_quantile - previous_lower_quantile)).to(self.output_device)
                    lower_quantile = torch.minimum(lower_quantile, one_quantile).to(self.output_device)

                    # same clipping on the calibration points, restricted to
                    # the points kept after the previous pair
                    lower_quantile_cal = torch.maximum(lower_quantile_cal, previous_lower_quantile_cal[ind_kept] + delta_lower * (
                                previous_upper_quantile_cal[ind_kept] - previous_lower_quantile_cal[ind_kept])).to(self.output_device)
                    upper_quantile_cal = torch.minimum(upper_quantile_cal, previous_lower_quantile_cal[ind_kept] + delta_upper * (
                                previous_upper_quantile_cal[ind_kept] - previous_lower_quantile_cal[ind_kept])).to(self.output_device)
                    lower_quantile_cal = torch.minimum(lower_quantile_cal, one_quantile_cal[ind_kept]).to(self.output_device)

                # the upper quantile may never drop below the lower one
                upper_quantile = torch.maximum(lower_quantile, upper_quantile).to(self.output_device)
                previous_lower_quantile = lower_quantile
                previous_upper_quantile = upper_quantile
                previous_lower_quantile_cal = lower_quantile_cal
                previous_upper_quantile_cal = upper_quantile_cal
                one_quantile_cal = one_quantile_cal[ind_kept.squeeze()]

            quantiles_lower.append(lower_quantile)
            quantiles_upper.append(upper_quantile)

            if n_mod<n_quantile_models-1:
                # relative position of each calibration target inside its
                # current [lower, upper] interval decides what the next pair sees
                delta_upper_lower = upper_quantile_cal - lower_quantile_cal
                delta_upper_lower = torch.maximum(delta_upper_lower, torch.tensor(1e-24))
                delta_quantile_level = (half_q_levels[n_mod+1] - half_q_levels[n_mod]).to(self.output_device)
                delta_upper = torch.quantile((y_cal - lower_quantile_cal)/delta_upper_lower, 1 - delta_quantile_level/interval_tot)
                delta_lower = torch.quantile((y_cal - lower_quantile_cal) / delta_upper_lower, delta_quantile_level/interval_tot)
                ind_kept = ((y_cal < lower_quantile_cal + delta_upper*(delta_upper_lower)) & (
                        y_cal > lower_quantile_cal + delta_lower*(delta_upper_lower))).reshape(y_cal.shape[0], )
                x_cal = x_cal[ind_kept]
                y_cal = y_cal[ind_kept]

                interval_tot = 1 - 2*half_q_levels[n_mod+1]

        quantiles_upper.reverse()
        quantiles = torch.hstack([torch.hstack(quantiles_lower).to(output_device), torch.hstack(quantiles_upper).to(output_device)]).to(output_device)

        # None and the legacy empty-list default both mean "raw outputs"
        if quantile is None or (isinstance(quantile, (list, tuple)) and len(quantile) == 0):
            return quantiles

        if not torch.is_tensor(quantile):
            quantile = torch.tensor(quantile)
        quantile = quantile.to(self.output_device)
        total_q_levels = torch.hstack([half_q_levels, 1- half_q_levels.flip(0) ]).to(self.output_device)
        return interp1d(total_q_levels.repeat(x.shape[0],1), quantiles, quantile)

    def get_quantiles(self, X, q_levels):
        """Return the quantiles at `q_levels`, recalibrated if `recalibrate` was run."""
        if self.iso_reg is None:
            return self(X, q_levels)

        length_data = self.X.shape[0]
        # grid of nominal levels on which the recalibration map is tabulated
        quantile_linspace = torch.linspace(0, 1, length_data)
        uncalibrated_quantiles = self(X, quantile_linspace.to(self.output_device))
        recalibrated_quantile_vals = self.iso_reg.predict(quantile_linspace.numpy())
        recalibrated_quantile_vals = torch.tensor(recalibrated_quantile_vals).to(self.output_device)
        if not torch.is_tensor(q_levels):
            q_levels = torch.tensor(q_levels)
        return interp1d(recalibrated_quantile_vals.repeat(X.shape[0], 1), uncalibrated_quantiles,
                                        q_levels.to(self.output_device)).to(self.output_device)


    def recalibrate(self, val=None):
        """Fit an isotonic-regression recalibration map and store it in `self.iso_reg`.

        Parameters
        ----------
        val : tuple (X_val, Y_val) or None
            Optional extra data concatenated to the stored calibration set.

        Side effects: resets any previous recalibration and switches the module
        to eval mode (it is not switched back).
        """
        self.iso_reg = None
        self.eval()
        if val is None:
            cdf_vals = self.get_cdf_value(self.X, self.Y)
        else:
            cdf_vals = self.get_cdf_value(torch.cat((self.X, val[0]), dim=0), torch.cat((self.Y, val[1]), dim=0))

        # empirical frequency of predicted CDF values not exceeding each one;
        # computed directly from the boolean mask so it stays on cdf_vals' device
        Phat_vals = torch.stack(
            [(cdf_vals <= cdf).float().mean() for cdf in cdf_vals]
        ).reshape(cdf_vals.shape[0], 1)

        ## CONVERT Phat_Vals and cdf_vals TO NUMPY FOR ISOTONIC REGRESSION
        cdf_vals_np = np.float64(cdf_vals.detach().cpu().numpy())
        Phat_vals_np = np.float64(Phat_vals.detach().cpu().numpy().reshape(Phat_vals.shape[0], ))

        # Recalibrate using Isotonic Regression
        iso_reg = IsotonicRegression(out_of_bounds='clip').fit(cdf_vals_np, Phat_vals_np)
        self.iso_reg = iso_reg


    def get_cdf_value(self, X, Y, num_samples = 10000):
        """Return the predicted CDF evaluated at `Y` for each row of `X`."""
        return self.get_cdf1_cdf2(X=X, Y1=Y, Y2=None, num_samples=num_samples)


    def _recalibrated_like(self, F):
        """Apply the isotonic map to `F`, keeping F's shape and dtype."""
        F_recal = self.iso_reg.predict(F.detach().cpu().numpy())
        return torch.as_tensor(F_recal, dtype=F.dtype).reshape(F.shape).to(self.output_device)


    def get_cdf1_cdf2(self, X, Y1, Y2 = None, num_samples = 10000):
        """Evaluate the predicted CDF at `Y1` (and `Y2`) for the same inputs `X`.

        Both values are computed from one shared quantile grid of
        `num_samples` levels, which is what the finite difference in
        `get_likelihood` needs. If a recalibration map is available it is
        applied to every returned value.

        Returns `F1` when `Y2` is None, otherwise the pair `(F1, F2)`.
        """
        # COMPUTE CDF FOR SAME VALUE OF X AND TWO DIFFERENT VALUES OF Y. REQUIRED TO COMPUTE
        # FINITE DIFFERENCE USED TO COMPUTE LIKELIHOOD. SIMULTANEOUS COMPUTATION FOR DIFFERENT YS
        # NECESSARY DUE TO SAMPLING-BASED COMPUTATION
        output_device = self.output_device
        q_levels = torch.linspace(0, 1, num_samples).to(output_device)

        quantiles = self(X, quantile=q_levels)

        # invert the quantile function: interpolate level as a function of value
        F1 = interp1d(quantiles, q_levels, Y1).clamp(min=0, max=1)
        if Y2 is None:
            if self.iso_reg is None:
                return F1
            F1_recal = self.iso_reg.predict(F1.detach().cpu().numpy())
            return torch.tensor(F1_recal).to(output_device)

        F2 = interp1d(quantiles, q_levels, Y2).clamp(min=0, max=1)
        if self.iso_reg is None:
            return F1, F2
        # previously the recalibrated values were computed and then discarded
        return self._recalibrated_like(F1), self._recalibrated_like(F2)


    def get_likelihood(self, X, Y, deltay = 1e-2, num_samples = 10000):
        """Central finite-difference estimate of the predictive density at `Y`.

        Parameters
        ----------
        deltay : float
            Half-width of the finite-difference stencil.
        num_samples : int
            Number of quantile levels used to tabulate the CDF.
        """
        output_device = self.output_device
        dY = deltay*torch.ones(Y.shape).to(output_device)
        Y_dplus = Y+dY
        Y_dminus = Y-dY
        F_minus, F_plus = self.get_cdf1_cdf2(X=X, Y1=Y_dminus, Y2=Y_dplus, num_samples=num_samples)
        likelihood = (F_plus - F_minus)/(2*dY.reshape(F_minus.shape))
        return likelihood

### Compute metrics for model evaluation

Metrics include the Expected Calibration Error (ECE), the average length of the 95% interval (sharpness), the average predictive variance, and the check score.

In [None]:
def check_score(X_test, Y_test, model, p_diagnostics =torch.linspace(0, 1, 25)):
    """Average check (pinball) score of `model` over the levels in `p_diagnostics`.

    For every level the pinball loss between `Y_test` and the corresponding
    column of `model.get_quantiles(X_test, p_diagnostics)` is averaged over the
    data; the per-level scores are then averaged.
    """
    device = model.output_device
    per_level = torch.zeros(p_diagnostics.shape).to(device)
    predicted = model.get_quantiles(X_test, p_diagnostics)
    # one row of residuals per quantile level
    residual_rows = (Y_test - predicted).T
    for col, tau in enumerate(p_diagnostics):
        residual = residual_rows[col]
        pinball = torch.maximum(tau * residual, (1 - tau) * (-residual))
        per_level[col] = torch.mean(pinball).to(device)
    return per_level.mean()

def expected_calibration_error(X_test, Y_test, model, p_diagnostics = torch.linspace(0, 1, 100)):
    """Squared calibration error integrated over the levels in `p_diagnostics`.

    The observed frequency of predicted CDF values not exceeding each level is
    compared with the level itself; the squared gap is integrated with a
    trapezoid-style average of neighbouring gaps.
    """
    device = model.output_device
    levels = p_diagnostics.to(model.output_device)
    cdf_vals = model.get_cdf_value(X=X_test, Y=Y_test)

    # empirical coverage at every diagnostic level
    coverage = [
        torch.mean((cdf_vals <= level).to(device=device, dtype=torch.get_default_dtype()))
        for level in levels
    ]
    phat = torch.stack(coverage).reshape(levels.shape[0],)

    gap = phat - levels
    widths = levels[1:] - levels[:-1]
    mid_gap = 0.5*gap[1:] + 0.5*gap[:-1]
    return torch.sum(mid_gap**2*widths)


def negative_log_likelihood(X_test, Y_test, model, dy = 1e-2):
    """Mean negative log-likelihood of `Y_test` under the model's predictive density.

    Parameters
    ----------
    X_test, Y_test : tensors
        Evaluation inputs and targets, forwarded to `model.get_likelihood`.
    model : object
        Must expose `output_device` and `get_likelihood(X, Y, deltay)`.
    dy : float
        Finite-difference half-width forwarded to `get_likelihood` as
        `deltay`. It was previously accepted but ignored.
    """
    output_device = model.output_device
    density = model.get_likelihood(X=X_test, Y=Y_test, deltay=dy)

    return -torch.mean(torch.log(density).to(output_device)).to(output_device)



def average_variance(X_test, model, p_diagnostics = torch.linspace(0, 1, 100)):
    """Average predictive variance, integrated numerically over quantile levels.

    The quantile function is sampled at `p_diagnostics`; mean and variance per
    test point are obtained from midpoint values weighted by the level
    spacing, and the variances are averaged over the test points.
    """
    device = model.output_device
    per_level = model.get_quantiles(X=X_test, q_levels=p_diagnostics).T
    widths = (p_diagnostics[1:] - p_diagnostics[:-1]).to(device)

    # midpoint of the quantile function on every level interval
    midpoints = 0.5*(per_level[1:] + per_level[:-1])
    mean_per_point = torch.sum(midpoints.T*widths, 1)
    squared_dev = (midpoints - mean_per_point)**2
    var_per_point = torch.sum(squared_dev.T*widths, 1).to(device)

    return torch.mean(var_per_point)

def average_95_interval(X_test, model):
    """Mean width of the central 95% interval (0.025 to 0.975 quantiles)."""
    device = model.output_device
    bounds = torch.tensor([0.025, 0.975]).to(device)
    per_level = model.get_quantiles(X=X_test, q_levels=bounds).T
    width = per_level[-1] - per_level[0]
    return torch.mean(width, 0).to(device)

def run_diagnostics(X_test, Y_test, model, rnd=4, display=True):
    """Compute the evaluation metrics, optionally print them, and return them.

    Returns a NumPy array `[ece, avg_interval, avg_var, chk_score]`, each
    rounded to `rnd` decimals. The negative log-likelihood is not computed
    here.
    """
    def _rounded(metric):
        return round(metric.item(), rnd)

    ece = _rounded(expected_calibration_error(X_test, Y_test, model))
    avg_interval = _rounded(average_95_interval(X_test, model))
    avg_var = _rounded(average_variance(X_test, model))
    chk_score = _rounded(check_score(X_test, Y_test, model))

    report = (
        ('Expected calibration error: ', ece),
        ('Average length of 95% interval: ', avg_interval),
        ('Average variance: ', avg_var),
        ('Check score: ', chk_score),
    )
    if display:
        for label, value in report:
            print(label, value)

    return np.array([value for _, value in report])

def plot_calibration_curve(
    exp_proportions,
    obs_proportions,
    title=None,
    curve_label=None,
    make_plots=False,
):
    """Return the mean absolute gap between expected and observed proportions.

    Despite the name, nothing is drawn: `title`, `curve_label` and `make_plots`
    are accepted but not used.
    """
    absolute_gap = torch.abs(exp_proportions - obs_proportions)
    return torch.mean(absolute_gap).item()

def get_q_idx(exp_props, q):
    """Locate the position of quantile level `q` inside `exp_props`.

    Parameters
    ----------
    exp_props : 1-D array or tensor
        Expected proportions; the interval test below only makes sense if
        they are in increasing order.
    q : float
        Level to look up.

    Returns
    -------
    int
        Index `i` with `exp_props[i] <= q < exp_props[i + 1]`, or the last
        index when `q` equals the last entry after rounding to two decimals.

    Raises
    ------
    ValueError
        If `q` does not fall inside `exp_props`. A leftover `pdb.set_trace()`
        used to run before this and blocked non-interactive runs.
    """
    last = exp_props.shape[0] - 1
    for idx in range(last):
        if exp_props[idx] <= q < exp_props[idx + 1]:
            return idx
    # the final entry has no right neighbour, so compare it after rounding
    if last >= 0 and round(q, 2) == round(float(exp_props[-1]), 2):
        return last
    raise ValueError("q must be within exp_props")

def test_group_cali(
    y,
    q_pred_mat,
    exp_props,
    y_range,
    ratio,
    num_group_draws=20,
    make_plots=False,
    num_trials=20,
):
    """Estimate group calibration error by resampling random groups.

    In each trial, `num_group_draws` groups are sampled with replacement and
    the worst (maximum) calibration gap across the groups is recorded; the
    function returns the mean of these worst-case gaps over all trials.

    Parameters
    ----------
    y : tensor
        Targets, indexable by an integer array along the first dimension.
    q_pred_mat : tensor of shape (num_pts, num_q)
        Predicted quantiles, one column per entry of `exp_props`.
    exp_props : tensor
        Expected proportions matching the columns of `q_pred_mat`.
    y_range, make_plots
        Accepted for interface compatibility; not used.
    ratio : float
        Fraction of the points drawn for every group (at least 2 points).
    num_group_draws : int
        Number of groups drawn per trial.
    num_trials : int
        Number of trials to average over (previously hard-coded to 20).

    Uses the global NumPy random state, so seed it for reproducible scores.
    """
    num_pts, _ = q_pred_mat.shape
    group_size = max([int(round(num_pts * ratio)), 2])

    worst_per_trial = []
    for _ in range(num_trials):
        group_cali_scores = []
        for _ in range(num_group_draws):
            rand_idx = np.random.choice(num_pts, group_size, replace=True)
            g_y = y[rand_idx]
            g_q_preds = q_pred_mat[rand_idx, :]
            # observed share of group targets at or below each predicted quantile
            g_obs_props = torch.mean(
                (g_q_preds >= g_y).float(), dim=0
            ).flatten()
            assert exp_props.shape == g_obs_props.shape
            group_cali_scores.append(
                plot_calibration_curve(exp_props, g_obs_props, make_plots=False)
            )

        # worst group of the trial (a maximum, not a mean)
        worst_per_trial.append(np.max(group_cali_scores))

    return np.mean(worst_per_trial)

def test_uq(
    model,
    x,
    y,
    exp_props,
    y_range,
    recal_model=None,
    recal_type=None,
    make_plots=False,
    test_group_cal=False,
    display=True,
):
    """Evaluate calibration, sharpness and pinball loss of `model` on `(x, y)`.

    Parameters
    ----------
    model : object
        Must be callable as `model(x)` and expose `get_quantiles(x, levels)`.
    x, y : tensors
        Evaluation inputs and targets; `y` is reshaped to `(num_pts, -1)`.
    exp_props : tensor
        Expected proportions (quantile levels) to evaluate; must contain the
        ranges needed to locate 0.01, 0.025, 0.975 and 0.99.
    y_range : float
        Normalisation constant for the sharpness score.
    recal_model, recal_type, make_plots
        Accepted for interface compatibility; not used.
    test_group_cal : bool
        If True, additionally compute group calibration scores for ten group
        sizes.
    display : bool
        If True, print progress information for the group calibration.

    Returns
    -------
    tuple
        `(cali_score, sharp_score, obs_props, quantile_preds, g_cali_scores,
        individual_cali, pinball_loss, model_obs, model_y)`.
    """
    num_pts = x.shape[0]
    y = y.reshape(num_pts, -1)

    quantile_preds = model.get_quantiles(
        x,
        exp_props,
    )  # of shape (num_pts, num_q)
    obs_props = torch.mean((quantile_preds >= y).float(), dim=0).flatten()

    assert exp_props.shape == obs_props.shape

    # per-output statistics of the raw model predictions
    # NOTE(review): the level assigned to column i is i / (num_quantiles - 1),
    # i.e. the raw outputs are treated as evenly spaced levels - confirm.
    pred = model(x)
    num_quantiles = pred.shape[1]
    pinball_loss = []
    model_obs = []
    model_y = []
    mask = (y >= pred)
    delta = (y - pred)
    for i in range(num_quantiles):
        quantile = i/(num_quantiles-1)
        q_mask = mask[:,i]
        q_delta = delta[:,i]
        pinball_loss.append(torch.mean(q_mask*q_delta*quantile + (~q_mask)*(-q_delta)*(1-quantile)).item())
        model_obs.append(torch.mean((y <= pred[:,i].reshape(y.shape)).float()).item())
        model_y.append(torch.mean(pred[:,i]).item())

    # average calibration error restricted to the levels between 0.01 and 0.99
    idx_01 = get_q_idx(exp_props, 0.01)
    idx_99 = get_q_idx(exp_props, 0.99)
    individual_cali = exp_props[idx_01 : idx_99 + 1] - obs_props[idx_01 : idx_99 + 1]
    cali_score = torch.abs(individual_cali).mean().item()

    # sharpness: mean width of the 95% interval, normalised by y_range
    order = torch.argsort(y.flatten())
    q_025 = quantile_preds[:, get_q_idx(exp_props, 0.025)][order]
    q_975 = quantile_preds[:, get_q_idx(exp_props, 0.975)][order]
    sharp_score = torch.mean(q_975 - q_025).item() / y_range

    g_cali_scores = []
    if test_group_cal:
        ratio_arr = np.linspace(0.01, 1.0, 10)
        if display:
            print(
                "Spanning group size from {} to {} in {} increments".format(
                    np.min(ratio_arr), np.max(ratio_arr), len(ratio_arr)
                )
            )
        # `tqdm` is imported as the class (`from tqdm import tqdm`), so the
        # former `tqdm.tqdm(...)` raised AttributeError; the getattr keeps this
        # working if `tqdm` is ever bound to the module instead.
        progress = getattr(tqdm, "tqdm", tqdm)
        for r in (progress(ratio_arr) if display else ratio_arr):
            gc = test_group_cali(
                y=y,
                q_pred_mat=quantile_preds[:, idx_01 : idx_99 + 1],
                exp_props=exp_props[idx_01 : idx_99 + 1],
                y_range=y_range,
                ratio=r,
            )
            g_cali_scores.append(gc)
        g_cali_scores = np.array(g_cali_scores)

    return (
        cali_score,
        sharp_score,
        obs_props,
        quantile_preds,
        g_cali_scores,
        individual_cali,
        pinball_loss,
        model_obs,
        model_y
    )

### Load datasets

In [None]:
def load_data_bpl(dataset, seed, dataset_path='datasets/UCI_Datasets', extra_val = None):
    """Load a UCI dataset, split it and standardise it with training statistics.

    The file `<dataset_path>/<dataset>.txt` is read with `np.loadtxt`; the last
    column is the target. 10% of the data is held out for testing, 20% of the
    remainder for validation and, if `extra_val` is given, a further
    `extra_val` share of the training part as a second validation set. All
    splits use `seed`. Scalers are fitted on the final training split only.

    Returns
    -------
    tuple of torch tensors
        `(x_tr, x_va, x_te, y_tr, y_va, y_te, y_al)` or, with `extra_val`,
        `(x_tr, x_va, x_va2, x_te, y_tr, y_va, y_va2, y_te, y_al)`, where
        `y_al` holds the standardised targets of the full dataset.
    """
    raw = np.loadtxt(f"{dataset_path}/{dataset}.txt")
    features = raw[:, :-1]
    targets = raw[:, -1].reshape(-1, 1)

    x_tr, x_te, y_tr, y_te = train_test_split(
        features, targets, test_size=0.1, random_state=seed
    )
    x_tr, x_va, y_tr, y_va = train_test_split(
        x_tr, y_tr, test_size=0.2, random_state=seed
    )
    with_extra_split = extra_val is not None
    if with_extra_split:
        x_tr, x_va2, y_tr, y_va2 = train_test_split(
            x_tr, y_tr, test_size=extra_val, random_state=seed
        )

    # scalers see the final training split only
    x_scaler = StandardScaler().fit(x_tr)
    y_scaler = StandardScaler().fit(y_tr)

    def _scaled_x(values):
        return torch.Tensor(x_scaler.transform(values))

    def _scaled_y(values):
        return torch.Tensor(y_scaler.transform(values))

    x_tr, x_va, x_te = _scaled_x(x_tr), _scaled_x(x_va), _scaled_x(x_te)
    y_tr, y_va, y_te = _scaled_y(y_tr), _scaled_y(y_va), _scaled_y(y_te)
    y_al = _scaled_y(targets)

    if with_extra_split:
        x_va2 = _scaled_x(x_va2)
        y_va2 = _scaled_y(y_va2)
        return x_tr, x_va, x_va2, x_te, y_tr, y_va, y_va2, y_te, y_al

    return x_tr, x_va, x_te, y_tr, y_va, y_te, y_al

### Train a base L1 regression model for initializing weights

As described in the paper, to help regularize the model, the networks are initialized using the weights obtained via L1 regression on the full training dataset.

In [None]:
def gen_model(dataset, seed, path, output_device, extra_val=None, dataset_path='datasets/UCI_Datasets'):
    """Train the base L1-regression network and save its weights to ``path``.

    The network (``bpl_nn``) is trained with Adam (lr=1e-3) and ``L1Loss`` on the
    training split returned by ``load_data_bpl``. After every epoch the loss on the
    validation split is computed; the weights with the lowest validation loss are
    kept, and training stops once more than 200 consecutive epochs bring no
    improvement (or after 1000 epochs).

    Args:
        dataset: Dataset name, forwarded to ``load_data_bpl``.
        seed: Seeds numpy and torch and is used as the split seed.
        path: Destination passed to ``torch.save`` for the resulting state dict.
        output_device: Device on which the model and data are placed.
        extra_val: Forwarded to ``load_data_bpl``; when given, an additional
            validation split is carved out of the training data (and not used here).
        dataset_path: Directory containing the dataset text files.

    Returns:
        A deep copy of the saved ``state_dict``.
    """
    val = True
    nepochs = 1000
    display = True

    np.random.seed(seed)
    torch.manual_seed(seed)

    # Only the training and validation splits are needed for the base model;
    # the test / extra-validation / full-target outputs are discarded.
    if extra_val is None:
        X, X_val, _, Y, Y_val, _, _ = load_data_bpl(dataset, seed, dataset_path=dataset_path)
    else:
        X, X_val, _, _, Y, Y_val, _, _, _ = load_data_bpl(dataset, seed, extra_val=extra_val, dataset_path=dataset_path)

    X = X.to(output_device)
    Y = Y.to(output_device)
    X_val = X_val.to(output_device)
    Y_val = Y_val.to(output_device)
    if not val:
        X = torch.cat((X, X_val), dim=0)
        Y = torch.cat((Y, Y_val), dim=0)

    class data_set(Dataset):
        """Minimal paired (X, Y) dataset over tensors that already live on the device."""

        def __init__(self, X, Y):
            self.X = X
            self.Y = Y

        def __len__(self):
            return len(self.X)

        def __getitem__(self, index):
            return self.X[index], self.Y[index]

    data = data_set(X=X, Y=Y)
    dataloader = DataLoader(data, batch_size=64, shuffle=True)

    our_model = bpl_nn(X.shape[1]).to(output_device)
    optimizer = optim.Adam(our_model.parameters(), lr=1e-3)
    loss_fun = torch.nn.L1Loss()

    best_loss = torch.inf
    best_weights = None
    early_stop_count = 0

    for _ in (tqdm(range(nepochs)) if display else range(nepochs)):
        our_model.train()
        # Batches are slices of tensors that are already on output_device.
        for Xbatch, Ybatch in dataloader:
            loss = loss_fun(our_model(Xbatch), Ybatch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if val:
            our_model.eval()
            with torch.no_grad():
                val_loss = loss_fun(our_model(X_val), Y_val)
            if val_loss < best_loss:
                early_stop_count = 0
                best_loss = val_loss
                best_weights = copy.deepcopy(our_model.state_dict())
            else:
                early_stop_count += 1
            if early_stop_count > 200:
                break

    # best_weights stays None when the validation loss never compared below inf
    # (e.g. it was NaN throughout); keep the final weights instead of crashing
    # inside load_state_dict(None).
    if val and best_weights is not None:
        our_model.load_state_dict(best_weights)

    torch.save(our_model.state_dict(), path)

    return copy.deepcopy(our_model.state_dict())


### Training and evaluation of CaliPSo model

In [None]:
# Root directory under which the result/weight folders used by main() and run_experiment() are created.
config_path = '.'
# NOTE(review): presumably set so that loading more than one OpenMP runtime does not abort the process — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Regularization helps prevent overfitting to outliers, particularly for outer quantiles
class param_scheduler:
    """Step-counting multiplicative decay of a single scalar.

    ``step()`` counts calls; on the call that pushes the counter past
    ``update_rate`` the counter is reset to zero and ``param`` is multiplied by
    ``decay_rate``. With the defaults, the first 50 calls return 1.0 and the
    51st returns 0.5.
    """

    def __init__(self, update_rate: int = 50, decay_rate: float = 0.5, param0: float = 1.0):
        self.update_rate, self.decay_rate = update_rate, decay_rate
        self.param = param0
        self.counter = 0

    def step(self):
        """Advance the counter by one call and return the (possibly decayed) parameter."""
        ticks = self.counter + 1
        if ticks <= self.update_rate:
            # Still inside the current plateau: just remember the new count.
            self.counter = ticks
            return self.param
        # Plateau exhausted: restart counting and apply one decay step.
        self.counter = 0
        self.param *= self.decay_rate
        return self.param

# Includes several hyperparameters; running the notebook as is will match the configuration reported in the paper.
# dataset can be one of: 'yacht','boston','concrete','energy','kin8nm','power','wine','naval','protein'
# val_recal: Should you recalibrate on a held-out validation set
# display: Whether to output model evaluation metrics
# dims: If >0, applies PCA to the input data with n_components=dims
# vanilla_weights_path: specifies path to L1 weights
# balanced_recal: when true, upsamples the heldout validation set to have the same number of samples as the training set when recalibrating the model
# va_split: defines a validation size, as a fraction of the train+val size
# cali_favoured: when using early stopping based on the best weighted sum of ECE and sharpness, cali_favoured is the weight of the ECE (sharpness has a fixed weight of 1)
# fix_cali: if true, stops based on the ECE achieving a satisfactory level (relative to the Beyond Pinball Loss paper's reported MAQR results), else use the best weighted sum as early stopping criterion
def run_experiment(rep, output_device, nepochs=1000, dataset='boston',val_recal=True,seed=0,val=False,display=True,dims=0,vanilla_weights_path=None,balanced_recal=False,va_split=None,cali_favoured=1,fix_cali=False):
    """Train, recalibrate and evaluate the quantile ensemble on one dataset split.

    Args:
        rep: Seeds numpy and torch and is used as the data-split seed.
        output_device: Device on which the data and the model are placed.
        nepochs: Maximum number of training epochs.
        dataset: Dataset name passed to ``load_data_bpl``; with ``fix_cali`` it must
            also be a key of the ``maqr_ece`` table defined below.
        val_recal: When true (and ``val`` is true) the final recalibration uses the
            validation data; otherwise ``recalibrate`` is called with ``None``.
        seed: Only used in the output path written when ``dims > 0``.
        val: When false the validation split is merged into the training data and
            no validation/test metrics, checkpointing or early stopping take place.
        display: Show a tqdm progress bar and print the final test metrics.
        dims: If > 0, project the inputs with PCA (``n_components=dims``) fitted on
            the training features, and pickle the selected weights at the end.
        vanilla_weights_path: Path to the L1 weights, handed to
            ``quantile_model_ensemble`` for the first two ensemble members.
        balanced_recal: When true the recalibration data is the validation set
            (plus the extra validation set when ``fix_cali``) repeated four times.
        va_split: If given, training and validation data are concatenated and
            re-split so that the last ``va_split`` fraction is the validation set.
        cali_favoured: Weight of the validation calibration score in the combined
            ``calibration * cali_favoured + sharpness`` selection metric.
        fix_cali: When true, weights are selected at the epoch with the lowest
            validation sharpness among epochs whose validation calibration score is
            at most ``maqr_ece[dataset]`` (evaluated on the extra validation split);
            otherwise the combined metric is used.

    Returns:
        ``(test_metrics, stats, training_metrics, state_dict)`` where
        ``test_metrics`` is ``(te_cali_score, te_sharp_score, te_g_cali_scores,
        te_indiv_cali_scores, pinball_loss, model_obs, model_y)``, ``stats`` is the
        stacked output of ``run_diagnostics``, ``training_metrics`` holds the
        per-epoch metric lists and ``state_dict`` is the final model state.
    """
    argdataset = dataset
    np.random.seed(rep)
    torch.manual_seed(rep)

    if fix_cali:
        X, X_val, X_true_val, X_test, Y, Y_val, Y_true_val, Y_test, Y_al = load_data_bpl(dataset, rep, extra_val=0.14)
        print([i.shape for i in (X, X_val, X_true_val, X_test, Y, Y_val, Y_true_val, Y_test, Y_al)])
    else:
        X, X_val, X_test, Y, Y_val, Y_test, Y_al = load_data_bpl(dataset, rep)

    if dims > 0:
        # PCA is not part of the notebook's import cell, so bring it in here.
        from sklearn.decomposition import PCA
        pca = PCA(n_components=dims)
        X = torch.Tensor(pca.fit_transform(X))
        X_val = torch.Tensor(pca.transform(X_val))
        X_test = torch.Tensor(pca.transform(X_test))
        if fix_cali:
            # The extra validation split must live in the same projected space as
            # the training data, otherwise its feature count no longer matches.
            X_true_val = torch.Tensor(pca.transform(X_true_val))

    X = X.to(output_device)
    Y = Y.to(output_device)
    X_val = X_val.to(output_device)
    Y_val = Y_val.to(output_device)
    X_test = X_test.to(output_device)
    Y_test = Y_test.to(output_device)
    if fix_cali:
        X_true_val = X_true_val.to(output_device)
        Y_true_val = Y_true_val.to(output_device)

    if not val:
        X=torch.cat((X, X_val), dim=0)
        Y=torch.cat((Y, Y_val), dim=0)

    if va_split is not None:
        X = torch.cat((X, X_val), dim=0)
        Y = torch.cat((Y, Y_val), dim=0)
        split_i = int(len(X) * (1-va_split))
        X_val = X[split_i:]
        Y_val = Y[split_i:]
        X = X[:split_i]
        Y = Y[:split_i]
        print(len(X), len(X_val))

    te_exp_props = torch.linspace(0.01, 0.99, 99).to(output_device)
    y_range = (Y_al.max() - Y_al.min()).item()

    class data_set(Dataset):
        """Minimal paired (X, Y) dataset over tensors."""

        def __init__(self, X, Y):
            self.X = X
            self.Y = Y

        def __len__(self):
            return len(self.X)

        def __getitem__(self, index):
            return self.X[index], self.Y[index]

    # From here on `dataset` is the torch Dataset; the name string lives in `argdataset`.
    dataset = data_set(X=X, Y=Y)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

    half_q_levels = torch.tensor([0, 0.025, 0.05, 0.1, 0.2])
    lambda_reg_vec = torch.tensor([0.05, 0.005, 0.005, 0.005, 0.005])
    scheduler = param_scheduler()
    our_model = quantile_model_ensemble(X=X, Y=Y, vanilla_model=bpl_nn, half_q_levels=half_q_levels, output_device=output_device, vanilla_weights_path=[vanilla_weights_path, vanilla_weights_path, None, None, None])
    optimizer = optim.Adam(our_model.parameters(), lr=1e-3)
    loss_fun = torch.nn.MSELoss()


    our_model.train()

    best_loss = torch.inf
    # Initialised so the "Stopping epochs" report cannot hit an unbound name when
    # the validation loss never improves on inf (or no epoch is run).
    best_loss_epoch = None
    best_weights = None
    early_stop_count = 0
    best_metric = torch.inf
    best_metric_epoch = None
    best_sharp = torch.inf
    best_sharp_epoch = None
    # Per-dataset calibration thresholds used when fix_cali is set; per the comment
    # above this function they derive from the Beyond Pinball Loss paper's MAQR results.
    maqr_ece = {
        'yacht':   (6.8+2.1*2)/100,
        'boston':  (6.2+1.8*2)/100,
        'kin8nm':  (1.8+0.4*2)/100,
        'energy':  (3.5+1.0*2)/100,
        'concrete':(5.3+0.4*2)/100,
        'wine':    (2.7+0.2*2)/100,
        'power':   (1.6+0.3*2)/100,
        'naval':   (2.3+0.2*2)/100,
        'protein': (2.6+0.3*2)/100,
        }

    training_metrics = {metric: {'tr': [], 'va': [], 'te': []} for metric in ['loss', 'cali', 'sharp', 'g_cali', 'ind_cali', 'pinball_loss', 'model_obs', 'model_y']}

    for epoch in (tqdm(range(nepochs)) if display else range(nepochs)):
        our_model.train()
        batch_loss = []
        # NOTE(review): this scaling is cumulative -- once the scheduler value drops
        # below 1 the whole vector is shrunk by that factor on every epoch. Left
        # unchanged to keep the reported configuration; confirm this is intended.
        lambda_reg_vec *= scheduler.step()
        for Xbatch, Ybatch in dataloader:
            Xbatch, Ybatch = Xbatch.to(output_device), Ybatch.to(output_device)
            quantile_preds = our_model(Xbatch)
            loss = loss_fun(quantile_preds,Ybatch.repeat(1, quantile_preds.shape[1]))
            # Calibration term: for every output column, the squared empirical
            # quantile of the residuals (prediction - target) on X_val, taken at the
            # level total_q_levels.flip(0)[q]. Predictions are made in eval mode but
            # with gradients enabled, so this term is back-propagated.
            our_model.eval()
            quantile_preds_val = our_model(X_val)
            diff_val = quantile_preds_val - Y_val
            loss_ece = 0
            total_q_levels = torch.hstack([half_q_levels, 1-half_q_levels.flip(0)]).to(output_device)
            for q in torch.arange(total_q_levels.shape[0]):
                loss_ece += torch.quantile(diff_val[:, q], total_q_levels.flip(0)[q])**2
            our_model.train()
            loss += loss_ece
            # Norm penalty on the parameters of each lower/upper ensemble member.
            for i in range(len(half_q_levels)):
                lambda_reg = lambda_reg_vec[i]
                if lambda_reg != 0:
                    for params in our_model.lower_quantile_models[i].parameters():
                        loss += lambda_reg*params.norm()
                    for params in our_model.upper_quantile_models[i].parameters():
                        loss += lambda_reg*params.norm()
            batch_loss.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        our_model.eval()
        training_metrics['loss']['tr'].append(np.mean(batch_loss))
        tr_cali_score,tr_sharp_score,_,_,tr_g_cali_scores, tr_indiv_cali_score, pinball_loss,model_obs,model_y = test_uq(our_model,X,Y,
            te_exp_props,y_range,recal_model=None,recal_type=None,test_group_cal=False,)
        training_metrics['cali']['tr'].append(tr_cali_score)
        training_metrics['sharp']['tr'].append(tr_sharp_score)
        training_metrics['g_cali']['tr'].append(tr_g_cali_scores)
        training_metrics['ind_cali']['tr'].append(tr_indiv_cali_score)
        training_metrics['pinball_loss']['tr'].append(pinball_loss)
        training_metrics['model_obs']['tr'].append(model_obs)
        training_metrics['model_y']['tr'].append(model_y)

        if val:
            with torch.no_grad():
                Xbatch, Ybatch = X_val.to(output_device), Y_val.to(output_device)
                quantile_preds = our_model(Xbatch)
                loss = loss_fun(quantile_preds,Ybatch.repeat(1, quantile_preds.shape[1]))
                # Tracked for reporting only; it does not drive weight selection.
                if loss < best_loss:
                    best_loss = loss
                    best_loss_epoch = epoch

            training_metrics['loss']['va'].append(loss.item())
            if fix_cali:
                va_cali_score,va_sharp_score,_,_,va_g_cali_scores,va_indiv_cali_score,pinball_loss,model_obs,model_y = test_uq(our_model,X_true_val,Y_true_val,
                    te_exp_props,y_range,recal_model=None,recal_type=None,test_group_cal=False,)
            else:
                va_cali_score,va_sharp_score,_,_,va_g_cali_scores,va_indiv_cali_score,pinball_loss,model_obs,model_y = test_uq(our_model,X_val,Y_val,
                    te_exp_props,y_range,recal_model=None,recal_type=None,test_group_cal=False,)
            training_metrics['cali']['va'].append(va_cali_score)
            training_metrics['sharp']['va'].append(va_sharp_score)
            training_metrics['g_cali']['va'].append(va_g_cali_scores)
            training_metrics['ind_cali']['va'].append(va_indiv_cali_score)
            training_metrics['pinball_loss']['va'].append(pinball_loss)
            training_metrics['model_obs']['va'].append(model_obs)
            training_metrics['model_y']['va'].append(model_y)

            # Combined selection metric; it only controls checkpointing and the
            # patience counter when fix_cali is off.
            metric = (va_cali_score*cali_favoured+va_sharp_score)
            if metric < best_metric:
                if not fix_cali:
                    early_stop_count = 0
                    best_weights = copy.deepcopy(our_model.state_dict())
                best_metric = metric
                best_metric_epoch = epoch
            elif not fix_cali:
                early_stop_count += 1

            if early_stop_count > 200:
                break

            # Compute metrics for test data (monitoring only, so no autograd graph is needed)
            with torch.no_grad():
                quantile_preds = our_model(X_test.to(output_device))
                loss = loss_fun(quantile_preds,Y_test.to(output_device).repeat(1, quantile_preds.shape[1]))

            training_metrics['loss']['te'].append(loss.item())
            te_cali_score,te_sharp_score,_,_,te_g_cali_scores,te_indiv_cali_score,pinball_loss,model_obs,model_y  = test_uq(our_model,X_test,Y_test,
                te_exp_props,y_range,recal_model=None,recal_type=None,test_group_cal=False,)
            training_metrics['cali']['te'].append(te_cali_score)
            training_metrics['sharp']['te'].append(te_sharp_score)
            training_metrics['g_cali']['te'].append(te_g_cali_scores)
            training_metrics['ind_cali']['te'].append(te_indiv_cali_score)
            training_metrics['pinball_loss']['te'].append(pinball_loss)
            training_metrics['model_obs']['te'].append(model_obs)
            training_metrics['model_y']['te'].append(model_y)

            if fix_cali:
                # Checkpoint the sharpest model seen so far whose validation
                # calibration score is within the dataset's threshold.
                if va_cali_score <= maqr_ece[argdataset] and va_sharp_score < best_sharp:
                    early_stop_count = 0
                    best_sharp = va_sharp_score
                    best_sharp_epoch = epoch
                    best_weights = copy.deepcopy(our_model.state_dict())
                else:
                    early_stop_count += 1

    ## RECALIBRATE QUANTILES NOT CORRESPONDING TO ENSEMBLE MEMBERS
    if val:
        print("Stopping epochs:", best_loss_epoch, best_metric_epoch, best_sharp_epoch)
        # best_weights is None when no epoch met the selection criterion; the final weights are kept then.
        if best_weights is not None:
            our_model.load_state_dict(best_weights)

    if balanced_recal:
        if fix_cali:
            X_val = torch.cat((X_val, X_true_val))
            Y_val = torch.cat((Y_val, Y_true_val))
        our_model.recalibrate((torch.cat((X_val, X_val, X_val, X_val)),torch.cat((Y_val, Y_val, Y_val, Y_val))) if (val_recal and val) else None)
    else:
        our_model.recalibrate((X_val,Y_val) if (val_recal and val) else None)

    ## OUR MODEL
    stats = []
    stats.append(run_diagnostics(X_test, Y_test, our_model,display=False))
    stats = np.stack(stats, axis=0)
    (
        te_cali_score,
        te_sharp_score,
        te_obs_props,
        te_q_preds,
        te_g_cali_scores,
        te_indiv_cali_scores,
        pinball_loss, model_obs, model_y,
    ) = test_uq(
        our_model,
        X_test,
        Y_test,
        te_exp_props,
        y_range,
        recal_model=None,
        recal_type=None,
        test_group_cal=True,
        display=False
    )
    if display:
        print("ECE       :",te_cali_score)
        print("Sharp     :",te_sharp_score)
        print("Group Cal :",te_g_cali_scores)
        print("Indiv Cal:", te_indiv_cali_scores)
        print("Pinball   :",pinball_loss)
        print("Actual Obs:",model_obs)
        print("Mean Y    :",model_y)
        print(np.mean(training_metrics['cali']['va']), np.mean(training_metrics['cali']['te']))
    if dims > 0:
        savepath = f'{config_path}/results_vis_val/{argdataset}/{str(val_recal)}/{seed}/'
        Path(savepath).mkdir(parents=True, exist_ok=True)
        with open(savepath + f'vis_{dims}.pkl', 'wb') as f:
            pickle.dump(best_weights, f)
    return (te_cali_score, te_sharp_score, te_g_cali_scores, te_indiv_cali_scores, pinball_loss, model_obs, model_y), stats, training_metrics, our_model.state_dict()#, (ece, sharpness)


def main(datasets, seeds, device):
    """Run the full pipeline for every (seed, dataset) pair and store the results.

    For each pair the L1 base weights are generated with ``gen_model`` unless the
    weights file already exists, ``run_experiment`` is executed, and its outputs are
    written below ``{config_path}/results/{dataset}/{seed}/``: everything except
    the final state dict is pickled to ``_ours_model.pkl`` and the state dict is
    saved to ``state_dict.pt``. A failure in one pair prints its traceback and the
    loop moves on to the next pair.
    """
    run_name = 'results'
    Path(f'{config_path}/{run_name}/').mkdir(parents=True, exist_ok=True)

    def train_and_store(dataset, seed):
        # One (dataset, seed) run: base weights -> experiment -> persist outputs.
        out_dir = f'{config_path}/{run_name}/{dataset}/{seed}/' #parallel_1
        print(f"{dataset}")

        Path(f'{config_path}/results/{dataset}').mkdir(parents=True, exist_ok=True)
        l1_weights = f'{config_path}/results/{dataset}/{seed}_extra_val_0.14_l1.pt'
        # Reuse previously trained L1 weights when they are already on disk.
        if not os.path.exists(l1_weights):
            gen_model(dataset, seed, l1_weights, device, extra_val=0.14)
        outs = run_experiment(
            seed,
            device,
            nepochs=1000,
            dataset=dataset,
            val_recal=True,
            val=True,
            display=True,
            vanilla_weights_path=l1_weights,
            balanced_recal=True,
            fix_cali=True,
        )

        Path(out_dir).mkdir(parents=True, exist_ok=True)
        with open(out_dir + f'_ours_model.pkl', 'wb') as handle:
            pickle.dump(outs[:-1], handle)

        torch.save(outs[-1], out_dir + f'state_dict.pt')

    for seed in seeds:
        print(f"\n{seed}")
        for dataset in datasets:
            try:
                train_and_store(dataset, seed)
            except Exception:
                # Best effort: report the failure and continue with the next run.
                print(traceback.format_exc())

### Execute training and testing for selected datasets and seeds

In [None]:
# Datasets to run. The full list of supported names is:
# ['yacht','boston','concrete','energy','kin8nm','power','wine','naval','protein']
datasets = ['boston','concrete','energy']
seeds = range(5)
# Use the GPU when one is available; otherwise fall back to the CPU so the cell
# still runs instead of failing on every dataset/seed.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
main(datasets, seeds, device)


0
boston


 50%|████▉     | 499/1000 [00:06<00:06, 73.48it/s]


[torch.Size([313, 13]), torch.Size([91, 13]), torch.Size([51, 13]), torch.Size([51, 13]), torch.Size([313, 1]), torch.Size([91, 1]), torch.Size([51, 1]), torch.Size([51, 1]), torch.Size([506, 1])]


 45%|████▌     | 451/1000 [06:29<07:53,  1.16it/s]


Stopping epochs: 423 95 249
ECE       : 0.12060607224702835
Sharp     : 0.10497283052775898
Group Cal : [0.49137374 0.35387207 0.27689395 0.24813014 0.22658806 0.22013653
 0.2177047  0.19702273 0.2052155  0.18910993]
Indiv Cal: tensor([-0.2253, -0.2153, -0.2053, -0.1953, -0.2049, -0.1949, -0.2241, -0.2141,
        -0.2433, -0.2333, -0.2233, -0.2329, -0.2229, -0.2129, -0.2029, -0.1929,
        -0.1829, -0.1729, -0.1629, -0.1529, -0.1429, -0.1329, -0.1229, -0.1129,
        -0.1029, -0.1322, -0.1614, -0.1514, -0.1414, -0.1902, -0.1802, -0.1702,
        -0.1602, -0.1502, -0.1794, -0.1890, -0.1790, -0.1690, -0.1590, -0.1490,
        -0.1390, -0.1290, -0.1190, -0.1090, -0.0990, -0.0890, -0.0986, -0.0886,
        -0.0786, -0.0686, -0.0586, -0.0682, -0.0778, -0.0678, -0.0775, -0.0675,
        -0.0967, -0.0867, -0.0767, -0.0863, -0.0763, -0.1055, -0.0955, -0.0855,
        -0.0755, -0.0655, -0.0555, -0.0455, -0.0355, -0.0255, -0.0155, -0.0055,
         0.0045,  0.0145,  0.0245,  0.0345,  0.0445,

 30%|██▉       | 297/1000 [00:03<00:07, 89.14it/s]


[torch.Size([313, 13]), torch.Size([91, 13]), torch.Size([51, 13]), torch.Size([51, 13]), torch.Size([313, 1]), torch.Size([91, 1]), torch.Size([51, 1]), torch.Size([51, 1]), torch.Size([506, 1])]


 46%|████▌     | 459/1000 [06:41<07:53,  1.14it/s]


Stopping epochs: 444 188 257
ECE       : 0.07656566798686981
Sharp     : 0.13125882895887941
Group Cal : [0.48758586 0.30934008 0.2312761  0.2176845  0.18943918 0.17917798
 0.17229412 0.15690657 0.14734905 0.14215925]
Indiv Cal: tensor([-0.1665, -0.1565, -0.1661, -0.1757, -0.1657, -0.1557, -0.1457, -0.1749,
        -0.1649, -0.1549, -0.1449, -0.1545, -0.1445, -0.1345, -0.1441, -0.1341,
        -0.1437, -0.1337, -0.1237, -0.1137, -0.1233, -0.1133, -0.1229, -0.1325,
        -0.1225, -0.1125, -0.1025, -0.0925, -0.0825, -0.0725, -0.0625, -0.0525,
        -0.0622, -0.0522, -0.0422, -0.0322, -0.0222, -0.0122, -0.0218, -0.0510,
        -0.0606, -0.0506, -0.0406, -0.0502, -0.0402, -0.0302, -0.0202, -0.0298,
        -0.0590, -0.0490, -0.0586, -0.0682, -0.0582, -0.0482, -0.0382, -0.0478,
        -0.0575, -0.0475, -0.0375, -0.0471, -0.0371, -0.0271, -0.0367, -0.0267,
        -0.0167, -0.0263, -0.0163, -0.0063,  0.0037,  0.0137,  0.0237,  0.0337,
         0.0437,  0.0537,  0.0637,  0.0345,  0.0053

 30%|███       | 303/1000 [00:03<00:07, 88.34it/s]


[torch.Size([313, 13]), torch.Size([91, 13]), torch.Size([51, 13]), torch.Size([51, 13]), torch.Size([313, 1]), torch.Size([91, 1]), torch.Size([51, 1]), torch.Size([51, 1]), torch.Size([506, 1])]


 62%|██████▎   | 625/1000 [08:52<05:19,  1.17it/s]


Stopping epochs: 618 217 423
ECE       : 0.09260052442550659
Sharp     : 0.12664088626755513
Group Cal : [0.47949495 0.35344109 0.25516499 0.24360191 0.22745324 0.17955137
 0.18590137 0.17204041 0.17451404 0.1688075 ]
Indiv Cal: tensor([-0.2253, -0.2153, -0.2053, -0.2149, -0.2049, -0.2145, -0.2437, -0.2337,
        -0.2237, -0.2137, -0.2037, -0.1937, -0.1837, -0.1933, -0.1833, -0.1929,
        -0.1829, -0.1729, -0.1825, -0.1725, -0.1625, -0.1722, -0.1622, -0.1522,
        -0.1422, -0.1322, -0.1222, -0.1318, -0.1414, -0.1314, -0.1214, -0.1114,
        -0.1014, -0.0914, -0.0814, -0.0714, -0.0614, -0.0514, -0.0414, -0.0510,
        -0.0410, -0.0310, -0.0210, -0.0110, -0.0206, -0.0106, -0.0006,  0.0094,
        -0.0198, -0.0098, -0.0194, -0.0486, -0.0582, -0.0482, -0.0382, -0.0282,
        -0.0182, -0.0082, -0.0178, -0.0471, -0.0371, -0.0467, -0.0367, -0.0463,
        -0.0363, -0.0263, -0.0163, -0.0063,  0.0037,  0.0137,  0.0237,  0.0141,
         0.0241,  0.0341,  0.0441,  0.0541,  0.0445

 34%|███▍      | 342/1000 [00:04<00:08, 79.40it/s]


[torch.Size([313, 13]), torch.Size([91, 13]), torch.Size([51, 13]), torch.Size([51, 13]), torch.Size([313, 1]), torch.Size([91, 1]), torch.Size([51, 1]), torch.Size([51, 1]), torch.Size([506, 1])]


 65%|██████▍   | 649/1000 [09:25<05:05,  1.15it/s]


Stopping epochs: 618 161 447
ECE       : 0.10099821537733078
Sharp     : 0.10584715961248406
Group Cal : [0.49459596 0.35019529 0.26053872 0.22044683 0.21495257 0.19809022
 0.18113547 0.18966162 0.17507744 0.16402654]
Indiv Cal: tensor([-0.1861, -0.1761, -0.1661, -0.1561, -0.1461, -0.1361, -0.1457, -0.1357,
        -0.1257, -0.1353, -0.1449, -0.1349, -0.1249, -0.1149, -0.1049, -0.0949,
        -0.0849, -0.0749, -0.0649, -0.0549, -0.0449, -0.0349, -0.0445, -0.0345,
        -0.0245, -0.0145, -0.0045,  0.0055,  0.0155,  0.0255,  0.0355,  0.0455,
         0.0555,  0.0459,  0.0559,  0.0659,  0.0563,  0.0467,  0.0567,  0.0667,
         0.0178,  0.0278,  0.0378,  0.0086,  0.0186,  0.0090,  0.0190,  0.0290,
         0.0194,  0.0294,  0.0394,  0.0494,  0.0594,  0.0694,  0.0598,  0.0698,
         0.0798,  0.0898,  0.0802,  0.0902,  0.1002,  0.1102,  0.1006,  0.1106,
         0.1206,  0.1110,  0.1210,  0.1310,  0.1018,  0.1118,  0.1218,  0.1318,
         0.1418,  0.1518,  0.1225,  0.1325,  0.1425

 62%|██████▏   | 616/1000 [00:07<00:04, 81.00it/s]


[torch.Size([313, 13]), torch.Size([91, 13]), torch.Size([51, 13]), torch.Size([51, 13]), torch.Size([313, 1]), torch.Size([91, 1]), torch.Size([51, 1]), torch.Size([51, 1]), torch.Size([506, 1])]


 31%|███       | 306/1000 [04:19<09:47,  1.18it/s]


Stopping epochs: 306 104 104
ECE       : 0.10713012516498566
Sharp     : 0.1450740498252923
Group Cal : [0.47612121 0.33688217 0.28115152 0.24860073 0.22577954 0.20873528
 0.18084611 0.19314395 0.17890686 0.17426224]
Indiv Cal: tensor([-0.1665, -0.1565, -0.1465, -0.1365, -0.1265, -0.1361, -0.1849, -0.1945,
        -0.1845, -0.1941, -0.2037, -0.2133, -0.2229, -0.2129, -0.2225, -0.2322,
        -0.2222, -0.2122, -0.2022, -0.1922, -0.1822, -0.1918, -0.2014, -0.1914,
        -0.2010, -0.1910, -0.1810, -0.1710, -0.1610, -0.1510, -0.1410, -0.1310,
        -0.1602, -0.1698, -0.1598, -0.1694, -0.1594, -0.1690, -0.1590, -0.1490,
        -0.1390, -0.1486, -0.1386, -0.1286, -0.1186, -0.1086, -0.0986, -0.1082,
        -0.0982, -0.0882, -0.0782, -0.0682, -0.0582, -0.0482, -0.0382, -0.0282,
        -0.0182, -0.0278, -0.0178, -0.0078,  0.0022,  0.0122,  0.0025, -0.0071,
        -0.0167, -0.0067, -0.0163, -0.0063,  0.0037, -0.0059,  0.0041,  0.0141,
         0.0241,  0.0341,  0.0441,  0.0541,  0.0641,