In [70]:
import torch
import numpy as np
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import os
from scipy.fft import dctn, idctn

import matplotlib.pyplot as plt
# Let torch print tensors of up to 10,000 elements in full before it starts
# abbreviating their repr.
torch.set_printoptions(threshold=10_000)

In [71]:
class CIFAR10_custom(CIFAR10):
    r"""CIFAR10 dataset extended with an RGB -> YCbCr conversion helper.

    Construction is delegated unchanged to ``CIFAR10``; every positional and
    keyword argument is forwarded as-is.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def to_ycbcr(self, in_place = False) -> "np.ndarray | None":
        r"""Convert entire dataset from RGB to YCbCr.

        ``self.data`` is transposed with ``(0, 3, 1, 2)``, i.e. it is read as
        a 4-D array whose last axis holds the R, G, B channels, and the
        result is laid out channel-first.

        Args:
            in_place: when true, overwrite ``self.data`` with the converted
                array and return ``None``; otherwise leave the dataset
                untouched and return the converted array.

        Returns:
            The float array of shape ``(N, 3, H, W)`` holding the Y, Cb and
            Cr planes, or ``None`` when ``in_place`` is true.
        """
        data = self.data.transpose((0, 3, 1, 2))  # convert to CHW

        r = data[..., 0, :, :]
        g = data[..., 1, :, :]
        b = data[..., 2, :, :]

        # Full-range conversion: chroma planes are offset by 128.
        y = 0.299 * r + 0.587 * g + 0.114 * b
        cb = -0.168736 * r - 0.331264 * g + 0.5 * b + 128
        cr = 0.5 * r - 0.418688 * g - 0.081312 * b + 128

        ycbcr = np.stack((y, cb, cr), axis=-3)

        if in_place:
            # NOTE(review): after this assignment self.data is channel-first
            # float data rather than the channel-last layout read above, so a
            # second call would transpose the wrong axes, and inherited code
            # that expects the original layout/dtype may misbehave -- confirm
            # this is what callers want.
            self.data = ycbcr
            return None
        return ycbcr


In [72]:
def apply_across_batch(array, func, *args, **kwargs):
    """Apply ``func`` to every item along the first (batch) axis of ``array``.

    Args:
        array: array whose leading axis indexes the batch items.
        func: callable invoked as ``func(item, *args, **kwargs)`` for each
            ``item = array[i]``. Its result must be assignable into a slot of
            shape ``array.shape[1:]``.
        *args, **kwargs: forwarded unchanged to ``func``.

    Returns:
        An array with the same shape as ``array`` holding the per-item
        results. Its dtype is taken from the first result rather than from
        the input, so e.g. float results computed from an integer batch are
        not silently truncated.
    """
    batch_size = array.shape[0]

    output = None
    for i in range(batch_size):
        result = np.asarray(func(array[i], *args, **kwargs))
        if output is None:
            # Allocate lazily: the result dtype is unknown until func has run.
            output = np.zeros_like(array, dtype=result.dtype)
        output[i] = result

    if output is None:
        # Empty batch: func was never called, so keep the input's dtype.
        return np.zeros_like(array)
    return output

In [73]:
def blockwise_dct(image: np.array, block_size: tuple[int, int]=(8, 8)):
    """Compute an orthonormal 2-D DCT independently on each tile of ``image``.

    The 2-D ``image`` is walked in steps of ``block_size`` (rows, columns).
    Tiles on the bottom/right edge are simply smaller when the image size is
    not a multiple of the block size. Coefficients are written back at the
    position of the tile they came from.

    Returns:
        A float32 array with the same shape as ``image``.
    """
    n_rows, n_cols = image.shape
    tile_rows, tile_cols = block_size

    coefficients = np.zeros_like(image, dtype=np.float32)

    for top in range(0, n_rows, tile_rows):
        row_span = slice(top, top + tile_rows)
        for left in range(0, n_cols, tile_cols):
            col_span = slice(left, left + tile_cols)
            # Transform the tile and store it where the tile was taken from.
            coefficients[row_span, col_span] = dctn(image[row_span, col_span], norm='ortho')

    return coefficients

In [74]:
# 8x8 quantization tables (the widely used JPEG example tables), built once
# instead of on every call.
_LUMINANCE_QUANTIZATION_MATRIX = np.array([
    [16, 11, 10, 16, 24, 40, 51, 61],
    [12, 12, 14, 19, 26, 58, 60, 55],
    [14, 13, 16, 24, 40, 57, 69, 56],
    [14, 17, 22, 29, 51, 87, 80, 62],
    [18, 22, 37, 56, 68, 109, 103, 77],
    [24, 35, 55, 64, 81, 104, 113, 92],
    [49, 64, 78, 87, 103, 121, 120, 101],
    [72, 92, 95, 98, 112, 100, 103, 99]
])

_CHROMINANCE_QUANTIZATION_MATRIX = np.array([
    [17, 18, 24, 47, 99, 99, 99, 99],
    [18, 21, 26, 66, 99, 99, 99, 99],
    [24, 26, 56, 99, 99, 99, 99, 99],
    [47, 66, 99, 99, 99, 99, 99, 99],
    [99, 99, 99, 99, 99, 99, 99, 99],
    [99, 99, 99, 99, 99, 99, 99, 99],
    [99, 99, 99, 99, 99, 99, 99, 99],
    [99, 99, 99, 99, 99, 99, 99, 99]
])


def blockwise_quantize(dct: np.ndarray, mode='l', alpha: int = 1, block_size=(8, 8)) -> np.ndarray:
    """Quantize a 2-D array of block DCT coefficients.

    Every coefficient is divided by the matching entry of an 8x8
    quantization table (scaled by ``alpha`` and tiled over the whole array)
    and rounded to the nearest integer.

    Args:
        dct: 2-D coefficient array whose height and width are multiples of
            the block size.
        mode: ``'l'`` selects the luminance table; any other value selects
            the chrominance table.
        alpha: multiplier applied to the table; larger values quantize more
            coarsely. Expected to be positive.
        block_size: (rows, columns) of one block. Only the 8x8 tables exist,
            so this must be ``(8, 8)``.

    Returns:
        int32 array with the same shape as ``dct``. A wide integer type is
        used because quantized coefficients routinely exceed the int8 range
        (e.g. a DC term of 2040 divided by 16 rounds to 128).

    Raises:
        ValueError: if ``block_size`` does not match the table shape or the
            shape of ``dct`` is not a multiple of ``block_size``.
    """
    # Choose the appropriate quantization matrix based on the mode
    base_matrix = _LUMINANCE_QUANTIZATION_MATRIX if mode == 'l' else _CHROMINANCE_QUANTIZATION_MATRIX

    # Adjust the quantization matrix by the alpha multiplier
    quantization_matrix = base_matrix * alpha

    height, width = dct.shape
    block_height, block_width = block_size

    # Fail with a clear message instead of an opaque broadcasting error.
    if (block_height, block_width) != quantization_matrix.shape:
        raise ValueError(
            f'block_size must be {quantization_matrix.shape} to match the '
            f'quantization table but got {(block_height, block_width)}'
        )
    if height % block_height or width % block_width:
        raise ValueError(
            f'dct shape {(height, width)} is not divisible by block_size '
            f'{(block_height, block_width)}'
        )

    # Tile the quantization matrix to cover the whole DCT matrix
    quantization_matrix_tiled = np.tile(
        quantization_matrix, (height // block_height, width // block_width)
    )

    # Perform the quantization
    return np.round(dct / quantization_matrix_tiled).astype(np.int32)

In [86]:
def compress_quantise_across_channels(image: np.ndarray, alpha: int = 1, block_size=(8, 8), *args, **kwargs):
    """Block-DCT and quantise each channel of one 3-channel image.

    Each channel is transformed with ``blockwise_dct`` and then quantised
    with ``blockwise_quantize``: channel 0 uses the luminance table
    (mode ``'l'``), channels 1 and 2 use the chrominance table (mode ``'c'``).

    Args:
        image: array of shape ``(3, H, W)``; the channels are treated as
            Y, Cb, Cr in that order.
        alpha: quantisation-table multiplier passed to ``blockwise_quantize``.
        block_size: (rows, columns) of one DCT block.
        *args, **kwargs: forwarded unchanged to both ``blockwise_dct`` and
            ``blockwise_quantize``.

    Returns:
        Array of shape ``(3, H, W)`` with the quantised coefficients of the
        three channels stacked in input order; dtype is whatever
        ``blockwise_quantize`` returns.

    Raises:
        AssertionError: if the leading dimension of ``image`` is not 3.
    """
    # Unpacking also enforces that the image is exactly 3-dimensional.
    channels, _, _ = image.shape

    assert channels == 3, f'channels must be 3 YCbCr but got {channels} instead'

    # Quantisation table used for each channel, in channel order.
    channel_modes = ('l', 'c', 'c')

    quantised = []
    for index, mode in enumerate(channel_modes):
        # Each channel is quantised from its OWN DCT coefficients.
        coefficients = blockwise_dct(image[index, :, :], block_size, *args, **kwargs)
        quantised.append(
            blockwise_quantize(coefficients, mode, alpha, block_size, *args, **kwargs)
        )

    return np.stack(quantised, axis=-3)

In [87]:
# Smoke test: push one random (3, 32, 32) array through the DCT + quantisation
# pipeline and display the shape of the result.
array = np.random.rand(3, 32, 32)
compress_quantise_across_channels(array).shape

y.shape (32, 32)


(3, 32, 32)