In [1]:
# Google Colab Only
# Installs the tutorial's dependencies, but only when `google.colab` can be
# imported (presumably only inside a Colab runtime). Anywhere else the import
# raises, the exception is swallowed, and this cell is a no-op.
try:
    import google.colab  # noqa: F401

    # specify the version of DataEval (==X.XX.X) for versions other than the latest
    %pip install -q dataeval maite-datasets
except Exception:
    # Deliberate best-effort: any failure in the block above is ignored so the
    # notebook keeps running in environments that already have the packages.
    pass

In [2]:
import numpy as np
from maite_datasets.image_classification import MNIST

from dataeval.data import Metadata, Select
from dataeval.data.selections import Indices
from dataeval.evaluators.linters import Duplicates

In [3]:
# Load the MNIST test split, stored under ./data/ (requested with download=True)
testing_dataset = MNIST(root="./data/", image_set="test", download=True)

# Wrap the dataset in Metadata and pull out the per-image class labels
testing_metadata = Metadata(testing_dataset)
labels = testing_metadata.class_labels

Processing datum metadata:   0%|          | 0/10000 [00:00<?, ?it/s]

In [4]:
# Pick one pair of images for each of a handful of digit classes.
# `duplicates` maps the dataset index that will be overwritten (the class's
# entry at position 78) to the dataset index whose image takes its place
# (the class's entry at position 23).
print("Exact duplicates")
duplicates = {}
for i in (1, 2, 5, 9):
    # Dataset positions of every image labelled with digit `i`
    matching_indices = np.nonzero(labels == i)[0]
    print(f"\t{i} - ({matching_indices[23]}, {matching_indices[78]})")
    # Cast to plain ints so the keys/values are not numpy scalars
    duplicates.update({int(matching_indices[78]): int(matching_indices[23])})

Exact duplicates
	1 - (180, 663)
	2 - (249, 728)
	5 - (219, 866)
	9 - (212, 773)


In [5]:
# Start from the identity ordering of the dataset and redirect every index
# listed in `duplicates` to its replacement, so those positions yield a
# second copy of an existing image.
dataset_size = len(testing_dataset)
indices_with_duplicates = [duplicates.get(position, position) for position in range(dataset_size)]

# View the original dataset through the swapped index list
swapped_selection = Indices(indices_with_duplicates)
duplicates_ds = Select(testing_dataset, swapped_selection)

In [6]:
# Build the duplicate detector with its default arguments, then run it over
# the dataset that now contains the planted duplicates.
identifyDuplicates = Duplicates()
results = identifyDuplicates.evaluate(duplicates_ds)

Processing images for HashStat:   0%|          | 0/10000 [00:00<?, ?it/s]

In [7]:
# Report each duplicate category: a "<name> - <group count>" header line
# followed by the tab-indented list of index groups.
duplicate_report = results.data()
for category, images in duplicate_report.items():
    print(f"{category} - {len(images)}", f"\t{images}", sep="\n")

exact - 4
	[[180, 663], [212, 773], [219, 866], [249, 728]]
near - 96
	[[57, 4039], [176, 5872], [178, 1867, 1876, 6141, 6901, 6917, 7280, 8068], [203, 272, 2822, 3003, 3430, 5637, 5699, 5896, 9368, 9395, 9415], [204, 8418], [223, 6171], [255, 3969], [330, 1238, 2541, 4191, 9282], [348, 1193], [377, 7340], [416, 9324], [430, 1657, 4651, 7717, 9499], [476, 1040, 7270, 9348], [652, 1760, 3152], [675, 831], [745, 9836], [772, 7636], [783, 2827, 3070, 5651, 7686, 8459], [809, 6913], [889, 6073], [920, 3480, 4524, 5211, 9464], [929, 1213, 4273, 4774, 6125, 6799], [941, 6830], [964, 8672], [993, 2937], [1011, 2867, 4179], [1025, 6634], [1075, 4050], [1083, 5880], [1137, 2164, 2379], [1280, 2626], [1424, 5917], [1448, 4858], [1484, 2984, 7303], [1674, 3777], [1729, 3452], [1836, 1844, 2239, 4006, 8633, 8666, 9143, 9434], [1861, 2674], [1939, 5614], [2357, 4147], [2418, 8352], [2434, 4643], [2442, 7950], [2688, 3741], [2736, 4936], [2786, 2997, 4104, 6159, 8488], [2874, 9586], [3054, 7362], [3

In [8]:
### TEST ASSERTION CELL ###
# The detector must report exactly as many exact-duplicate groups as were
# planted, and each planted pair must appear as a [source, replaced] group.
exact_groups = results.exact
assert len(exact_groups) == len(duplicates)
for replaced_index, source_index in duplicates.items():
    assert [source_index, replaced_index] in exact_groups