In [None]:
!git clone https://github.com/FabianSommerauer/ViTiny.git

In [None]:
!pip install einops

In [None]:
import os
import matplotlib.pyplot as plt
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
from einops import rearrange

import sys
sys.path.append("ViTiny/src/")

from ViTinyBase import ViTinyBase


Basic constants and functions

In [None]:
# Training configuration.
BATCH_SIZE = 50
EPOCHS = 10
# Directory that model checkpoints are written to.
MODELS_FOLDER = './models'

# Per-channel mean and std of 0.5: (x - 0.5) / 0.5 maps [0, 1] onto [-1, 1].
_HALF_PER_CHANNEL = (0.5,) * 3

# Preprocessing applied to every image: convert to a tensor, then normalize.
normalize_images = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_HALF_PER_CHANNEL, std=_HALF_PER_CHANNEL),
])

def imshow(img, label=None):
    """Display a normalized, channel-first image tensor with matplotlib.

    Args:
        img: 3-D image tensor laid out as (channels, height, width) whose
            values were normalized with mean 0.5 and std 0.5.
        label: Optional title drawn above the image; skipped when falsy.
    """
    # detach + cpu make the tensor convertible to numpy even if it lives on a
    # GPU or is part of an autograd graph; both are no-ops for plain CPU tensors.
    img = img.detach().cpu()
    # Undo the normalization: x / 2 + 0.5 is the inverse of (x - 0.5) / 0.5.
    img = img / 2 + 0.5
    # matplotlib expects the channel axis last.
    img = rearrange(img.numpy(), "c h w -> h w c")
    plt.imshow(img)
    if label:
        plt.title(label)
    plt.show()

Mount Google Drive (optional)

In [None]:
# Optional: keep checkpoints on Google Drive when running on Colab.
# Outside Colab the import fails, and the MODELS_FOLDER chosen earlier is kept
# instead of aborting the notebook.
try:
    from google.colab import drive
except ImportError:
    print('google.colab is not available; keeping MODELS_FOLDER unchanged')
else:
    drive_mount_point = '/content/drive'
    drive.mount(drive_mount_point)
    # Build the path from the mount point so it does not depend on the
    # current working directory being /content.
    MODELS_FOLDER = os.path.join(drive_mount_point, 'MyDrive', 'vitiny_models')

Load the data

In [None]:
# Settings shared by both CIFAR-10 splits and by both loaders.
cifar_kwargs = dict(root='./data', download=True, transform=normalize_images)
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=5)

# Training split: reshuffled by the loader.
train_set = torchvision.datasets.CIFAR10(train=True, **cifar_kwargs)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_kwargs)

# Test split: iterated in a fixed order.
test_set = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_kwargs)

Prepare some variables and show example images

In [None]:
# Class names, looked up by integer label.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Draw a single batch from the training loader to use as example images.
train_iter = iter(train_loader)
images, labels = next(train_iter)

# Human-readable list of the batch's class names, e.g. "[cat, ship, ...]".
class_names = [classes[lbl] for lbl in labels.tolist()]
labels_str = f"[{', '.join(class_names)}]"

# Spatial size of the images: the last two dimensions of the batch.
image_size = images.shape[-2:]

# Patch size handed to the model; (4, 4) presumably means 4x4-pixel patches
# -- confirm against ViTinyBase.
patch_size = (4, 4)

In [None]:
# Tile the example batch into a single image and show it with its class names as the title.
example_grid = torchvision.utils.make_grid(images)
imshow(example_grid, labels_str)

Create the model and the necessary loss + optimizer

In [None]:
# Make sure the checkpoint directory exists before any weights are written.
os.makedirs(MODELS_FOLDER, exist_ok=True)

# The five trailing integers are passed positionally to ViTinyBase; their
# meaning is defined by that constructor, which is not visible here
# -- TODO confirm and name them.
num_classes = len(classes)
model = ViTinyBase(image_size, patch_size, num_classes, 8, 16, 16, 8, 16)
# GPU placement is currently disabled; to enable it:
# cuda0 = torch.device('cuda:0'); model.to(cuda0)

# Cross-entropy loss, and Adam with its default settings.
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())

Now we train the model

In [None]:
# Number of mini-batches between two loss reports.
print_interval = 200

# Train on whichever device the model currently lives on, so the loop works
# unchanged for a CPU model and for a model that was moved to a GPU.
device = next(model.parameters()).device

for epoch in range(EPOCHS):
    # Put the model in training mode at the start of every epoch.
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        # No-op when the batch is already on the model's device.
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the mean loss over the last `print_interval` mini-batches.
        running_loss += loss.item()
        if i % print_interval == print_interval - 1:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / print_interval:.3f}')
            running_loss = 0.0

    # Checkpoint after every epoch so an interrupted run can be resumed or inspected.
    torch.save(model.state_dict(), os.path.join(MODELS_FOLDER, f'cifar_vitiny_epoch_{epoch + 1}.pth'))

# Final weights after the last epoch.
torch.save(model.state_dict(), os.path.join(MODELS_FOLDER, 'cifar_vitiny.pth'))
print('Finished Training')

In [None]:
# Manual checkpoint cell. The training loop already writes
# 'cifar_vitiny_epoch_3.pth' after epoch 3, so on a full run this cell would
# silently replace that checkpoint with the current weights.
# Only write the file when it does not exist yet.
os.makedirs(MODELS_FOLDER, exist_ok=True)
manual_checkpoint_path = os.path.join(MODELS_FOLDER, 'cifar_vitiny_epoch_3.pth')
if os.path.exists(manual_checkpoint_path):
    print(f'{manual_checkpoint_path} already exists; not overwriting it')
else:
    torch.save(model.state_dict(), manual_checkpoint_path)