In [1]:
try:
    import google.colab  # noqa: F401

    # specify the version of DataEval (==X.XX.X) for versions other than the latest
    %pip install -q dataeval[tensorflow]
except Exception:
    pass

import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

In [2]:
import numpy as np

from dataeval.detectors.ood import OOD_AE, OOD_VAEGMM
from dataeval.utils.tensorflow.models import AE, VAEGMM, create_model
from dataeval.utils.torch.datasets import MNIST

E0000 00:00:1730287673.920437    1326 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1730287673.926248    1326 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Training portion of MNIST, limited to 4000 samples. Images are requested in
# channels-first layout with unit_interval=True.
train_ds = MNIST(
    root="./data/",
    train=True,
    download=True,
    size=4000,
    unit_interval=True,
    channels="channels_first",
)

# Keep the images and their labels as separate variables.
images = train_ds.data
labels = train_ds.targets

# Shape of one image; used below to size the detector models.
input_shape = images[0].shape

Files already downloaded and verified


In [4]:
# Each OOD detector wraps a model built by create_model for the matching
# architecture; every model is sized from the shape of a single training image.
detectors = [
    detector_cls(create_model(model_cls, input_shape))
    for detector_cls, model_cls in ((OOD_AE, AE), (OOD_VAEGMM, VAEGMM))
]

W0000 00:00:1730287678.205340    1326 gpu_device.cc:2344] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [5]:
# Fit every detector on the same training images with identical settings.
fit_options = {"threshold_perc": 99, "epochs": 20, "verbose": False}

for detector in detectors:
    detector_name = type(detector).__name__
    print(f"Training {detector_name}...")
    detector.fit(images, **fit_options)

Training OOD_AE...


Training OOD_VAEGMM...


In [6]:
# Second MNIST load: 2000 training samples with the "translate" corruption
# applied, read from the existing data directory (download=False).
corruption_kwargs = dict(
    root="./data",
    train=True,
    download=False,
    size=2000,
    unit_interval=True,
    channels="channels_first",
    corruption="translate",
)
corruption = MNIST(**corruption_kwargs)

# Only the corrupted images are needed for the comparison below.
corrupted_images = corruption.data

Files already downloaded and verified


In [7]:
# Fraction of the images the detectors were fit on that each one flags as OOD.
train_ood_rates = []
for detector in detectors:
    flagged = detector.predict(images).is_ood
    train_ood_rates.append((type(detector).__name__, np.mean(flagged)))

# Bare last expression so the notebook displays the (name, rate) pairs.
train_ood_rates

[('OOD_AE', 0.01), ('OOD_VAEGMM', 0.01225)]

In [8]:
# Fraction of the corrupted images that each detector flags as OOD, shown as
# (detector name, rate) pairs.
list(
    (type(detector).__name__, np.mean(detector.predict(corrupted_images).is_ood))
    for detector in detectors
)

[('OOD_AE', 0.958), ('OOD_VAEGMM', 0.1115)]