In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import os
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
#from linformer import Linformer

In [2]:
from vit import ViT

In [3]:
# Training settings
batch_size = 64
epochs = 20
lr = 3e-5
gamma = 0.7
seed = 42

In [4]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

In [5]:
device = 'cuda'

In [6]:
train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()])

val_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()])

test_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()])

In [7]:
trainset = datasets.CIFAR10(root='./data',
                            train=True,
                            download=False, 
                            transform=train_transforms)

trainloader = DataLoader(trainset,
                         batch_size=batch_size,
                         shuffle=True,
                         num_workers=2)

testset = datasets.CIFAR10(root='./data',
                           train=False,
                           download=False,
                           transform=val_transforms)

testloader = DataLoader(testset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=2)

In [8]:
model = ViT(
    dim=128,
    image_size=224,
    patch_size=32,
    num_classes=10,
    #transformer=efficient_transformer,
    channels=3,
    depth=12,
    heads=8,
    mlp_dim = 2048
).to(device)

In [9]:
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

In [10]:
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(trainloader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(trainloader)
        epoch_loss += loss / len(trainloader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in testloader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(testloader)
            epoch_val_loss += val_loss / len(testloader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 1 - loss : 2.0846 - acc: 0.2100 - val_loss : 1.9870 - val_acc: 0.2642



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 2 - loss : 1.9114 - acc: 0.2927 - val_loss : 1.8162 - val_acc: 0.3363



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 3 - loss : 1.7560 - acc: 0.3597 - val_loss : 1.6610 - val_acc: 0.4058



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 4 - loss : 1.6286 - acc: 0.4125 - val_loss : 1.5632 - val_acc: 0.4362



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 5 - loss : 1.5566 - acc: 0.4396 - val_loss : 1.5259 - val_acc: 0.4534



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 6 - loss : 1.5091 - acc: 0.4587 - val_loss : 1.4736 - val_acc: 0.4710



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 7 - loss : 1.4681 - acc: 0.4738 - val_loss : 1.4658 - val_acc: 0.4714



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 8 - loss : 1.4378 - acc: 0.4862 - val_loss : 1.4237 - val_acc: 0.4951



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 9 - loss : 1.3965 - acc: 0.5013 - val_loss : 1.3881 - val_acc: 0.5052



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 10 - loss : 1.3678 - acc: 0.5123 - val_loss : 1.3814 - val_acc: 0.5121



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 11 - loss : 1.3383 - acc: 0.5223 - val_loss : 1.3512 - val_acc: 0.5166



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 12 - loss : 1.3189 - acc: 0.5320 - val_loss : 1.3217 - val_acc: 0.5307



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 13 - loss : 1.2939 - acc: 0.5422 - val_loss : 1.3216 - val_acc: 0.5255



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 14 - loss : 1.2706 - acc: 0.5509 - val_loss : 1.3382 - val_acc: 0.5291



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 15 - loss : 1.2553 - acc: 0.5554 - val_loss : 1.2619 - val_acc: 0.5510



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 16 - loss : 1.2354 - acc: 0.5613 - val_loss : 1.2548 - val_acc: 0.5497



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 17 - loss : 1.2160 - acc: 0.5711 - val_loss : 1.2349 - val_acc: 0.5603



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 18 - loss : 1.2014 - acc: 0.5747 - val_loss : 1.2756 - val_acc: 0.5406



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 19 - loss : 1.1862 - acc: 0.5801 - val_loss : 1.2298 - val_acc: 0.5603



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 20 - loss : 1.1732 - acc: 0.5833 - val_loss : 1.2332 - val_acc: 0.5641

