# Tutorial on associating changes in metadata with observed dataset drifts.


## _Problem Statement_

When incoming data have been found to have drifted, we may wish to understand the underlying causes of the drift. Metadata may help with this task. We can look for metadata features that are predictive of out-of-distribution (OOD) examples, or, if OOD examples are few, we can examine the significance and magnitude of distributional shifts in metadata factors.


### _When to use_

Once you have detected drift using the `dataeval.detectors` drift detection classes, you should employ these tools to look either for metadata features that accurately predict out-of-distribution (OOD) examples, or for significant differences in two metadata distributions, i.e. the metadata corresponding to your reference and drifted datasets.


### _What you will need_

1. A reference dataset which you have used to train a model.
2. A drifted dataset, i.e. new data for which you have detected drift (see DriftDetectionTutorial; steps reproduced here), with corresponding metadata.
3. Metadata corresponding to each of your two datasets, OR defined methods that generate intrinsic metadata from the data examples, for each dataset.
4. A python environment with the following packages installed:
   - `dataeval[torch]` or `dataeval[all]`


### _New tools developed_

1. `predict_ood_mi()` - a standalone function that quantifies the power of metadata features to predict OOD examples
2. `ks_compare()` - a standalone function that compares incoming metadata features to a reference and reports normalized shifts and their significance


### _Setting up_

Let's import the libraries needed to set up a minimal working example.


In [None]:
from functools import partial

import numpy as np
import matplotlib.pyplot as plt

import torch

from torch.utils.data import DataLoader

from metadata_utils import InstanceMNIST, collate_fn_2
from metadata_tools import predict_ood_mi, ks_compare

from dataeval.detectors.drift import (
    DriftCVM,
    DriftKS,
    DriftMMD,
    preprocess_drift,
)
from dataeval.utils.torch.models import AE_torch
from dataeval.utils.torch.datasets import MNIST

from dataeval.detectors.ood.ae_torch import OOD_AE
from dataeval.utils.torch.models import AriaAutoencoder
from dataeval.utils.torch.trainer import AETrainer

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'device is {device}.')


## Loading in data

Let's start by loading some TensorFlow MNIST datasets. We will then wrap them in PyTorch datasets, attach some metadata, and examine them. The new PyTorch class [InstanceMNIST](metadata_utils.py) stores the MNIST data and computes and stores intrinsic metadata for the purposes of this demo.

The `__init__` method of InstanceMNIST takes the corruption type, or a list of corruption types, to load. As used below, it also accepts `train` and `size` keyword args, which select the split and the number of examples. Possible corruption types are listed inside the get_MNIST_data() method and are also stored in the `corruptions` attribute of each instance. For more information see https://www.tensorflow.org/datasets/catalog/mnist_corrupted .


Besides storing images and labels, the wrapper class InstanceMNIST lets us add methods that generate and store whatever intrinsic metadata we want. The existing methods, e.g. bbox(), provide a template for doing this. First write a function that returns the "one scalar per image" quantity you want to compute. Then package the quantity as a dict of lists: each metadata feature name is a dict key, and each key maps to a list of metadata values, one per example in the dataset.
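The dict-of-lists layout can be illustrated with a minimal NumPy sketch. The names `bbox_features` and `collect_metadata` are hypothetical and are not the actual InstanceMNIST methods. The feature function returns one scalar per image for each feature. The collector packages those scalars with one key per feature and one list entry per example.

```python
import numpy as np


def bbox_features(img, thresh=0.0):
    """Hypothetical per-image metadata: bounding box of pixels brighter than `thresh`."""
    rows = np.flatnonzero(np.any(img > thresh, axis=1))
    cols = np.flatnonzero(np.any(img > thresh, axis=0))
    if rows.size == 0:  # blank image: no bounding box
        return {"bbox_height": np.nan, "bbox_width": np.nan,
                "bbox_row_center": np.nan, "bbox_col_center": np.nan}
    return {
        "bbox_height": float(rows[-1] - rows[0] + 1),
        "bbox_width": float(cols[-1] - cols[0] + 1),
        "bbox_row_center": (rows[0] + rows[-1]) / 2,
        "bbox_col_center": (cols[0] + cols[-1]) / 2,
    }


def collect_metadata(images, feature_fn):
    """Package one scalar per image per feature as a dict of lists."""
    metadata = {}
    for img in images:
        for name, value in feature_fn(img).items():
            metadata.setdefault(name, []).append(value)
    return metadata


images = np.zeros((3, 28, 28), dtype=np.float32)
images[0, 5:10, 8:12] = 1.0    # 5 rows x 4 columns
images[1, 10:20, 10:20] = 0.5  # 10 x 10; image 2 stays blank
metadata = collect_metadata(images, bbox_features)
print(metadata["bbox_height"])  # [5.0, 10.0, nan]
```

Each list has one entry per image, in dataset order, so a random reference feature can be added later as just another key.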

InstanceMNIST also explicitly normalizes MNIST pixel values to be between 0 and 1, and casts to numpy float32.


### Instantiate a reference dataset and some corrupted datasets

We can then use these in our experiments below. The corruptions are intended to simulate drifts that we might observe in practice.

2024-10-22: Need to change InstanceMNIST so that it grabs some examples of each corruption in a list, and then returns the dataset for a particular corruption via a method call.


In [None]:
corruption_list = ['identity', 'translate', 'shot_noise', 'motion_blur', 'scale']

mnist = InstanceMNIST(corruption_list, train=True, size=8000)
mnist_val = InstanceMNIST('identity', train=False, size=8000)

refdata = mnist.identity
valdata = mnist_val.identity
shiftdata = mnist.translate
spikydata = mnist.shot_noise
blurdata = mnist.motion_blur
scaledata = mnist.scale



## Make the drift detectors

To reduce the dimensionality of the data, we can pass the encoder half of a simple autoencoder to the drift detectors through the `preprocess_fn` keyword arg. This is not crucial for the MNIST dataset, but it is highly recommended for higher-dimensional datasets because it reduces the number of comparisons made.
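The following sketch shows why fewer features means fewer comparisons. Univariate tests such as KS are commonly applied feature by feature with a multiple-testing correction. We assume a Bonferroni correction here for illustration; this is not DataEval's implementation, and `featurewise_ks_drift` is a hypothetical helper. With 784 raw MNIST pixels, each per-feature test would have to clear 0.05/784 ≈ 6.4e-5. For a hypothetical 16-dimensional embedding, the threshold is 0.05/16 ≈ 3.1e-3.

```python
import numpy as np
from scipy.stats import ks_2samp


def featurewise_ks_drift(x_ref, x_new, alpha=0.05):
    """Feature-wise KS tests with a Bonferroni correction (illustrative sketch)."""
    # Flatten everything after the batch axis so each column is one feature
    x_ref = x_ref.reshape(len(x_ref), -1)
    x_new = x_new.reshape(len(x_new), -1)
    n_features = x_ref.shape[1]
    p_values = np.array(
        [ks_2samp(x_ref[:, j], x_new[:, j]).pvalue for j in range(n_features)]
    )
    # Bonferroni: each test must clear alpha / n_features, so more features
    # means a stricter per-feature threshold
    threshold = alpha / n_features
    return bool(np.any(p_values < threshold)), threshold


rng = np.random.default_rng(0)
ref = rng.normal(size=(500, 16))               # stand-in for 16-dim embeddings
shifted = rng.normal(loc=0.5, size=(500, 16))  # every feature shifted by 0.5
drift, threshold = featurewise_ks_drift(ref, shifted)
print(drift, threshold)
```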

For the purposes of the tutorial, we will use 3 forms of drift detectors: Maximum Mean Discrepancy (MMD), Cramér-von Mises (CVM), and Kolmogorov-Smirnov (KS). These detectors are built using `data_reference` (defined below), so they measure drifts relative to that dataset.
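Unlike the feature-wise KS and CVM tests, MMD is a single multivariate two-sample statistic. Below is a minimal NumPy sketch of an unbiased MMD² estimate with a Gaussian (RBF) kernel, plus a permutation test for its p-value. Several choices here are our own assumptions and are not taken from DataEval's DriftMMD: the median-distance bandwidth, the 200 permutations, and all function names.

```python
import numpy as np


def rbf_kernel(a, b, sigma):
    """Gaussian (RBF) kernel matrix between the rows of a and b."""
    d2 = (a**2).sum(1)[:, None] + (b**2).sum(1)[None, :] - 2 * a @ b.T
    return np.exp(-d2 / (2 * sigma**2))


def mmd2_unbiased(x, y, sigma):
    """Unbiased estimate of squared MMD (diagonal kernel terms excluded)."""
    m, n = len(x), len(y)
    kxx, kyy, kxy = rbf_kernel(x, x, sigma), rbf_kernel(y, y, sigma), rbf_kernel(x, y, sigma)
    return ((kxx.sum() - np.trace(kxx)) / (m * (m - 1))
            + (kyy.sum() - np.trace(kyy)) / (n * (n - 1))
            - 2 * kxy.mean())


def mmd_permutation_test(x, y, n_perm=200, seed=0):
    """Permutation p-value for MMD^2, using the median pairwise distance as bandwidth."""
    rng = np.random.default_rng(seed)
    pooled = np.vstack([x, y])
    d2 = ((pooled[:, None, :] - pooled[None, :, :]) ** 2).sum(-1)
    sigma = np.sqrt(np.median(d2[d2 > 0]))
    observed = mmd2_unbiased(x, y, sigma)
    exceed = 0
    for _ in range(n_perm):
        # Randomly relabel the pooled samples and recompute the statistic
        idx = rng.permutation(len(pooled))
        exceed += mmd2_unbiased(pooled[idx[:len(x)]], pooled[idx[len(x):]], sigma) >= observed
    return observed, (exceed + 1) / (n_perm + 1)


rng = np.random.default_rng(1)
x = rng.normal(size=(200, 8))
y = rng.normal(loc=0.5, size=(200, 8))
stat, p_value = mmd_permutation_test(x, y)
print(f"MMD^2 = {stat:.4f}, p = {p_value:.4f}")
```

With 200 permutations, the smallest attainable p-value is 1/201. Small p-values therefore require enough permutations.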


In [None]:
model = AriaAutoencoder(channels=1)
trainer = AETrainer(model, device=device)

We now train the autoencoder to reconstruct examples contained in the refdata dataset. Our corruptions here are so heavy-handed and easily detected that training is optional: the `train_the_autoencoder` flag lets you skip it, and the snippet that shows how to train is left in.

Then, we use the encoder part of the autoencoder (trained or not) to generate projections into a latent space, i.e. embeddings, and we use the embeddings to decide whether or not new incoming datasets have drifted relative to the reference.
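The `preprocess_fn` built below with `functools.partial` bundles an encoder with batching options. The NumPy sketch below illustrates that pattern only and does not reproduce the internals of `preprocess_drift`. It uses a hypothetical helper, `embed_in_batches`, and a random linear map standing in for the trained encoder.

```python
from functools import partial

import numpy as np


def embed_in_batches(x, model, batch_size=64):
    """Apply `model` to `x` chunk by chunk and stack the embeddings."""
    outputs = [model(x[i:i + batch_size]) for i in range(0, len(x), batch_size)]
    return np.concatenate(outputs, axis=0)


rng = np.random.default_rng(0)
W = rng.normal(size=(28 * 28, 16))  # stand-in "encoder": 784 pixels -> 16 dims


def toy_encoder(batch):
    # Flatten each (1, 28, 28) image and project it into the latent space
    return batch.reshape(len(batch), -1) @ W


preprocess = partial(embed_in_batches, model=toy_encoder, batch_size=64)
images = rng.random((2000, 1, 28, 28)).astype(np.float32)
embeddings = preprocess(images)
print(embeddings.shape)  # (2000, 16)
```

Batching keeps memory bounded for large datasets. The result matches applying the encoder to everything at once.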


In [None]:
train_the_autoencoder = True
if train_the_autoencoder:
    print(f"Training {model.__class__.__name__}...")
    training_loss = trainer.train(refdata, epochs=10)
else:
    print('NOT TRAINING AUTOENCODER!')
    
encoder_net = model.encoder.to(device)


We will build the drift detectors using the first 2000 images from refdata.


In [None]:
data_reference = refdata.images[0:2000]

# define preprocessing function
preprocess_fn = partial(preprocess_drift, model=encoder_net, batch_size=64, device=device)

# initialise drift detectors
drift_detectors = [detector(data_reference, preprocess_fn=preprocess_fn) for detector in [DriftMMD, DriftCVM, DriftKS]]

## Test reference against control

Let's check for drift between the 2000 reference images and 2000 images from the held-out validation data (`valdata.images[2000:4000]`). Both come from uncorrupted MNIST, so the drift detectors should not detect any drift.


In [None]:
data_control = valdata.images[2000:4000]

Run the test by calling the predict() method of each detector with data_control as its argument. Then examine the `is_drift` attribute of the result.


In [None]:
print('Test two samples from same dataset for drift:')
[(type(detector).__name__, detector.predict(data_control).is_drift) for detector in drift_detectors]

Thus, for all three drift detectors, we find no significant drift between these two MNIST subsets, as expected.


## Look for drift with translated data

The translate corruption moves each digit within its image, towards a randomly selected corner, by a few pixels.
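This NumPy sketch shows why center-of-mass metadata should track the translate corruption. `translate_toward_corner` is a hypothetical zero-filled shift that stands in for the TensorFlow Datasets corruption; we do not reproduce that exact implementation. `center_of_mass` computes the intensity-weighted centroid of an image.

```python
import numpy as np


def center_of_mass(img):
    """Intensity-weighted (row, col) centroid of a 2-D image."""
    rows, cols = np.indices(img.shape)
    total = img.sum()
    return (rows * img).sum() / total, (cols * img).sum() / total


def translate_toward_corner(img, shift, corner):
    """Shift img by `shift` pixels toward corner, e.g. ("bottom", "right"); vacated pixels become 0."""
    dr = shift if corner[0] == "bottom" else -shift
    dc = shift if corner[1] == "right" else -shift
    h, w = img.shape
    out = np.zeros_like(img)
    # Source region that stays inside the frame after shifting
    src = img[max(0, -dr):h - max(0, dr), max(0, -dc):w - max(0, dc)]
    r0, c0 = max(0, dr), max(0, dc)
    out[r0:r0 + src.shape[0], c0:c0 + src.shape[1]] = src
    return out


digit = np.zeros((28, 28), dtype=np.float32)
digit[11:17, 12:16] = 1.0  # a blob centered at (13.5, 13.5)
moved = translate_toward_corner(digit, shift=3, corner=("bottom", "right"))
print(center_of_mass(digit), center_of_mass(moved))
```

A shift of 3 pixels toward the bottom-right corner moves the centroid from (13.5, 13.5) to (16.5, 16.5).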


In [None]:
corrupt_images = shiftdata.images[0:2000]
print('Test corrupted image sample for drift:')
[(type(detector).__name__, detector.predict(corrupt_images).is_drift) for detector in drift_detectors]

So the translate corruption indeed leads to a measurable drift. Such a drift might be the footprint of data examples in the test set which are OOD for the reference set. We will make a data loader that grabs a batch of 2000 and use it to look for OOD examples.


Call the dataloader "refbb": it will grab a big batch of reference data. After getting a batch, we convert its images to NumPy arrays for use with the OOD detectors. We do the same for a second dataloader that holds images for validation.


In [None]:
big_batch_size = 2000
collate_fn = collate_fn_2

# "bb" = big batch: each loader yields batches of 2000 examples
refbb = DataLoader(refdata, collate_fn=collate_fn, batch_size=big_batch_size)
valbb = DataLoader(valdata, collate_fn=collate_fn, batch_size=big_batch_size)

# Grab the first big batch from each loader (images, labels, metadata)
train_images, _, _ = next(iter(refbb))
val_images, _, _ = next(iter(valbb))

# Convert to NumPy arrays for the OOD detectors; the batch shape is left unchanged
train_images = train_images.detach().numpy()
val_images = val_images.detach().numpy()

We can now instantiate the OOD detectors.


The next cell trains the OOD detectors, which will take a few minutes, but you only need to do it once for each reference dataset. After training an OOD detector, you can test as many incoming datasets as you care to for OOD examples.
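We assume here that `threshold_perc=99` in the next cell sets the OOD threshold at the 99th percentile of the instance scores on the training images. Under that assumption, roughly 1% of in-distribution images are flagged by construction. Please check the DataEval documentation for the exact definition. Below is a minimal NumPy sketch of that rule, using gamma-distributed numbers as hypothetical stand-ins for reconstruction scores.

```python
import numpy as np

rng = np.random.default_rng(0)
train_scores = rng.gamma(shape=2.0, scale=1.0, size=4000)  # in-distribution scores

# Threshold = 99th percentile of the training scores
threshold = np.percentile(train_scores, 99)

fresh_ref_scores = rng.gamma(shape=2.0, scale=1.0, size=2000)  # new in-distribution sample
drifted_scores = rng.gamma(shape=2.0, scale=3.0, size=2000)    # larger errors after drift

for name, scores in [("train", train_scores), ("fresh reference", fresh_ref_scores),
                     ("drifted", drifted_scores)]:
    print(f"{name:16s} OOD fraction: {np.mean(scores > threshold):.3f}")
```

Under this rule, even a well-trained detector flags about 1% of reference images. That is one reason the identity corruption can still yield some OOD detections.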


In [None]:

# Load in the training mnist dataset and use the first 4000
train_ds = MNIST(root="./data/", train=True, download=True, size=4000, unit_interval=True, channels="channels_first")

# Split out the images and labels
images, labels = train_ds.data, train_ds.targets
input_shape = images[0].shape

OOD_detectors = [OOD_AE(AE_torch((1, 28, 28)))]

for detector in OOD_detectors:
    print(f"Training {detector.__class__.__name__}...")
    detector.fit(images, threshold_perc=99, epochs=10, verbose=True)

We can now make a few data loaders for corrupted images. These have bb in their names, for "big batch".


In [None]:
corr1bb = DataLoader(shiftdata, collate_fn=collate_fn, batch_size=big_batch_size)
corr2bb = DataLoader(spikydata, collate_fn=collate_fn, batch_size=big_batch_size)
corr3bb = DataLoader(blurdata, collate_fn=collate_fn, batch_size=big_batch_size)
corr4bb = DataLoader(scaledata, collate_fn=collate_fn, batch_size=big_batch_size)

We will choose the translate corruption and run the OOD detectors.


In [None]:
corrbb = corr1bb # corr1bb is translate corruption

for corrimages, corrlabels, corrmetadata in corrbb:
    break

print('What fraction of reference images have OOD examples?')
print([(type(detector).__name__, np.mean(detector.predict(train_images).is_ood)) for detector in OOD_detectors])
print('\nWhat fraction of images from drifted dataset have OOD examples?')
print([(type(detector).__name__, np.mean(detector.predict(corrimages).is_ood)) for detector in OOD_detectors])

And indeed, our OOD_AE detector reports that most images in the mnist-translate dataset are OOD relative to MNIST itself. We can now ask if any available metadata features are predictive of OOD.


In [None]:
ood_ae = OOD_detectors[0]
is_ood = ood_ae.predict(corrimages).is_ood

Training loss and validation loss should be comparable; here the mean OOD instance score on each set serves as the loss. If validation loss is much greater than training loss, the OOD model has probably overfit, e.g. because it was trained for too long. Overtraining has another symptom in this notebook: an excess of OOD detections on other samples of the reference dataset. In that case the "identity" corruption shows a high rate of detections, when it should actually be near zero.


In [None]:
train_pred = ood_ae.predict(train_images)
train_loss = np.mean(train_pred.instance_score)

val_pred = ood_ae.predict(val_images)
val_loss = np.mean(val_pred.instance_score)

print(f'training loss reached {train_loss:0.4f}.')
print(f'validation loss is {val_loss:0.4f}.')


The function [predict_ood_mi](metadata_tools.py#predict_ood_mi) takes big batch dataloaders and an ood detector, and returns a list of metadata factors, sorted in decreasing order of the mutual information they share with the OOD flag.


We can test predict_ood_mi() on our 3 corrupted datasets. First there is translate, in which digits have been moved toward a randomly selected corner of their images.

Note that I have added a random feature to the metadata. This is easy to do with any dataset: just pair a random value with every data example. Doing so provides a point of reference for evaluating whether any metadata association test we might perform is meaningful.

You can specify whether your metadata features are continuous or discrete through the `discrete_features` keyword arg. Here all the features are continuous, so I set `discrete_features=False`. If you have a mixture of continuous and discrete features, pass one bool per feature. For example, with 3 features where only the first is discrete, use `discrete_features=[True, False, False]`.
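predict_ood_mi itself is not reproduced here. Its `discrete_features` argument mirrors the parameter of the same name in scikit-learn's `mutual_info_classif`, but we have not confirmed that it wraps that function. The sketch below illustrates the quantity being ranked with a plug-in estimate of the mutual information, in bits, between a quantile-binned continuous feature and a binary OOD flag. It uses synthetic data with one informative feature and one random reference feature. Because the flag is binary, its entropy H(flag) ≤ 1 bit bounds the MI. MI / H(flag) is therefore the fraction of the flag's information that a feature carries.

```python
import numpy as np


def binned_mutual_info_bits(feature, flag, n_bins=10):
    """Plug-in estimate of I(feature; flag) in bits, after quantile-binning feature."""
    inner_edges = np.quantile(feature, np.linspace(0, 1, n_bins + 1)[1:-1])
    bins = np.digitize(feature, inner_edges)  # bin index in 0..n_bins-1
    joint = np.zeros((n_bins, 2))
    np.add.at(joint, (bins, np.asarray(flag, dtype=int)), 1)
    joint /= joint.sum()
    # Product of the marginals, i.e. the joint distribution under independence
    expected = joint.sum(axis=1, keepdims=True) @ joint.sum(axis=0, keepdims=True)
    nz = joint > 0
    return float(np.sum(joint[nz] * np.log2(joint[nz] / expected[nz])))


def entropy_bits(flag):
    """Entropy of a binary flag, in bits (assumes both values occur)."""
    p = np.mean(flag)
    return float(-(p * np.log2(p) + (1 - p) * np.log2(1 - p)))


rng = np.random.default_rng(0)
n = 4000
com_offset = rng.normal(size=n)  # synthetic "center of mass offset"
is_ood = np.abs(com_offset) + 0.3 * rng.normal(size=n) > 1.0
features = {"com_offset": com_offset, "random": rng.uniform(size=n)}

h_flag = entropy_bits(is_ood)
mi = {name: binned_mutual_info_bits(f, is_ood) for name, f in features.items()}
for name in sorted(mi, key=mi.get, reverse=True):
    print(f"{name:10s} MI = {mi[name]:.3f} bits ({mi[name] / h_flag:.0%} of H(flag))")
```

The random feature's MI is small but not exactly zero, because plug-in estimates are biased upward. That is why it is useful as a baseline.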


In [None]:
print('translate')
predict_ood_mi(refbb, corr1bb, ood_ae, discrete_features=False);

The centers of mass and the centers of the bboxes are the most predictive metadata, as we might expect. In fact, they carry most of the (at most) one bit of information that the binary OOD flag represents.

Next we try the shot_noise corruption, which adds random values to the nonzero pixels; we might expect the spikiness measure to associate with this.


In [None]:
print('shot noise')
predict_ood_mi(refbb, corr2bb, ood_ae, discrete_features=False);

Finally, the scale corruption (usually) has somewhat weaker associations between metadata and OOD, relative to the translate corruption. Spikiness is oddly strong; perhaps the scaling operation also smooths adjacent pixel values as a side effect, so that the **absence** of spikiness predicts OOD to some small degree; mutual information can detect this also.


In [None]:
print('scale')
predict_ood_mi(refbb, corr4bb, ood_ae, discrete_features=False);

### Explore causes of drift with scaled data

In cases where the metadata are more weakly associated with OOD examples, we might dig deeper for an explanation of the drift. We will try this with scale.

First confirm that scale data indeed results in detection of drift.


In [None]:
corrupt_images = scaledata.images[:2000]
print('Test scaled digits for dataset drift:')
[(type(detector).__name__, detector.predict(corrupt_images).is_drift) for detector in drift_detectors]

So the scale corruption also generates data drift. Can the drift be attributed to OOD examples?


In [None]:
corrbb = DataLoader(scaledata, collate_fn=collate_fn, batch_size=big_batch_size)

for corrimages, corrlabels, corrmetadata in corrbb:
    break

print('What fraction of scaled digits are OOD?')
print([(type(detector).__name__, np.mean(detector.predict(corrimages).is_ood)) for detector in OOD_detectors])

We can of course run the OOD detector on images made with each of the other corruptions, to see which has the most OOD examples.


In [None]:
corruptions = mnist.corruptions
mnist_all = InstanceMNIST(corruptions, size=3750)

# Fraction of images flagged as OOD, per corruption and per detector
c_ood_dict = {}
for c in corruptions:
    cdata = getattr(mnist_all, c)
    cbb = DataLoader(cdata, collate_fn=collate_fn, batch_size=big_batch_size)
    corrimages, _, _ = next(iter(cbb))  # first big batch only
    c_ood_dict[c] = [np.mean(detector.predict(corrimages).is_ood) for detector in OOD_detectors]

# Report the first detector's OOD fraction, sorted from lowest to highest
names = list(c_ood_dict)
ood_frac = [v[0] for v in c_ood_dict.values()]
iord = np.argsort(ood_frac)
maxlen = max(len(name) for name in names)

hdr = 'corruption'
print(f'{hdr:{maxlen}} |  ood fraction')
print('=' * (maxlen + 15))
for i in iord:
    print(f'{names[i]:{maxlen+1}}:      {ood_frac[i]:.3f}')

Note that the identity corruption, which is not actually a corruption at all, can nevertheless yield some OOD detections. One can estimate the true prevalence of OOD examples by comparing the rate of these false detections with the rates of other putative detections. In other words, an OOD detector can be overly sensitive, and then we might get many detections even when testing the reference distribution against itself. In the case shown above, we can see that the scale, glass_blur, and motion_blur corruptions have essentially no examples detected as OOD.
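Below is a minimal sketch of comparing a corruption's OOD rate against the identity baseline. It uses a one-sided two-proportion z-test; this choice of test is ours, since the notebook itself compares the fractions by eye. The counts are hypothetical and are not results from this notebook. The test relies on a normal approximation, which is reasonable when the counts are not tiny.

```python
import math


def excess_ood_rate(flagged, total, flagged_identity, total_identity):
    """One-sided two-proportion z-test: is the OOD rate above the identity baseline?"""
    p1, p0 = flagged / total, flagged_identity / total_identity
    pooled = (flagged + flagged_identity) / (total + total_identity)
    se = math.sqrt(pooled * (1 - pooled) * (1 / total + 1 / total_identity))
    z = (p1 - p0) / se if se > 0 else 0.0
    p_value = 0.5 * math.erfc(z / math.sqrt(2))  # P(Z >= z) under the null
    return p1 - p0, z, p_value


# Hypothetical counts out of 2000 images each
identity = (20, 2000)
for name, flagged in [("corruption_a", 24), ("corruption_b", 1500)]:
    excess, z, p = excess_ood_rate(flagged, 2000, *identity)
    print(f"{name}: excess OOD rate {excess:+.3f}, z = {z:.1f}, one-sided p = {p:.2g}")
```

A corruption whose rate is indistinguishable from the identity rate probably has few real OOD examples. This is the case where the metadata KS comparison below becomes the more useful tool.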


### Compare metadata features using KS two-sample test.

In cases where OOD examples are relatively few, we might want to ask if there is anything in the metadata that has undergone a distributional change that might be contributing to the observed drift.

The function [ks_compare()](metadata_tools.py#ks_compare) does a feature-wise KS test on the drifted metadata, relative to reference metadata. It lumps batches together until the Kolmogorov-Smirnov test statistic has reached a stable value. Thus, small batch dataloaders can be passed to ks_compare() without leading to underpowered statistical tests.
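Here is a sketch of the batch-pooling idea for a single metadata feature. The function pools small batches and stops once the KS statistic has changed by less than `tol` for `patience` consecutive batches. The stopping rule, its parameters, and the name `ks_until_stable` are our assumptions; ks_compare's actual criterion may differ.

```python
import numpy as np
from scipy.stats import ks_2samp


def ks_until_stable(ref_batches, new_batches, tol=0.01, patience=3, min_n=100):
    """Pool small batches until the KS statistic changes by < tol for `patience` steps."""
    ref_vals, new_vals = [], []
    result, prev_stat, steady = None, None, 0
    for ref_b, new_b in zip(ref_batches, new_batches):
        ref_vals.extend(ref_b)
        new_vals.extend(new_b)
        if min(len(ref_vals), len(new_vals)) < min_n:
            continue  # too few samples for a meaningful statistic yet
        result = ks_2samp(ref_vals, new_vals)
        if prev_stat is not None and abs(result.statistic - prev_stat) < tol:
            steady += 1
            if steady >= patience:
                break
        else:
            steady = 0
        prev_stat = result.statistic
    return result, len(ref_vals)


rng = np.random.default_rng(0)
# Batches of 11 values of one metadata feature, reference vs. drifted
ref_batches = (rng.normal(0.0, 1.0, 11) for _ in range(500))
new_batches = (rng.normal(1.5, 1.0, 11) for _ in range(500))
result, n_used = ks_until_stable(ref_batches, new_batches)
print(n_used, result.statistic, result.pvalue)
```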

ks_compare returns a dict containing ks_2samp results for each metadata feature. It also prints the p-value for each feature in ascending order, i.e. most significant first.


Make some new dataloaders with a smaller, more typical batch size of 11, using the same datasets created above. Then use them to look for metadata distributional shifts with ks_compare().


In [None]:
collate_fn = collate_fn_2
batch_size = 11
refdl = DataLoader(refdata, collate_fn=collate_fn, batch_size=batch_size)
valdl = DataLoader(valdata, collate_fn=collate_fn, batch_size=batch_size)
refdlc1 = DataLoader(shiftdata, collate_fn=collate_fn, batch_size=batch_size)
refdlc2 = DataLoader(spikydata, collate_fn=collate_fn, batch_size=batch_size)
refdlc3 = DataLoader(blurdata, collate_fn=collate_fn, batch_size=batch_size)
refdlc4 = DataLoader(scaledata, collate_fn=collate_fn, batch_size=batch_size)

For example, we want to know if a drift of any metadata features might explain the drift observed for scaled data. (The MNIST corruptions are quite dramatic and it is fairly obvious what is leading to the detected drift. But it is possible to detect drift in cases where the OOD detector fails to see OOD examples; in such cases one would use ks_compare to find subtle distributional shifts in metadata factors.)

First compare metadata factors from the training and validation datasets. We expect to see high p-values for most metadata features.


In [None]:
ks_compare(refdl, valdl);

What about the scale corruption, which yielded few OOD detections?


In [None]:
print('scale')
res = ks_compare(refdl, refdlc4)

Most of the available metadata factors have undergone significant distributional changes, as seen by the small p-values from the KS test. The shift metric shows how far the distribution has shifted on average, in units of the width of the reference distribution. The values are sorted from greatest to least statistical significance (increasing p-values).
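For intuition about this output, here is a sketch that computes a KS p-value and a normalized shift for each metadata feature, then sorts by p-value. We assume the shift is the difference in means divided by the reference standard deviation. ks_compare may measure location and width differently, e.g. with medians or quantiles, and `compare_metadata` is a hypothetical name. The metadata are synthetic but use the same dict-of-lists layout InstanceMNIST uses.

```python
import numpy as np
from scipy.stats import ks_2samp


def compare_metadata(ref_md, new_md):
    """Per-feature KS test plus a normalized shift, sorted by p-value (most significant first)."""
    rows = []
    for name in ref_md:
        ref = np.asarray(ref_md[name], dtype=float)
        new = np.asarray(new_md[name], dtype=float)
        result = ks_2samp(ref, new)
        width = ref.std()  # assumed width of the reference distribution
        shift = (new.mean() - ref.mean()) / width if width > 0 else float("nan")
        rows.append((name, float(result.pvalue), float(shift)))
    return sorted(rows, key=lambda row: row[1])


rng = np.random.default_rng(0)
ref_md = {"bbox_width": rng.normal(12, 2, 1000).tolist(), "random": rng.uniform(size=1000).tolist()}
new_md = {"bbox_width": rng.normal(15, 2, 1000).tolist(), "random": rng.uniform(size=1000).tolist()}
rows = compare_metadata(ref_md, new_md)
for name, p, shift in rows:
    print(f"{name:12s} p={p:.3g} shift={shift:+.2f}")
```

Here the width feature moves by 3 units against a reference spread of 2, so its shift is close to +1.5 reference widths. The random feature should show a shift near zero.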

We can of course also run ks_compare on other corrupted datasets.


In [None]:
print('translate')
res1 = ks_compare(refdl, refdlc1)
print('\n\nshot_noise')
res2 = ks_compare(refdl, refdlc2)
print('\n\nmotion_blur')
res3 = ks_compare(refdl, refdlc3)
print('\n\nscale')
res4 = ks_compare(refdl, refdlc4)


### _Summary_

We demonstrate a method for associating OOD examples with metadata features, in cases where dataset drift has been observed, using mutual information. By including a random metadata feature, we can evaluate whether a given association, however weak, could be considered significant or not.

We also demonstrate a method for deciding whether a set of detected OOD examples is mostly real or not: by running the OOD detector on an uncorrupted version of the reference distribution.

We also demonstrate a method for finding significant distributional shifts in metadata features, and display the p-values of these shifts. Here again, including a random metadata feature provides a sanity check. We also compute a measure of the magnitudes of distributional shifts, relative to the width of their reference distributions.
