## Setup

In [None]:
import os
import pickle
import random
import sys

import numpy as np
import timm
import torch
import torchvision

sys.path.append("[PATH TO PROJECT DIRECTORY]")

from utils.data_utils import load_cifar10, load_cifar100
from utils.ece_utils import *
from utils.eval_utils import *
from utils.noise_utils import *
from utils.plotting_utils import *

In [None]:
# Pick the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Compare the device *type*: a torch.device never compares equal to the plain
# string "cpu", so `device != "cpu"` was always True and the CUDA queries below
# crashed on CPU-only machines.
if device.type != "cpu":
    print("Device count: ", torch.cuda.device_count())
    print("GPU being used: {}".format(torch.cuda.get_device_name(0)))

In [None]:
# Fix every random number generator the notebook relies on, for reproducibility.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)

## Synthetic Experiments

### Data Setup

In [None]:
def generate_synthetic(n_sample=1000):
    """Draw a toy 1-D binary classification dataset.

    Each label is 0.0 or 1.0 (a rounded uniform draw). The single feature is
    ``label - 0.5``, so it is -0.5 for class 0 and +0.5 for class 1.

    Args:
        n_sample: Number of points to draw.

    Returns:
        ``(data, labels)``: float tensors, both of shape ``(n_sample, 1)``.
    """
    ys = torch.rand(n_sample, 1).round()
    xs = ys - 0.5
    return xs, ys

In [None]:
# Draw the synthetic dataset used by the experiments below.
n_sample = 1000
data, labels = generate_synthetic(n_sample=n_sample)

In [None]:
# Bias-free 1-D linear model with a tiny weight, so its logits on the +/-0.5
# synthetic features stay very close to 0 (probabilities close to 0.5).
model = torch.nn.Linear(in_features=1, out_features=1, bias=False)
with torch.no_grad():
    model.weight.fill_(1e-3)

In [None]:
# Positive-class probabilities of the synthetic model on every data point.
model_probs = torch.sigmoid(model(data))

In [None]:
# Sanity checks.
print("Accuracy: ", (model_probs.round() == labels).sum() / len(labels))
print("Max pred: ", model_probs.max())
print("Min pred: ", model_probs.min())

### ECE Evals

In [None]:
from netcal.metrics import ECE

In [None]:
# Values swept as the binned-ECE bin count and as 1/sigma for LS-ECE.
bin_range = [*range(1, 101)]
# Passed to logit_smoothed_ece as its third argument; presumably the number of
# noise samples — TODO confirm against utils.ece_utils.
n_t = 10_000

In [None]:
# Sweep `bin_range` and record binned ECE and LS-ECE (both in %) for the
# synthetic model. The logits, probabilities and numpy label vector do not
# depend on the bin size, so compute them once instead of on every iteration.
logits = model(data).detach()
preds = torch.sigmoid(logits).reshape(-1).numpy()
targets = labels.reshape(-1).numpy()

bin_eces, ls_eces = [], []
for bin_size in bin_range:
    # LS-ECE: Gaussian logit noise with sigma = 1 / bin_size.
    noise = GaussianNoise(sigma=1 / bin_size)
    ls_eces.append(100 * logit_smoothed_ece(logits, labels, n_t, noise))

    # Binned ECE: netcal ECE with `bin_size` bins on the flattened probabilities.
    ece = ECE(bin_size)
    bin_eces.append(100 * ece.measure(preds, targets))

In [None]:
# Synthetic comparison: binned ECE vs. LS-ECE across the bin-size sweep.
synth_curves = [bin_eces, ls_eces]
plot_multi_dataset_metrics(
    fname="synth_experiments.png",
    xs=bin_range,
    metric_means=synth_curves,
    metric_stds=None,
    datasets=["Binned ECE", "LS-ECE"],
    x_label=r"Bin Size ($1/\sigma$)",
    y_label="ECE Value (%)",
)

## Pretrained Models

### Data Setup

In [None]:
# Only the CIFAR-10 test split is needed to evaluate the pretrained models;
# keep its order fixed so results are reproducible.
_, cifar10_test = load_cifar10()
cifar10_loader = torch.utils.data.DataLoader(
    dataset=cifar10_test,
    batch_size=500,
    shuffle=False,
)

In [None]:
# Only the CIFAR-100 test split is needed to evaluate the pretrained models.
_, cifar100_test = load_cifar100()
# Wrap the CIFAR-100 split: this previously wrapped `cifar10_test`, so every
# "CIFAR-100" result was silently computed on CIFAR-10 images.
cifar100_loader = torch.utils.data.DataLoader(cifar100_test, batch_size=500, shuffle=False)

### Evaluation

In [None]:
def eval_models(model_names, loader, bin_var_range, device="cpu", n_t=10000):
    """Compute binned-ECE and LS-ECE curves for pretrained CIFAR models.

    Each model is fetched with ``torch.hub.load`` from
    ``chenyaofo/pytorch-cifar-models``, moved to ``device`` and evaluated on
    ``loader``.

    Args:
        model_names: Hub entry-point names, e.g. ``"cifar10_resnet32"``.
        loader: DataLoader over the evaluation split; used for BOTH metrics.
        bin_var_range: Values swept as the netcal ``ECE`` bin count and as
            ``1 / sigma`` of the Gaussian logit noise for LS-ECE.
        device: Device the models are moved to.
        n_t: Forwarded to ``logit_smoothed_ece``. The default matches the
            notebook's ``n_t`` setting.

    Returns:
        ``(model_bin_eces, model_ls_eces)``: two lists with one inner list per
        model. Each inner list holds one ECE value (in %) per entry of
        ``bin_var_range``.
    """
    model_bin_eces, model_ls_eces = [], []
    for model_name in model_names:
        model = torch.hub.load(
            "chenyaofo/pytorch-cifar-models", model_name, pretrained=True
        ).to(device)
        # Inference mode so BatchNorm/Dropout do not use training behaviour;
        # a no-op if the helpers below already switch the model to eval.
        model.eval()

        softmaxes, class_labels = get_softmax_and_labels(model, loader, device)
        bin_eces = [
            100 * ECE(bin_size).measure(softmaxes, class_labels)
            for bin_size in bin_var_range
        ]

        # Use the same `loader` as above; this was previously hard-coded to the
        # global `cifar10_loader`, so CIFAR-100 LS-ECE ran on CIFAR-10 data.
        logits, binary_labels = get_binary_logits_and_labels(model, loader, device)
        ls_eces = [
            100 * logit_smoothed_ece(
                logits, binary_labels, n_t, GaussianNoise(sigma=1 / bin_size)
            )
            for bin_size in bin_var_range
        ]

        model_bin_eces.append(bin_eces)
        model_ls_eces.append(ls_eces)

    return model_bin_eces, model_ls_eces

In [None]:
# Build both hub model lists from one architecture list so they cannot drift
# apart. The CIFAR-100 list previously contained the CIFAR-10 checkpoints by
# copy-paste mistake.
_architectures = ["resnet32", "vgg16_bn", "mobilenetv2_x1_4"]
cifar10_models = [f"cifar10_{arch}" for arch in _architectures]
cifar100_models = [f"cifar100_{arch}" for arch in _architectures]

# Legend entries: binned-ECE curves first, then LS-ECE curves, in the same
# architecture order as the model lists above.
_display_names = ["ResNet-32", "VGG-16", "MobileNet V2"]
datasets = [f"{name} (Bin)" for name in _display_names] + [
    f"{name} (LS)" for name in _display_names
]

In [None]:
# Coarse sweep 1, 10, 20, ..., 100 (starts at 1 because 0 bins / sigma = 1/0 is invalid).
bin_var_range = [1] + list(range(10, 101, 10))

In [None]:
# Binned-ECE and LS-ECE curves for every CIFAR-10 model.
cifar10_bin_eces, cifar10_ls_eces = eval_models(
    model_names=cifar10_models,
    loader=cifar10_loader,
    bin_var_range=bin_var_range,
    device=device,
)

In [None]:
# CIFAR-10: solid lines = binned ECE, dashed = LS-ECE; one colour per architecture.
plot_multi_dataset_metrics(
    fname="cifar10_eval.png",
    xs=bin_var_range,
    metric_means=[*cifar10_bin_eces, *cifar10_ls_eces],
    metric_stds=None,
    datasets=datasets,
    x_label=r"Bin Size ($1/\sigma$)",
    y_label="ECE Value (%)",
    custom_colors=["C0", "C1", "C2"] * 2,
    custom_lines=["solid"] * 3 + ["dashed"] * 3,
)

In [None]:
# Binned-ECE and LS-ECE curves for every CIFAR-100 model.
cifar100_bin_eces, cifar100_ls_eces = eval_models(
    model_names=cifar100_models,
    loader=cifar100_loader,
    bin_var_range=bin_var_range,
    device=device,
)

In [None]:
# CIFAR-100: solid lines = binned ECE, dashed = LS-ECE; one colour per architecture.
plot_multi_dataset_metrics(
    fname="cifar100_eval.png",
    xs=bin_var_range,
    metric_means=[*cifar100_bin_eces, *cifar100_ls_eces],
    metric_stds=None,
    datasets=datasets,
    x_label=r"Bin Size ($1/\sigma$)",
    y_label="ECE Value (%)",
    custom_colors=["C0", "C1", "C2"] * 2,
    custom_lines=["solid"] * 3 + ["dashed"] * 3,
)