In [1]:
import logging, tqdm
from sklearn.model_selection import train_test_split, GroupShuffleSplit
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.dataset import TensorDataset, Dataset
from torch.utils.data.dataloader import DataLoader

import copy
import time, yaml
import torch.nn.functional as F
import pandas as pd, numpy as np
from collections import defaultdict, OrderedDict
import matplotlib.pyplot as plt

from reliability import reliability_diagram, reliability_diagrams, compute_calibration

In [2]:
def get_device():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    return device

def one_hot_embedding(labels, num_classes=10):
    # Convert to One Hot Encoding
    y = torch.eye(num_classes)
    return y[labels]

In [3]:
def relu_evidence(y):
    return F.relu(y)


def exp_evidence(y):
    return torch.exp(torch.clamp(y, -10, 10))


def softplus_evidence(y):
    return F.softplus(y)


def kl_divergence(alpha, num_classes, device=None):
    if not device:
        device = get_device()
    ones = torch.ones([1, num_classes], dtype=torch.float32, device=device)
    sum_alpha = torch.sum(alpha, dim=1, keepdim=True)
    first_term = (
        torch.lgamma(sum_alpha)
        - torch.lgamma(alpha).sum(dim=1, keepdim=True)
        + torch.lgamma(ones).sum(dim=1, keepdim=True)
        - torch.lgamma(ones.sum(dim=1, keepdim=True))
    )
    second_term = (
        (alpha - ones)
        .mul(torch.digamma(alpha) - torch.digamma(sum_alpha))
        .sum(dim=1, keepdim=True)
    )
    kl = first_term + second_term
    return kl


def loglikelihood_loss(y, alpha, device=None):
    if not device:
        device = get_device()
    y = y.to(device)
    alpha = alpha.to(device)
    S = torch.sum(alpha, dim=1, keepdim=True)
    loglikelihood_err = torch.sum((y - (alpha / S)) ** 2, dim=1, keepdim=True)
    loglikelihood_var = torch.sum(
        alpha * (S - alpha) / (S * S * (S + 1)), dim=1, keepdim=True
    )
    loglikelihood = loglikelihood_err + loglikelihood_var
    return loglikelihood


def mse_loss(y, alpha, epoch_num, num_classes, annealing_step, device=None):
    if not device:
        device = get_device()
    y = y.to(device)
    alpha = alpha.to(device)
    loglikelihood = loglikelihood_loss(y, alpha, device=device)

    annealing_coef = torch.min(
        torch.tensor(1.0, dtype=torch.float32),
        torch.tensor(epoch_num / annealing_step, dtype=torch.float32),
    )

    kl_alpha = (alpha - 1) * (1 - y) + 1
    kl_div = annealing_coef * kl_divergence(kl_alpha, num_classes, device=device)
    return loglikelihood + kl_div


def edl_loss(func, y, alpha, epoch_num, num_classes, annealing_step, device=None):
    y = y.to(device)
    alpha = alpha.to(device)
    S = torch.sum(alpha, dim=1, keepdim=True)

    A = torch.sum(y * (func(S) - func(alpha)), dim=1, keepdim=True)

    annealing_coef = torch.min(
        torch.tensor(1.0, dtype=torch.float32),
        torch.tensor(epoch_num / annealing_step, dtype=torch.float32),
    )

    kl_alpha = (alpha - 1) * (1 - y) + 1
    kl_div = annealing_coef * kl_divergence(kl_alpha, num_classes, device=device)
    return A + kl_div


def edl_mse_loss(output, target, epoch_num, num_classes, annealing_step, device=None):
    if not device:
        device = get_device()
    evidence = relu_evidence(output)
    alpha = evidence + 1
    loss = torch.mean(
        mse_loss(target, alpha, epoch_num, num_classes, annealing_step, device=device)
    )
    return loss


def edl_log_loss(output, target, epoch_num, num_classes, annealing_step, device=None):
    if not device:
        device = get_device()
    evidence = relu_evidence(output)
    alpha = evidence + 1
    loss = torch.mean(
        edl_loss(
            torch.log, target, alpha, epoch_num, num_classes, annealing_step, device
        )
    )
    return loss


def edl_digamma_loss(
    output, target, epoch_num, num_classes, annealing_step, device=None
):
    if not device:
        device = get_device()
    evidence = relu_evidence(output)
    alpha = evidence + 1
    loss = torch.mean(
        edl_loss(
            torch.digamma, target, alpha, epoch_num, num_classes, annealing_step, device
        )
    )
    return loss

In [32]:
def train_model(
    model,
    dataloaders,
    num_classes,
    criterion,
    optimizer,
    scheduler=None,
    num_epochs=25,
    device=None,
    uncertainty=False,
    metric="accuracy"
):

    since = time.time()

    if not device:
        device = get_device()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    training_results = defaultdict(list)
    for epoch in range(num_epochs):
        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            if phase == "train":
                #print("Training...")
                model.train()  # Set model to training mode
            else:
                #print("Validating...")
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0.0
            correct = 0
            
            if verbose:
                total = int(np.ceil(len(dataloaders[phase].dataset) / batch_size))
                my_iter = tqdm.tqdm(enumerate(dataloaders[phase]),
                                total = total,
                                leave = True)
            else:
                my_iter = enumerate(dataloaders[phase])

            # Iterate over data.
            results_dict = defaultdict(list)
            for i, (inputs, labels) in my_iter:

                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == "train"):

                    if uncertainty:
                        y = one_hot_embedding(labels, num_classes)
                        y = y.to(device)
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(
                            outputs, y.float(), epoch, num_classes, 10, device
                        )

                        match = torch.reshape(torch.eq(preds, labels).float(), (-1, 1))
                        acc = torch.mean(match)
                        evidence = relu_evidence(outputs)
                        alpha = evidence + 1
                        u = num_classes / torch.sum(alpha, dim=1, keepdim=True)

                        total_evidence = torch.sum(evidence, 1, keepdim=True)
                        mean_evidence = torch.mean(total_evidence)
                        mean_evidence_succ = torch.sum(
                            torch.sum(evidence, 1, keepdim=True) * match
                        ) / torch.sum(match + 1e-20)
                        mean_evidence_fail = torch.sum(
                            torch.sum(evidence, 1, keepdim=True) * (1 - match)
                        ) / (torch.sum(torch.abs(1 - match)) + 1e-20)

                    else:
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # statistics
                results_dict["loss"].append(loss.item())
                results_dict["acc"].append(torch.mean((preds == labels.data).float()).item())
                
                if verbose:
                    print_str = f"Epoch: {epoch} "
                    print_str += f'{phase}_loss: {np.mean(results_dict["loss"]):.4f} '
                    print_str += f'{phase}_acc: {np.mean(results_dict["acc"]):.4f}'
                    my_iter.set_description(print_str)
                    my_iter.refresh()

            epoch_loss = np.mean(results_dict["loss"])
            epoch_acc = np.mean(results_dict["acc"])
            
            if phase == "train":
                training_results["train_loss"].append(epoch_loss)
                training_results["train_acc"].append(epoch_acc)
            else:
                training_results["valid_loss"].append(epoch_loss)
                training_results["valid_acc"].append(epoch_acc)
            
            training_results["epoch"].append(epoch)
            
            if scheduler is not None:
                if phase == "val":
                    scheduler.step(1-epoch_acc)

            # deep copy the model
            if phase == "val" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                
            # Stop training if we have not improved after X epochs
            if phase == "val":
                best_epoch = [i for i,j in enumerate(
                    training_results["valid_acc"]) if j == max(training_results["valid_acc"])][0]
                offset = epoch - best_epoch
                if offset >= stopping_patience:
                    break

    time_elapsed = time.time() - since
    print(
        "Training complete in {:.0f}m {:.0f}s".format(
            time_elapsed // 60, time_elapsed % 60
        )
    )
    print("Best val Acc: {:4f}".format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)

    return model, optimizer, training_results

In [5]:
config = "config/p-type.yml"
with open(config) as cf:
    conf = yaml.load(cf, Loader=yaml.FullLoader)

In [6]:
df = pd.read_parquet(conf['data_path'])

In [7]:
features = conf['tempvars'] + conf['tempdewvars'] + conf['ugrdvars'] + conf['vgrdvars']
outputs = conf['outputvars']
num_classes = len(outputs)
n_splits = conf['trainer']['n_splits']
train_size1 = conf['trainer']['train_size1'] # sets test size
train_size2 = conf['trainer']['train_size2'] # sets valid size
num_hidden_layers = conf['trainer']['num_hidden_layers']
hidden_sizes = conf['trainer']['hidden_sizes']
dropout_rate = conf['trainer']['dropout_rate']
batch_size = conf['trainer']['batch_size']
learning_rate = conf['trainer']['learning_rate']
metrics = conf['trainer']['metrics']
run_eagerly = conf['trainer']['run_eagerly']
shuffle = conf['trainer']['shuffle']
epochs = conf['trainer']['epochs']

lr_patience = 3
stopping_patience = 10

verbose = True

loss = "digamma"
use_uncertainty = False if loss == "ce" else True

In [8]:
#split and preprocess the data
df['day'] = df['datetime'].apply(lambda x: str(x).split(' ')[0])

splitter = GroupShuffleSplit(n_splits=n_splits, train_size=train_size1)
train_idx, test_idx = list(splitter.split(df, groups=df['day']))[0]
train_data, test_data = df.iloc[train_idx], df.iloc[test_idx]

splitter = GroupShuffleSplit(n_splits=n_splits, train_size=train_size2)
train_idx, valid_idx = list(splitter.split(train_data, groups=train_data['day']))[0]
train_data, valid_data = train_data.iloc[train_idx], train_data.iloc[valid_idx]

scaler_x = StandardScaler()
x_train = scaler_x.fit_transform(train_data[features])
x_valid = scaler_x.transform(valid_data[features])
x_test = scaler_x.transform(test_data[features])
y_train = np.argmax(train_data[outputs].to_numpy(), 1)
y_valid = np.argmax(valid_data[outputs].to_numpy(), 1)
y_test = np.argmax(test_data[outputs].to_numpy(), 1)

In [9]:
train_split = TensorDataset(
    torch.from_numpy(x_train).float(),
    torch.from_numpy(y_train).long()
)
train_loader = DataLoader(train_split, 
                          batch_size=batch_size, 
                          shuffle=True, 
                          num_workers=0)

valid_split = TensorDataset(
    torch.from_numpy(x_valid).float(),
    torch.from_numpy(y_valid).long()
)
valid_loader = DataLoader(valid_split, 
                          batch_size=batch_size, 
                          shuffle=False, 
                          num_workers=0)

dataloaders = {
    "train": train_loader,
    "val": valid_loader,
}

In [None]:
class PrecipNet(nn.Module):
    def __init__(self, dropout=False):
        super().__init__()
        self.use_dropout = dropout
        self.conv1 = nn.Conv1d(1, 20, kernel_size=5)
        self.conv2 = nn.Conv1d(20, 50, kernel_size=5)
        self.fc1 = nn.Linear(20000, 500)
        self.fc2 = nn.Linear(500, 4)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 1))
        x = F.relu(F.max_pool2d(self.conv2(x), 1))
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        if self.use_dropout:
            x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x
    
def load_mlp_model(input_size, middle_size, output_size, dropout):
    model = nn.Sequential(
            nn.utils.spectral_norm(nn.Linear(input_size, middle_size)),
            #nn.BatchNorm1d(middle_size),
            nn.Dropout(dropout),
            nn.LeakyReLU(),
            #nn.Tanh(),
            nn.utils.spectral_norm(nn.Linear(middle_size, output_size))
    ) 
    return model

In [20]:
if use_uncertainty:
    if loss == "digamma":
        criterion = edl_digamma_loss
    elif loss == "log":
        criterion = edl_log_loss
    elif loss == "mse":
        criterion = edl_mse_loss
    else:
        logging.error("--uncertainty requires --mse, --log or --digamma.")
else:
    criterion = nn.CrossEntropyLoss()

In [33]:
model = load_mlp_model(len(features), 100, len(outputs), dropout_rate) #PrecipNet()

optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.005)

#exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    patience = lr_patience, 
    verbose = verbose,
    min_lr = 1.0e-13
)

In [34]:
device = get_device()
model = model.to(device)

In [35]:
model, optimizer, metrics = train_model(
    model,
    dataloaders,
    num_classes,
    criterion,
    optimizer,
    scheduler=lr_scheduler,
    num_epochs=100,
    device=device,
    uncertainty=use_uncertainty,
)

Epoch: 0 train_loss: 0.8490 train_acc: 0.7853: 100%|██████████| 172/172 [00:08<00:00, 20.77it/s]
Epoch: 0 val_loss: 0.7608 val_acc: 0.8108: 100%|██████████| 23/23 [00:00<00:00, 28.95it/s]
Epoch: 1 train_loss: 0.7918 train_acc: 0.8053: 100%|██████████| 172/172 [00:08<00:00, 21.08it/s]
Epoch: 1 val_loss: 0.7819 val_acc: 0.8168: 100%|██████████| 23/23 [00:00<00:00, 28.89it/s]
Epoch: 2 train_loss: 0.8018 train_acc: 0.8127: 100%|██████████| 172/172 [00:08<00:00, 20.97it/s]
Epoch: 2 val_loss: 0.8112 val_acc: 0.8153: 100%|██████████| 23/23 [00:00<00:00, 28.88it/s]
Epoch: 3 train_loss: 0.8216 train_acc: 0.8158: 100%|██████████| 172/172 [00:08<00:00, 20.76it/s]
Epoch: 3 val_loss: 0.7983 val_acc: 0.8283: 100%|██████████| 23/23 [00:00<00:00, 31.45it/s]
Epoch: 4 train_loss: 0.8400 train_acc: 0.8183: 100%|██████████| 172/172 [00:08<00:00, 20.96it/s]
Epoch: 4 val_loss: 0.8354 val_acc: 0.8258: 100%|██████████| 23/23 [00:00<00:00, 28.76it/s]
Epoch: 5 train_loss: 0.8595 train_acc: 0.8199: 100%|████████

Epoch 00015: reducing learning rate of group 0 to 1.0000e-04.


Epoch: 15 train_loss: 0.9240 train_acc: 0.8268: 100%|██████████| 172/172 [00:08<00:00, 20.92it/s]
Epoch: 15 val_loss: 0.9153 val_acc: 0.8329: 100%|██████████| 23/23 [00:00<00:00, 28.18it/s]
Epoch: 16 train_loss: 0.9161 train_acc: 0.8290: 100%|██████████| 172/172 [00:08<00:00, 20.82it/s]
Epoch: 16 val_loss: 0.9183 val_acc: 0.8320: 100%|██████████| 23/23 [00:00<00:00, 31.46it/s]
Epoch: 17 train_loss: 0.9157 train_acc: 0.8294: 100%|██████████| 172/172 [00:08<00:00, 20.82it/s]
Epoch: 17 val_loss: 0.9164 val_acc: 0.8329: 100%|██████████| 23/23 [00:00<00:00, 28.93it/s]
Epoch: 18 train_loss: 0.9160 train_acc: 0.8294: 100%|██████████| 172/172 [00:08<00:00, 20.97it/s]
Epoch: 18 val_loss: 0.9136 val_acc: 0.8345: 100%|██████████| 23/23 [00:00<00:00, 29.07it/s]
Epoch: 19 train_loss: 0.9155 train_acc: 0.8294: 100%|██████████| 172/172 [00:08<00:00, 20.87it/s]
Epoch: 19 val_loss: 0.9125 val_acc: 0.8341: 100%|██████████| 23/23 [00:00<00:00, 28.51it/s]
Epoch: 20 train_loss: 0.9159 train_acc: 0.8294: 10

Epoch 00031: reducing learning rate of group 0 to 1.0000e-05.


Epoch: 31 train_loss: 0.9126 train_acc: 0.8310: 100%|██████████| 172/172 [00:08<00:00, 20.36it/s]
Epoch: 31 val_loss: 0.9101 val_acc: 0.8344: 100%|██████████| 23/23 [00:00<00:00, 28.75it/s]
Epoch: 32 train_loss: 0.9113 train_acc: 0.8316: 100%|██████████| 172/172 [00:08<00:00, 20.59it/s]
Epoch: 32 val_loss: 0.9113 val_acc: 0.8352: 100%|██████████| 23/23 [00:00<00:00, 28.01it/s]
Epoch: 33 train_loss: 0.9109 train_acc: 0.8314: 100%|██████████| 172/172 [00:08<00:00, 20.56it/s]
Epoch: 33 val_loss: 0.9101 val_acc: 0.8352: 100%|██████████| 23/23 [00:00<00:00, 27.62it/s]
Epoch: 34 train_loss: 0.9111 train_acc: 0.8315: 100%|██████████| 172/172 [00:08<00:00, 20.57it/s]
Epoch: 34 val_loss: 0.9106 val_acc: 0.8346: 100%|██████████| 23/23 [00:00<00:00, 28.52it/s]


Epoch 00035: reducing learning rate of group 0 to 1.0000e-06.


Epoch: 35 train_loss: 0.9110 train_acc: 0.8315: 100%|██████████| 172/172 [00:08<00:00, 20.55it/s]
Epoch: 35 val_loss: 0.9112 val_acc: 0.8348: 100%|██████████| 23/23 [00:00<00:00, 29.06it/s]
Epoch: 36 train_loss: 0.9107 train_acc: 0.8318: 100%|██████████| 172/172 [00:08<00:00, 20.60it/s]
Epoch: 36 val_loss: 0.9108 val_acc: 0.8352: 100%|██████████| 23/23 [00:00<00:00, 28.06it/s]
Epoch: 37 train_loss: 0.9105 train_acc: 0.8318: 100%|██████████| 172/172 [00:08<00:00, 20.42it/s]
Epoch: 37 val_loss: 0.9105 val_acc: 0.8353: 100%|██████████| 23/23 [00:00<00:00, 28.54it/s]
Epoch: 38 train_loss: 0.9104 train_acc: 0.8318: 100%|██████████| 172/172 [00:08<00:00, 20.57it/s]
Epoch: 38 val_loss: 0.9113 val_acc: 0.8349: 100%|██████████| 23/23 [00:00<00:00, 28.55it/s]


Epoch 00039: reducing learning rate of group 0 to 1.0000e-07.


Epoch: 39 train_loss: 0.9108 train_acc: 0.8317: 100%|██████████| 172/172 [00:08<00:00, 20.46it/s]
Epoch: 39 val_loss: 0.9110 val_acc: 0.8349: 100%|██████████| 23/23 [00:00<00:00, 27.77it/s]
Epoch: 40 train_loss: 0.9104 train_acc: 0.8317: 100%|██████████| 172/172 [00:08<00:00, 20.36it/s]
Epoch: 40 val_loss: 0.9108 val_acc: 0.8349: 100%|██████████| 23/23 [00:00<00:00, 31.24it/s]
Epoch: 41 train_loss: 0.9104 train_acc: 0.8318: 100%|██████████| 172/172 [00:08<00:00, 20.47it/s]
Epoch: 41 val_loss: 0.9107 val_acc: 0.8349: 100%|██████████| 23/23 [00:00<00:00, 27.83it/s]
Epoch: 42 train_loss: 0.9105 train_acc: 0.8318: 100%|██████████| 172/172 [00:08<00:00, 20.53it/s]
Epoch: 42 val_loss: 0.9110 val_acc: 0.8350: 100%|██████████| 23/23 [00:00<00:00, 28.98it/s]


Epoch 00043: reducing learning rate of group 0 to 1.0000e-08.


Epoch: 43 train_loss: 0.9109 train_acc: 0.8317: 100%|██████████| 172/172 [00:08<00:00, 20.58it/s]
Epoch: 43 val_loss: 0.9111 val_acc: 0.8349: 100%|██████████| 23/23 [00:00<00:00, 27.82it/s]
Epoch: 44 train_loss: 0.9110 train_acc: 0.8318: 100%|██████████| 172/172 [00:08<00:00, 20.47it/s]
Epoch: 44 val_loss: 0.9111 val_acc: 0.8350: 100%|██████████| 23/23 [00:00<00:00, 28.60it/s]
Epoch: 45 train_loss: 0.9107 train_acc: 0.8318: 100%|██████████| 172/172 [00:08<00:00, 20.64it/s]
Epoch: 45 val_loss: 0.9111 val_acc: 0.8350: 100%|██████████| 23/23 [00:00<00:00, 27.86it/s]
Epoch: 46 train_loss: 0.9107 train_acc: 0.8319: 100%|██████████| 172/172 [00:08<00:00, 20.52it/s]
Epoch: 46 val_loss: 0.9111 val_acc: 0.8350: 100%|██████████| 23/23 [00:00<00:00, 27.96it/s]
Epoch: 47 train_loss: 0.9108 train_acc: 0.8317: 100%|██████████| 172/172 [00:08<00:00, 20.70it/s]
Epoch: 47 val_loss: 0.9111 val_acc: 0.8350: 100%|██████████| 23/23 [00:00<00:00, 27.95it/s]
Epoch: 48 train_loss: 0.9108 train_acc: 0.8318: 10

Training complete in 15m 15s
Best val Acc: 0.836696


### Evaluate

In [None]:
test_split = TensorDataset(
    torch.from_numpy(x_test).float(),
    torch.from_numpy(y_test).long()
)
test_loader = DataLoader(test_split, 
                          batch_size=batch_size, 
                          shuffle=False, 
                          num_workers=0)

In [None]:
model.eval()
with torch.no_grad():

    if verbose:
        total = int(np.ceil(len(test_loader.dataset) / batch_size))
        my_iter = tqdm.tqdm(enumerate(test_loader),
                        total = total,
                        leave = True)
    else:
        my_iter = enumerate(test_loader)

    # Iterate over data.
    results_dict = defaultdict(list)
    for i, (inputs, labels) in my_iter:

        inputs = inputs.to(device)
        labels = labels.to(device)

        if use_uncertainty:
            output = model(inputs)
            evidence = relu_evidence(output)
            alpha = evidence + 1
            uncertainty = num_classes / torch.sum(alpha, dim=1, keepdim=True)
            _, preds = torch.max(output, 1)
            prob = alpha / torch.sum(alpha, dim=1, keepdim=True)
            results_dict["pred_uncertainty"].append(uncertainty)

        else:
            output = model(img_variable)
            _, preds = torch.max(output, 1)
            prob = F.softmax(output, dim=1)

        results_dict["pred_labels"].append(preds.unsqueeze(-1))
        results_dict["true_labels"].append(labels.unsqueeze(-1))
        results_dict["pred_probs"].append(prob)

        # statistics
        results_dict["acc"].append(torch.mean((preds == labels.data).float()).item())

        if verbose:
            print_str = f'test_acc: {np.mean(results_dict["acc"]):.4f}'
            my_iter.set_description(print_str)
            my_iter.refresh()

    results_dict["pred_uncertainty"] = torch.cat(results_dict["pred_uncertainty"], 0)
    results_dict["pred_probs"] = torch.cat(results_dict["pred_probs"], 0)
    results_dict["pred_labels"] = torch.cat(results_dict["pred_labels"], 0)
    results_dict["true_labels"] = torch.cat(results_dict["true_labels"], 0)

In [None]:
for idx in range(len(outputs)):
    test_data[f"{outputs[idx]}_conf"] = results_dict["pred_probs"][:, idx].cpu().numpy()
test_data["uncertainty"] = results_dict["pred_uncertainty"][:, 0].cpu().numpy()
test_data["pred_labels"] = results_dict["pred_labels"][:, 0].cpu().numpy()
test_data["true_labels"] = results_dict["true_labels"][:, 0].cpu().numpy()
test_data["pred_conf"] = np.max(results_dict["pred_probs"].cpu().numpy(), 1)

In [None]:
cond = (test_data["true_labels"] == 0)

title = "p-type"
fig = reliability_diagram(
    test_data[cond]["true_labels"].to_numpy(), 
    test_data[cond]["pred_labels"].to_numpy(), 
    test_data[cond]["pred_conf"].to_numpy(), 
    num_bins=10, draw_ece=True,
    draw_bin_importance="alpha", draw_averages=True,
    title=title, figsize=(5, 5), dpi=100, 
    return_fig=True)

In [None]:
cond0 = (test_data["true_labels"] == 0)
cond1 = (test_data["true_labels"] == 1)
cond2 = (test_data["true_labels"] == 2)
cond3 = (test_data["true_labels"] == 3)

results = OrderedDict()
results[outputs[0]] = {
    "true_labels": test_data[cond0]["true_labels"].values, 
    "pred_labels": test_data[cond0]["pred_labels"].values, 
    "confidences": test_data[cond0]["pred_conf"].values
}
results[outputs[1]] = {
    "true_labels": test_data[cond1]["true_labels"].values, 
    "pred_labels": test_data[cond1]["pred_labels"].values, 
    "confidences": test_data[cond1]["pred_conf"].values
}
results[outputs[2]] = {
    "true_labels": test_data[cond2]["true_labels"].values, 
    "pred_labels": test_data[cond2]["pred_labels"].values, 
    "confidences": test_data[cond2]["pred_conf"].values
}
results[outputs[3]] = {
    "true_labels": test_data[cond3]["true_labels"].values, 
    "pred_labels": test_data[cond3]["pred_labels"].values, 
    "confidences": test_data[cond3]["pred_conf"].values
}

In [None]:
fig = reliability_diagrams(results, num_bins=10, draw_bin_importance="alpha",
                           num_cols=2, dpi=100, return_fig=True)

In [None]:
def compute_cov(df, col = "pred_conf", quan = "uncertainty", ascending = False):
    df = df.copy()
    df = df.sort_values(col, ascending = ascending)
    df["dummy"] = 1
    df[f"cu_{quan}"] = df[quan].cumsum() / df["dummy"].cumsum()
    df[f"cu_{col}"] = df[col].cumsum() / df["dummy"].cumsum()
    df[f"{col}_cov"] = df["dummy"].cumsum() / len(df)
    return df

In [None]:
test_data_sorted = compute_cov(test_data, col = "pred_conf", quan = "acc")

In [None]:
fig = plt.figure(figsize=(7, 5))
ax = fig.add_subplot(1, 1, 1)

ax.plot(
    test_data_sorted["pred_conf_cov"],
    test_data_sorted["cu_acc"]
)
ax.set_ylabel("Cumulative accuracy")
ax.set_xlabel("Coverage (sorted by predicted confidence)")
#test_data_sorted[outputs + [f"{x}_conf" for x in outputs] + ["acc", "uncertainty", "pred_conf"]]

In [None]:
fig = plt.figure(figsize=(7, 5))
ax = fig.add_subplot(1, 1, 1)

test_data_sorted = compute_cov(test_data, col = "uncertainty", quan = "acc", ascending = True)

l1, = ax.plot(
    test_data_sorted["uncertainty"],
    test_data_sorted["cu_acc"],
    label = "sigma"
)

ax2 = ax.twiny()
l2, = ax2.plot(
    test_data_sorted["uncertainty_cov"],
    test_data_sorted["cu_acc"], 
    color='orange', ls = "--", 
    label = "fraction")

ax.set_ylabel("Cumulative accuracy")
ax2.set_xlabel("Test data fraction")
ax.set_xlabel("Predicted uncertainty")

plt.legend([l1, l2], ["sigma", "fraction"])

In [None]:
fig = plt.figure(figsize=(7, 5))
ax = fig.add_subplot(1, 1, 1)

test_data_sorted = compute_cov(test_data, col = "pred_conf", quan = "uncertainty")

ax.plot(
    test_data_sorted["pred_conf_cov"],
    test_data_sorted["cu_uncertainty"]
)
ax.set_ylabel("Cumulative uncertainty")
ax.set_xlabel("Coverage (sorted by predicted confidence)")

In [None]:
fig = plt.figure(figsize=(7, 5))
ax = fig.add_subplot(1, 1, 1)

test_data_sorted = compute_cov(test_data, col = "pred_conf", quan = "uncertainty")

ax.plot(
    test_data_sorted["pred_conf"],
    test_data_sorted["uncertainty"]
)
ax.set_ylabel("Uncertainty")
ax.set_xlabel("Prediction confidence")