In [1]:
import torch
import numpy as np
import os
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize

from preprocessing.dataset import CIFAR10_custom
from preprocessing.transforms import CompressedToTensor, ZigZagOrder, ChooseAC, FlattenZigZag
from model.init import init_truncated_normal, set_seed
from model.vit import CompressedVisionTransformer

In [2]:
# Root directory of the CIFAR-10 files. The datasets below are created with
# download=False, so the data must already be present here.
DOWNLOAD_PATH = os.path.join('data', 'cifar10')
# Alias for the original (misspelled) name, which later cells still reference.
DOWNLAOD_PATH = DOWNLOAD_PATH
SEED = 42               # seed for the random train/validation split
VALIDATION_SET = 0.1    # intended fraction of the training set held out for validation
BATCH_SIZE = 128        # mini-batch size (presumably for DataLoader; not used in the cells shown)

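# Loading the data

Two copies of the CIFAR-10 train and test sets are loaded from disk (no download).
One copy holds plain pixel tensors. The other copy is converted to YCbCr and compressed in place in the next cell.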

In [3]:
def _pixel_transform():
    """Return the raw-pixel pipeline: tensor conversion, then Normalize(mean=0.5, std=0.5)."""
    return Compose([ToTensor(), Normalize((0.5,), (0.5,))])


def _compressed_transform():
    """Return the compressed-domain pipeline: CompressedToTensor -> ZigZagOrder -> ChooseAC(5)."""
    return Compose([CompressedToTensor(), ZigZagOrder(), ChooseAC(5)])


# Train and test use identical pipelines; each split still gets its own instance.
transform_train = _pixel_transform()
transform_test = _pixel_transform()
transform_compressed_train = _compressed_transform()
transform_compressed_test = _compressed_transform()

# Arguments shared by every dataset below (local files only, no label transform).
_common_dataset_kwargs = dict(root=DOWNLAOD_PATH, target_transform=None, download=False)

# Raw-pixel CIFAR-10.
cifar = CIFAR10(train=True, transform=transform_train, **_common_dataset_kwargs)
cifar_test = CIFAR10(train=False, transform=transform_test, **_common_dataset_kwargs)

# Project-specific variant; it is converted and compressed in place in the next cell.
cifar_compressed = CIFAR10_custom(train=True, transform=transform_compressed_train,
                                  compression=None, **_common_dataset_kwargs)
cifar_compressed_test = CIFAR10_custom(train=False, transform=transform_compressed_test,
                                       compression=None, **_common_dataset_kwargs)

In [4]:
# Convert each compressed-domain dataset: colour space first, then compression,
# train set before test set.
# NOTE(review): in_place=True suggests these calls mutate the datasets. Re-running
# this cell without re-running the loading cell may apply the conversion twice — confirm.
for _dataset in (cifar_compressed, cifar_compressed_test):
    _dataset.to_ycbcr(in_place=True)
    _dataset.compress(in_place=True)
del _dataset

In [5]:
# Hold out a VALIDATION_SET fraction of the training images for validation.
# Each split gets its own generator seeded with SEED, so the raw-pixel and the
# compressed datasets are partitioned into exactly the same train/val indices.
# (Two calls sharing the global RNG would yield two different partitions.)
with set_seed(SEED):
    num_total = len(cifar)
    # The shared split only makes sense if both datasets line up index-for-index.
    assert len(cifar_compressed) == num_total, (
        f"raw ({num_total}) and compressed ({len(cifar_compressed)}) "
        "training sets differ in length"
    )
    num_val = int(VALIDATION_SET * num_total)
    num_train = num_total - num_val

    cifar_train, cifar_val = random_split(
        cifar, [num_train, num_val],
        generator=torch.Generator().manual_seed(SEED))
    cifar_compressed_train, cifar_compressed_val = random_split(
        cifar_compressed, [num_train, num_val],
        generator=torch.Generator().manual_seed(SEED))

# Models

In [16]:
# Hyper-parameters for the compressed-domain vision transformer, gathered in one
# place. NOTE(review): `ac` presumably has to match ChooseAC(5) in the compressed
# transforms — confirm against model.vit.
cvit_config = dict(
    ac=5,
    channels=3,
    patch_num=16,
    num_classes=10,
    d_model=248,
    nhead=8,
    dim_feedforward=1024,
    dropout=0.1,
    activation=nn.GELU(),
    ntransformers=4,
    layer_norm_eps=1e-5,
    norm_first=False,
    bias=True,
    learnable_positional=True,
)

cvit = CompressedVisionTransformer(**cvit_config)
# Initialise the weights with the project's truncated-normal initializer, then call
# the model's pre_training() hook (its semantics are defined in model.vit).
cvit.init_weights(init_truncated_normal)
cvit.pre_training()