In [None]:
import numpy as np

from daml._internal.metrics.stats import DatasetStats

In [None]:
# MNIST data: helpers to download and cache the dataset archive
import hashlib
import os
import typing
from urllib.error import HTTPError, URLError
from urllib.request import urlretrieve


def download_mnist() -> str:
    """Fetch the MNIST ``.npz`` archive and return the local path to it.

    Delegates to ``_get_file`` with the archive's URL and the digest the
    file is expected to have. Adapted from keras/datasets:

    https://github.com/keras-team/keras/blob/v2.15.0/keras/datasets/mnist.py#L25-L86
    """
    archive_name = "mnist.npz"
    base_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
    expected_digest = "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"
    return _get_file(
        archive_name,
        origin=base_url + archive_name,
        file_hash=expected_digest,
    )


def _get_file(
    fname: str,
    origin: str,
    file_hash: typing.Optional[str] = None,
):
    cache_dir = os.path.join(os.path.expanduser("~"), ".keras")
    datadir_base = os.path.expanduser(cache_dir)
    if not os.access(datadir_base, os.W_OK):
        datadir_base = os.path.join("/tmp", ".keras")
    datadir = os.path.join(datadir_base, "datasets")
    os.makedirs(datadir, exist_ok=True)

    fname = os.fspath(fname) if isinstance(fname, os.PathLike) else fname
    fpath = os.path.join(datadir, fname)

    download = False
    if os.path.exists(fpath):
        if file_hash is not None and not _validate_file(fpath, file_hash):
            download = True
    else:
        download = True

    if download:
        try:
            error_msg = "URL fetch failure on {}: {} -- {}"
            try:
                urlretrieve(origin, fpath)
            except HTTPError as e:
                raise Exception(error_msg.format(origin, e.code, e.msg)) from e
            except URLError as e:
                raise Exception(error_msg.format(origin, e.errno, e.reason)) from e
        except (Exception, KeyboardInterrupt):
            if os.path.exists(fpath):
                os.remove(fpath)
            raise

        if os.path.exists(fpath) and file_hash is not None and not _validate_file(fpath, file_hash):
            raise ValueError(
                "Incomplete or corrupted file detected. "
                f"The sha256 file hash does not match the provided value "
                f"of {file_hash}.",
            )
    return fpath


def _validate_file(fpath, file_hash, chunk_size=65535):
    hasher = hashlib.sha256()
    with open(fpath, "rb") as fpath_file:
        for chunk in iter(lambda: fpath_file.read(chunk_size), b""):
            hasher.update(chunk)

    return str(hasher.hexdigest()) == str(file_hash)


# Runs at cell execution: downloads the MNIST archive (or reuses the cached copy)
# and keeps its local path for loading with np.load.
mnist_path = download_mnist()

In [None]:
# Build the image arrays: a raw slice of the MNIST training images and a
# 3-channel, scaled, jittered copy of it.
size = 10000
rng = np.random.default_rng(33)  # fixed seed so the jitter is reproducible

# NOTE(review): allow_pickle=True permits unpickling objects from the archive;
# confirm it is actually required for this file.
with np.load(mnist_path, allow_pickle=True) as fp:
    test_images = fp["x_train"][:size]
    labels = fp["y_train"][:size]

# Insert a new axis at position 1, repeat it three times, then divide by 255.
norm_test_imgs = np.true_divide(np.repeat(test_images[:, None, :, :], 3, axis=1), 255)
# Integer noise drawn from [0, 10), one value per array element, added in place.
jitter = rng.integers(10, size=norm_test_imgs.shape)
norm_test_imgs += jitter

# Shuffling is currently disabled:
# rng.shuffle(test_images)
# rng.shuffle(norm_test_imgs)

print(test_images.shape, norm_test_imgs.shape, sep="\n")

In [None]:
# Disabled alternative: stats over the full image slice.
# dataset_stats = DatasetStats(test_images)
# Build DatasetStats from the first 100 raw images only.
dataset_multistats = DatasetStats(test_images[:100])
# Disabled per-image experiments (these classes are not imported in this notebook):
# imagestats = SingleImageStats(norm_test_imgs[0])
# image_stats = ImageStats(test_images[0])
# image_stats.__dict__

In [None]:
# dataset_stats.image_stats[0].__dict__
# NOTE(review): this bare attribute access is not the last expression of the
# cell, so its value is not displayed; it only evaluates the attribute.
dataset_multistats.ch_percentiles
# NOTE(review): the meaning of the arguments (3, 2) is not visible here —
# confirm against DatasetStats.get_channel_mask.
mask = dataset_multistats.get_channel_mask(3, 2)
# Cell output: the shape of index 0 along the second axis of the masked ch_map.
dataset_multistats.ch_map[mask][:, 0].shape

In [None]:
from importlib import reload

import linter

import daml._internal.metrics.hash as hasher
import daml._internal.metrics.stats as stats

# Re-import the modules so that on-disk edits are picked up without a kernel restart.
# NOTE(review): the reload order presumably matters if linter imports from
# stats/hash — confirm before reordering.
reload(stats)
reload(hasher)
reload(linter)

# Number of images handed to the linter; also the denominator of the percentage printed below.
count = 5000
lint = linter.Linter(norm_test_imgs[:count])
# NOTE(review): "modzscore" and 3.75 look like the outlier method and its
# threshold — confirm against Linter.get_outliers.
results = lint.get_outliers("modzscore", 3.75)
print(f"{len(results)} ({round(100*len(results)/count,2)}%) outliers found.")
# Last expression: display the raw results as the cell output.
results

In [None]:
# Requires `lint` and `count` to be defined already in the kernel.
# The result is indexed by the keys 'exact' and 'near'; each entry's length is
# reported as a count and as a percentage of `count`.
dupes = lint.get_duplicates()
print(f"{len(dupes['exact'])} ({round(100*len(dupes['exact'])/count,2)}%) exact duplicates found.")
print(f"{len(dupes['near'])} ({round(100*len(dupes['near'])/count,2)}%) near duplicates found.")
# Last expression: display the full result as the cell output.
dupes

In [None]:
from PIL import Image

# Show a hand-picked set of raw images, selected by index into test_images,
# using PIL's Image.show().
for image_index in (23, 4383, 80, 2448):
    Image.fromarray(test_images[image_index]).show()