In [1]:
try:
    import google.colab  # noqa: F401

    # specify the version of DataEval (==X.XX.X) for versions other than the latest
    %pip install -q dataeval[tensorflow]
except Exception:
    pass

import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

In [2]:
import numpy as np

from dataeval._internal.datasets import MNIST
from dataeval.detectors.ood import OOD_AE, OOD_VAEGMM
from dataeval.tensorflow.models import AE, VAEGMM, create_model

In [3]:
# Build the MNIST training split, requesting 4000 samples (size=4000).
# NOTE(review): unit_interval / channels presumably scale pixels to [0, 1]
# and put the channel axis first — confirm against the MNIST loader.
train_ds = MNIST(
    root="./data/",
    train=True,
    download=True,
    size=4000,
    unit_interval=True,
    channels="channels_first",
)

# Pull the images and their labels off the dataset object
images = train_ds.data
labels = train_ds.targets

# Shape of a single image; used when constructing the models
input_shape = images[0].shape

Files already downloaded and verified


In [4]:
# Pair each OOD detector class with the model class it wraps, then build
# one detector per pair using a model created for the dataset's image shape.
detectors = [
    detector_cls(create_model(model_cls, input_shape))
    for detector_cls, model_cls in ((OOD_AE, AE), (OOD_VAEGMM, VAEGMM))
]

In [5]:
# Fit every detector on the (uncorrupted) training images, announcing each
# one by class name before its training run starts.
# NOTE(review): threshold_perc=99 presumably places the OOD threshold at the
# 99th percentile of the training scores — confirm against the detector docs.
for detector in detectors:
    print(f"Training {type(detector).__name__}...")
    detector.fit(images, epochs=20, threshold_perc=99, verbose=False)

Training OOD_AE...


Training OOD_VAEGMM...


In [6]:
# Load 2000 MNIST samples with the "translate" corruption requested, using the
# same scaling and channel layout options as the training data. These images
# are later passed to the detectors' predict() calls.
corruption = MNIST(
    root="./data",
    corruption="translate",
    train=True,
    size=2000,
    download=False,
    channels="channels_first",
    unit_interval=True,
)

# Keep just the image array for prediction
corrupted_images = corruption.data

Files already downloaded and verified


In [7]:
# Per detector: mean of the is_ood flags predicted on the original training images
[
    (type(det).__name__, np.mean(det.predict(images).is_ood))
    for det in detectors
]

[('OOD_AE', 0.01), ('OOD_VAEGMM', 0.0085)]

In [8]:
# Per detector: mean of the is_ood flags predicted on the corrupted images
[
    (type(det).__name__, np.mean(det.predict(corrupted_images).is_ood))
    for det in detectors
]

[('OOD_AE', 0.994), ('OOD_VAEGMM', 0.011)]