In [1]:
try:
    import google.colab  # noqa: F401

    # change the version (==vX.XX.X) in the statement below in order to get the latest version of dataeval.
    %pip install -q dataeval[torch]
except Exception:
    pass

In [2]:
import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms.v2 as v2

from dataeval.detectors.linters import Duplicates

In [3]:
# Load the MNIST test split, downloading it to ./data/ on first use.
# NOTE(review): `testing_dataset.data` is the dataset's stored image tensor; the
# `to_tensor` transform is only applied when samples are fetched by index, so
# `test_data` presumably holds unscaled pixel values (cast to float64 here).
to_tensor = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)
testing_dataset = datasets.MNIST(
    root="./data/",
    train=False,
    download=True,
    transform=to_tensor,
)
labels = np.array(testing_dataset.targets)
test_data = np.array(testing_dataset.data, dtype=float)

In [4]:
# Plant exact duplicates. For each chosen digit, find the dataset indices of
# that digit's samples, then overwrite the sample at class position 78 with a
# copy of the one at class position 23 (both positions are arbitrary picks).
# `duplicates` maps digit -> (copied index, overwritten index, index at class
# position 2); the third entry is not used within this cell.
print("Exact duplicates")
duplicates = {}
for i in (1, 2, 5, 9):
    matching_indices = np.nonzero(labels == i)[0]
    source_idx, target_idx = matching_indices[23], matching_indices[78]
    test_data[target_idx] = test_data[source_idx]
    duplicates[i] = (source_idx, target_idx, matching_indices[2])
    print(f"\t{i} - ({source_idx}, {target_idx})")

Exact duplicates
	1 - (180, 663)
	2 - (249, 728)
	5 - (219, 866)
	9 - (212, 773)


In [5]:
print(f"Number of samples:  {len(test_data)}")

Number of samples:  10000


In [6]:
# Create a Duplicates detector instance with its default settings.
duplicator = Duplicates()

# Run the detector over every image in `test_data`; the result groups sample
# indices into duplicate categories (the printed output shows "exact" and "near").
results = duplicator.evaluate(test_data)

In [7]:
# For each duplicate category, print how many index groups were found and then
# the groups themselves on a tab-indented line.
for category, images in results.dict().items():
    print(f"{category} - {len(images)}", f"\t{images}", sep="\n")

exact - 4
	[[180, 663], [212, 773], [219, 866], [249, 728]]
near - 96
	[[57, 4039], [176, 5872], [178, 1867, 1876, 6141, 6901, 6917, 7280, 8068], [203, 272, 2822, 3003, 3430, 5637, 5699, 5896, 9368, 9395, 9415], [204, 8418], [223, 6171], [255, 3969], [330, 1238, 2541, 4191, 9282], [348, 1193], [377, 7340], [416, 9324], [430, 1657, 4651, 7717, 9499], [476, 1040, 7270, 9348], [652, 1760, 3152], [675, 831], [745, 9836], [772, 7636], [783, 2827, 3070, 5651, 7686, 8459], [809, 6913], [889, 6073], [920, 3480, 4524, 5211, 9464], [929, 1213, 4273, 4774, 6125, 6799], [941, 6830], [964, 8672], [993, 2937], [1011, 2867, 4179], [1025, 6634], [1075, 4050], [1083, 5880], [1137, 2164, 2379], [1280, 2626], [1424, 5917], [1448, 4858], [1484, 2984, 7303], [1674, 3777], [1729, 3452], [1836, 1844, 2239, 4006, 8633, 8666, 9143, 9434], [1861, 2674], [1939, 5614], [2357, 4147], [2418, 8352], [2434, 4643], [2442, 7950], [2688, 3741], [2736, 4936], [2786, 2997, 4104, 6159, 8488], [2874, 9586], [3054, 7362], [3