In [None]:
from typing import Optional

import numpy as np
import scipy as sp
from PIL import Image, ImageFilter

In [None]:
# # MNIST Data
import hashlib
import os
import typing
from urllib.error import HTTPError, URLError
from urllib.request import urlretrieve


def download_mnist() -> str:
    """Fetch the ``mnist.npz`` archive via ``_get_file`` and return its local path.

    The download code is adapted from keras/datasets:

    https://github.com/keras-team/keras/blob/v2.15.0/keras/datasets/mnist.py#L25-L86
    """
    archive_name = "mnist.npz"
    base_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
    # SHA-256 digest the archive is checked against by _get_file.
    expected_sha256 = "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"

    return _get_file(
        archive_name,
        origin=base_url + archive_name,
        file_hash=expected_sha256,
    )


def _get_file(
    fname: str,
    origin: str,
    file_hash: typing.Optional[str] = None,
):
    cache_dir = os.path.join(os.path.expanduser("~"), ".keras")
    datadir_base = os.path.expanduser(cache_dir)
    if not os.access(datadir_base, os.W_OK):
        datadir_base = os.path.join("/tmp", ".keras")
    datadir = os.path.join(datadir_base, "datasets")
    os.makedirs(datadir, exist_ok=True)

    fname = os.fspath(fname) if isinstance(fname, os.PathLike) else fname
    fpath = os.path.join(datadir, fname)

    download = False
    if os.path.exists(fpath):
        if file_hash is not None and not _validate_file(fpath, file_hash):
            download = True
    else:
        download = True

    if download:
        try:
            error_msg = "URL fetch failure on {}: {} -- {}"
            try:
                urlretrieve(origin, fpath)
            except HTTPError as e:
                raise Exception(error_msg.format(origin, e.code, e.msg)) from e
            except URLError as e:
                raise Exception(error_msg.format(origin, e.errno, e.reason)) from e
        except (Exception, KeyboardInterrupt):
            if os.path.exists(fpath):
                os.remove(fpath)
            raise

        if (
            os.path.exists(fpath)
            and file_hash is not None
            and not _validate_file(fpath, file_hash)
        ):
            raise ValueError(
                "Incomplete or corrupted file detected. "
                f"The sha256 file hash does not match the provided value "
                f"of {file_hash}.",
            )
    return fpath


def _validate_file(fpath, file_hash, chunk_size=65535):
    hasher = hashlib.sha256()
    with open(fpath, "rb") as fpath_file:
        for chunk in iter(lambda: fpath_file.read(chunk_size), b""):
            hasher.update(chunk)

    return str(hasher.hexdigest()) == str(file_hash)


# Download MNIST (or reuse the copy already on disk) and keep the archive path
# for the data-loading cell.
mnist_path = download_mnist()

In [None]:
# Build two small test datasets from the first `size` MNIST training images:
#   * test_images    - single-channel images exactly as stored in the archive
#   * norm_test_imgs - 3-channel, channels-first copies divided by 255 plus jitter
rng = np.random.default_rng(33)
size = 25

# NOTE(review): allow_pickle=True lets np.load execute pickled objects from the
# archive. The file is hash-checked on download; allow_pickle=False is presumably
# sufficient for plain numeric arrays - confirm before changing.
with np.load(mnist_path, allow_pickle=True) as fp:
    test_images, labels = fp["x_train"][:size], fp["y_train"][:size]

norm_test_imgs = np.repeat(test_images[:, np.newaxis, :, :], 3, axis=1) / 255
# NOTE(review): the jitter is drawn from the integers 0-9, so after this step the
# values are no longer confined to [0, 1] - confirm this is intended.
jitter = rng.integers(10, size=norm_test_imgs.shape)
norm_test_imgs += jitter

# Shuffle with ONE shared permutation so the single-channel images, their
# 3-channel counterparts and the labels stay aligned. Shuffling each array
# independently in place would break that correspondence.
order = rng.permutation(len(test_images))
test_images = test_images[order]
norm_test_imgs = norm_test_imgs[order]
labels = labels[order]

print(test_images.shape)
print(norm_test_imgs.shape)

In [None]:
#########
# All of these functions operate on a single image and assume a channels-first
# layout. The image is assumed to be a NumPy ndarray.


def _slice_to_3_dimensions(array):
    # Calculate how many dimensions need to be sliced
    num_slices_needed = array.ndim - 3

    # Generate a slicing tuple to keep the last three dimensions
    # and take the first element of the rest
    slice_tuple = (0,) * num_slices_needed + (slice(None),) * 3

    # Apply the slicing tuple to the array
    sliced_array = array[slice_tuple]

    return sliced_array


def _use_PIL(image):
    """Run Pillow's ``FIND_EDGES`` filter on a channels-first image.

    The image is moved to channels-last, converted to a Pillow image, reduced
    to greyscale (mode ``"L"``) and filtered.

    Args:
        image: Channels-first ndarray.
            NOTE(review): assumes a dtype/band count Pillow can interpret
            (e.g. uint8 with 1 or 3 bands) - confirm against callers.

    Returns:
        The filtered greyscale image as an ndarray.
    """
    # Pillow expects the band axis last.
    adj_img = np.moveaxis(image, 0, -1)
    # Image.fromarray does not accept a trailing band axis of length 1, so a
    # single-band image is handed over as a plain 2-D array.
    if adj_img.ndim == 3 and adj_img.shape[-1] == 1:
        adj_img = adj_img[..., 0]
    im = Image.fromarray(adj_img)
    gray = im.convert("L")
    ed = gray.filter(ImageFilter.FIND_EDGES)
    return np.array(ed)


def _edge_filter(image):
    offset = 0.5
    kernel = np.ones((3, 3), np.uint8) * -1
    kernel[1, 1] = 8
    img = np.sum(image, axis=0).astype(np.float32)
    edges = np.zeros_like(img, dtype=np.float32)

    for y in range(1, img.shape[0] - 1):
        for x in range(1, img.shape[1] - 1):
            region = img[y - 1 : y + 2, x - 1 : x + 2]
            edges[y, x] = np.sum(region * kernel) + offset

    return edges.astype(np.uint8)


class ImageStats:
    """Statistics of a single channels-first image.

    All statistics are computed in ``__init__``; afterwards the pixel data is
    released (``self.image`` is replaced by an empty array).

    Attributes set during construction:
        bands, height, width, size, aspect_ratio: geometry of the image.
        bit_range, rescale: inferred value range and whether brightness
            values are rescaled with it.
        missing, zero: NaN count and per-band count of zero-valued pixels.
        mean, var, skew, kurtosis, percentiles: per-band statistics;
            ``percentiles`` has shape ``(bands, 5)`` for q = 0, 25, 50, 75, 100.
        histogram: per-band 256-bin histogram over ``bit_range``.
        avg_brightness, brightness: scalar brightness measures.
        blurry: standard deviation of the edge-filtered image.
        entropy: entropy (base 2) of the band-averaged histogram.
    """

    def __init__(self, image: np.ndarray):
        self.image = image
        # Potentially need to add a check to make sure the image contains values
        self.get_channels()
        self.get_size_and_aspect_ratio()
        self.get_image_bit()
        self.get_missing_and_zero()
        self.get_basic_stats_per_band()
        self.get_histogram()
        self.get_brightness()
        self.get_blurriness()
        self.get_entropy()
        # Release the pixel data once every statistic has been computed.
        self.image = np.array([])

    def get_channels(self):
        """Set ``bands`` and normalise ``image`` to exactly three dimensions.

        Raises:
            ValueError: If the array has fewer than two dimensions.
        """
        dim = self.image.ndim
        if dim == 2:
            self.bands = 1
            self.image = np.expand_dims(self.image, axis=0)
        elif dim == 3:
            self.bands = self.image.shape[0]
        elif dim > 3:
            # Implicit string concatenation keeps the message free of the
            # indentation whitespace a backslash continuation would embed.
            print(
                "Image has more than 3 dimensions. "
                "This is for single images, not batches or videos. "
                "Selecting the first index in the beginning dimensions "
                "for continued processing."
            )
            self.bands = self.image.shape[-3]
            self.image = _slice_to_3_dimensions(self.image)
        else:
            raise ValueError("You provided a 1-D array, not an image.")

    def get_size_and_aspect_ratio(self):
        """Set ``height``, ``width``, ``size`` and ``aspect_ratio`` (always <= 1)."""
        self.height = self.image.shape[-2]
        self.width = self.image.shape[-1]

        self.size = self.height * self.width
        self.aspect_ratio = min(self.width / self.height, self.height / self.width)

    def get_image_bit(self):
        """Infer ``bit_range`` from the extreme values and set ``rescale``.

        Images with a negative minimum use ``(min, max)``; images whose maximum
        is at most 1 use ``(0, 1)`` and are not rescaled; otherwise the smallest
        of the 8/12/16/32-bit ranges that holds the maximum is chosen.
        """
        max_val = np.max(self.image)
        min_val = np.min(self.image)

        self.rescale = True
        if min_val < 0:
            self.bit_range = (min_val, max_val)
        elif max_val <= 1:
            self.bit_range = (0, 1)
            self.rescale = False
        elif max_val < 2**8:
            self.bit_range = (0, 2**8 - 1)
        elif max_val < 2**12:
            self.bit_range = (0, 2**12 - 1)
        elif max_val < 2**16:
            self.bit_range = (0, 2**16 - 1)
        else:
            self.bit_range = (0, 2**32 - 1)

    def get_missing_and_zero(self):
        """Count NaN values (whole image) and zero-valued pixels (per band)."""
        self.missing = np.sum(np.isnan(self.image))
        self.zero = self.size - np.count_nonzero(self.image, axis=(1, 2))

    def get_basic_stats_per_band(self):
        """Compute mean, variance, skew, kurtosis and percentiles per band."""
        self.mean = np.mean(self.image, axis=(1, 2))
        self.var = np.var(self.image, axis=(1, 2))

        # One row per band, so the remaining statistics reduce over a single
        # axis instead of depending on tuple-axis support.
        flat = self.image.reshape(self.bands, -1)
        self.skew = sp.stats.skew(flat, axis=1)
        self.kurtosis = sp.stats.kurtosis(flat, axis=1)

        # Percentiles 0/50/100 double as min/median/max of each band.
        # np.percentile returns (n_percentiles, bands); transpose to
        # (bands, n_percentiles) for every band count, including 1.
        self.percentiles = np.percentile(flat, q=[0, 25, 50, 75, 100], axis=1).T

    def get_histogram(self):
        """Compute a 256-bin histogram per band over ``bit_range``."""
        # Depending on max and min values this creates 256 bins of equal width
        # Might need to consider normalizing first
        # or determine how to define the range that works in all cases
        self.histogram = np.vstack(
            [
                np.histogram(
                    self.image[i, :, :],
                    bins=256,
                    range=self.bit_range,
                )[0]
                for i in range(self.bands)
            ]
        )

    def _rescale(self, values):
        """Map ``values`` from ``bit_range`` onto [0, 1] as floats."""
        low, high = float(self.bit_range[0]), float(self.bit_range[1])
        span = high - low
        # Work in float so integer inputs cannot overflow when shifted.
        values = np.asarray(values, dtype=float)
        if span == 0:
            # Constant image: every value sits at the bottom of the range.
            return np.zeros_like(values)
        return (values - low) / span

    def get_brightness(self):
        """Set the scalar ``avg_brightness`` and ``brightness`` measures.

        Values are first mapped onto [0, 1] when ``rescale`` is set.  Three-band
        images use luma-weighted squared values; any other band count uses the
        plain mean.
        """
        luma = np.array([0.2126, 0.7152, 0.0722])
        if self.rescale:
            mean = self._rescale(self.mean)
            image = self._rescale(self.image)
        else:
            mean = self.mean
            image = self.image

        if self.bands == 3:
            # Both results are scalars for every branch so that callers can
            # store them in a per-image slot.
            self.avg_brightness = np.sum(luma * mean**2)
            self.brightness = (
                np.sum(luma[:, np.newaxis, np.newaxis] * image**2) / self.size
            )
        else:
            self.avg_brightness = np.mean(mean)
            self.brightness = np.mean(np.sum(image, axis=(1, 2)) / self.size)

    def get_entropy(self):
        """Set ``entropy`` from the band-averaged histogram (0.0 if it is empty)."""
        flat_hist = np.sum(self.histogram, axis=0) / self.bands
        flat_sum = flat_hist.sum()
        if flat_sum == 0:
            # No pixel fell inside the histogram range; still define the
            # attribute so consumers never hit an AttributeError.
            self.entropy = 0.0
            return

        probabilities = flat_hist / flat_sum
        probabilities = probabilities[probabilities > 0]

        self.entropy = -np.sum(probabilities * np.log2(probabilities))

    def get_blurriness(self):
        """Set ``blurry`` to the standard deviation of the edge-filtered image."""
        edges = _edge_filter(self.image)

        self.blurry = np.std(edges)


class DatasetStats:
    """Aggregate ``ImageStats`` over a collection of images.

    Construction processes every image, gathers the per-channel statistics
    into NaN-padded arrays and derives dataset-level summaries.

    Args:
        images: Array whose first axis indexes the images (``images.shape[0]``
            is used as the image count).
        labels: Optional labels; stored but not used by the visible code.
        boxes: Optional boxes; stored but not used by the visible code.

    Raises:
        ValueError: If ``images`` contains no image.
    """

    def __init__(
        self,
        images,
        labels: Optional[np.ndarray] = None,
        boxes: Optional[np.ndarray] = None,
    ) -> None:
        self.images = images
        self.labels = labels
        self.boxes = boxes

        n_images = self.images.shape[0]
        if n_images == 0:
            raise ValueError("DatasetStats requires at least one image.")

        # One scalar slot per image.
        self.img_height = np.zeros(n_images)
        self.img_width = np.zeros(n_images)
        self.img_size = np.zeros(n_images)
        self.img_aspect_ratio = np.zeros(n_images)
        self.img_channels = np.zeros(n_images)
        self.img_missing = np.zeros(n_images)
        self.img_brightness = np.zeros(n_images)
        self.img_entropy = np.zeros(n_images)
        self.img_avg_brightness = np.zeros(n_images)
        self.img_blurriness = np.zeros(n_images)
        # Maps each observed bit_range tuple to the number of images using it.
        self.img_range = {}
        self.image_stats = []

        self.process_images()
        self.process_channel_stats()
        self.dataset_stats()

    @staticmethod
    def _min_mean_max(values):
        """Return ``(min, mean, max)`` of a per-image array."""
        return (values.min(), values.mean(), values.max())

    @staticmethod
    def _nan_min_mean_max(values, axis=None):
        """Return ``(min, mean, max)`` ignoring the NaN padding."""
        return (
            np.nanmin(values, axis=axis),
            np.nanmean(values, axis=axis),
            np.nanmax(values, axis=axis),
        )

    def process_images(self):
        """Compute ``ImageStats`` for every image and collect the scalar stats."""
        for i, image in enumerate(self.images):
            stats = ImageStats(image)
            self.image_stats.append(stats)

            # These stats are per image so grabbing now
            self.img_height[i] = stats.height
            self.img_width[i] = stats.width
            self.img_size[i] = stats.size
            self.img_aspect_ratio[i] = stats.aspect_ratio
            self.img_channels[i] = stats.bands
            if stats.missing:
                self.img_missing[i] = 1
            self.img_range[stats.bit_range] = self.img_range.get(stats.bit_range, 0) + 1
            self.img_avg_brightness[i] = stats.avg_brightness
            self.img_brightness[i] = stats.brightness
            self.img_entropy[i] = stats.entropy
            self.img_blurriness[i] = stats.blurry

    def process_channel_stats(self):
        """Gather per-channel statistics into arrays padded with NaN.

        Every array has one row per image and ``max_channels`` columns; images
        with fewer channels leave their trailing columns as NaN.
        """
        n_images = self.images.shape[0]
        max_channels = int(self.img_channels.max())

        # Pre-fill with NaN so only the channels an image has need writing.
        self.img_zeros = np.full((n_images, max_channels), np.nan)
        self.img_mean = np.full((n_images, max_channels), np.nan)
        self.img_var = np.full((n_images, max_channels), np.nan)
        self.img_skew = np.full((n_images, max_channels), np.nan)
        self.img_kurtosis = np.full((n_images, max_channels), np.nan)
        self.img_percentile = np.full((n_images, max_channels, 5), np.nan)
        self.img_histogram = np.full((n_images, max_channels, 256), np.nan)

        for i, stat in enumerate(self.image_stats):
            # img_channels is a float array; slice bounds must be integers.
            n = int(self.img_channels[i])
            self.img_zeros[i, :n] = stat.zero
            self.img_mean[i, :n] = stat.mean
            self.img_var[i, :n] = stat.var
            self.img_skew[i, :n] = stat.skew
            self.img_kurtosis[i, :n] = stat.kurtosis
            self.img_percentile[i, :n, :] = stat.percentiles
            self.img_histogram[i, :n, :] = stat.histogram

    def dataset_stats(self):
        """Derive the dataset-level summaries from the per-image arrays."""
        # These stats are listed in the form of (min, mean, max)
        self.height = self._min_mean_max(self.img_height)
        self.width = self._min_mean_max(self.img_width)
        self.size = self._min_mean_max(self.img_size)
        self.aspect_ratio = self._min_mean_max(self.img_aspect_ratio)
        # img_zeros may carry NaN padding, so it needs the NaN-aware reducers.
        self.zeros = self._nan_min_mean_max(self.img_zeros)
        self.avg_brightness = self._min_mean_max(self.img_avg_brightness)
        self.brightness = self._min_mean_max(self.img_brightness)
        self.entropy = self._min_mean_max(self.img_entropy)
        self.blurriness = self._min_mean_max(self.img_blurriness)

        # These stats are based on the dataset as a whole
        if self.images.ndim == 3:
            self.images = np.expand_dims(self.images, axis=1)
        # NOTE(review): the variance reduces over axes 1-3 (one value per
        # image) while skew/kurtosis use scipy's default axis - confirm that
        # this asymmetry is intended.
        self.dataset_var = np.var(self.images, axis=(1, 2, 3))
        self.dataset_skew = sp.stats.skew(self.images)
        self.dataset_kurtosis = sp.stats.kurtosis(self.images)

        # These stats give counts
        self.missing = np.sum(self.img_missing)
        self.value_range = self.img_range

        # These stats give a tuple of (min/channel, mean/channel, max/channel)
        # for example 3 channels would give ([1,2,3], [2,3,4], [3,4,5])
        self.mean = self._nan_min_mean_max(self.img_mean, axis=0)
        self.var = self._nan_min_mean_max(self.img_var, axis=0)
        self.skew = self._nan_min_mean_max(self.img_skew, axis=0)
        self.kurtosis = self._nan_min_mean_max(self.img_kurtosis, axis=0)
        self.percentile = self._nan_min_mean_max(self.img_percentile, axis=0)
        self.histogram = self._nan_min_mean_max(self.img_histogram, axis=0)

In [None]:
# Exercise DatasetStats on the single-channel images as loaded from the archive
# (test_images is not divided by 255).
single_channel = DatasetStats(test_images)

In [None]:
# Exercise DatasetStats on the 3-channel images (values divided by 255, with
# integer jitter added afterwards).
# NOTE(review): because of the jitter these values are not confined to [0, 1],
# so the "normalized" code path may not be the one exercised here - confirm.
multi_channel = DatasetStats(norm_test_imgs)