In [None]:
from __future__ import annotations

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification
import torch
from torch_imp import HistogramBinning

# Seed every RNG this notebook touches up front so stochastic steps are reproducible.
for _seed_rng in (np.random.seed, torch.manual_seed):
    _seed_rng(42)

print("Torch version:", torch.__version__)

In [None]:
# Synthetic binary classification dataset (illustration of an index-based split).
# NOTE: nothing later in this notebook trains on or calibrates with X / y.
# The calibration cell further down reuses the same index ranges
# (1500-1750 for calibration, 1750 onward for test), but applies them to the
# CIFAR-10 "cat" probabilities `y_pred`, not to these synthetic features.
X_np, y_np = make_classification(
    n_samples=2000,
    n_features=20,
    n_informative=10,
    n_redundant=2,
    random_state=42,
)

# Keep the raw NumPy arrays; expose float32 tensor copies under X / y.
X = torch.tensor(X_np, dtype=torch.float32)
y = torch.tensor(y_np, dtype=torch.float32)

# Index-based split:
#   train:       [0, 1500)
#   calibration: [1500, 1750)
#   test:        [1750, 2000)
X_train, y_train = X[:1500], y[:1500]
X_cal, y_cal = X[1500:1750], y[1500:1750]
X_test, y_test = X[1750:], y[1750:]

print("shapes:", X_train.shape, X_cal.shape, X_test.shape)

In [None]:
# Pretrained ResNet-18 with ImageNet weights, used here only for inference as
# an (uncalibrated) scorer. `torch` is already imported in the setup cell.
from torchvision import models

model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Switch BatchNorm layers to inference behaviour; the trailing ";" keeps the
# notebook from rendering the very long module repr as cell output.
model.eval();

In [None]:
# Load the CIFAR-10 *test* split and preprocess it the way the ImageNet ResNet
# expects: resize to 256, center-crop to 224, then ImageNet mean/std normalization.
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

_imagenet_norm = dict(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(**_imagenet_norm),
    ]
)

dataset = datasets.CIFAR10(root="data", train=False, download=True, transform=preprocess)

# shuffle=False keeps batch order equal to dataset index order, which later
# cells rely on to line predictions up with labels.
loader = DataLoader(dataset, batch_size=32, shuffle=False)

In [None]:
# Convert CIFAR-10 to a binary task: "cat" (class 3) vs. "not cat".
classes = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

target_class = 3
# Guard against an index/name mismatch with the dataset's own class list.
assert dataset.classes[target_class] == classes[target_class] == "cat"

# Build labels straight from `dataset.targets` instead of iterating `dataset`:
# iterating runs `preprocess` on every test image, and stacking the results
# would keep several GB of unused 3x224x224 float32 tensors in memory.
# Order matches `loader` (shuffle=False), so binary_labels[i] pairs with sample i.
binary_labels = (torch.as_tensor(dataset.targets) == target_class).long()

len(dataset), len(binary_labels)


In [None]:
# Extract softmax probabilities for the ImageNet "tabby cat" class (index 281)
# and use them as the raw, uncalibrated "cat" score for each CIFAR-10 image.
# NOTE(review): ImageNet-1k has several other domestic-cat classes (282-285 in
# the standard index); summing them would give a fuller "cat" score.
import torch.nn.functional as F

CAT_CLASS_IDX = 281

model.eval()  # defensive: ensure inference behaviour even if re-run out of order

probs = []
with torch.no_grad():
    for imgs, _labels in loader:
        logits = model(imgs)
        softmaxed = F.softmax(logits, dim=1)
        probs.append(softmaxed[:, CAT_CLASS_IDX])

y_pred = torch.cat(probs)
# loader uses shuffle=False, so y_pred[i] corresponds to binary_labels[i].
assert len(y_pred) == len(binary_labels)
y_pred[:10]

In [None]:
# Fit histogram binning on the calibration slice, evaluate on the test slice.
device = torch.device("cpu")
calibrator = HistogramBinning(base_model=model, device=device)

# Calibration uses indices [CAL_START, CAL_END); the test slice is everything after.
CAL_START, CAL_END = 1500, 1750

# Fit on CALIBRATION set.
# `.to(..., copy=True)` replaces `torch.tensor(<tensor>)`, which PyTorch warns
# against; it still hands the calibrator fresh float32 copies, not views of y_pred.
cal_preds = y_pred[CAL_START:CAL_END].to(torch.float32, copy=True)
cal_labels = binary_labels[CAL_START:CAL_END].to(torch.float32, copy=True)
assert 0 < cal_labels.sum().item() < len(cal_labels), "calibration slice needs both classes"
calibrator.fit(cal_preds, cal_labels)

# Evaluate on TEST set
test_preds = y_pred[CAL_END:].to(torch.float32, copy=True)
test_labels = binary_labels[CAL_END:].to(torch.float32, copy=True)

test_preds_calibrated = calibrator.predict(test_preds)

print("Test predictions before calibration (first 10):")
print(test_preds[:10])

print("\nTest predictions **after** calibration (first 10):")
print(test_preds_calibrated[:10])

In [None]:
# --- Compute ECE and plot reliability diagram on the test set ---

import numpy as np
import torch


# Function to compute ECE (Expected Calibration Error)
def compute_ece(preds: torch.Tensor, labels: torch.Tensor, n_bins: int = 15) -> float:
    """Expected Calibration Error (ECE) of positive-class probability predictions.

    The interval [0, 1] is split into ``n_bins`` equal-width bins. Each non-empty
    bin contributes ``|mean(pred) - mean(label)|`` weighted by the fraction of
    all predictions that fall into it.

    Args:
        preds: predicted probabilities; values outside [0, 1] fall in no bin.
        labels: 0/1 targets indexable by a boolean mask of ``preds`` (same shape).
        n_bins: number of equal-width bins.

    Returns:
        The ECE as a Python float; 0.0 for empty input.
    """
    n_total = preds.numel()
    if n_total == 0:
        return 0.0

    edges = torch.linspace(0.0, 1.0, n_bins + 1, device=preds.device)
    ece = 0.0

    for i in range(n_bins):
        start, end = edges[i], edges[i + 1]
        # Bins are half-open [start, end) except the last one, which is closed so
        # predictions of exactly 1.0 (common after histogram binning) are counted.
        upper = preds <= end if i == n_bins - 1 else preds < end
        in_bin = (preds >= start) & upper
        bin_size = int(in_bin.sum().item())
        if bin_size == 0:
            continue

        # Cast to float so integer/bool labels average correctly.
        avg_pred = preds[in_bin].float().mean().item()
        avg_label = labels[in_bin].float().mean().item()
        ece += (bin_size / n_total) * abs(avg_pred - avg_label)

    return ece


# ECE on the held-out test slice: raw scores first, then calibrated scores.
ece_before, ece_after = (
    compute_ece(scores, test_labels)
    for scores in (test_preds, test_preds_calibrated)
)

print(f"ECE before calibration: {ece_before:.4f}")
print(f"ECE after calibration:  {ece_after:.4f}")


# Function to plot reliability diagram
def reliability_plot(raw: torch.Tensor, calibrated: torch.Tensor, labels: torch.Tensor, n_bins: int = 15) -> None:
    """Draw a reliability diagram comparing raw and calibrated probabilities.

    For each of ``n_bins`` equal-width bins over [0, 1], plots the mean predicted
    probability (x) against the observed positive rate (y). Empty bins are
    skipped rather than drawn as (0, 0).

    Args:
        raw: uncalibrated probabilities.
        calibrated: calibrated probabilities for the same samples.
        labels: 0/1 targets, same shape as ``raw`` and ``calibrated``.
        n_bins: number of equal-width bins.
    """
    edges = torch.linspace(0.0, 1.0, n_bins + 1, device=raw.device)

    def bin_stats(preds: torch.Tensor) -> tuple[np.ndarray, np.ndarray]:
        """Return (mean confidence, positive rate) for each non-empty bin."""
        conf = []
        acc = []
        for i in range(n_bins):
            # Last bin is closed on the right so predictions of exactly 1.0 count.
            upper = preds <= edges[i + 1] if i == n_bins - 1 else preds < edges[i + 1]
            in_bin = (preds >= edges[i]) & upper
            if not in_bin.any():
                # Plotting an empty bin as (0, 0) would draw a false segment to the origin.
                continue
            conf.append(preds[in_bin].float().mean().item())
            acc.append(labels[in_bin].float().mean().item())
        return np.array(conf), np.array(acc)

    raw_conf, raw_acc = bin_stats(raw)
    cal_conf, cal_acc = bin_stats(calibrated)

    _fig, ax = plt.subplots(figsize=(6, 6))
    ax.plot([0, 1], [0, 1], "--", label="Perfect Calibration")
    ax.plot(raw_conf, raw_acc, marker="o", label="Before Calibration")
    ax.plot(cal_conf, cal_acc, marker="o", label="After Calibration")
    ax.set_xlim(0.0, 1.0)
    ax.set_ylim(0.0, 1.0)
    ax.set_xlabel("Average predicted probability")
    ax.set_ylabel("Accuracy in bin")
    ax.set_title("Reliability Diagram")
    ax.legend()
    ax.grid(True)
    plt.show()


# Reliability diagram for the held-out test slice (raw vs. calibrated scores).
reliability_plot(raw=test_preds, calibrated=test_preds_calibrated, labels=test_labels)
