In [1]:
# Google Colab Only
try:
    import google.colab  # noqa: F401

    # specify the version of DataEval (==X.XX.X) for versions other than the latest
    %pip install -q dataeval maite-datasets
except Exception:
    pass

In [2]:
from maite_datasets.image_classification import MNIST

from dataeval.data import Embeddings
from dataeval.metrics.estimators import divergence

In [3]:
# Load the MNIST training split into ./data/ (download=True allows fetching it).
train_ds = MNIST(root="./data/", image_set="train", download=True)

# Only the first N_SAMPLES images are embedded to keep the example fast.
N_SAMPLES = 4000
embeddings = Embeddings(train_ds, batch_size=400)[:N_SAMPLES]

In [4]:
# Gather the summary values first, then report them.
sample_count = len(embeddings)
first_shape = embeddings[0].shape
print("Number of samples: ", sample_count)
print("Image shape:", first_shape)

Number of samples:  4000
Image shape: torch.Size([784])


In [5]:
# Split the clean embeddings into two equal, disjoint halves. Their divergence
# serves as the "no shift" baseline for the corrupted comparison below.
half = len(embeddings) // 2
data_a = embeddings[:half]
data_b = embeddings[half:]

In [6]:
# Baseline: divergence between two halves of the same clean training data.
# A distinct name keeps this result separate from the corrupted comparison
# (`div`) that the assertion cell checks.
div_clean = divergence(data_a, data_b)
print(div_clean)

{'divergence': np.float64(0.1855), 'errors': np.int64(1629)}


In [7]:
# Load the same training split with the "translate" corruption applied
# (presumably shifted digits — see maite_datasets docs), using the same root as above.
corrupted_ds = MNIST(root="./data/", image_set="train", corruption="translate", download=True)

# Embed as many corrupted samples as there are in the reference set data_a.
corrupted_emb = Embeddings(corrupted_ds, batch_size=64)[: len(data_a)]

In [8]:
# Gather the summary values for the corrupted set, then report them.
corrupted_count = len(corrupted_emb)
corrupted_shape = corrupted_emb[0].shape
print("Number of corrupted samples: ", corrupted_count)
print("Corrupted image shape:", corrupted_shape)

Number of corrupted samples:  2000
Corrupted image shape: torch.Size([784])


In [9]:
# Compare the clean reference half against the corrupted embeddings.
div = divergence(
    data_a,  # clean reference embeddings
    corrupted_emb,  # "translate"-corrupted embeddings
)
print(str(div))

{'divergence': np.float64(0.963), 'errors': np.int64(74)}


In [10]:
### TEST ASSERTION CELL ###
# The clean-vs-translated comparison is expected to be almost fully separable.
assert div.divergence > 0.95, f"Expected divergence > 0.95 for translated MNIST, got {div.divergence}"