In [1]:
# import packages
import numpy as np
import torch
from matplotlib import pyplot as plt
%matplotlib inline


from utils import data, measures, models, plot, run

## Failure Modes

As mentioned earlier, the Prior Networks approach requires the user to provide examples of out-of-distribution data for the network to be trained on. This is problematic because (1) specifying all possible out-of-distribution data is difficult, and (2) failing to provide examples of all possible out-of-distribution data results in the model failing to identify out-of-distribution data that it was not trained on.

Here, we demonstrate this phenomenon with another toy example. There are three in-distribution clusters and the user naively specifies a fourth out-of-distribution cluster, thinking this will be sufficient for the model to learn to identify all possible out-of-distribution inputs. However, at test time, we present the model with a fifth, previously-unseen out-of-distribution cluster and show that the model generates highly confident (low entropy) predictions for the incorrect class.

We start by generating and displaying these five clusters:

In [9]:
failure_data = data.create_data(
    create_data_functions=[data.create_data_mixture_of_gaussians,],
    functions_args=[{
        'gaussians_means': 5 * np.array([
            [0., 2.],
            [-np.sqrt(3), -1.],
            [np.sqrt(3), -1.],
            [-5., -5.],
            [5., 0.]]
        ),
        'gaussians_covariances': np.array([
            [[2.0, 0], [0, 2.0]],
            [[2.0, 0], [0, 2.0]],
            [[2.0, 0], [0, 2.0]],
            [[2.0, 0], [0, 2.0]],
            [[2.0, 0], [0, 2.0]],
        ]),
        'n_samples_per_gaussian': np.array(
            [250, 250, 250, 250, 250]),
        'out_of_distribution': np.array(
            [False, False, False, True, True])}])

In [10]:
labels_names = ['Class 1', 'Class 2', 'Class 3', 'Training OOD', 'Testing OOD']

plot.plot_training_data(
    samples=failure_data['samples'].numpy(),
    labels=failure_data['targets'].numpy(),
    labels_names=labels_names,
    plot_title='Training Data',
    xaxis=dict(title='Patient Feature 1 (e.g. age)'),
    yaxis=dict(title='Patient Feature 2 (e.g. BMI)')
)

Like in the successful examples above, we'll create our model and train it on the three in-distriubtion clusters and one out-of-distribution cluster. The loss demonstrates that the model converges.

In [11]:
# create the model, optimizer, training data
model = run.create_model(in_dim=2, out_dim=3, n_per_hidden_layer=[50], args={})
optimizer = run.create_optimizer(model=model, args={'lr': 0.001})
loss_fn = run.create_loss_fn(loss_fn_str='kl', args={})
train_data_indices = failure_data['targets'] != 4
train_data = {k: v[train_data_indices] for k, v in failure_data.items()}
test_data = {k: v[~train_data_indices] for k, v in failure_data.items()}

In [12]:
# fit the model
model, optimizer, training_loss = run.train_model(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_data=train_data,
    args={},
    n_epochs=1000,
    batch_size=32)

In [13]:
# plot the training loss
plot.plot_training_loss(training_loss=training_loss)

However, when we plot the predicted classes, we see that the model is highly confident in the wrong class. By visualizing the model's decision surface, we see that the model has learned to identify the out-of-distribution cluster in the training data, it is not capable of generalizing to new out-of-distribution data.

In [14]:
# print predictions
test_model_outputs = model(test_data['samples'])
print(np.round(test_model_outputs['y_pred'].detach().numpy(), 3))

[[0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0.

In [15]:
#please rerun this to create interactive 3D plot
plot.plot_decision_surface(model=model,
                           samples=failure_data['samples'],
                           labels=failure_data['targets'],
                           labels_names=labels_names,
                           z_fn=measures.entropy_categorical,
                           x_axis_title='Patient Feature 1 (e.g. age)',
                           y_axis_title='Patient Feature 2 (e.g. BMI)',
                           z_axis_title='Predicted Class Entropy'
                          )