In [1]:
# Google Colab Only
try:
    import google.colab  # noqa: F401

    # specify the version of DataEval (==X.XX.X) for versions other than the latest
    %pip install -q dataeval
except Exception:
    pass

In [2]:
import numpy as np

from dataeval.metrics.estimators import ber
from dataeval.utils.data import Embeddings, Metadata
from dataeval.utils.data.datasets import MNIST


In [3]:
# Load the MNIST *training* split only, limited to 6000 samples of digits 1, 4 and 9.
# NOTE(review): unit_interval=True presumably rescales pixels to [0, 1] -- confirm in DataEval docs.
train_ds = MNIST(root="./data/", train=True, size=6000, classes=[1, 4, 9], unit_interval=True)

# Extract the per-sample embeddings (as a tensor) and the matching class labels.
embeddings = Embeddings(train_ds, batch_size=64).to_tensor()
labels = Metadata(train_ds).targets.labels

# BER pairs each embedding with a label, so the two must line up one-to-one.
assert len(embeddings) == len(labels), f"{len(embeddings)} embeddings vs {len(labels)} labels"

Determining if data needs to be downloaded
Loaded data successfully
Running data preprocessing steps


In [4]:
# Sanity-check what was loaded: sample count, embedding tensor shape, and class balance.
print("Number of training samples: ", len(embeddings))
print("Image shape:", embeddings.shape)
print(
    "Label counts: ",
    np.unique(labels, return_counts=True),
)

Number of training samples:  6000
Image shape: torch.Size([6000, 784])
Label counts:  (array([1, 4, 9]), array([2000, 2000, 2000]))


In [5]:
# Estimate the Bayes error rate (BER) for the three-class problem (digits 1, 4, 9)
# using the minimum-spanning-tree ("MST") estimator.
# 1 - BER is our estimate of the upper bound on achievable accuracy.
base_result = ber(
    embeddings,
    labels,
    method="MST",
)

In [6]:
# Report the raw BER estimate for the three-class labels.
print(
    "The bayes error rate estimation:",
    base_result.ber,
)

The bayes error rate estimation: 0.022833333333333334


In [7]:
### TEST ASSERTION CELL ###
# Bounds bracket the accuracy upper bound recorded for this notebook (~0.9772);
# the message surfaces the actual value if the estimate drifts.
assert 0.976 < 1 - base_result.ber < 0.978, (
    f"Expected maximum accuracy in (0.976, 0.978), got {1 - base_result.ber}"
)

In [8]:
# Upper bound on accuracy for the three-class problem: the complement of the BER.
print(
    "The maximum achievable accuracy:",
    1 - base_result.ber,
)

The maximum achievable accuracy: 0.9771666666666666


In [9]:
# Collapse the task into a binary one: True where the digit is 1, False for digits 4 and 9.
# The boolean mask is used directly as the new label vector.
labels_merged = (labels == 1)
print(
    "New label counts:",
    np.unique(labels_merged, return_counts=True),
)

New label counts: (array([False,  True]), array([4000, 2000]))


In [10]:
# Re-estimate the BER on the same embeddings, now with the merged binary labels
# (digit 1 vs. digits 4/9), using the same "MST" estimator as before.
new_result = ber(
    embeddings,
    labels_merged,
    method="MST",
)

In [11]:
# Report the BER estimate for the merged binary labels.
print(
    "The bayes error rate estimation:",
    new_result.ber,
)

The bayes error rate estimation: 0.005333333333333333


In [12]:
### TEST ASSERTION CELL ###
# Bounds bracket the accuracy upper bound recorded for the merged labels (~0.9947);
# the message surfaces the actual value if the estimate drifts.
assert 0.994 < 1 - new_result.ber < 0.996, (
    f"Expected maximum accuracy in (0.994, 0.996), got {1 - new_result.ber}"
)

In [13]:
# Upper bound on accuracy for the binary problem: the complement of the BER.
print(
    "The maximum achievable accuracy:",
    1 - new_result.ber,
)

The maximum achievable accuracy: 0.9946666666666667
