In [1]:
import torch
import numpy as np
import os
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.models import resnet18

from preprocessing.dataset import CIFAR10_custom 
from preprocessing.transforms import CompressedToTensor, ZigZagOrder, ChooseAC, FlattenZigZag, ConvertToFrequencyDomain, ConvertToYcbcr, Quantize, LUMINANCE_QUANTIZATION_MATRIX, CHROMINANCE_QUANTIZATION_MATRIX
from model.init import init_truncated_normal, init_kaiming_normal, set_seed
from model.vit import CompressedVisionTransformer, VisionTransformer
from model.baseline import ResNet18, NeuralNet

In [2]:
# Notebook-wide configuration constants.
DOWNLOAD_PATH = os.path.join('data', 'cifar10')  # root directory handed to the CIFAR10 dataset
# Misspelled legacy name, kept as an alias so cells that still read it keep working.
DOWNLAOD_PATH = DOWNLOAD_PATH
SEED = 42  # value passed to set_seed(...) around the split and the weight initialisation
VALIDATION_SET = 0.1  # fraction of the training set held out for validation
BATCH_SIZE = 128  # mini-batch size
AC = 5  # AC setting passed to ChooseAC and to the compressed models (`ac=AC`)

# Loading in the data

In [3]:
# One quantization table per channel: the luminance table first, then the
# chrominance table twice (presumably Y, Cb, Cr in that order — confirm against Quantize).
quantization_matrices = [LUMINANCE_QUANTIZATION_MATRIX] + 2 * [CHROMINANCE_QUANTIZATION_MATRIX]

# Plain pipeline: tensor conversion only (author's note: pixels in range [0-1]).
vanilla_transform = Compose([ToTensor()])

# Compressed-domain pipeline. The trailing comments are the author's annotations
# for the output of each step.
compressed_steps = [
    CompressedToTensor(),                       # 3x32x32, pixels in range [0-255]
    ConvertToYcbcr(),                           # 3x32x32, pixels in range [0-1]
    ConvertToFrequencyDomain(norm='ortho'),     # 3x32x32
    Quantize(
        quantization_matrices=quantization_matrices,
        alpha=1.0,
        floor=True,
    ),                                          # 3x32x32
    ZigZagOrder(),                              # 3x16x64
    ChooseAC(AC),                               # 3x16x(AC+1)
    FlattenZigZag(),                            # 3x(16x(AC+1))
]
transform = Compose(compressed_steps)

In [19]:
# Build the CIFAR-10 training and test datasets with identical settings; only the
# `train` flag differs. Both apply the compressed-domain `transform`.
# download=False: the files must already be present under DOWNLAOD_PATH.
cifar, cifar_test = (
    CIFAR10(
        root=DOWNLAOD_PATH,
        train=is_train,
        transform=transform,
        target_transform=None,
        download=False,
    )
    for is_train in (True, False)
)

In [5]:
# Hold out a validation subset from the CIFAR-10 training set.
# The split runs inside set_seed(SEED); presumably this makes the random
# partition reproducible — confirm that set_seed seeds the torch RNG.
with set_seed(SEED):
    num_total = len(cifar)
    # Size the validation set from the configured fraction rather than a literal.
    num_val = int(VALIDATION_SET * num_total)
    num_train = num_total - num_val

    cifar_train, cifar_val = random_split(cifar, [num_train, num_val])

# Models

In [6]:
# Hyperparameters that the two transformer variants have in common.
encoder_config = dict(
    num_classes=10,
    d_model=248,
    nhead=8,
    dim_feedforward=1024,
    dropout=0.1,
    ntransformers=4,
    layer_norm_eps=1e-5,
    norm_first=False,
    bias=True,
    learnable_positional=True,
)

# Transformer fed with the compressed-domain features (AC setting, 16 patches).
# Each model receives its own GELU instance.
cvit = CompressedVisionTransformer(
    ac=AC,
    channels=3,
    patch_num=16,
    activation=nn.GELU(),
    **encoder_config,
)

# Transformer configured for 3x32x32 inputs split into 4x4 patches.
vit = VisionTransformer(
    in_channels=3,
    height=32,
    width=32,
    patch_size=4,
    activation=nn.GELU(),
    **encoder_config,
)

# Initialise both models under the same seed, CVIT first and ViT second.
with set_seed(SEED):
    for transformer in (cvit, vit):
        transformer.init_weights(init_truncated_normal)

# Call each model's pre_training() hook (defined in model.vit), again CVIT first.
for transformer in (cvit, vit):
    transformer.pre_training()

In [7]:
# Baseline models.
# NOTE(review): binding `resnet18` here shadows the torchvision `resnet18`
# function imported at the top of the notebook.
resnet18 = ResNet18(channels=3, num_classes=10)

nn_config = {
    'ac': AC,
    'channels': 3,
    'patch_num': 16,
    'num_classes': 10,
    'hidden_size': 4,
    'dim_feedforward': 1024,
    'activation': nn.ReLU(),
    'bias': True,
    'layer_norm': False,
}
neural_net = NeuralNet(**nn_config)

# Seeded weight initialisation; the call order is kept exactly as written.
with set_seed(SEED):
    # NOTE(review): init_weights runs twice on resnet18. If both initialisers touch
    # the same parameters, the truncated-normal pass overrides the Kaiming one —
    # confirm which initialisation is intended.
    resnet18.init_weights(init_kaiming_normal)
    resnet18.init_weights(init_truncated_normal)
    neural_net.init_weights(init_truncated_normal)

# CVIT training

In [8]:
# Restore previously saved CVIT weights before continuing training.
# map_location='cpu' lets a checkpoint written on a GPU machine load on a CPU-only
# one; load_state_dict then copies the tensors into the model's parameters.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True on PyTorch versions that support it).
cvit_checkpoint = torch.load('checkpoints/cvit5d248w8h30e.pth', map_location='cpu')
cvit.load_state_dict(cvit_checkpoint)

In [12]:
# Training setup for the compressed vision transformer (CVIT).
criterion = nn.CrossEntropyLoss()
lr = 1e-4
num_epochs = 10
batch_size = BATCH_SIZE  # reuse the notebook-wide constant from the config cell
weight_decay = 0
checkpoint_every_th_epoch = None  # not read by the training loop
cvit_optimizer = optim.Adam(cvit.parameters(), lr=lr, weight_decay=weight_decay)

# cifar_train / cifar_val / cifar_test are the datasets this notebook builds with the
# compressed-domain `transform`; no `cifar_compressed_*` names are defined anywhere,
# so referencing them raises NameError on a fresh kernel.
train_compressed = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)
# Validation order does not affect the metrics, so keep it deterministic.
val_compressed = DataLoader(cifar_val, batch_size=batch_size, shuffle=False)
# The whole test set is served as a single batch.
test_compressed = DataLoader(cifar_test, batch_size=len(cifar_test), shuffle=False)

In [13]:
# Train CVIT for `num_epochs` epochs, reporting the mean training loss and the
# validation loss / accuracy after every epoch.
for epoch in range(num_epochs):
    # Training phase
    cvit.train()
    train_loss = 0.0
    for images, labels in train_compressed:
        images = images.to(torch.float32)  # the model is fed float32 inputs

        cvit_optimizer.zero_grad()

        outputs = cvit(images)
        loss = criterion(outputs, labels)
        loss.backward()
        cvit_optimizer.step()

        # Weight the batch loss by the batch size so a smaller final batch
        # does not skew the epoch average.
        train_loss += loss.item() * images.size(0)

    # Normalise by the size of the dataset the loader actually iterates over.
    train_loss /= len(train_compressed.dataset)

    # Validation phase
    cvit.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_compressed:
            images = images.to(torch.float32)

            outputs = cvit(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * images.size(0)
            predicted = outputs.argmax(dim=1)  # index of the highest logit per sample
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss /= len(val_compressed.dataset)
    val_accuracy = correct / total

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')

In [11]:
# Persist the current CVIT weights (state_dict only) to the checkpoints folder.
# NOTE(review): the file name ends in `60e`, while this notebook loads a `...30e`
# checkpoint and the CVIT section sets num_epochs = 10 — confirm the epoch count
# encoded in the name.
torch.save(
    obj=cvit.state_dict(),
    f='checkpoints/cvit5d248w8h60e.pth',
)

## NN training on compressed data

In [8]:
# Training setup for the fully connected baseline (`neural_net`) on compressed inputs.
criterion = nn.CrossEntropyLoss()
lr = 1e-3
num_epochs = 30
batch_size = BATCH_SIZE  # reuse the notebook-wide constant from the config cell
weight_decay = 0
checkpoint_every_th_epoch = None  # not read by the training loop
nn_optimizer = optim.Adam(neural_net.parameters(), lr=lr, weight_decay=weight_decay)

# cifar_train / cifar_val / cifar_test are the datasets this notebook builds with the
# compressed-domain `transform`; no `cifar_compressed_*` names are defined anywhere,
# so referencing them raises NameError on a fresh kernel.
train_compressed = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)
# Validation order does not affect the metrics, so keep it deterministic.
val_compressed = DataLoader(cifar_val, batch_size=batch_size, shuffle=False)
# The whole test set is served as a single batch.
test_compressed = DataLoader(cifar_test, batch_size=len(cifar_test), shuffle=False)

In [9]:
# Train the fully connected baseline for `num_epochs` epochs, reporting the mean
# training loss and the validation loss / accuracy after every epoch.
for epoch in range(num_epochs):
    # Training phase
    neural_net.train()
    train_loss = 0.0
    for images, labels in train_compressed:
        images = images.to(torch.float32)  # the model is fed float32 inputs

        nn_optimizer.zero_grad()

        outputs = neural_net(images)
        loss = criterion(outputs, labels)
        loss.backward()
        nn_optimizer.step()

        # Weight the batch loss by the batch size so a smaller final batch
        # does not skew the epoch average.
        train_loss += loss.item() * images.size(0)

    # Normalise by the size of the dataset the loader actually iterates over.
    train_loss /= len(train_compressed.dataset)

    # Validation phase
    neural_net.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_compressed:
            images = images.to(torch.float32)

            outputs = neural_net(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * images.size(0)
            predicted = outputs.argmax(dim=1)  # index of the highest logit per sample
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss /= len(val_compressed.dataset)
    val_accuracy = correct / total

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')

In [10]:
# Persist the baseline network's weights (state_dict only) to the checkpoints folder.
torch.save(
    obj=neural_net.state_dict(),
    f='checkpoints/nn5d1024w30e.pth',
)

# ResNet18 training

In [43]:
# Training setup for the ResNet18 baseline.
criterion = nn.CrossEntropyLoss()
lr = 1e-3
num_epochs = 30
batch_size = BATCH_SIZE  # reuse the notebook-wide constant from the config cell
weight_decay = 0
checkpoint_every_th_epoch = None  # not read by the training loop
res_optimizer = optim.Adam(resnet18.parameters(), lr=lr, weight_decay=weight_decay)

# NOTE(review): cifar_train / cifar_val / cifar_test are built with the compressed-domain
# `transform`, and `vanilla_transform` is never used in this notebook — confirm that
# ResNet18 is meant to train on these datasets rather than on plain image tensors.
train = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)
# Validation order does not affect the metrics, so keep it deterministic.
val = DataLoader(cifar_val, batch_size=batch_size, shuffle=False)
# The whole test set is served as a single batch.
test = DataLoader(cifar_test, batch_size=len(cifar_test), shuffle=False)

In [44]:
# Train ResNet18 for `num_epochs` epochs; after each epoch print the mean training
# loss together with the validation loss and accuracy.
for epoch in range(num_epochs):
    # ---- training pass ----
    resnet18.train()
    train_loss = 0.0
    for images, labels in train:
        res_optimizer.zero_grad()

        outputs = resnet18(images)
        loss = criterion(outputs, labels)
        loss.backward()
        res_optimizer.step()

        # Accumulate the batch loss weighted by the number of samples in the batch.
        batch_size_seen = images.size(0)
        train_loss += loss.item() * batch_size_seen

    train_loss = train_loss / len(cifar_train)

    # ---- validation pass ----
    resnet18.eval()
    val_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in val:
            outputs = resnet18(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * images.size(0)
            predicted = outputs.argmax(dim=1)  # index of the highest logit per sample
            hits = (predicted == labels).sum().item()
            total += labels.size(0)
            correct += hits

    val_loss = val_loss / len(cifar_val)
    val_accuracy = correct / total

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')