In [1]:
# Google Colab Only
try:
    import google.colab  # noqa: F401

    # specify the version of DataEval (==X.XX.X) for versions other than the latest
    %pip install -q dataeval
except Exception:
    pass

In [2]:
import numpy as np
from maite_datasets.image_classification import MNIST

from dataeval.data import Embeddings, Metadata, Select
from dataeval.data.selections import ClassBalance, ClassFilter, Limit
from dataeval.metrics.estimators import ber

In [3]:
# Configure the dataset transforms. Named functions (instead of lambdas) keep the
# transform list readable when the dataset is printed at the end of this cell.
def scale_to_unit(x):
    """Scale pixel values to [0, 1] (assumes raw values lie in [0, 255])."""
    return x / 255.0


def to_float32(x):
    """Cast the image array to float32."""
    return x.astype(np.float32)


transforms = [scale_to_unit, to_float32]

# Load the full MNIST train split and apply the transforms. It is kept under its
# own name so the unfiltered dataset is not shadowed by the selection below.
mnist_train = MNIST(root="./data/", image_set="train", transforms=transforms, download=True)

# Select a class-balanced subset: Limit, ClassFilter and ClassBalance together
# yield 6,000 samples of digits 1, 4 and 9 (2,000 per digit, as the label
# counts printed in the next cell confirm).
train_ds = Select(mnist_train, selections=[Limit(6000), ClassFilter((1, 4, 9)), ClassBalance()])

# Extract the embeddings (one vector per sample) and the class labels
embeddings = Embeddings(train_ds, batch_size=64).to_tensor()
labels = Metadata(train_ds).class_labels

print(train_ds)

Processing datum metadata:   0%|          | 0/6000 [00:00<?, ?it/s]

Select Dataset
--------------
    Selections: [Limit(size=6000), ClassFilter(classes=(1, 4, 9), filter_detections=True), ClassBalance()]
    Selected Size: 6000

MNIST Dataset
-------------
    Corruption: None
    Transforms: [<function <lambda> at 0x7c3ee86e3ec0>, <function <lambda> at 0x7c3ee86e3f60>]
    Image_set: train
    Metadata: {'id': 'MNIST_train', 'index2label': {0: 'zero', 1: 'one', 2: 'two', 3: 'three', 4: 'four', 5: 'five', 6: 'six', 7: 'seven', 8: 'eight', 9: 'nine'}, 'split': 'train'}
    Path: /builds/jatic/aria/dataeval/docs/source/notebooks/data/mnist
    Size: 60000


In [4]:
# Sanity-check the selection: number of samples, the shape of the embeddings
# tensor (samples x features), and the per-class label counts.
# No trailing spaces in the labels: print() already inserts a separator.
print("Number of training samples:", len(embeddings))
print("Embeddings shape:", embeddings.shape)
print("Label counts:", np.unique(labels, return_counts=True))

Number of training samples:  6000
Image shape: torch.Size([6000, 784])
Label counts:  (array([1, 4, 9]), array([2000, 2000, 2000]))


In [5]:
# Estimate the Bayes error rate (BER) of the three-class problem (digits 1, 4, 9)
# from the embeddings and labels, using the "MST" estimation method.
# One minus this value is our estimate of the upper bound on the accuracy any
# classifier could reach on this data.
base_result = ber(embeddings, labels, method="MST")

In [6]:
# Report the raw BER estimate for the three-class (digits 1, 4, 9) problem
print("The Bayes error rate estimate:", base_result.ber)

The bayes error rate estimation: 0.024833333333333332


In [7]:
### TEST ASSERTION CELL ###
# Regression check: for digits 1, 4, 9 the maximum achievable accuracy
# (1 - BER) must fall strictly between 0.97 and 0.98. The message reports the
# observed value so a CI failure is diagnosable without re-running the notebook.
assert 0.97 < 1 - base_result.ber < 0.98, (
    f"Expected 1 - BER in (0.97, 0.98), got {1 - base_result.ber}"
)

In [8]:
# 1 - BER: the estimated upper bound on classifier accuracy for digits 1, 4, 9
print(f"The maximum achievable accuracy: {1 - base_result.ber}")

The maximum achievable accuracy: 0.9751666666666666


In [9]:
# Collapse the three classes into two: digit 1 becomes True, while digits 4 and
# 9 both become False. The boolean array is used directly as the new labels.
labels_merged = np.equal(labels, 1)
print("New label counts:", np.unique(labels_merged, return_counts=True))

New label counts: (array([False,  True]), array([4000, 2000]))


In [10]:
# Re-estimate the BER with the merged labels (digit 1 vs. digits 4 and 9).
# The embeddings and the "MST" method are unchanged from the three-class run,
# so any difference in the result comes only from the relabelling.
new_result = ber(embeddings, labels_merged, method="MST")

In [11]:
# Report the raw BER estimate for the binary (digit 1 vs. 4/9) problem
print("The Bayes error rate estimate:", new_result.ber)

The bayes error rate estimation: 0.005


In [12]:
### TEST ASSERTION CELL ###
# Regression check: after merging labels, the maximum achievable accuracy
# (1 - BER) must fall strictly between 0.994 and 0.996. The message reports the
# observed value so a CI failure is diagnosable without re-running the notebook.
assert 0.994 < 1 - new_result.ber < 0.996, (
    f"Expected 1 - BER in (0.994, 0.996), got {1 - new_result.ber}"
)

In [13]:
# 1 - BER: the estimated upper bound on classifier accuracy for the binary task
print(f"The maximum achievable accuracy: {1 - new_result.ber}")

The maximum achievable accuracy: 0.995
