In [None]:
!pip install imagecorruptions
!pip install imgaug

In [None]:
!pip install --upgrade imgaug

Walk through a code example that compares a vanilla neural network, a network trained with Bayes by Backprop (BBB), and a deep ensemble.
Use the CIFAR-10 dataset and draw box plots of accuracy and calibration error over corruption severity levels.

In [None]:
import cv2
import imgaug.augmenters.imgcorruptlike as icl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import tensorflow_probability as tfp
from sklearn.metrics import accuracy_score

Prepare dataset

In [None]:
cifar = tf.keras.datasets.cifar10
# load_data() yields a (train, test) pair, each an (images, labels) pair.
train_split, test_split = cifar.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split

In [None]:
# Human-readable label names, looked up by the integer class id a model predicts.
CLASS_NAMES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

In [None]:
# Training-set size; the BBB model divides its KL term by this value.
NUM_TRAIN_EXAMPLES = len(train_images)

Define helper functions to build the vanilla and ensemble networks

In [None]:
def cnn_building_block(num_filters):
    """Return a small conv block: 3x3 ReLU convolution then stride-2 max-pooling.

    Args:
        num_filters: Number of output channels of the convolution.

    Returns:
        A ``tf.keras.Sequential`` holding the two layers.
    """
    conv = tf.keras.layers.Conv2D(filters=num_filters, kernel_size=(3, 3), activation="relu")
    downsample = tf.keras.layers.MaxPool2D(strides=2)
    return tf.keras.Sequential([conv, downsample])


def build_and_compile_model():
    """Build and compile the deterministic CNN used for the vanilla and ensemble models.

    Layers: multiply inputs by 1/255, three conv blocks (16, 32, 64 filters),
    one further stride-2 max-pool, a 64-unit ReLU dense layer and a 10-way
    softmax output.

    Returns:
        A ``tf.keras.Sequential`` compiled with Adam, sparse categorical
        cross-entropy and an accuracy metric.

    NOTE(review): with 32x32 inputs the three blocks leave a 2x2x64 feature
    map. The extra ``MaxPool2D`` below reduces it to 1x1x64.
    ``build_and_compile_model_bbb`` has no such layer, so the vanilla/ensemble
    and BBB architectures are not identical. Confirm this is intended.
    """
    layers = [
        tf.keras.layers.Rescaling(1.0 / 255, input_shape=(32, 32, 3)),
        *[cnn_building_block(width) for width in (16, 32, 64)],
        tf.keras.layers.MaxPool2D(strides=2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
    model = tf.keras.Sequential(layers)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return model

Vanilla model

In [None]:
# Single deterministic network, trained for 10 epochs on the full training set.
vanilla_model = build_and_compile_model()
vanilla_model.fit(x=train_images, y=train_labels, epochs=10)

Deep ensemble model

In [None]:
NUM_ENSEMBLE_MEMBERS = 5
ensemble_model = []
# Every member is an independently built and trained copy of the vanilla
# architecture; presumably they differ only through training randomness.
for member_idx in range(NUM_ENSEMBLE_MEMBERS):
    new_member = build_and_compile_model()
    print(f"Train model {member_idx:02}")
    new_member.fit(train_images, train_labels, epochs=10)
    ensemble_model.append(new_member)

Define helper functions to build the BBB network

In [None]:
def cnn_building_block_bbb(num_filters, kl_divergence_function):
    """Bayesian conv block: variational 3x3 ReLU convolution then stride-2 max-pooling.

    Args:
        num_filters: Number of output channels of the convolution.
        kl_divergence_function: Callable ``(q, p, _)`` passed to the layer as
            its ``kernel_divergence_fn``.

    Returns:
        A ``tf.keras.Sequential`` holding the two layers.
    """
    bayesian_conv = tfp.layers.Convolution2DReparameterization(
        num_filters,
        kernel_size=(3, 3),
        kernel_divergence_fn=kl_divergence_function,
        activation=tf.nn.relu,
    )
    downsample = tf.keras.layers.MaxPool2D(strides=2)
    return tf.keras.Sequential([bayesian_conv, downsample])


def build_and_compile_model_bbb():
    """Build and compile the Bayes-by-Backprop (BBB) version of the CNN.

    Every conv and dense layer is a TFP reparameterization layer. Each uses a
    kernel divergence equal to KL(q || p) divided by ``NUM_TRAIN_EXAMPLES``.
    There is no extra max-pool after the last conv block, so the flattened
    feature vector has 2x2x64 = 256 values for 32x32 inputs.

    Returns:
        A compiled and built ``tf.keras`` Sequential model ending in a 10-way
        softmax. Predictions are stochastic, which is why
        ``get_bbb_predictions`` averages several ``predict`` calls.
    """

    def kl_divergence_function(q, p, _):
        # Scale the KL term by the training-set size (per-example scaling).
        return tfp.distributions.kl_divergence(q, p) / tf.cast(
            NUM_TRAIN_EXAMPLES, dtype=tf.float32
        )

    def bayesian_dense(units, activation):
        # Variational dense layer sharing the scaled KL term above.
        return tfp.layers.DenseReparameterization(
            units,
            kernel_divergence_fn=kl_divergence_function,
            activation=activation,
        )

    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Rescaling(1.0 / 255, input_shape=(32, 32, 3)),
            *[cnn_building_block_bbb(width, kl_divergence_function) for width in (16, 32, 64)],
            tf.keras.layers.Flatten(),
            bayesian_dense(64, tf.nn.relu),
            bayesian_dense(10, tf.nn.softmax),
        ]
    )

    compile_options = dict(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
        # NOTE(review): legacy TF1-era flag; presumably ignored by current tf.keras — confirm.
        experimental_run_tf_function=False,
    )
    model.compile(**compile_options)

    model.build(input_shape=[None, 32, 32, 3])
    return model

BBB network

In [None]:
# The BBB model is trained for 15 epochs (vs. 10 for the deterministic models).
bbb_model = build_and_compile_model_bbb()
bbb_model.fit(x=train_images, y=train_labels, epochs=15)

Test images

In [None]:
# Evaluate on the first NUM_SUBSET test images only, to keep corruption and
# inference affordable.
NUM_SUBSET = 1000
test_images_subset, test_labels_subset = test_images[:NUM_SUBSET], test_labels[:NUM_SUBSET]

Apply dataset shift

In [None]:
# imgaug "imgcorruptlike" augmenters; each one is applied below at every
# severity from 1 to NUM_LEVELS. Their order fixes the meaning of the
# "type" index used throughout the notebook.
corruption_functions = [
    # noise
    icl.GaussianNoise, icl.ShotNoise, icl.ImpulseNoise,
    # blur
    icl.DefocusBlur, icl.GlassBlur, icl.MotionBlur, icl.ZoomBlur,
    # snow, frost, fog, brightness, contrast
    icl.Snow, icl.Frost, icl.Fog, icl.Brightness, icl.Contrast,
    # remaining corruptions
    icl.ElasticTransform, icl.Pixelate, icl.JpegCompression,
]
NUM_TYPES = len(corruption_functions)
NUM_LEVELS = 5

In [None]:
# Result layout: (severity level, corruption type, image, height, width, channel).
# Every augmenter is seeded with 0 so the corruptions are reproducible.
corrupted_images = np.stack(
    [
        np.stack(
            [
                corruption(severity=severity, seed=0)(images=test_images_subset)
                for corruption in corruption_functions
            ],
            axis=0,
        )
        for severity in range(1, NUM_LEVELS + 1)
    ],
    axis=0,
)

**Inference - get predictions**

In [None]:
# Fold the (level, type, image) axes into one batch axis so each model can
# predict on all corrupted images in a single call.
corrupted_images = np.reshape(corrupted_images, (-1, 32, 32, 3))

In [None]:
# Softmax outputs on the clean test subset.
vanilla_predictions = vanilla_model.predict(test_images_subset)
# Softmax outputs on the corrupted images, unfolded to (level, type, image, class).
vanilla_predictions_on_corrupted = np.reshape(
    vanilla_model.predict(corrupted_images),
    (NUM_LEVELS, NUM_TYPES, NUM_SUBSET, -1),
)

In [None]:
def get_ensemble_predictions(images, num_inferences):
    """Average the softmax outputs of the first ``num_inferences`` ensemble members.

    Reads the global ``ensemble_model`` list. The per-member outputs are
    stacked on a new leading axis, which is then averaged away, so the result
    has the shape of a single member's ``predict`` output.
    """
    per_member = []
    for member_idx in range(num_inferences):
        per_member.append(ensemble_model[member_idx].predict(images))
    stacked = tf.stack(per_member, axis=0)
    return np.mean(stacked, axis=0)

In [None]:
# Member-averaged softmax outputs on the clean test subset.
ensemble_predictions = get_ensemble_predictions(test_images_subset, NUM_ENSEMBLE_MEMBERS)
# Member-averaged outputs on corrupted images, unfolded to (level, type, image, class).
ensemble_predictions_on_corrupted = np.reshape(
    get_ensemble_predictions(corrupted_images, NUM_ENSEMBLE_MEMBERS),
    (NUM_LEVELS, NUM_TYPES, NUM_SUBSET, -1),
)

In [None]:
def get_bbb_predictions(images, num_inferences):
    """Monte Carlo estimate of the BBB predictive distribution.

    Calls the global ``bbb_model.predict`` ``num_inferences`` times (each call
    can produce different outputs) and returns the mean over those samples.
    """
    samples = []
    for _sample_idx in range(num_inferences):
        samples.append(bbb_model.predict(images))
    stacked = tf.stack(samples, axis=0)
    return np.mean(stacked, axis=0)

In [None]:
# Number of stochastic forward passes averaged per BBB prediction.
NUM_INFERENCES_BBB = 20
# Sample-averaged softmax outputs on the clean test subset.
bbb_predictions = get_bbb_predictions(test_images_subset, NUM_INFERENCES_BBB)
# Sample-averaged outputs on corrupted images, unfolded to (level, type, image, class).
bbb_predictions_on_corrupted = np.reshape(
    get_bbb_predictions(corrupted_images, NUM_INFERENCES_BBB),
    (NUM_LEVELS, NUM_TYPES, NUM_SUBSET, -1),
)

Inference - get classes and scores

In [None]:
def get_classes_and_scores(model_predictions):
    """Reduce probability vectors along the last axis.

    Args:
        model_predictions: Array whose last axis holds per-class scores.

    Returns:
        ``(predicted_classes, scores)``: the arg-max class index and the
        corresponding maximum score. Both have the input shape minus its last
        axis.
    """
    class_axis = -1
    predicted = np.argmax(model_predictions, axis=class_axis)
    confidence = np.amax(model_predictions, axis=class_axis)
    return predicted, confidence

In [None]:
# Clean subset: shape (image,). Corrupted: shape (level, type, image).
vanilla_predicted_classes, vanilla_scores = get_classes_and_scores(vanilla_predictions)
vanilla_predicted_classes_on_corrupted, vanilla_scores_on_corrupted = get_classes_and_scores(
    vanilla_predictions_on_corrupted
)

In [None]:
# Clean subset: shape (image,). Corrupted: shape (level, type, image).
ensemble_predicted_classes, ensemble_scores = get_classes_and_scores(ensemble_predictions)
ensemble_predicted_classes_on_corrupted, ensemble_scores_on_corrupted = get_classes_and_scores(
    ensemble_predictions_on_corrupted
)

In [None]:
# Clean subset: shape (image,). Corrupted: shape (level, type, image).
bbb_predicted_classes, bbb_scores = get_classes_and_scores(bbb_predictions)
bbb_predicted_classes_on_corrupted, bbb_scores_on_corrupted = get_classes_and_scores(
    bbb_predictions_on_corrupted
)

Visualise scores under dataset shift

In [None]:
# Restore the (level, type, image) layout of the flattened corrupted batch for plotting.
plot_images = np.reshape(corrupted_images, (NUM_LEVELS, NUM_TYPES, NUM_SUBSET, 32, 32, 3))

In [None]:
# Index (within the test subset) of the image to inspect
ind_image = 9
# Rows: the first few corruption types; columns: severity levels 1..NUM_LEVELS
NUM_PLOT_TYPES = min(3, NUM_TYPES)
# squeeze=False keeps `axes` 2-D even when a row or column count is 1
fig, axes = plt.subplots(
    nrows=NUM_PLOT_TYPES, ncols=NUM_LEVELS, figsize=(16, 10), squeeze=False
)
# (title label, scores, predicted classes) per model, each shaped (level, type, image)
model_outputs = (
    ("Vanilla", vanilla_scores_on_corrupted, vanilla_predicted_classes_on_corrupted),
    ("Ensemble", ensemble_scores_on_corrupted, ensemble_predicted_classes_on_corrupted),
    ("BBB", bbb_scores_on_corrupted, bbb_predicted_classes_on_corrupted),
)
# Loop over corruption levels
for ind_level in range(NUM_LEVELS):
    # Loop over corruption types
    for ind_type in range(NUM_PLOT_TYPES):
        ax = axes[ind_type, ind_level]
        # Plot slightly upscaled image for easier inspection
        image = plot_images[ind_level, ind_type, ind_image]
        image_upscaled = cv2.resize(image, dsize=(150, 150), interpolation=cv2.INTER_CUBIC)
        ax.imshow(image_upscaled)
        # One title line per model: max softmax score and predicted class name
        title_text = " \n".join(
            f"{label}: {scores[ind_level, ind_type, ind_image]:.3f} "
            f"[{CLASS_NAMES[classes[ind_level, ind_type, ind_image]]}]"
            for label, scores, classes in model_outputs
        )
        ax.set_title(title_text, fontsize=14)
        # Remove axes ticks and labels
        ax.axis("off")
fig.tight_layout()
plt.show()

Accuracy

In [None]:
# Accuracy of each model on the uncorrupted test subset.
true_labels = test_labels_subset.flatten()
vanilla_acc = accuracy_score(true_labels, vanilla_predicted_classes)
ensemble_acc = accuracy_score(true_labels, ensemble_predicted_classes)
bbb_acc = accuracy_score(true_labels, bbb_predicted_classes)

In [None]:
# One line each: vanilla, ensemble, BBB.
for clean_accuracy in (vanilla_acc, ensemble_acc, bbb_acc):
    print(clean_accuracy)

In [None]:
# Seed the results table with the clean accuracies; type 0 / level 0 marks
# "no corruption".
accuracies = [
    {"model_name": model_name, "type": 0, "level": 0, "accuracy": clean_acc}
    for model_name, clean_acc in (
        ("vanilla", vanilla_acc),
        ("ensemble", ensemble_acc),
        ("bbb", bbb_acc),
    )
]

In [None]:
# Predicted classes on corrupted images per model, each shaped (level, type, image).
corrupted_classes_by_model = {
    "vanilla": vanilla_predicted_classes_on_corrupted,
    "ensemble": ensemble_predicted_classes_on_corrupted,
    "bbb": bbb_predicted_classes_on_corrupted,
}
flat_labels = test_labels_subset.flatten()
# Type and level are stored 1-based so that 0 stays reserved for clean data.
for ind_type in range(NUM_TYPES):
    for ind_level in range(NUM_LEVELS):
        for model_name, predicted_classes in corrupted_classes_by_model.items():
            accuracies.append(
                {
                    "model_name": model_name,
                    "type": ind_type + 1,
                    "level": ind_level + 1,
                    "accuracy": accuracy_score(
                        flat_labels, predicted_classes[ind_level, ind_type, :]
                    ),
                }
            )

In [None]:
# One box per (severity level, model), aggregating accuracy over corruption
# types; level 0 is the uncorrupted test subset.
df = pd.DataFrame(accuracies)
fig, ax = plt.subplots(dpi=100)
sns.boxplot(data=df, x="level", y="accuracy", hue="model_name", ax=ax)
ax.set(
    title="Accuracy under increasing corruption severity",
    xlabel="Corruption severity level (0 = clean)",
    ylabel="Accuracy",
)
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
# tight_layout must be *called*; referencing the function object alone does nothing.
fig.tight_layout()
plt.show()

Calibration

In [None]:
def expected_calibration_error(
    pred_correct,
    pred_score,
    n_bins=5,
):
    """Compute the expected calibration error (ECE) of a set of predictions.

    The interval [0, 1] is split into ``n_bins`` equal-width bins. For every
    non-empty bin, take the absolute gap between its accuracy (fraction of
    correct predictions) and its mean confidence. The ECE is the average of
    these gaps, weighted by the number of samples in each bin.

    Parameters
    ----------
    pred_correct : array-like of bool or int
        Whether each prediction is correct (True/1) or not (False/0).
    pred_score : array-like of float, same shape as ``pred_correct``
        Confidence in each prediction, expected to lie in [0, 1].
    n_bins : int, default=5
        Number of bins used to discretize the [0, 1] interval.

    Returns
    -------
    float
        The expected calibration error.

    Raises
    ------
    ValueError
        If the inputs are empty, have different shapes, or ``n_bins < 1``.
    """
    # Convert from bool to integer (makes counting easier); ravel so any
    # matching input shape is accepted by np.bincount.
    pred_correct = np.asarray(pred_correct)
    pred_score = np.asarray(pred_score, dtype=float)
    if pred_correct.shape != pred_score.shape:
        raise ValueError(
            f"pred_correct and pred_score must have the same shape, "
            f"got {pred_correct.shape} and {pred_score.shape}"
        )
    if pred_score.size == 0:
        raise ValueError("cannot compute calibration error of an empty set of predictions")
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")
    pred_correct = pred_correct.astype(np.int32).ravel()
    pred_score = pred_score.ravel()

    # Bin ids run from 0 to n_bins - 1. A score equal to an inner edge falls
    # into the lower bin; scores >= the last inner edge (including 1.0) fall
    # into the last bin.
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    binids = np.searchsorted(bins[1:-1], pred_score)

    # Per-bin sample counts, correct-prediction counts and confidence sums
    bin_counts = np.bincount(binids, minlength=n_bins)
    bin_true_counts = np.bincount(binids, weights=pred_correct, minlength=n_bins)
    bin_score_sums = np.bincount(binids, weights=pred_score, minlength=n_bins)

    # Only bins that contain samples contribute (avoids division by zero)
    nonzero = bin_counts > 0
    bin_acc = bin_true_counts[nonzero] / bin_counts[nonzero]
    bin_conf = bin_score_sums[nonzero] / bin_counts[nonzero]

    return np.average(np.abs(bin_acc - bin_conf), weights=bin_counts[nonzero])

In [None]:
NUM_BINS = 10

# ECE of each model on the uncorrupted test subset.
flat_test_labels = test_labels_subset.flatten()
vanilla_cal, ensemble_cal, bbb_cal = (
    expected_calibration_error(flat_test_labels == classes, scores, n_bins=NUM_BINS)
    for classes, scores in (
        (vanilla_predicted_classes, vanilla_scores),
        (ensemble_predicted_classes, ensemble_scores),
        (bbb_predicted_classes, bbb_scores),
    )
)

In [None]:
# One line each: vanilla, ensemble, BBB.
for clean_cal in (vanilla_cal, ensemble_cal, bbb_cal):
    print(clean_cal)

In [None]:
# Seed the calibration table with the clean ECEs; type 0 / level 0 marks
# "no corruption".
calibration = [
    {"model_name": model_name, "type": 0, "level": 0, "calibration_error": clean_ece}
    for model_name, clean_ece in (
        ("vanilla", vanilla_cal),
        ("ensemble", ensemble_cal),
        ("bbb", bbb_cal),
    )
]

In [None]:
# (predicted classes, scores) on corrupted images per model, each shaped (level, type, image).
corrupted_outputs_by_model = {
    "vanilla": (vanilla_predicted_classes_on_corrupted, vanilla_scores_on_corrupted),
    "ensemble": (ensemble_predicted_classes_on_corrupted, ensemble_scores_on_corrupted),
    "bbb": (bbb_predicted_classes_on_corrupted, bbb_scores_on_corrupted),
}
flat_test_labels = test_labels_subset.flatten()
# Type and level are stored 1-based so that 0 stays reserved for clean data.
for ind_type in range(NUM_TYPES):
    for ind_level in range(NUM_LEVELS):
        for model_name, (predicted_classes, scores) in corrupted_outputs_by_model.items():
            # Use the same NUM_BINS as the clean (level 0) ECE; otherwise the
            # function's default bin count would be used and levels would not
            # be comparable.
            cal_error = expected_calibration_error(
                flat_test_labels == predicted_classes[ind_level, ind_type, :],
                scores[ind_level, ind_type, :],
                n_bins=NUM_BINS,
            )
            calibration.append(
                {
                    "model_name": model_name,
                    "type": ind_type + 1,
                    "level": ind_level + 1,
                    "calibration_error": cal_error,
                }
            )

In [None]:
# One box per (severity level, model), aggregating ECE over corruption types;
# level 0 is the uncorrupted test subset.
df = pd.DataFrame(calibration)
fig, ax = plt.subplots(dpi=100)
sns.boxplot(data=df, x="level", y="calibration_error", hue="model_name", ax=ax)
ax.set(
    title="Expected calibration error under increasing corruption severity",
    xlabel="Corruption severity level (0 = clean)",
    ylabel="Expected calibration error",
)
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
# tight_layout must be *called*; referencing the function object alone does nothing.
fig.tight_layout()
plt.show()