# **On Calibration of Modern Neural Networks**

In [88]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

import numpy as np
import pandas as pd

from sklearn.metrics import accuracy_score
from sklearn.isotonic import IsotonicRegression

from netcal.binning import HistogramBinning, BBQ
from netcal.scaling import TemperatureScaling
from netcal.metrics import ECE

In [89]:
def create_mnist_data(batch_size=64):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([
        transforms.ToTensor(),  
        transforms.Lambda(lambda x: x.view(-1))  
    ])

    full_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Divide treino completo em treino (60%) e validação (40%)
    train_size = int(0.6 * len(full_train))  # 60% de 60000
    val_size = len(full_train) - train_size  # 40% restante
    train_data, val_data = random_split(full_train, [train_size, val_size], generator=torch.Generator().manual_seed(42))

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size)
    test_loader = DataLoader(test_data, batch_size=batch_size)

    return train_loader, val_loader, test_loader, device

class SimpleNN(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=128, output_dim=10):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)

# Criação dos dataloaders e modelo
train_loader, val_loader, test_loader, device = create_mnist_data()
model = SimpleNN().to(device)


### **Vector Scaling // Matrix Scaling**

In [90]:
class VectorScaling(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(n_classes))
        self.bias = nn.Parameter(torch.zeros(n_classes))
    def forward(self, logits):
        return logits * self.scale + self.bias


class MatrixScaling(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.weight = nn.Parameter(torch.eye(n_classes))
        self.bias = nn.Parameter(torch.zeros(n_classes))
        
    def forward(self, logits):
        return torch.matmul(logits, self.weight) + self.bias


class _ECELoss(nn.Module):
    def __init__(self, n_bins=15):
        super().__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)
        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            in_bin = (confidences > bin_lower) * (confidences <= bin_upper)
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
        return ece

In [91]:
def train_simple_model(train_loader, input_dim, n_classes, device, epochs=20, lr=0.01):
    model = SimpleNN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            logits = model(x_batch)
            loss = criterion(logits, y_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}")
    return model


model = train_simple_model(train_loader, 784, 10, device)

Epoch 5/20, Loss: 0.1030
Epoch 10/20, Loss: 0.0905
Epoch 15/20, Loss: 0.0622
Epoch 20/20, Loss: 0.0649


### **BBQ // Hist. Bins // Isotonic Regresiion // Temperature Scaling**

In [None]:
def evaluate(probs, y_true, name):
    acc = accuracy_score(y_true, np.argmax(probs, axis=1))
    ece = ECE(15).measure(probs, y_true)
    #print(f"{name} - Accuracy: {acc:.4f}, ECE: {ece:.4f}")
    return [name, acc, ece]


def train_nn(model, train_loader, val_loader, device, epochs=20, lr=0.01):
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            logits = model(x_batch)
            loss = criterion(logits, y_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # if (epoch + 1) % 5 == 0:
            # print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}")
    
    # Função para extrair probabilidades no conjunto de validação e teste
    def get_probs(model, loader):
        model.eval()
        all_probs = []
        all_labels = []
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                logits = model(x)
                probs = F.softmax(logits, dim=1).cpu().numpy()
                all_probs.append(probs)
                all_labels.append(y.numpy())
        return np.vstack(all_probs), np.hstack(all_labels)
    
    probs_val, y_val = get_probs(model, val_loader)
    probs_test, y_test = get_probs(model, test_loader)
    
    return model, probs_val, y_val, probs_test, y_test


def calibrate_and_evaluate(probs_val, y_val, probs_test, y_test):
    results = []
    
    results.append(evaluate(probs_test, y_test, "Uncalibrated NN"))
    
    hb = HistogramBinning()
    hb.fit(probs_val, y_val)
    results.append(evaluate(hb.transform(probs_test), y_test, "Histogram Binning"))
    
    iso_probs = []
    for k in range(probs_val.shape[1]):
        ir = IsotonicRegression(out_of_bounds='clip')
        y_bin = (y_val == k).astype(int)
        ir.fit(probs_val[:, k], y_bin)
        iso_probs.append(ir.predict(probs_test[:, k]))
    iso_probs = np.vstack(iso_probs).T
    results.append(evaluate(iso_probs, y_test, "Isotonic Regression"))
    
    bbq = BBQ()
    bbq.fit(probs_val, y_val)
    results.append(evaluate(bbq.transform(probs_test), y_test, "BBQ"))
    
    ts = TemperatureScaling()
    ts.fit(probs_val, y_val)
    results.append(evaluate(ts.transform(probs_test), y_test, "Temperature Scaling"))
    
    return pd.DataFrame(results, columns=["Model", "Accuracy", "ECE"])

model, probs_val, y_val, probs_test, y_test = train_nn(model, train_loader, val_loader, device, epochs=20, lr=0.01)

results = calibrate_and_evaluate(probs_val, y_val, probs_test, y_test)

In [93]:
class ModelWithVectorScaling(nn.Module):
    def __init__(self, model, n_classes):
        super(ModelWithVectorScaling, self).__init__()
        self.model = model
        self.calibrator = VectorScaling(n_classes)

    def forward(self, x):
        logits = self.model(x)
        calibrated_logits = self.calibrator(logits)
        return calibrated_logits

    def set_temperature(self, valid_loader, lr=0.01, max_iter=100):
        device = next(self.parameters()).device
        self.calibrator.train()
        nll_criterion = nn.CrossEntropyLoss()
        ece_criterion = _ECELoss()

        logits_list = []
        labels_list = []
        with torch.no_grad():
            for input, label in valid_loader:
                input = input.to(device)
                logits = self.model(input)
                logits_list.append(logits)
                labels_list.append(label.to(device))
            logits = torch.cat(logits_list)
            labels = torch.cat(labels_list)

        before_nll = nll_criterion(logits, labels).item()
        before_ece = ece_criterion(logits, labels).item()
        #print(f"Before vector scaling - NLL: {before_nll:.3f}, ECE: {before_ece:.3f}")

        optimizer = optim.LBFGS(self.calibrator.parameters(), lr=lr, max_iter=max_iter)

        def closure():
            optimizer.zero_grad()
            loss = nll_criterion(self.calibrator(logits), labels)
            loss.backward()
            return loss

        optimizer.step(closure)

        after_nll = nll_criterion(self.calibrator(logits), labels).item()
        after_ece = ece_criterion(self.calibrator(logits), labels).item()
        #print(f"After vector scaling - NLL: {after_nll:.3f}, ECE: {after_ece:.3f}")
        return self
    
    
class ModelWithMatrixScaling(nn.Module):
    def __init__(self, model, n_classes):
        super(ModelWithMatrixScaling, self).__init__()
        self.model = model
        self.calibrator = MatrixScaling(n_classes)

    def forward(self, x):
        logits = self.model(x)
        calibrated_logits = self.calibrator(logits)
        return calibrated_logits

    def set_temperature(self, valid_loader, lr=0.01, max_iter=100):
        device = next(self.parameters()).device
        self.calibrator.train()
        nll_criterion = nn.CrossEntropyLoss()
        ece_criterion = _ECELoss()

        logits_list = []
        labels_list = []
        with torch.no_grad():
            for input, label in valid_loader:
                input = input.to(device)
                logits = self.model(input)
                logits_list.append(logits)
                labels_list.append(label.to(device))
            logits = torch.cat(logits_list)
            labels = torch.cat(labels_list)

        before_nll = nll_criterion(logits, labels).item()
        before_ece = ece_criterion(logits, labels).item()
        #print(f"Before matrix scaling - NLL: {before_nll:.3f}, ECE: {before_ece:.3f}")

        optimizer = optim.LBFGS(self.calibrator.parameters(), lr=lr, max_iter=max_iter)

        def closure():
            optimizer.zero_grad()
            loss = nll_criterion(self.calibrator(logits), labels)
            loss.backward()
            return loss

        optimizer.step(closure)

        after_nll = nll_criterion(self.calibrator(logits), labels).item()
        after_ece = ece_criterion(self.calibrator(logits), labels).item()
        #print(f"After matrix scaling - NLL: {after_nll:.3f}, ECE: {after_ece:.3f}")
        return self

### **Results**

In [94]:
def evaluate_torch_model(model, dataloader, device):
    model.eval()
    all_logits = []
    all_labels = []
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            all_logits.append(logits)
            all_labels.append(y)
    logits = torch.cat(all_logits)
    labels = torch.cat(all_labels)

    softmaxes = F.softmax(logits, dim=1).cpu().numpy()
    preds = np.argmax(softmaxes, axis=1)
    labels_np = labels.cpu().numpy()
    acc = (preds == labels_np).mean()

    ece = _ECELoss()(logits, labels).item()
    return acc, ece

n_classes = 10

model_ms = ModelWithMatrixScaling(model, n_classes).to(device)
model_ms.set_temperature(val_loader)
acc_ms, ece_ms = evaluate_torch_model(model_ms, test_loader, device)

model_vs = ModelWithVectorScaling(model, n_classes).to(device)
model_vs.set_temperature(val_loader)
acc_vs, ece_vs = evaluate_torch_model(model_vs, test_loader, device)

results.loc[len(results)] = ["Vector Scaling", acc_vs, ece_vs]
results.loc[len(results)] = ["Matrix Scaling", acc_ms, ece_ms]

results

Unnamed: 0,Model,Accuracy,ECE
0,Uncalibrated NN,0.9628,0.025151
1,Histogram Binning,0.963,0.005922
2,Isotonic Regression,0.9638,0.004706
3,BBQ,0.9633,0.003576
4,Temperature Scaling,0.9628,0.02515
5,Vector Scaling,0.9613,0.012878
6,Matrix Scaling,0.9613,0.013089
