In [42]:
import torch
import numpy as np
import os
from scipy.fft import dctn, idctn

from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.models import resnet18


from preprocessing.transforms import CompressedToTensor, ZigZagOrder, ChooseAC, FlattenZigZag, ConvertToFrequencyDomain, ConvertToYcbcr, Quantize, LUMINANCE_QUANTIZATION_MATRIX, CHROMINANCE_QUANTIZATION_MATRIX
from model.init import init_truncated_normal, init_kaiming_normal, set_seed
from model.vit import CompressedVisionTransformer, VisionTransformer

In [43]:
# Root directory holding the CIFAR-10 files (relative to the working directory).
DOWNLOAD_PATH = os.path.join('data', 'cifar10')
# Backward-compatible alias for the original misspelled name, still used by later cells.
DOWNLAOD_PATH = DOWNLOAD_PATH
# Seed used to make the train/validation split reproducible.
SEED = 42
# Fraction of the CIFAR-10 training split to hold out for validation.
VALIDATION_SET = 0.1
BATCH_SIZE = 128
# Number of AC coefficients passed to ChooseAC; per the pipeline's shape comments
# its output is 3x16x(AC+1), i.e. AC coefficients plus one more (presumably DC).
AC = 5

In [44]:
# One quantization matrix per channel: luminance first, then chrominance twice
# (assumes ConvertToYcbcr emits channels in Y, Cb, Cr order — TODO confirm).
quantization_matrices = [LUMINANCE_QUANTIZATION_MATRIX] + [CHROMINANCE_QUANTIZATION_MATRIX] * 2

# Baseline pipeline for the uncompressed model; ToTensor yields float pixels in [0, 1].
vanilla_transform = Compose([ToTensor()])

# Compressed-domain pipeline. The shape notes describe each step's output for a
# 32x32 RGB input. The value-range notes are carried over from the original
# author and are not verified here.
# NOTE(review): if ConvertToYcbcr really outputs values in [0, 1] and the
# quantization matrices are standard JPEG tables, Quantize would zero most
# coefficients — confirm the scale Quantize expects.
transform = Compose(
    [
        CompressedToTensor(),                                   # 3x32x32, noted range [0, 255]
        ConvertToYcbcr(),                                       # 3x32x32, noted range [0, 1]
        ConvertToFrequencyDomain(norm='ortho'),                 # 3x32x32, presumably DCT coefficients
        Quantize(quantization_matrices=quantization_matrices),  # 3x32x32
        ZigZagOrder(),                                          # 3x16x64
        ChooseAC(AC),                                           # 3x16x(AC+1)
        FlattenZigZag(),                                        # 3x(16x(AC+1))
    ]
)

In [45]:
# Build the train split, then the test split. Both use the same root and the
# compressed transform. With download=False the files must already be present
# under the root directory, otherwise torchvision raises.
cifar, cifar_test = (
    CIFAR10(
        root=DOWNLAOD_PATH,
        train=is_train_split,
        transform=transform,
        target_transform=None,
        download=False,
    )
    for is_train_split in (True, False)
)

In [46]:
# Hold out VALIDATION_SET of the training split for validation. The split runs
# inside set_seed(SEED) so it is reproducible across kernel restarts
# (assumes set_seed seeds torch's global RNG, which random_split uses by default).
with set_seed(SEED):
    num_train = len(cifar)
    # Use the configured fraction instead of a hard-coded 0.1.
    num_val = int(VALIDATION_SET * num_train)
    num_train -= num_val

    cifar_train, cifar_val = random_split(cifar, [num_train, num_val])

# Sanity checks: both subsets are non-empty and together cover the whole split.
assert num_train > 0 and num_val > 0, "VALIDATION_SET yields an empty subset"
assert len(cifar_train) + len(cifar_val) == len(cifar)

In [47]:
# Training batches are reshuffled every epoch. Validation order does not change
# the metrics, so it is kept fixed for deterministic, repeatable evaluation.
train = DataLoader(cifar_train, batch_size=BATCH_SIZE, shuffle=True)
val = DataLoader(cifar_val, batch_size=BATCH_SIZE, shuffle=False)
# The whole test split is served as a single batch (batch_size == dataset size).
# NOTE(review): this materializes every test sample at once — confirm it fits in memory.
test = DataLoader(cifar_test, batch_size=len(cifar_test), shuffle=False)

# Example training