In [8]:
import torch
import numpy as np
import os
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.models import resnet18

from preprocessing.dataset import CIFAR10_custom
from preprocessing.transforms import CompressedToTensor, ZigZagOrder, ChooseAC, FlattenZigZag
from model.init import init_truncated_normal, init_kaiming_normal, set_seed
from model.vit import CompressedVisionTransformer
from model.resnet import ResNet18

In [9]:
# Notebook-wide configuration constants.
DOWNLOAD_PATH = os.path.join('data', 'cifar10')  # dataset root passed to the dataset constructors
# Backward-compatible alias: the original (misspelled) name is still referenced by other cells.
DOWNLAOD_PATH = DOWNLOAD_PATH
SEED = 42             # seed handed to set_seed(...) for the split and weight initialisation
VALIDATION_SET = 0.1  # intended fraction of the training set held out for validation
BATCH_SIZE = 128      # intended mini-batch size for the data loaders

# Loading in the data

In [10]:
# Pixel-space pipelines: convert to a tensor, then normalise with mean 0.5 / std 0.5.
# The train and test pipelines are currently identical.
transform_train = Compose([
    ToTensor(),
    Normalize((0.5,), (0.5,)),
])
transform_test = Compose([
    ToTensor(),
    Normalize((0.5,), (0.5,)),
])

# Compressed-domain pipelines: tensor conversion, zig-zag ordering, then ChooseAC(5).
# NOTE(review): the 5 here presumably has to agree with the `ac` argument of the
# compressed ViT built later — confirm against preprocessing.transforms.
transform_compressed_train = Compose([
    CompressedToTensor(),
    ZigZagOrder(),
    ChooseAC(5),
])
transform_compressed_test = Compose([
    CompressedToTensor(),
    ZigZagOrder(),
    ChooseAC(5),
])

# Standard torchvision CIFAR-10. download=False means the files must already be
# present under DOWNLAOD_PATH.
cifar = CIFAR10(
    root=DOWNLAOD_PATH,
    train=True,
    transform=transform_train,
    target_transform=None,
    download=False,
)
cifar_test = CIFAR10(
    root=DOWNLAOD_PATH,
    train=False,
    transform=transform_test,
    target_transform=None,
    download=False,
)

# Project dataset wrapper used for the compressed representation; it is created with
# compression=None here and processed in a later cell.
cifar_compressed = CIFAR10_custom(
    root=DOWNLAOD_PATH,
    train=True,
    transform=transform_compressed_train,
    target_transform=None,
    download=False,
    compression=None,
)
cifar_compressed_test = CIFAR10_custom(
    root=DOWNLAOD_PATH,
    train=False,
    transform=transform_compressed_test,
    target_transform=None,
    download=False,
    compression=None,
)

In [11]:
# Run the YCbCr conversion followed by the compression step on the training set,
# then on the test set. Both calls are made with in_place=True.
# NOTE(review): because the datasets are modified in place, re-running this cell
# presumably processes already-converted data again — confirm before re-executing.
for compressed_dataset in (cifar_compressed, cifar_compressed_test):
    compressed_dataset.to_ycbcr(in_place=True)
    compressed_dataset.compress(in_place=True)

In [12]:
# Split both datasets into train/validation subsets.
#
# The validation size comes from VALIDATION_SET rather than a hard-coded literal,
# and each split gets its own generator seeded with SEED. Two consecutive
# random_split calls on one shared RNG stream draw two different permutations,
# so the pixel and compressed datasets would otherwise hold out different samples.
with set_seed(SEED):
    num_total = len(cifar)
    num_val = int(VALIDATION_SET * num_total)
    num_train = num_total - num_val

    cifar_train, cifar_val = random_split(
        cifar,
        [num_train, num_val],
        generator=torch.Generator().manual_seed(SEED),
    )
    # Same seed -> same permutation -> same sample indices as the pixel-space split.
    # random_split raises if the lengths do not sum to len(cifar_compressed).
    cifar_compressed_train, cifar_compressed_val = random_split(
        cifar_compressed,
        [num_train, num_val],
        generator=torch.Generator().manual_seed(SEED),
    )

# Models

In [13]:
# Hyper-parameters of the compressed-domain Vision Transformer, gathered in one
# place and unpacked into the constructor.
# NOTE(review): ac=5 presumably has to match the ChooseAC(5) transform — confirm.
cvit_config = dict(
    ac=5,
    channels=3,
    patch_num=16,
    num_classes=10,
    d_model=248,
    nhead=8,
    dim_feedforward=1024,
    dropout=0.1,
    activation=nn.GELU(),
    ntransformers=4,
    layer_norm_eps=1e-5,
    norm_first=False,
    bias=True,
    learnable_positional=True,
)
cvit = CompressedVisionTransformer(**cvit_config)

# Initialise the weights under a fixed seed so the starting point is repeatable.
with set_seed(SEED):
    cvit.init_weights(init_truncated_normal)

# NOTE(review): presumably switches the model into its pre-training configuration —
# confirm in model.vit.
cvit.pre_training()

In [14]:
# Build the project's ResNet18 for 3-channel inputs and 10 output classes.
# NOTE(review): this rebinds the name `resnet18`, which the import cell also brings
# in from torchvision.models; the torchvision factory is no longer reachable by
# that name after this cell runs.
resnet18 = ResNet18(channels=3, num_classes=10)

with set_seed(SEED):
    # Two initialisation passes run back to back inside the same seed context.
    # NOTE(review): if init_weights applies the given initialiser to every layer,
    # the second call overrides the first — confirm that both passes are intended.
    resnet18.init_weights(init_kaiming_normal)
    resnet18.init_weights(init_truncated_normal)

# CVIT Parameters

In [41]:
# Training hyper-parameters for the compressed ViT.
criterion = nn.CrossEntropyLoss()
lr = 1e-3
num_epochs = 30
batch_size = BATCH_SIZE  # reuse the notebook-wide constant instead of repeating the literal
weight_decay = 0
checkpoint_every_th_epoch = None  # not referenced by the training loop in this section
cvit_optimizer = optim.Adam(cvit.parameters(), lr=lr, weight_decay=weight_decay)

# Only the training loader shuffles. Shuffling validation data does not change the
# averaged metrics, but it draws from the global RNG on every pass and makes the
# evaluation order differ from run to run.
train_compressed = DataLoader(cifar_compressed_train, batch_size=batch_size, shuffle=True)
val_compressed = DataLoader(cifar_compressed_val, batch_size=batch_size, shuffle=False)
# The whole test set is served as a single batch.
test_compressed = DataLoader(cifar_compressed_test, batch_size=len(cifar_compressed_test), shuffle=False)

In [42]:
for epoch in range(num_epochs):
    # --- Training pass: one optimiser step per mini-batch ---
    cvit.train()
    train_loss_sum = 0.0
    for images, labels in train_compressed:
        images = images.to(torch.float32)  # inputs are cast to float32 before the forward pass

        cvit_optimizer.zero_grad()
        outputs = cvit(images)
        loss = criterion(outputs, labels)
        loss.backward()
        cvit_optimizer.step()

        # criterion returns the batch mean; weight it by the batch size so the
        # epoch figure is a per-sample average even when the last batch is smaller.
        train_loss_sum += loss.item() * images.size(0)

    train_loss = train_loss_sum / len(cifar_compressed_train)

    # --- Validation pass: no gradients, model in eval mode ---
    cvit.eval()
    val_loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_compressed:
            images = images.to(torch.float32)

            outputs = cvit(images)
            loss = criterion(outputs, labels)
            val_loss_sum += loss.item() * images.size(0)

            # Predicted class = index of the largest logit along dim 1.
            predicted = torch.max(outputs, 1).indices
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = val_loss_sum / len(cifar_compressed_val)
    val_accuracy = correct / total

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')

Epoch 1/30, Train Loss: 1.9610, Val Loss: 1.8702, Val Accuracy: 0.2674
Epoch 2/30, Train Loss: 1.8527, Val Loss: 1.8531, Val Accuracy: 0.2956
Epoch 3/30, Train Loss: 1.8244, Val Loss: 1.7551, Val Accuracy: 0.3360
Epoch 4/30, Train Loss: 1.7391, Val Loss: 1.7589, Val Accuracy: 0.3336
Epoch 5/30, Train Loss: 1.7142, Val Loss: 1.6513, Val Accuracy: 0.3716
Epoch 6/30, Train Loss: 1.6465, Val Loss: 1.5815, Val Accuracy: 0.4112
Epoch 7/30, Train Loss: 1.6089, Val Loss: 1.5609, Val Accuracy: 0.4216
Epoch 8/30, Train Loss: 1.5754, Val Loss: 1.6861, Val Accuracy: 0.3966
Epoch 9/30, Train Loss: 1.5368, Val Loss: 1.5507, Val Accuracy: 0.4420
Epoch 10/30, Train Loss: 1.5190, Val Loss: 1.5327, Val Accuracy: 0.4514
Epoch 11/30, Train Loss: 1.5027, Val Loss: 1.5107, Val Accuracy: 0.4422
Epoch 12/30, Train Loss: 1.4781, Val Loss: 1.4715, Val Accuracy: 0.4626
Epoch 13/30, Train Loss: 1.4479, Val Loss: 1.4806, Val Accuracy: 0.4636
Epoch 14/30, Train Loss: 1.4356, Val Loss: 1.4578, Val Accuracy: 0.4780
E