In [1]:
import torch
import numpy as np
import os
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from preprocessing.dataset import CIFAR10_custom
from preprocessing.transforms import CompressedToTensor, ZigZagOrder, ChooseAC, FlattenZigZag
from model.init import init_truncated_normal
from model.vit import CompressedVisionTransformer

In [2]:
# Dataset root directory, built relative to the current working directory;
# it is handed to CIFAR10_custom as `root` in the dataset cell.
download_path = os.path.join(
    'data',
    'cifar10',
)

In [3]:
# Per-sample preprocessing pipeline, shared by the train and test datasets.
# NOTE(review): ChooseAC(5) presumably has to agree with the model's `ac`
# argument — confirm against preprocessing.transforms.
transform = Compose([
    CompressedToTensor(),
    ZigZagOrder(),
    ChooseAC(5),
])

# Training split. download=False is passed, so the data is presumably
# expected to already exist under download_path — confirm in CIFAR10_custom.
cifar_compressed = CIFAR10_custom(
    root=download_path,
    train=True,
    transform=transform,
    target_transform=None,
    download=False,
    compression=None,
)

# Held-out split, built with the same root and transform (train=False).
cifar_compressed_test = CIFAR10_custom(
    root=download_path,
    train=False,
    transform=transform,
    target_transform=None,
    download=False,
    compression=None,
)

In [4]:
# Prepare the training dataset: call to_ycbcr first, then compress, both with
# in_place=True (so presumably the dataset object itself is modified — confirm
# against CIFAR10_custom). The call order is kept exactly as written.
# NOTE(review): only cifar_compressed is processed in this cell;
# cifar_compressed_test is not — verify it is handled before it is used.
cifar_compressed.to_ycbcr(in_place=True)
cifar_compressed.compress(in_place=True)

In [5]:
# Number of samples per mini-batch drawn from the training dataset.
batch_size = 16

# Training loader; shuffle=True reshuffles the sample order every epoch.
train_loader = DataLoader(
    dataset=cifar_compressed,
    batch_size=batch_size,
    shuffle=True,
)

# Models

In [16]:
# Build the compressed-domain vision transformer.
# NOTE(review): `ac=5` presumably has to match ChooseAC(5) in the dataset
# transform, and `num_classes=10` the CIFAR-10 label count — confirm in model.vit.
cvit = CompressedVisionTransformer(
    ac=5,
    channels=3,
    patch_num=16,
    num_classes=10,
    d_model=248,
    nhead=8,
    dim_feedforward=1024,
    dropout=0.1,
    activation=nn.GELU(),
    ntransformers=4,
    layer_norm_eps=1e-5,
    norm_first=False,
    bias=True,
    learnable_positional=True,
)

# Hand the project's truncated-normal initializer to the model's own init hook.
cvit.init_weights(init_truncated_normal)
# NOTE(review): presumably switches the model into its pre-training
# configuration — confirm in model.vit. Kept as the cell's final expression.
cvit.pre_training()