# Tutorial on exploring out-of-distribution examples using metadata


## _Problem Statement_

In computer vision tasks like **image classification** and **object detection**, when out-of-distribution (OOD) examples are detected, there are two things we would like to know about them. First, looking at each example individually, what makes it stand out from the reference dataset? Second, looking at them as a population, which of their properties have shifted the most relative to the reference dataset? Metadata can help answer both questions.

For the first question, we can look at the value of each metadata feature for each example and find which one lies furthest out in the tails of the reference distribution. Since we are interested in extreme values, we take the median of each reference feature and then compute the absolute deviation of each OOD example's value from that median. (We can store the sign if we wish, but significance should be evaluated in terms of absolute deviation.) To compare across features, we normalize each deviation by the interquartile range (IQR) of that feature's reference distribution.
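
A minimal NumPy sketch of this per-example ranking follows. It assumes the metadata are already arranged as 2-D arrays with one column per feature and that every reference feature has a nonzero IQR. The function names and the synthetic data are hypothetical; this is not the implementation of `least_likely_features()` in `metadata_tools.py`.

```python
import numpy as np


def normalized_deviations(reference, ood):
    """Deviation of each OOD feature value from the reference median, in reference-IQR units.

    reference: array of shape (n_ref, n_features), metadata of the reference dataset
    ood:       array of shape (n_ood, n_features), metadata of the OOD examples
    Assumes every reference feature has a nonzero IQR.
    Returns (absolute deviations, signs), both of shape (n_ood, n_features).
    """
    median = np.median(reference, axis=0)
    q1, q3 = np.percentile(reference, [25, 75], axis=0)
    signed = (ood - median) / (q3 - q1)
    return np.abs(signed), np.sign(signed)


def most_unusual_feature(reference, ood, feature_names):
    """For each OOD example, the feature with the largest normalized absolute deviation."""
    abs_dev, _ = normalized_deviations(reference, ood)
    return [feature_names[i] for i in np.argmax(abs_dev, axis=1)]


# Synthetic example; the feature names are hypothetical
rng = np.random.default_rng(0)
names = ["brightness", "fill_frac"]
ref = rng.normal(size=(1000, 2))
ood = np.array([[0.1, 5.0],    # extreme only in the second feature
                [-4.0, 0.2]])  # extreme (and negative) only in the first feature
print(most_unusual_feature(ref, ood, names))  # ['fill_frac', 'brightness']
```

Ranking by absolute deviation while returning the signs separately matches the advice above: the sign can be kept for interpretation, but it plays no part in deciding which feature is most unusual.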

For the second question, we can compare the distribution of each feature to the reference distribution using the two-sample Kolmogorov-Smirnov (KS) test. For features that show a statistically significant difference, we can use the Wasserstein distance to measure the size of the shift (again normalized by the IQR of the reference).
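
For a single feature, this comparison can be sketched with SciPy's `ks_2samp` and `wasserstein_distance`. The sketch assumes SciPy is installed and uses a hypothetical significance level of 0.05; `compare_feature` is an illustrative name, not the `ks_compare` helper used later in this tutorial.

```python
import numpy as np
from scipy.stats import ks_2samp, wasserstein_distance


def compare_feature(reference, test, alpha=0.05):
    """Two-sample KS test for one metadata feature, plus the size of any significant shift.

    Returns (p_value, shift). shift is the Wasserstein distance divided by the
    reference IQR, or None when the KS test is not significant at level alpha.
    """
    p_value = ks_2samp(reference, test).pvalue
    if p_value >= alpha:
        return p_value, None
    q1, q3 = np.percentile(reference, [25, 75])
    shift = wasserstein_distance(reference, test) / (q3 - q1)
    return p_value, shift


rng = np.random.default_rng(0)
ref = rng.normal(size=2000)

# Identical samples: the KS statistic is 0 and no shift is reported
print(compare_feature(ref, ref))

# Moving every value by 1.35 (about one IQR of a standard normal) gives a
# Wasserstein distance of 1.35, i.e. roughly 1 in reference-IQR units
p, shift = compare_feature(ref, ref + 1.35)
print(p < 1e-6, round(shift, 2))
```

Dividing by the reference IQR makes shifts comparable across features with very different scales: a value of 1 means the distribution moved by about one reference IQR.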


### _When to use_

When OOD examples have been detected, e.g. by DataEval's `OOD_AE` class or a similar detector, the tools developed here can help you learn which specific image properties move each image out of the distribution.


### _What you will need_

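1. A training image dataset with a low percentage of known OOD images.
2. A test image dataset to evaluate for OOD images.
3. A Python environment with the following packages installed:
   - `tensorflow-datasets`
   - `torch`
   - `numpy`
   - `dataeval`
4. The helper modules `metadata_utils` and `metadata_tools` imported in this tutorial.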


### _Setting up_

Let's import the libraries needed to set up a minimal working example.


In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader

# Helper modules used throughout this tutorial
from metadata_utils import InstanceMNIST
from metadata_utils import collate_fn_2 as collate_fn
from metadata_tools import ks_compare

from dataeval.detectors.ood import OOD_AE
from dataeval.utils.torch.models import AE
from dataeval.utils.dataset.datasets import MNIST

# Run on the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"



## Load the data

We will use `InstanceMNIST`, a PyTorch wrapper for the clean and corrupted MNIST datasets from TensorFlow Datasets, for this tutorial.


In [None]:
# Corruptions to load from the MNIST training split ('identity' means uncorrupted)
corruption_list = ['identity', 'identity', 'translate', 'shot_noise', 'motion_blur', 'scale']

mnist = InstanceMNIST(corruption_list, size=8000)
mnist_val = InstanceMNIST('identity', train=False, size=8000)

# Uncorrupted reference (training) and validation sets
refdata = mnist.identity
valdata = mnist_val.identity

# Corrupted versions of the training data
shiftdata = mnist.translate
spikydata = mnist.shot_noise
blurdata = mnist.motion_blur
scaledata = mnist.scale



In [None]:
# Load the first 2000 images of the MNIST training and test splits,
# scaled to [0, 1] with channels first
train_ds = MNIST(root="./data/", train=True, download=True, size=2000, unit_interval=True, channels="channels_first")
val_ds = MNIST(root="./data/", train=False, download=True, size=2000, unit_interval=True, channels="channels_first")

# Split out the images and labels
images, labels = train_ds.data, train_ds.targets
val_images, val_labels = val_ds.data, val_ds.targets

input_shape = images[0].shape

## Initialize the model

Now, let's look at how to use DataEval's OOD detection methods.
We will focus on a simple autoencoder network (`AE`) wrapped in DataEval's `OOD_AE` detector.


In [None]:
# Kept as a list so that additional detectors are easy to try
detectors = [OOD_AE(AE(input_shape), device)]

## Train the model

Next, we train the detector on the training images.
For better results, increase the number of epochs.
We set the threshold so that the most extreme 1% of the training data is flagged as out-of-distribution (`threshold_perc=99`). Training may take several minutes.
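
The meaning of a 99th-percentile threshold can be illustrated with synthetic scores. This sketch assumes the detector assigns each image a scalar score where larger values mean more anomalous, and that `threshold_perc=99` places the cutoff at the 99th percentile of the training scores; it is not DataEval's implementation.

```python
import numpy as np

# Synthetic stand-ins for per-image OOD scores on the training set
# (e.g. autoencoder reconstruction errors; larger means more anomalous)
rng = np.random.default_rng(0)
train_scores = rng.exponential(size=2000)

# threshold_perc=99: the 99th percentile of the training scores is the cutoff
threshold = np.percentile(train_scores, 99)

# Images scoring above the cutoff are flagged as OOD
is_ood = train_scores > threshold
print(is_ood.sum(), is_ood.mean())  # 20 0.01
```

By construction, about 1% of the training images exceed the cutoff; new images are then flagged when their scores exceed the same value.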


In [None]:
for detector in detectors:
    print(f"Training {detector.__class__.__name__}...")
    detector.fit(images, threshold_perc=99, epochs=23, verbose=True)

## Test for OOD

We have trained our detector on a dataset of digits.  
What happens when we give it corrupted images of digits (which we expect to be "OOD")?


In [None]:
# Load 2000 training images with shot noise added
corruption = MNIST(
    root="./data",
    train=True,
    download=True,
    size=2000,
    unit_interval=True,
    channels="channels_first",
    corruption="shot_noise",
)
corrupted_images = corruption.data

Now we evaluate both datasets with the trained detector, reporting the fraction of each that is flagged as OOD.


In [None]:
[(type(detector).__name__, np.mean(detector.predict(images).is_ood)) for detector in detectors]

In [None]:
[(type(detector).__name__, np.mean(detector.predict(corrupted_images).is_ood)) for detector in detectors]

### Results

We can see that the autoencoder-based OOD detector identifies many of the `shot_noise` images as outliers.

Depending on your needs, other OOD detectors may work better under specific conditions; you can add them to the `detectors` list.


In [None]:
ood_detector = detectors[0]

### Understand OOD using metadata

We can now look at the metadata features for OOD examples, and find which metadata features are the most surprising for each one. The function [least_likely_features()](metadata_tools.py#least_likely_features) will do this for us.


In [None]:
from metadata_tools import least_likely_features
od = least_likely_features(refdata, spikydata, ood_detector)


The table above records how many times each metadata feature was the most unusual feature for a given example. It makes sense that `fill_frac` would often lie far out in the tail of the reference `fill_frac` distribution for images with added shot noise: a single nonzero pixel far from the digit dramatically increases the area of the convex hull, which is the denominator of the fill fraction.
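
A toy calculation shows the size of this effect. It assumes `fill_frac` is the number of nonzero pixels divided by the area of their convex hull (the exact definition in `metadata_utils` may differ); `fill_fraction` and the toy images are hypothetical, and the hull comes from SciPy's `ConvexHull`.

```python
import numpy as np
from scipy.spatial import ConvexHull


def fill_fraction(image):
    """Nonzero-pixel count divided by the area of the convex hull around those pixels.

    Illustrative definition only. Requires the nonzero pixels to span a 2-D region.
    """
    rows, cols = np.nonzero(image)
    # Use the four corners of every nonzero pixel so that a solid rectangle
    # of pixels has a fill fraction of exactly 1
    corners = np.array([(c + dc, r + dr)
                        for r, c in zip(rows, cols)
                        for dr in (0, 1) for dc in (0, 1)], dtype=float)
    hull = ConvexHull(np.unique(corners, axis=0))
    # For 2-D points, ConvexHull.volume is the enclosed area (.area is the perimeter)
    return len(rows) / hull.volume


digit = np.zeros((28, 28))
digit[8:20, 10:18] = 1.0   # a solid 12 x 8 block standing in for a digit
noisy = digit.copy()
noisy[1, 1] = 1.0          # a single shot-noise pixel far from the "digit"

print(round(fill_fraction(digit), 3))   # 1.0
print(round(fill_fraction(noisy), 3))   # 0.522
```

One extra pixel barely changes the numerator (96 to 97) but nearly doubles the hull area (96 to 186), so the fill fraction drops by almost half.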

We can also compare the **distribution** of each metadata feature to the reference. [ks_compare](metadata_tools.py#ks_compare) uses the two-sample Kolmogorov-Smirnov test to look for significant shifts in metadata features, and reports them in order of decreasing statistical significance. It also reports the Wasserstein distance between each pair of distributions, in units of the reference IQR.

We first compare to the validation data, where, as expected, we see no significant metadata shifts.


In [None]:
from metadata_tools import ks_compare

# Load the metadata in large batches
big_batch_size = 2000

refbb = DataLoader(refdata, collate_fn=collate_fn, batch_size=big_batch_size)

# Compare each dataset's metadata to the reference, starting with the
# uncorrupted validation data
comparisons = {
    'identity': valdata,
    'translate': shiftdata,
    'shot_noise': spikydata,
    'motion_blur': blurdata,
    'scale': scaledata,
}
for name, data in comparisons.items():
    print(f'\n{name}')
    batches = DataLoader(data, collate_fn=collate_fn, batch_size=big_batch_size)
    ks_compare(refbb, batches)

### _Summary_

We demonstrate a method for investigating individual OOD examples with metadata: for each example, find the metadata feature that is most unusual relative to the reference dataset.

We also demonstrate a method for finding significant distributional shifts in metadata features and reporting their p-values, along with a measure of each shift's magnitude relative to the width (IQR) of its reference distribution.
