In [1]:
import torch
import numpy as np
import os
from scipy.fft import dctn, idctn

from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.models import resnet18


from preprocessing.transforms import CompressedToTensor, ZigZagOrder, ChooseAC, FlattenZigZag, ConvertToFrequencyDomain, ConvertToYcbcr, Quantize, LUMINANCE_QUANTIZATION_MATRIX, CHROMINANCE_QUANTIZATION_MATRIX
from model.init import init_truncated_normal, init_kaiming_normal, set_seed
from model.vit import CompressedVisionTransformer, VisionTransformer

In [2]:
# Root directory for the CIFAR-10 dataset files.
DOWNLOAD_PATH = os.path.join('data', 'cifar10')
# Backward-compatible alias for the original (misspelled) name, which the rest
# of the notebook references.
DOWNLAOD_PATH = DOWNLOAD_PATH
SEED = 42             # seed passed to set_seed() for reproducible random operations
VALIDATION_SET = 0.1  # fraction of the training set to hold out for validation
BATCH_SIZE = 128      # mini-batch size for the training/validation loaders
AC = 5                # AC coefficients kept per block; ChooseAC(AC) yields AC+1 values per block

In [3]:
# One quantization table per channel after the YCbCr conversion: the first
# channel uses the luminance table, the other two share the chrominance table
# (assumes channel order Y, Cb, Cr — TODO confirm against ConvertToYcbcr).
quantization_matrices = [LUMINANCE_QUANTIZATION_MATRIX] + 2 * [CHROMINANCE_QUANTIZATION_MATRIX]

# Plain tensor pipeline; ToTensor yields pixels in [0, 1].
vanilla_transform = Compose([ToTensor()])

# JPEG-style frequency-domain pipeline. Per-image output shape is noted after
# each step (presumably channels x 8x8-blocks x coefficients from ZigZagOrder on).
transform = Compose(
    [
        CompressedToTensor(),                                   # 3x32x32, pixels in [0, 255]
        ConvertToYcbcr(),                                       # 3x32x32, values in [0, 1]
        ConvertToFrequencyDomain(norm='ortho'),                 # 3x32x32
        Quantize(quantization_matrices=quantization_matrices),  # 3x32x32
        ZigZagOrder(),                                          # 3x16x64
        ChooseAC(AC),                                           # 3x16x(AC+1)
        FlattenZigZag(),                                        # 3x(16x(AC+1))
    ]
)

In [4]:
# Both CIFAR-10 splits share the same root and compressed-domain transform.
# download=False: the files are expected to already exist under DOWNLAOD_PATH.
cifar_kwargs = dict(root=DOWNLAOD_PATH, transform=transform, target_transform=None, download=False)
cifar = CIFAR10(train=True, **cifar_kwargs)
cifar_test = CIFAR10(train=False, **cifar_kwargs)

In [5]:
# Hold out a VALIDATION_SET fraction of the training images for validation.
# An empty (or total) validation split would break the training loop's
# per-sample averaging, so reject out-of-range fractions up front.
if not 0.0 < VALIDATION_SET < 1.0:
    raise ValueError(f'VALIDATION_SET must be in (0, 1), got {VALIDATION_SET}')

with set_seed(SEED):  # For reproducible results run any random operations with set_seed()
    num_val = int(VALIDATION_SET * len(cifar))
    num_train = len(cifar) - num_val

    cifar_train, cifar_val = random_split(cifar, [num_train, num_val])

In [6]:
# Only the training loader is reshuffled each epoch. Validation/test metrics are
# aggregated over the whole split, so shuffling them is unnecessary work.
train = DataLoader(cifar_train, batch_size=BATCH_SIZE, shuffle=True)
val = DataLoader(cifar_val, batch_size=BATCH_SIZE, shuffle=False)
# The whole test split is served as a single batch.
# NOTE(review): this materializes every test sample at once; switch to
# BATCH_SIZE if memory becomes a problem.
test = DataLoader(cifar_test, batch_size=len(cifar_test), shuffle=False)

# Example training

In [None]:
# Hyperparameters for the compressed-input vision transformer, gathered in one
# dict so they can be inspected or logged alongside results.
cvit_hparams = dict(
    ac=AC,  # Required for proper positional encoding
    channels=3,
    patch_num=16,
    num_classes=10,
    d_model=248,
    nhead=8,
    dim_feedforward=1024,
    dropout=0.1,
    activation=nn.GELU(),
    ntransformers=4,
    layer_norm_eps=1e-5,
    norm_first=False,
    bias=True,
    learnable_positional=True,
)
cvit = CompressedVisionTransformer(**cvit_hparams)

cvit.pre_training()

# Weight initialisation is random; run it under set_seed() for reproducibility.
with set_seed(SEED):
    cvit.init_weights(init_truncated_normal)

In [None]:
# Default reduction='mean': each call returns the mean loss over the batch.
criterion = nn.CrossEntropyLoss()
lr = 1e-4
num_epochs = 10
# The DataLoaders are built with BATCH_SIZE; alias it rather than repeating the
# literal, so editing this value can't silently disagree with the loaders.
batch_size = BATCH_SIZE
weight_decay = 0
# NOTE(review): not read by the training loop below; checkpointing is not implemented yet.
checkpoint_every_th_epoch = None
cvit_optimizer = optim.Adam(cvit.parameters(), lr=lr, weight_decay=weight_decay)

In [None]:
# Train for num_epochs, reporting per-sample mean train/val loss and val accuracy.
with set_seed(SEED):  # For reproducible results run any random operations with set_seed()
    for epoch in range(num_epochs):
        # ---- training pass ----
        cvit.train()
        train_loss = 0.0
        train_seen = 0
        for images, labels in train:
            images = images.to(torch.float32)

            cvit_optimizer.zero_grad()
            outputs = cvit(images)
            loss = criterion(outputs, labels)
            loss.backward()
            cvit_optimizer.step()

            # criterion returns the batch mean; weight it by batch size and divide
            # by the number of samples (not batches) so the epoch figure is a true
            # per-sample average even when the last batch is smaller.
            train_loss += loss.item() * images.size(0)
            train_seen += images.size(0)

        train_loss /= max(train_seen, 1)

        # ---- validation pass ----
        cvit.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val:
                images = images.to(torch.float32)

                outputs = cvit(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * images.size(0)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # TODO: Add more metrics

        # Same per-sample normalisation as training; guard an empty loader.
        val_loss /= max(total, 1)
        val_accuracy = correct / total if total else float('nan')

        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')