# Tutorial on exploring out-of-distribution examples using metadata.


## _Problem Statement_

In computer vision tasks like **image classification** and **object detection**, when out-of-distribution (OOD) examples are detected, there are two things we might like to know about them. First, looking at each example individually, what makes it stand out from the reference dataset? Second, looking at the OOD examples as a population, which of their properties have shifted the most relative to the reference dataset? Metadata can help answer both questions.

For the first question, we can look at each metadata feature of each example and find the one that lies furthest out in the tails of its reference distribution. Since we are interested in extreme values, we take the median of each feature over the reference dataset and compute the absolute deviation of each OOD example's value from that median. (We can store the sign if we wish, but we should judge significance by the absolute deviation.) To compare across features, we normalize each deviation by the interquartile range (IQR) of that feature's reference distribution.
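The per-example calculation can be sketched in a few lines of NumPy. This is only an illustration of the idea described above, not the code used later in this tutorial: the function names `normalized_deviations` and `most_deviant_feature` are hypothetical, and the metadata is assumed to be available as plain 2-D arrays with one row per example and one column per feature.

```python
import numpy as np


def normalized_deviations(reference, examples):
    """Signed deviation of each example from the reference median, in IQR units.

    reference: array of shape (n_reference, n_features)
    examples:  array of shape (n_examples, n_features)
    """
    reference = np.asarray(reference, dtype=float)
    examples = np.asarray(examples, dtype=float)

    # Median and interquartile range of each feature in the reference data
    median = np.median(reference, axis=0)
    q25, q75 = np.percentile(reference, [25, 75], axis=0)
    iqr = q75 - q25

    # A constant reference feature has zero IQR; mark it as NaN so it is ignored
    iqr = np.where(iqr > 0, iqr, np.nan)

    return (examples - median) / iqr


def most_deviant_feature(reference, examples):
    """Index of the feature with the largest absolute normalized deviation."""
    deviations = np.abs(normalized_deviations(reference, examples))
    return np.nanargmax(deviations, axis=1)
```

Keeping the signed deviation and taking the absolute value only when ranking features preserves the direction of the shift for later inspection.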

For the second question, we can compare the distribution of each feature to its reference distribution using the two-sample Kolmogorov-Smirnov test. For features that show a statistically significant difference, we can use the Wasserstein distance to measure the size of the shift (again, normalized by the IQR of the reference).
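For a single feature, the population-level comparison might look like the sketch below. It uses SciPy's `ks_2samp` and `wasserstein_distance`; the helper name `compare_feature`, the `alpha` cutoff of 0.05, and the returned dictionary keys are assumptions made for illustration and are not taken from the tools used later in this tutorial.

```python
import numpy as np
from scipy.stats import ks_2samp, wasserstein_distance


def compare_feature(reference, test, alpha=0.05):
    """Compare one metadata feature between a reference sample and a test sample."""
    reference = np.asarray(reference, dtype=float)
    test = np.asarray(test, dtype=float)

    # Two-sample Kolmogorov-Smirnov test: small p-values suggest a distribution shift
    pvalue = ks_2samp(reference, test).pvalue

    # Size of the shift, expressed in units of the reference IQR
    q25, q75 = np.percentile(reference, [25, 75])
    iqr = q75 - q25
    if iqr > 0:
        shift = wasserstein_distance(reference, test) / iqr
    else:
        shift = float("nan")  # a constant reference feature has no usable scale

    return {
        "pvalue": float(pvalue),
        "shift_over_iqr": float(shift),
        "significant": bool(pvalue < alpha),
    }
```

Reporting the p-value and the normalized shift together is useful because, with large samples, a shift can be statistically significant while still being small relative to the spread of the reference distribution.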


### _When to use_

When OOD examples have been detected, e.g. by the DataEval `OOD_AE` class or similar, the tools developed here should be used to try to learn more about specifically what image properties move each image out of the distribution.


### _What you will need_

1. A training image dataset with a low percentage of known OOD images.
2. A test image dataset to evaluate for OOD images.
3. A Python environment with the following packages installed:
   - `dataeval[all]`


### _Setting up_

Let's import the libraries needed to set up a minimal working example.


In [1]:
import numpy as np
from metadata_utils import InstanceMNIST
from metadata_utils import collate_fn_2 as collate_fn

from dataeval.detectors.ood.ae_torch import OOD_AE
from dataeval.utils.torch.models import AE_torch

from torch.utils.data import DataLoader
from dataeval.utils.torch.datasets import MNIST


## Load the data

We will use InstanceMNIST, a PyTorch wrapper for the TensorFlow MNIST datasets, for this tutorial.


In [2]:
corruption_list = ['identity', 'identity', 'translate', 'shot_noise', 'motion_blur', 'scale']

mnist = InstanceMNIST(corruption_list, size=8000)
mnist_val = InstanceMNIST('identity', train=False, size=8000)

refdata = mnist.identity
valdata = mnist_val.identity
shiftdata = mnist.translate
spikydata = mnist.shot_noise
blurdata = mnist.motion_blur
scaledata = mnist.scale



In [3]:
# Use a large batch so that one batch holds a sizeable sample of the reference data
big_batch_size = 2000

refbb = DataLoader(refdata, collate_fn=collate_fn, batch_size=big_batch_size)
for images, labels, metadata in refbb:
    break

# Append a trailing channel axis to the images array.
input_shape = (*images[0].shape, 1)
bbshape = (*images.shape,1)
images = images.reshape(bbshape).detach().numpy()

In [4]:
# Load in the training mnist dataset and use the first 4000
train_ds = MNIST(root="./data/", train=True, download=True, size=4000, unit_interval=True, channels="channels_first")

# Split out the images and labels
images, labels = train_ds.data, train_ds.targets
input_shape = images[0].shape

Files already downloaded and verified


## Initialize the model

Now, let's look at how to use DataEval's OOD detection methods.  
We will focus on a simple autoencoder network from our Alibi Detect provider.


In [5]:
input_shape = (1, 28, 28)  # n_channels, n_rows, n_columns
# Keep the detectors in a list to make it easy to try additional detectors.
detectors = [OOD_AE(AE_torch(input_shape))]

## Train the model

Next we will train the model on the dataset.
For better results, the number of epochs can be increased.
We set `threshold_perc=99` so that the most extreme 1% of the training data is flagged as out-of-distribution. (Training may take several minutes.)
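The idea behind a percentile threshold can be sketched independently of any particular detector. The snippet below is not DataEval's implementation; it assumes each image has already been given a scalar OOD score (for example, a reconstruction error), and the function names `percentile_threshold` and `flag_ood` are hypothetical.

```python
import numpy as np


def percentile_threshold(train_scores, threshold_perc=99.0):
    """Score value below which `threshold_perc` percent of the training scores fall."""
    return float(np.percentile(train_scores, threshold_perc))


def flag_ood(scores, threshold):
    """Boolean mask marking the examples whose score exceeds the threshold."""
    return np.asarray(scores) > threshold
```

With a threshold chosen this way, roughly 1% of the training data itself is flagged, which gives a baseline rate against which the fraction flagged in new data can be compared.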


In [6]:
# 12 epochs take about 5 minutes
for detector in detectors:
    print(f"Training {detector.__class__.__name__}...")
    detector.fit(images, threshold_perc=99, epochs=12, verbose=True)

Training OOD_AE...
Epoch 0...
loss: 0.111, |grad|: 0.176
loss: 0.025, |grad|: 0.207
loss: 0.046, |grad|: 0.210
loss: 0.022, |grad|: 0.160
loss: 0.013, |grad|: 0.104
loss: 0.012, |grad|: 0.108
loss: 0.011, |grad|: 0.141
loss: 0.016, |grad|: 0.095
Epoch 1...
loss: 0.013, |grad|: 0.110
loss: 0.013, |grad|: 0.135
loss: 0.024, |grad|: 0.205
loss: 0.012, |grad|: 0.118
loss: 0.009, |grad|: 0.089
loss: 0.008, |grad|: 0.079
loss: 0.005, |grad|: 0.062
loss: 0.016, |grad|: 0.083
Epoch 2...
loss: 0.011, |grad|: 0.101
loss: 0.009, |grad|: 0.115
loss: 0.018, |grad|: 0.173
loss: 0.009, |grad|: 0.093
loss: 0.007, |grad|: 0.079
loss: 0.007, |grad|: 0.090
loss: 0.004, |grad|: 0.058
loss: 0.014, |grad|: 0.083
Epoch 3...
loss: 0.009, |grad|: 0.085
loss: 0.006, |grad|: 0.077
loss: 0.010, |grad|: 0.102
loss: 0.007, |grad|: 0.079
loss: 0.006, |grad|: 0.072
loss: 0.005, |grad|: 0.059
loss: 0.004, |grad|: 0.079
loss: 0.013, |grad|: 0.082
Epoch 4...
loss: 0.009, |grad|: 0.094
loss: 0.006, |grad|: 0.084
loss: 0.

## Test for OOD

We have trained our detector on a dataset of digits.  
What happens when we give it corrupted images of digits (which we expect to be "OOD")?


In [14]:
corrbb = DataLoader(shiftdata, collate_fn=collate_fn, batch_size=big_batch_size)
for corrimages, corrlabels, corrmetadata in corrbb:
    break


Now we evaluate the two datasets using the trained model.


In [15]:
[(type(detector).__name__, np.mean(detector.predict(images).is_ood)) for detector in detectors]

[('OOD_AE', 0.01)]

In [16]:
[(type(detector).__name__, np.mean(detector.predict(corrimages).is_ood)) for detector in detectors]

[('OOD_AE', 0.816)]

### Results

We can see that the autoencoder-based OOD detector flagged 81.6% of the translated images (`shiftdata`) as OOD, compared with 1% of the training images.

Depending on your needs, other outlier detectors may work better under specific conditions; you can add them to the detectors list.


In [17]:
ood_detector = detectors[0]

### Understand OOD using metadata

We can now look at the metadata features of OOD examples and find which feature is the most surprising for each one. The function [least_likely_features()](metadata_tools.py#least_likely_features) does this for us; below we apply it to the `shot_noise` data (`spikydata`).


In [18]:
from metadata_tools import least_likely_features
od = least_likely_features(refdata, spikydata, ood_detector)


feature        |  occurences
fill_frac      :    2895
spikiness      :    1538
cm_y           :    1095
y_ctr          :    858
cm_x           :    684
x_ctr          :    622
height         :    137
width          :    11
isolated_pixels:    2


We can also compare the distribution of each metadata feature to the reference. [ks_compare](metadata_tools.py#ks_compare) uses the two-sample Kolmogorov-Smirnov test to look for significant shifts in metadata features, and reports them in order of decreasing statistical significance. It also reports the Wasserstein distance between each pair of distributions, in units of the IQR of the reference.

We compare first to the validation data, where, as expected, we see no substantial metadata shifts: every shift is below 0.03 IQR, even though `fill_frac` has a p-value of 0.024.


In [12]:
from metadata_tools import ks_compare

valbb = DataLoader(valdata, collate_fn=collate_fn, batch_size=big_batch_size)
corr1bb = DataLoader(shiftdata, collate_fn=collate_fn, batch_size=big_batch_size)
corr2bb = DataLoader(spikydata, collate_fn=collate_fn, batch_size=big_batch_size)
corr3bb = DataLoader(blurdata, collate_fn=collate_fn, batch_size=big_batch_size)
corr4bb = DataLoader(scaledata, collate_fn=collate_fn, batch_size=big_batch_size)

print('identity')
ks_compare(refbb, valbb);
print('\ntranslate')
ks_compare(refbb, corr1bb);
print('\nshot_noise')
ks_compare(refbb, corr2bb);
print('\nmotion_blur')
ks_compare(refbb, corr3bb);
print('\nscale')
ks_compare(refbb, corr4bb);

identity
feature        | p-value | shift/IQR
fill_frac      :  0.024  :   0.029
width          :  0.060  :   0.028
random         :  0.196  :   0.017
spikiness      :  0.239  :   0.022
cm_y           :  0.517  :   0.008
y_ctr          :  0.740  :   0.026
x_ctr          :  0.968  :   0.015
cm_x           :  0.973  :   0.005
height         :  1.000  :   0.003
isolated_pixels:  1.000  :   0.001

translate
feature        | p-value | shift/IQR
cm_x           :  0.000  :   5.422
cm_y           :  0.000  :   5.327
y_ctr          :  0.000  :   1.899
x_ctr          :  0.000  :   1.002
height         :  0.000  :   0.039
width          :  0.000  :   0.020
spikiness      :  0.017  :   0.025
fill_frac      :  0.409  :   0.013
random         :  0.917  :   0.010
isolated_pixels:  1.000  :   0.002

shot_noise
feature        | p-value | shift/IQR
fill_frac      :  0.000  :   0.729
cm_y           :  0.000  :   0.144
cm_x           :  0.000  :   0.066
spikiness      :  0.000  :   0.121
width          : 

### _Summary_

We demonstrate a method for investigating individual OOD examples using metadata, by finding which metadata feature is most unusual for each example relative to the reference dataset.

We also demonstrate a method for finding significant distributional shifts in metadata features, reporting both the p-value of each shift and its magnitude relative to the width of the reference distribution.
