In [244]:
import torch
import numpy as np
import os
from torch import nn
from dataset import CIFAR10_custom
from transforms import CompressedToTensor
from torchvision.transforms import Compose

torch.set_printoptions(threshold=10_000)
# np.set_printoptions(threshold=np.inf)

In [21]:
download_path = os.path.join('..', 'data', 'cifar10')

if not os.path.exists(download_path):
    os.makedirs(download_path)

In [22]:
transform = Compose([CompressedToTensor()])
with_zig_zag = Compose([CompressedToTensor(),
                        ZigZagOrder()])

cifar = CIFAR10_custom(download_path, transform = transform,  download = True)
cifar2 = CIFAR10_custom(download_path, transform = transform,  download = True)

Files already downloaded and verified


In [23]:
cifar.format

'rgb'

In [24]:
cifar[0][0]

tensor([[[0.2314, 0.1686, 0.1961,  ..., 0.6196, 0.5961, 0.5804],
         [0.0627, 0.0000, 0.0706,  ..., 0.4824, 0.4667, 0.4784],
         [0.0980, 0.0627, 0.1922,  ..., 0.4627, 0.4706, 0.4275],
         ...,
         [0.8157, 0.7882, 0.7765,  ..., 0.6275, 0.2196, 0.2078],
         [0.7059, 0.6784, 0.7294,  ..., 0.7216, 0.3804, 0.3255],
         [0.6941, 0.6588, 0.7020,  ..., 0.8471, 0.5922, 0.4824]],

        [[0.2431, 0.1804, 0.1882,  ..., 0.5176, 0.4902, 0.4863],
         [0.0784, 0.0000, 0.0314,  ..., 0.3451, 0.3255, 0.3412],
         [0.0941, 0.0275, 0.1059,  ..., 0.3294, 0.3294, 0.2863],
         ...,
         [0.6667, 0.6000, 0.6314,  ..., 0.5216, 0.1216, 0.1333],
         [0.5451, 0.4824, 0.5647,  ..., 0.5804, 0.2431, 0.2078],
         [0.5647, 0.5059, 0.5569,  ..., 0.7216, 0.4627, 0.3608]],

        [[0.2471, 0.1765, 0.1686,  ..., 0.4235, 0.4000, 0.4039],
         [0.0784, 0.0000, 0.0000,  ..., 0.2157, 0.1961, 0.2235],
         [0.0824, 0.0000, 0.0314,  ..., 0.1961, 0.1961, 0.

In [25]:
cifar.to_ycbcr(in_place=True)
output = cifar.compress(in_place=True)

In [110]:
compressed_img = cifar[0][0]
compressed_batch = torch.from_numpy(cifar.data[0:4,:,:,:].transpose((0, 3, 1, 2)))

In [51]:
ex = torch.rand((4, 3, 32, 32))
ex2 = torch.rand((3, 32, 32))

In [31]:
torch.flatten(ex, -2, -1).shape

torch.Size([3, 1024])

In [34]:
ex.shape[:-2]

torch.Size([3])

In [38]:
for channels in ex2.shape[:-2]:
    print(channels)

In [118]:
def _zigzag_order(block: torch.Tensor) -> torch.Tensor:
    # Zigzag index pattern for an 8x8 block
    zigzag_indices = torch.tensor([
        0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5,
        12, 19, 26, 33, 40, 48, 41, 34, 27, 20, 13, 6, 7, 14, 21, 28,
        35, 42, 49, 56, 57, 50, 43, 36, 29, 22, 15, 23, 30, 37, 44, 51,
        58, 59, 52, 45, 38, 31, 39, 46, 53, 60, 61, 54, 47, 55, 62, 63
    ])

    # Flatten the block and reorder it using the zigzag indices
    flat_block = block.flatten()
    zigzag_block = flat_block[zigzag_indices]

    return zigzag_block

In [121]:
def _blockwise_zigzag_order(image: torch.Tensor, block_size: tuple[int, int] = (8, 8)) -> torch.Tensor:

    img_height, img_width = image.shape
    block_height, block_width = block_size
    block_height_num, block_width_num = img_height//block_height, img_width//block_width
    zigzag_length = block_size[0] * block_size[1]

    zigzag_patches = torch.zeros((block_height_num, block_width_num, zigzag_length), dtype = torch.int8)


    for i in range(0, img_height, block_height):
        for j in range(0, img_width, block_width):
            block = image[i:i+block_height, j:j+block_width]
            zigzag_block = _zigzag_order(block)
            zigzag_patches[i//block_height, j//block_width, :] = zigzag_block

    return zigzag_patches

In [227]:
def zigzag_order(image: torch.Tensor, block_size: tuple[int, int] = (8, 8)) -> torch.Tensor:
    batch_size = image.shape[:-3]
    channels = image.shape[-3]

    unfold = torch.nn.Unfold(kernel_size=block_size, stride=block_size)
    windows = unfold(image).transpose(-2, -1).view(*batch_size, -1, channels, block_size[0]*block_size[1]).transpose(-2, -3)

    print(windows.shape)
    return windows[..., zigzag_indices]

In [221]:
_zigzag_order(compressed_img[:, 0:8, 0:8])

tensor([ 37., -13.,  -7.,   3.,  -7.,  -4.,   1.,   2.,   1.,   2.,   1.,   1.,
          3.,   1.,   0.,   0.,   0.,   0.,   0.,   1.,   1.,   1.,   1.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.], dtype=torch.float64)

In [222]:
unfold = torch.nn.Unfold(kernel_size=(8,8), stride=(8,8))
print(compressed_batch.shape)
print(compressed_batch[0, 0, 0:8, 0:8])
unfolded = unfold(compressed_batch)

torch.Size([4, 3, 32, 32])
tensor([[ 37., -13.,  -4.,   1.,   0.,   0.,   0.,   0.],
        [ -7.,  -7.,   2.,   1.,   0.,   0.,   0.,   0.],
        [  3.,   1.,   3.,   0.,   0.,   0.,   0.,   0.],
        [  2.,   1.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  1.,   1.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  1.,   1.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  1.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.]], dtype=torch.float64)


In [223]:
compressed_batch.shape

torch.Size([4, 3, 32, 32])

In [180]:
torch.nn.functional.unfold()

TypeError: unfold() missing 2 required positional arguments: 'input' and 'kernel_size'

In [181]:
_zigzag_order(unfolded.transpose(1, 2)[0, :, :64])

tensor([ 37., -13.,  -7.,   3.,  -7.,  -4.,   1.,   2.,   1.,   2.,   1.,   1.,
          3.,   1.,   0.,   0.,   0.,   0.,   0.,   1.,   1.,   1.,   1.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.], dtype=torch.float64)

In [195]:
trans = unfolded.transpose(-2, -1).view(4, -1, 3, 64).transpose(-2, -3)

In [196]:
trans[0, 0, :, :].shape

torch.Size([16, 64])

In [197]:
zigzag_indices = torch.tensor([
    0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5,
    12, 19, 26, 33, 40, 48, 41, 34, 27, 20, 13, 6, 7, 14, 21, 28,
    35, 42, 49, 56, 57, 50, 43, 36, 29, 22, 15, 23, 30, 37, 44, 51,
    58, 59, 52, 45, 38, 31, 39, 46, 53, 60, 61, 54, 47, 55, 62, 63
])

In [203]:
final = trans[:, :, :, zigzag_indices]
final[0, 0, :, :]

tensor([[ 37., -13.,  -7.,   3.,  -7.,  -4.,   1.,   2.,   1.,   2.,   1.,   1.,
           3.,   1.,   0.,   0.,   0.,   0.,   0.,   1.,   1.,   1.,   1.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.],
        [ 42.,   4.,   7.,   0.,  -5.,   1.,   0.,   1.,   5.,   3.,   2.,   0.,
           0.,   1.,  -1.,   0.,   0.,   0.,   0.,   0.,   2.,   0.,   0.,   0.,
           1.,   1.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.],
        [ 44.,  -2.,   8.,   0.,  -2.,   4.,   1.,  -4.,   4.,  -1.,   4.,  -2.,
           3.,   0.,   0.,   0.,   1.,   

In [228]:
check3 = _blockwise_zigzag_order(compressed_img[0, :, :])

In [229]:
check1 = zigzag_order(compressed_batch)[0, 0, :, :]

torch.Size([4, 3, 16, 64])


In [230]:
check2 = zigzag_order(compressed_img)[0, :, :]

torch.Size([3, 16, 64])


In [231]:
check1 == check2

tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, Tr

In [232]:
check3.view(-1, 64) == check1

tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, Tr

In [128]:
unfolded.transpose(-2, -1).view(4, 16, 3, -1).transpose(-2, -3).shape

torch.Size([4, 3, 16, 64])

In [125]:
unfolded.shape

torch.Size([4, 192, 16])

In [141]:
check2.shape

torch.Size([3, 16, 64])