In [1]:
import torch
import numpy as np
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import os
from scipy.fft import dctn, idctn

import matplotlib.pyplot as plt
# torch.set_printoptions(threshold=10_000)
# np.set_printoptions(threshold=np.inf)

In [42]:
class CIFAR10_custom(CIFAR10):
    def __init__(self,*args, **kwargs) -> None:
        super(CIFAR10_custom, self).__init__(*args, **kwargs)
        self.format = 'rgb'


    def to_ycbcr(self, in_place = False) -> None:
        r"""Convert entire dataset from RGB to YCbCr.

             Args:
                in_place (bool) = False: boolean whether to modify the entire CIFRA10 dataset in memory. If set to false, returns the (images, targets) of (B, C, H, W) images.

        """
        if self.format == 'ycbcr' and not in_place: # prevent applying transformation twice
            return self.data.transpose((0, 3, 1, 2)), self.targets

        if self.format == 'ycbcr' and in_place:
            return None

        data = self.data.transpose((0, 3, 1, 2))  # convert to CHW


        r = data[..., 0, :, :]
        g = data[..., 1, :, :]
        b = data[..., 2, :, :]

        y = 0.299 * r + 0.587 * g + 0.114 * b
        cb = -0.168736 * r - 0.331264 * g + 0.5 * b + 128
        cr = 0.5 * r - 0.418688 * g - 0.081312 * b + 128


        if in_place:
            self.format = 'ycbcr'
            self.data = np.stack((y, cb, cr), axis=-3).transpose((0, 2, 3, 1))
        else:
            return np.stack((y, cb, cr), axis=-3), self.targets


In [43]:
def apply_across_batch(array, func, *args, **kwargs):
    # Get the shape of the input array
    batch_size, channels, H, W = array.shape

    # Initialize an output array of the same shape as the input
    output = np.zeros_like(array)

    # Iterate over the batch and channel dimensions
    for i in range(batch_size):
        output[i, :, :, :] = func(array[i, :, :, :], *args, **kwargs)

    return output

In [44]:
def blockwise_dct(image: np.array, block_size: tuple[int, int]=(8, 8)):
    height, width = image.shape
    block_height, block_width = block_size

    dct_blocks = np.zeros_like(image, dtype=np.float32)


    for i in range(0, height, block_height):
        for j in range(0, width, block_width):
            # Extract the block from the image
            block = image[i:i+block_height, j:j+block_width]

            # Apply the 2D DCT to the block
            dct_block = dctn(block, norm='ortho')

            # Store the DCT coefficients
            dct_blocks[i:i+block_height, j:j+block_width] = dct_block

    return dct_blocks

In [45]:
def blockwise_quantize(dct: np.ndarray, mode='l', alpha: int = 1, block_size=(8, 8)) -> np.ndarray:
    luminance_quantization_matrix = np.array([
        [16, 11, 10, 16, 24, 40, 51, 61],
        [12, 12, 14, 19, 26, 58, 60, 55],
        [14, 13, 16, 24, 40, 57, 69, 56],
        [14, 17, 22, 29, 51, 87, 80, 62],
        [18, 22, 37, 56, 68, 109, 103, 77],
        [24, 35, 55, 64, 81, 104, 113, 92],
        [49, 64, 78, 87, 103, 121, 120, 101],
        [72, 92, 95, 98, 112, 100, 103, 99]
    ])

    chrominance_quantization_matrix = np.array([
        [17, 18, 24, 47, 99, 99, 99, 99],
        [18, 21, 26, 66, 99, 99, 99, 99],
        [24, 26, 56, 99, 99, 99, 99, 99],
        [47, 66, 99, 99, 99, 99, 99, 99],
        [99, 99, 99, 99, 99, 99, 99, 99],
        [99, 99, 99, 99, 99, 99, 99, 99],
        [99, 99, 99, 99, 99, 99, 99, 99],
        [99, 99, 99, 99, 99, 99, 99, 99]
    ])

    # Choose the appropriate quantization matrix based on the mode
    quantization_matrix = luminance_quantization_matrix if mode == 'l' else chrominance_quantization_matrix

    # Adjust the quantization matrix by the alpha multiplier
    quantization_matrix = quantization_matrix * alpha

    # Ensure the dimensions of the DCT input are divisible by the block size
    height, width = dct.shape
    block_height_num, block_width_num = height // block_size[0], width // block_size[1]

    # Tile the quantization matrix to cover the whole DCT matrix
    quantization_matrix_tiled = np.tile(quantization_matrix, (block_height_num, block_width_num))

    # Perform the quantization
    quantized_dct = np.round(dct / quantization_matrix_tiled).astype(np.int8)

    return quantized_dct

In [46]:
def compress_quantise_across_channels(image: np.ndarray, alpha: int = 1, block_size=(8, 8), *args, **kwargs):
    channels, height, width = image.shape

    assert channels == 3, f'channels must be 3 YCbCr but got {channels} instead'


    y = blockwise_dct(image[0, :, :], block_size, *args, **kwargs)
    cb = blockwise_dct(image[1, :, :], block_size, *args, **kwargs)
    cr = blockwise_dct(image[2, :, :],block_size, *args, **kwargs)

    y = blockwise_quantize(y, 'l', alpha, block_size, *args, **kwargs)
    cb = blockwise_quantize(y, 'c', alpha, block_size, *args, **kwargs)
    cr = blockwise_quantize(y, 'c', alpha, block_size, *args, **kwargs)


    output = np.stack((y, cb, cr), axis=-3)
    return output

In [47]:
array = np.random.rand(3, 32, 32)
compress_quantise_across_channels(array).shape

(3, 32, 32)

In [48]:
download_path = os.path.join('..', 'data', 'cifar10')

if not os.path.exists(download_path):
    os.makedirs(download_path)

In [49]:
cifar = CIFAR10_custom(download_path, download = True)

Files already downloaded and verified


In [63]:
cifar.format

'ycbcr'

In [68]:
cifar.to_ycbcr(in_place=True)

In [70]:
output = cifar.to_ycbcr(in_place=False)

In [71]:
cifar.data[:4]

array([[[[ 61.217   , 129.006208, 126.418688],
         [ 44.989   , 128.006208, 126.581312],
         [ 48.028   , 125.162528, 129.40656 ],
         ...,
         [137.038   , 111.612864, 142.951488],
         [130.451   , 111.944128, 143.370176],
         [128.782   , 113.450336, 141.707552]],

        [[ 18.804   , 128.674944, 126.      ],
         [  0.      , 128.      , 128.      ],
         [ 10.078   , 122.31264 , 133.650496],
         ...,
         [ 94.703   , 105.59424 , 148.183296],
         [ 90.002   , 105.425504, 148.683296],
         [ 94.045   , 107.09424 , 147.93936 ]],

        [[ 23.957   , 126.331264, 128.743936],
         [  8.893   , 122.981376, 133.069184],
         [ 31.412   , 114.787808, 140.544928],
         ...,
         [ 90.29    , 105.262976, 147.764608],
         [ 90.888   , 104.925504, 148.764608],
         [ 80.23    , 106.425504, 148.520672]],

        ...,

        [[172.926   ,  84.588032, 153.017088],
         [153.786   ,  60.400672, 161.676128]

In [72]:
images, _ = output

In [73]:
fimage = images[:4]
fimage

array([[[[ 61.217   ,  44.989   ,  48.028   , ..., 137.038   ,
          130.451   , 128.782   ],
         [ 18.804   ,   0.      ,  10.078   , ...,  94.703   ,
           90.002   ,  94.045   ],
         [ 23.957   ,   8.893   ,  31.412   , ...,  90.29    ,
           90.888   ,  80.23    ],
         ...,
         [172.926   , 153.786   , 156.673   , ..., 133.891   ,
           35.739   ,  38.085   ],
         [146.357   , 128.716   , 143.562   , ..., 152.608   ,
           69.273   ,  59.804   ],
         [150.675   , 136.671   , 146.793   , ..., 188.552   ,
          123.991   ,  98.989   ]],

        [[129.006208, 128.006208, 125.162528, ..., 111.612864,
          111.944128, 113.450336],
         [128.674944, 128.      , 122.31264 , ..., 105.59424 ,
          105.425504, 107.09424 ],
         [126.331264, 122.981376, 114.787808, ..., 105.262976,
          104.925504, 106.425504],
         ...,
         [ 84.588032,  60.400672,  54.256768, ...,  91.944128,
          111.7816  , 117

In [45]:
cfimage = apply_across_batch(fimage, compress_quantise_across_channels)

In [46]:
cfimage[0, 0, 0:8, 0:8]

array([[ 37., -13.,  -4.,   1.,   0.,   0.,   0.,   0.],
       [ -7.,  -7.,   2.,   1.,   0.,   0.,   0.,   0.],
       [  3.,   1.,   3.,   0.,   0.,   0.,   0.,   0.],
       [  2.,   1.,   0.,   0.,   0.,   0.,   0.,   0.],
       [  1.,   1.,   0.,   0.,   0.,   0.,   0.,   0.],
       [  1.,   1.,   0.,   0.,   0.,   0.,   0.,   0.],
       [  1.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
       [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.]])