## Bayesian Binning into Quantiles (BBQ) Calibration Method, OVERVIEW

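Bayesian Binning into Quantiles (BBQ) is a post-hoc calibration method for binary classifiers. Instead of committing to a single histogram binning of the predicted scores, it considers several equal-frequency (quantile) binnings with different numbers of bins and combines their calibrated estimates via Bayesian model averaging. In this notebook it is applied to a multi-class model in a one-vs-rest fashion: one calibrator is fitted per class, and the calibrated scores are renormalized so that each row sums to one.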

## 2.1 Load the model

For simplicity, we use the ResNet18 model from torchvision with pretrained ImageNet weights and replace its final fully connected layer with a 10-class output layer.

In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.models import ResNet18_Weights
import torchvision.transforms as T
from tqdm import tqdm

# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Start from ImageNet-pretrained weights.
net = torchvision.models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
# Replace the pretrained classification head with a fresh 10-way linear layer.
# The input width is read from the existing head instead of hard-coding 512,
# so the line keeps working if the backbone is swapped for another ResNet.
net.fc = nn.Linear(net.fc.in_features, 10, device=device)
net = net.to(device)

Using device: cpu


## 2.2 Load the data

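Here we load the CIFAR10 dataset from torchvision. The official training set is split 80/20 into a training part, used for a quick training session, and a held-out calibration part, used later to fit the calibrators; the official test set is kept for evaluation.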

In [2]:
# Images are only converted to tensors; no augmentation or normalization is applied.
transforms = T.Compose([T.ToTensor()])

train = torchvision.datasets.CIFAR10(root="~/datasets", train=True, download=True, transform=transforms)
# Seed the split so the train/calibration partition is the same on every run;
# otherwise the calibration results are not reproducible.
split_generator = torch.Generator().manual_seed(0)
train, cal = torch.utils.data.random_split(train, [0.8, 0.2], generator=split_generator)
test = torchvision.datasets.CIFAR10(root="~/datasets", train=False, download=True, transform=transforms)

# Only the training loader needs shuffling. The calibration and test loaders
# are used purely for inference, so a fixed order is preferable.
train_loader = DataLoader(train, batch_size=256, shuffle=True)
cal_loader = DataLoader(cal, batch_size=256, shuffle=False)
test_loader = DataLoader(test, batch_size=256, shuffle=False)

## 2.3 Train the model

So let's give the model a quick training session with our dataset.

In [3]:
# Short fine-tuning run: Adam with default hyperparameters and cross-entropy loss.
epochs = 5
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters())

for epoch in tqdm(range(epochs)):
    net.train()
    running_loss = 0.0  # sum of per-batch losses for this epoch
    for inputs, targets in train_loader:
        batch_inputs = inputs.to(device)
        batch_targets = targets.to(device)

        optimizer.zero_grad()
        outputs = net(batch_inputs)
        loss = criterion(outputs, batch_targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    # Report the mean batch loss of the epoch.
    print(f"Epoch {epoch + 1}, Running loss: {running_loss / len(train_loader)}")

  0%|          | 0/5 [00:00<?, ?it/s]

 20%|██        | 1/5 [02:43<10:55, 163.91s/it]

Epoch 1, Running loss: 0.9217946464848367


 40%|████      | 2/5 [05:49<08:49, 176.38s/it]

Epoch 2, Running loss: 0.5628288703359616


 60%|██████    | 3/5 [08:58<06:04, 182.20s/it]

Epoch 3, Running loss: 0.418066429484422


 80%|████████  | 4/5 [12:03<03:03, 183.51s/it]

Epoch 4, Running loss: 0.31996190842169864


100%|██████████| 5/5 [15:04<00:00, 180.98s/it]

Epoch 5, Running loss: 0.2449185317678816





In [None]:
# demo_bbq_calibration_with_diagram.py
import torch
import torch.nn.functional as F
from probly.calibration.bayesian_binning.torch2 import BayesianBinningQuantiles
from probly.evaluation.metrics import brier_score, expected_calibration_error
from probly.calibration.visualization.reliability_diagram import compute_reliability_diagram, plot_reliability_diagram

# Use the same device preference as the model-loading cell (CUDA, then Apple's
# MPS backend, then CPU). Checking only for CUDA would silently move the
# trained network to the CPU on machines that were training on MPS.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Assume net, cal_loader, test_loader are defined elsewhere
net.to(device)
net.eval()  # inference mode for layers such as batch norm

# --- Step 1: collect calibration logits and targets
cal_logits_list, cal_targets_list = [], []
with torch.no_grad():  # no gradients are needed for inference
    for inputs, targets in cal_loader:
        inputs = inputs.to(device)
        outputs = net(inputs)  # logits on device
        # Move each batch back to the CPU so the full set does not have to
        # stay in accelerator memory.
        cal_logits_list.append(outputs.cpu())
        cal_targets_list.append(targets)

# Concatenate the per-batch results along the sample dimension.
cal_logits = torch.cat(cal_logits_list, dim=0)
cal_targets = torch.cat(cal_targets_list, dim=0)

# --- Step 2: turn the calibration logits into class probabilities
cal_probs = F.softmax(cal_logits, dim=1)  # CPU tensor, shape (N_cal, n_classes)

n_classes = cal_probs.shape[1]
# One slot per class; a slot stays None when no usable calibrator exists.
bbq_calibrators = [None] * n_classes

# --- Step 3: fit one one-vs-rest calibrator per class
for class_idx in range(n_classes):
    class_probs = cal_probs[:, class_idx]
    binary_labels = (cal_targets == class_idx).long()  # 1 where the sample belongs to this class

    # Only fit when the class has at least two positive samples.
    if int(binary_labels.sum()) >= 2:
        calibrator = BayesianBinningQuantiles(max_bins=10)
        calibrator.fit(class_probs, binary_labels)

        # Keep the calibrator only if it reports a successful fit.
        if getattr(calibrator, "is_fitted", False):
            bbq_calibrators[class_idx] = calibrator

# --- Step 4: run the network over the test set and gather logits and targets
test_logits_list = []
test_targets_list = []
with torch.no_grad():
    for inputs, targets in test_loader:
        outputs = net(inputs.to(device))
        test_targets_list.append(targets)
        test_logits_list.append(outputs.cpu())  # keep the accumulated logits on the CPU

# Stack the per-batch pieces along the sample dimension.
test_targets = torch.cat(test_targets_list, dim=0)
test_logits = torch.cat(test_logits_list, dim=0)

# --- Step 5: softmax over the class dimension gives uncalibrated probabilities
test_probs = F.softmax(test_logits, dim=1)  # CPU tensor

# --- Step 6: calibrate per-class
import warnings

calibrated_probs = torch.zeros_like(test_probs, dtype=torch.float32)

for class_idx in range(n_classes):
    class_probs = test_probs[:, class_idx]
    calibrator = bbq_calibrators[class_idx]
    if calibrator is None or not getattr(calibrator, "is_fitted", False):
        # No usable calibrator for this class: keep the uncalibrated scores.
        calibrated_probs[:, class_idx] = class_probs
    else:
        try:
            preds = calibrator.predict(class_probs)
            calibrated_probs[:, class_idx] = preds
        except Exception as exc:
            # Best-effort fallback to the uncalibrated scores, but report the
            # failure instead of hiding it: a silently skipped calibrator makes
            # the metrics below look like "calibrated" results when they are not.
            warnings.warn(
                f"BBQ calibration failed for class {class_idx} ({exc!r}); "
                "falling back to uncalibrated probabilities.",
                RuntimeWarning,
            )
            calibrated_probs[:, class_idx] = class_probs

# --- Step 7: renormalize rows
# Each class was calibrated independently, so rows no longer sum to one.
row_sums = calibrated_probs.sum(dim=1, keepdim=True)
# squeeze(1) removes only the kept class dimension, so the mask stays 1-D
# (one entry per row) even when there is a single test sample.
zero_rows = (row_sums == 0).squeeze(1)
if zero_rows.any():
    # A row whose calibrated scores are all zero cannot be normalized;
    # fall back to the original softmax output for those rows.
    calibrated_probs[zero_rows] = test_probs[zero_rows]
    row_sums = calibrated_probs.sum(dim=1, keepdim=True)
calibrated_probs /= row_sums

# --- Step 8: score the calibrated probabilities
# The metric and plotting helpers are called with NumPy arrays.
test_targets_np = test_targets.numpy()
calibrated_probs_np = calibrated_probs.numpy()

# Accuracy: fraction of samples whose highest-probability class matches the label.
predicted_labels = calibrated_probs_np.argmax(axis=1)
accuracy = (predicted_labels == test_targets_np).mean()
ece = expected_calibration_error(calibrated_probs_np, test_targets_np)
brier = brier_score(calibrated_probs_np, test_targets_np)

print(f"Accuracy after BBQ calibration: {accuracy:.4f}")
print(f"ECE after BBQ calibration: {ece:.4f}")
print(f"Brier score after BBQ calibration: {brier:.4f}")

# --- Step 9: build the reliability diagram data and draw it
diagram = compute_reliability_diagram(calibrated_probs_np, test_targets_np)
plot_reliability_diagram(diagram, "After Bayesian Binning Quantiles Calibration")


Class 0: probs shape=torch.Size([10000]), labels shape=torch.Size([10000])
 probs dtype=torch.float32, labels dtype=torch.int64
 probs range=[0.0000, 1.0000]
 labels sum=964 out of 10000
Class 1: probs shape=torch.Size([10000]), labels shape=torch.Size([10000])
 probs dtype=torch.float32, labels dtype=torch.int64
 probs range=[0.0000, 1.0000]
 labels sum=1065 out of 10000
Class 2: probs shape=torch.Size([10000]), labels shape=torch.Size([10000])
 probs dtype=torch.float32, labels dtype=torch.int64
 probs range=[0.0000, 1.0000]
 labels sum=965 out of 10000
Class 3: probs shape=torch.Size([10000]), labels shape=torch.Size([10000])
 probs dtype=torch.float32, labels dtype=torch.int64
 probs range=[0.0000, 0.9997]
 labels sum=979 out of 10000
Class 4: probs shape=torch.Size([10000]), labels shape=torch.Size([10000])
 probs dtype=torch.float32, labels dtype=torch.int64
 probs range=[0.0000, 1.0000]
 labels sum=983 out of 10000
Class 5: probs shape=torch.Size([10000]), labels shape=torch.Siz

  bin_ids = torch.bucketize(calibration_set, edges) - 1



Accuracy after BBQ calibration: 0.7880
ECE after BBQ calibration: 0.0174
Brier score after BBQ calibration: 0.3027
