In [225]:
import torch
import numpy as np
import torchvision
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import os
from scipy.fft import dctn, idctn
from typing import Optional, Tuple, Any
from copy import deepcopy

import matplotlib.pyplot as plt
# torch.set_printoptions(threshold=10_000)
# np.set_printoptions(threshold=np.inf)

In [227]:
class CIFAR10_custom(CIFAR10):
    def __init__(self,*args, **kwargs) -> None:
        super(CIFAR10_custom, self).__init__(*args, **kwargs)
        self.format = 'rgb'


    def to_ycbcr(self, in_place = False) -> Optional[Tuple[np.ndarray, list]]:
        r"""Convert entire dataset from RGB to YCbCr.

             Args:
                in_place (bool, optional): Whether to modify the entire CIFAR10 dataset in memory.
                If set to False, returns the (images, targets) as tuples where images are of
                shape (B, C, H, W). Defaults to False.
        """
        assert self.format != 'compressed', f'cannot transform format {self.format} into ycbcr, have you ran to_ycbcr after compressed?'

        if self.format == 'ycbcr' and not in_place: # prevent applying transformation twice
            return self.data.transpose((0, 3, 1, 2)), self.targets

        if self.format == 'ycbcr' and in_place: # prevent applying transformation twice
            return None

        data = self.data.transpose((0, 3, 1, 2))  # convert to CHW


        r = data[..., 0, :, :]
        g = data[..., 1, :, :]
        b = data[..., 2, :, :]

        y = 0.299 * r + 0.587 * g + 0.114 * b
        cb = -0.168736 * r - 0.331264 * g + 0.5 * b + 128
        cr = 0.5 * r - 0.418688 * g - 0.081312 * b + 128


        if in_place:
            self.format = 'ycbcr'
            self.data = np.stack((y, cb, cr), axis=-3).transpose((0, 2, 3, 1)) # convert to HWC
        else:
            return np.stack((y, cb, cr), axis=-3), self.targets


    def compress(self, in_place = False, *args, **kwargs) -> Optional[Tuple[np.ndarray, list]]:
        r"""Compresses the entire dataset from YCbCr format.

        Args:
            in_place (bool, optional): Whether to modify the entire CIFAR10 dataset in memory.
                If set to False, returns the (images, targets) as tuples where images are of
                shape (B, C, H, W). Defaults to False.
            block_size (tuple[int, int], optional): Size of the window to apply compression.
                Defaults to (8, 8).
            alpha (int, optional): Alpha parameter to control the magnitude of compression.
                Defaults to 1.
        """


        assert self.format != 'rgb', f'format should be ycbcr but is {self.format}, call .to_ycbcr(in_place = True) before compress'

        if self.format == 'compressed' and not in_place: # prevent applying transformation twice
            return self.data.transpose((0, 3, 1, 2)), self.targets

        if self.format == 'compressed' and in_place: # prevent applying transformation twice
            return None

        data = self.data.transpose((0, 3, 1, 2)) # convert to CHW

        data = apply_across_batch(data, compress_quantise_across_channels, *args, **kwargs)

        if in_place:
            self.format = 'compressed'
            self.data = data.transpose((0, 2, 3, 1)) # convert to HWC
        else:
            return data, self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """

        if self.format == 'rgb':
            return super().__getitem__(index)

        if self.format in ('compressed', 'ycbcr'):

            img, target = self.data[index], self.targets[index]

            if self.transform is not None:
                self._check_transform()
                img = self.transform(img)

            if self.target_transform is not None:
                target = self.target_transform(target)

            return img, target

    def _check_transform(self) -> Any:
        invalid_transform = torchvision.transforms.transforms.__name__

        transformations = deepcopy(self.transform.__getstate__()['transforms'])

        are_valid = list(map(lambda t: t.__class__.__module__ != invalid_transform, transformations))

        assert all(are_valid), f"base image transformations from torchvision are not supported when format is {self.format}, use custom ones from preprocessing.transforms"



def apply_across_batch(array, func, *args, **kwargs):
    # Get the shape of the input array
    batch_size, channels, H, W = array.shape

    # Initialize an output array of the same shape as the input
    output = np.zeros_like(array)

    # Iterate over the batch and channel dimensions
    for i in range(batch_size):
        output[i, :, :, :] = func(array[i, :, :, :], *args, **kwargs)

    return output

In [228]:
def blockwise_dct(image: np.array, block_size: tuple[int, int] = (8, 8)):
    height, width = image.shape
    block_height, block_width = block_size

    dct_blocks = np.zeros_like(image, dtype=np.float32)


    for i in range(0, height, block_height):
        for j in range(0, width, block_width):
            block = image[i:i+block_height, j:j+block_width]

            dct_block = dctn(block, norm='ortho')

            dct_blocks[i:i+block_height, j:j+block_width] = dct_block

    return dct_blocks

In [229]:
def blockwise_quantize(dct: np.ndarray, mode='l', block_size: tuple[int, int] = (8, 8), alpha: int = 1,) -> np.ndarray:
    luminance_quantization_matrix = np.array([
        [16, 11, 10, 16, 24, 40, 51, 61],
        [12, 12, 14, 19, 26, 58, 60, 55],
        [14, 13, 16, 24, 40, 57, 69, 56],
        [14, 17, 22, 29, 51, 87, 80, 62],
        [18, 22, 37, 56, 68, 109, 103, 77],
        [24, 35, 55, 64, 81, 104, 113, 92],
        [49, 64, 78, 87, 103, 121, 120, 101],
        [72, 92, 95, 98, 112, 100, 103, 99]
    ])

    chrominance_quantization_matrix = np.array([
        [17, 18, 24, 47, 99, 99, 99, 99],
        [18, 21, 26, 66, 99, 99, 99, 99],
        [24, 26, 56, 99, 99, 99, 99, 99],
        [47, 66, 99, 99, 99, 99, 99, 99],
        [99, 99, 99, 99, 99, 99, 99, 99],
        [99, 99, 99, 99, 99, 99, 99, 99],
        [99, 99, 99, 99, 99, 99, 99, 99],
        [99, 99, 99, 99, 99, 99, 99, 99]
    ])


    quantization_matrix = luminance_quantization_matrix if mode == 'l' else chrominance_quantization_matrix

    quantization_matrix = quantization_matrix * alpha

    height, width = dct.shape
    block_height_num, block_width_num = height // block_size[0], width // block_size[1]

    quantization_matrix_tiled = np.tile(quantization_matrix, (block_height_num, block_width_num))

    return np.round(dct / quantization_matrix_tiled).astype(np.int8)

In [230]:
def compress_quantise_across_channels(image: np.ndarray, block_size: tuple[int, int] = (8, 8), alpha: int = 1, *args, **kwargs):
    channels, height, width = image.shape

    assert channels == 3, f'channels must be 3 YCbCr but got {channels} instead'


    y = blockwise_dct(image[0, :, :], block_size, *args, **kwargs)
    cb = blockwise_dct(image[1, :, :], block_size, *args, **kwargs)
    cr = blockwise_dct(image[2, :, :],block_size, *args, **kwargs)

    y = blockwise_quantize(y, 'l', block_size, alpha, *args, **kwargs)
    cb = blockwise_quantize(y, 'c', block_size, alpha, *args, **kwargs)
    cr = blockwise_quantize(y, 'c', block_size, alpha, *args, **kwargs)


    return np.stack((y, cb, cr), axis=-3)

In [231]:
array = np.random.rand(3, 32, 32)
compress_quantise_across_channels(array).shape

(3, 32, 32)

In [232]:
download_path = os.path.join('..', 'data', 'cifar10')

if not os.path.exists(download_path):
    os.makedirs(download_path)

In [233]:
transform = transforms.Compose([transforms.ToTensor()])

cifar = CIFAR10_custom(download_path, transform=transform, download = True)

Files already downloaded and verified


In [234]:
cifar.format

'rgb'

In [235]:
cifar[0][0]

tensor([[[0.2314, 0.1686, 0.1961,  ..., 0.6196, 0.5961, 0.5804],
         [0.0627, 0.0000, 0.0706,  ..., 0.4824, 0.4667, 0.4784],
         [0.0980, 0.0627, 0.1922,  ..., 0.4627, 0.4706, 0.4275],
         ...,
         [0.8157, 0.7882, 0.7765,  ..., 0.6275, 0.2196, 0.2078],
         [0.7059, 0.6784, 0.7294,  ..., 0.7216, 0.3804, 0.3255],
         [0.6941, 0.6588, 0.7020,  ..., 0.8471, 0.5922, 0.4824]],

        [[0.2431, 0.1804, 0.1882,  ..., 0.5176, 0.4902, 0.4863],
         [0.0784, 0.0000, 0.0314,  ..., 0.3451, 0.3255, 0.3412],
         [0.0941, 0.0275, 0.1059,  ..., 0.3294, 0.3294, 0.2863],
         ...,
         [0.6667, 0.6000, 0.6314,  ..., 0.5216, 0.1216, 0.1333],
         [0.5451, 0.4824, 0.5647,  ..., 0.5804, 0.2431, 0.2078],
         [0.5647, 0.5059, 0.5569,  ..., 0.7216, 0.4627, 0.3608]],

        [[0.2471, 0.1765, 0.1686,  ..., 0.4235, 0.4000, 0.4039],
         [0.0784, 0.0000, 0.0000,  ..., 0.2157, 0.1961, 0.2235],
         [0.0824, 0.0000, 0.0314,  ..., 0.1961, 0.1961, 0.

In [236]:
cifar.to_ycbcr(in_place=True)
output = cifar.compress(in_place=True)

In [237]:
cifar[0][0]

AssertionError: base image transformations from torchvision are not supported when format is compressed, use custom ones from preprocessing.transforms

In [162]:
images, _ = output

In [163]:
fimage = images[:4]
fimage

array([[[[ 37., -13.,  -4., ...,   0.,   0.,   0.],
         [ -7.,  -7.,   2., ...,   0.,   0.,   0.],
         [  3.,   1.,   3., ...,   0.,   0.,   0.],
         ...,
         [ -1.,   0.,   0., ...,   0.,   0.,   0.],
         [  0.,   0.,   0., ...,   0.,   0.,   0.],
         [  0.,   0.,   0., ...,   0.,   0.,   0.]],

        [[  2.,  -1.,   0., ...,   0.,   0.,   0.],
         [  0.,   0.,   0., ...,   0.,   0.,   0.],
         [  0.,   0.,   0., ...,   0.,   0.,   0.],
         ...,
         [  0.,   0.,   0., ...,   0.,   0.,   0.],
         [  0.,   0.,   0., ...,   0.,   0.,   0.],
         [  0.,   0.,   0., ...,   0.,   0.,   0.]],

        [[  2.,  -1.,   0., ...,   0.,   0.,   0.],
         [  0.,   0.,   0., ...,   0.,   0.,   0.],
         [  0.,   0.,   0., ...,   0.,   0.,   0.],
         ...,
         [  0.,   0.,   0., ...,   0.,   0.,   0.],
         [  0.,   0.,   0., ...,   0.,   0.,   0.],
         [  0.,   0.,   0., ...,   0.,   0.,   0.]]],


       [[[ 71.

In [179]:
train_transformations = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])

In [209]:
train_transformations.__getstate__()['transforms'][3].__class__

torchvision.transforms.transforms.ToTensor

In [202]:
train_transformations.__getstate__()['transforms'][3].__class__()

ToTensor()

In [196]:
torchvision.transforms.transforms.ToTensor in train_transformations.__getstate__()['transforms']

False

In [210]:
torchvision.transforms.transforms.__name__ is train_transformations.__getstate__()['transforms'][3].__class__.__module__

True