In [1]:
# Google Colab Only
try:
    import google.colab  # noqa: F401

    # specify the version of DataEval (==X.XX.X) for versions other than the latest
    %pip install -q dataeval maite-datasets
except Exception:
    pass

In [2]:
from maite_datasets.image_classification import MNIST

from dataeval import Embeddings
from dataeval.core import divergence_fnn
from dataeval.encoders import NumpyFlattenEncoder

In [3]:
# Experiment sizes: how many training images to embed, and how many images the
# encoder processes per batch.
NUM_SAMPLES = 4000
BATCH_SIZE = 400

# Load the MNIST training split (download=True allows fetching it into ./data/).
train_ds = MNIST(root="./data/", image_set="train", download=True)

# Encoder that flattens each image into a 1-D vector.
encoder = NumpyFlattenEncoder(batch_size=BATCH_SIZE)

# Keep only the first NUM_SAMPLES embeddings of the training split.
embeddings = Embeddings(train_ds, encoder=encoder)[:NUM_SAMPLES]

In [4]:
# Report how many embeddings were kept and the shape of a single embedding.
sample_count = len(embeddings)
print("Number of samples: ", sample_count)
first_shape = embeddings[0].shape
print("Image shape:", first_shape)

Number of samples:  4000
Image shape: (784,)


In [5]:
# Split the clean embeddings into two equal, non-overlapping halves drawn from
# the same data, giving a same-distribution baseline. The midpoint is derived
# from the data so the halves stay balanced if the sample count changes.
split_idx = len(embeddings) // 2
data_a = embeddings[:split_idx]
data_b = embeddings[split_idx:]

In [6]:
# Baseline: divergence between the two clean halves. It is stored under its own
# name so it cannot be mistaken for the corrupted comparison, whose result is
# what the assertion cell checks via `div`.
div_clean = divergence_fnn(data_a, data_b)
print(div_clean)

{'divergence': np.float64(0.1855), 'errors': np.int64(1629)}




In [7]:
# Load the MNIST training split again, this time with the "translate"
# corruption applied.
corrupted_ds = MNIST(
    root="./data",
    image_set="train",
    download=True,
    corruption="translate",
)

# A separate flatten encoder for the corrupted images (64 images per batch).
corrupted_encoder = NumpyFlattenEncoder(batch_size=64)

# Keep the first 2000 corrupted embeddings for comparison with a clean half.
corrupted_emb = Embeddings(corrupted_ds, encoder=corrupted_encoder)[:2000]

In [8]:
# Report how many corrupted embeddings were kept and the shape of one of them.
corrupted_count = len(corrupted_emb)
print("Number of corrupted samples: ", corrupted_count)
corrupted_shape = corrupted_emb[0].shape
print("Corrupted image shape:", corrupted_shape)

Number of corrupted samples:  2000
Corrupted image shape: (784,)


In [9]:
# Compare one clean half against the translated images; the assertion cell
# below reads this result through `div`.
div = divergence_fnn(
    data_a,
    corrupted_emb,
)
print(div)

{'divergence': np.float64(0.963), 'errors': np.int64(74)}


In [10]:
### TEST ASSERTION CELL ###
# The clean-vs-translated comparison is expected to give a divergence above 0.95.
# Report the observed value on failure so regressions are easy to diagnose.
assert div["divergence"] > 0.95, f"expected divergence > 0.95 for translated MNIST, got {div['divergence']}"