In [1]:
# Google Colab Only
# Install DataEval only when running on Google Colab. Outside Colab the
# google.colab import is expected to fail, so the install is skipped.
try:
    import google.colab  # noqa: F401

    # specify the version of DataEval (==X.XX.X) for versions other than the latest
    %pip install -q dataeval
except Exception:
    # Deliberate best-effort: any failure in this cell (most commonly
    # google.colab not being importable) is ignored so the notebook keeps
    # running against the locally installed environment.
    pass

In [2]:
import numpy as np
from torch.utils.data import Subset

from dataeval.detectors.linters import Duplicates
from dataeval.utils.data import Metadata
from dataeval.utils.data.datasets import MNIST

In [3]:
# Load the MNIST test split from ./data/.
# NOTE(review): unit_interval=True presumably rescales pixels into [0, 1];
# confirm against the DataEval MNIST documentation.
testing_dataset = MNIST(
    root="./data/",
    image_set="test",
    unit_interval=True,
)

# One class label per image, read from the dataset's metadata targets
labels = Metadata(testing_dataset).targets.labels

In [4]:
# Choose which images to turn into exact duplicates.
# For every digit class in DUPLICATE_CLASSES, take the KEPT_RANK-th and
# REPLACED_RANK-th images of that class (0-based order within the class).
# `duplicates` maps the replaced image's dataset index to the kept image's
# index, so a later index remapping makes the replaced slot an exact copy.
DUPLICATE_CLASSES = [1, 2, 5, 9]  # digit classes that receive one injected duplicate
KEPT_RANK = 23  # within-class occurrence whose image is kept as the original
REPLACED_RANK = 78  # within-class occurrence whose slot is overwritten by the copy

print("Exact duplicates")
duplicates = {}
for i in DUPLICATE_CLASSES:
    # Dataset indices of every image labelled with class i, in dataset order
    matching_indices = np.where(labels == i)[0]
    kept = int(matching_indices[KEPT_RANK])
    replaced = int(matching_indices[REPLACED_RANK])
    print(f"\t{i} - ({kept}, {replaced})")
    duplicates[replaced] = kept

Exact duplicates
	1 - (239, 789)
	2 - (238, 788)
	5 - (230, 780)
	9 - (235, 785)


In [5]:
# Identity index map over the whole dataset, except that every position listed
# as a key in `duplicates` is redirected to its mapped source index. Reading
# through this map therefore yields exact copies at those positions.
indices_with_duplicates = [
    duplicates[idx] if idx in duplicates else idx
    for idx in range(len(testing_dataset))
]
duplicates_ds = Subset(testing_dataset, indices_with_duplicates)

In [6]:
# Build a Duplicates detector (constructed with no arguments) and run it over
# the dataset that now contains the injected copies. The returned `results`
# exposes the detected groups by category (see its `exact` attribute and
# `.dict()` below).
identifyDuplicates = Duplicates()
results = identifyDuplicates.evaluate(duplicates_ds)

In [7]:
# For each duplicate category, report how many groups were found and then list
# them; each group is a list of dataset indices judged to be duplicates.
for category, images in results.dict().items():
    print(f"{category} - {len(images)}\n\t{images}")

exact - 4
	[[230, 780], [235, 785], [238, 788], [239, 789]]
near - 72
	[[99, 4609], [219, 6699], [229, 2189, 2209, 7029, 7859, 7889, 8259], [257, 6287], [299, 379, 3219, 3409, 3879, 6449, 6529, 6749], [307, 4067], [449, 1539, 2929, 4839], [489, 1479], [529, 8329], [619, 1909, 5279, 8709], [649, 1359, 8249], [779, 2039, 3579], [809, 1039], [833, 2623], [947, 6957], [949, 8639], [969, 3239, 3479, 6479, 8689], [999, 7879], [1099, 6939], [1139, 3979, 5109, 5959], [1149, 1509, 4939, 5459, 7009, 7769], [1289, 3249, 4799], [1309, 7529], [1379, 4619], [1389, 6719], [1429, 2479, 2779], [1579, 2989], [1739, 6779], [1769, 5509], [1789, 3389, 8289], [1949, 4279], [2009, 3919], [2149, 2169, 2559, 4519], [2179, 3019], [2289, 6429], [2507, 8077], [2739, 4729], [2849, 5259], [2867, 5047], [3039, 4229], [3107, 7487], [3159, 3399, 4679, 7039], [3270, 5100], [3454, 8644], [3599, 8339], [3613, 3663], [3869, 6199], [3999, 5549], [4109, 5699], [4399, 7219], [4549, 7849], [4649, 7279], [4759, 8309], [4969, 5

In [8]:
### TEST ASSERTION CELL ###
# The detector must report exactly the four injected exact-duplicate pairs.
expected_exact_groups = ([230, 780], [235, 785], [238, 788], [239, 789])
assert len(results.exact) == len(expected_exact_groups)
for expected_group in expected_exact_groups:
    assert expected_group in results.exact