In [17]:
# https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial15/Vision_Transformer.html

import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from IPython.display import clear_output
import copy

import torchvision

In [None]:
# CIFAR-10 splits (train=True / train=False); every image goes through ToTensor.
to_tensor = torchvision.transforms.ToTensor()
dataset_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=to_tensor)
dataset_valid = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=to_tensor)

In [3]:
# Both loaders share the same mini-batch size; only the training loader shuffles.
batch_size = 32
loader_kwargs = dict(batch_size=batch_size)
train_loader = DataLoader(dataset_train, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(dataset_valid, shuffle=False, **loader_kwargs)

In [4]:
def img_to_patch(x, patch_size):
  """Cut a batch of images into flattened, non-overlapping square patches.

  Args:
    x: tensor of shape [B, C, H, W].
    patch_size: side length of a patch; H and W must be multiples of it,
      otherwise the first reshape fails.

  Returns:
    Tensor of shape [B, (H // patch_size) * (W // patch_size),
    C * patch_size * patch_size]. Patches are ordered row by row, and each
    patch is flattened channel-first.
  """
  batch, channels, height, width = x.shape
  rows, cols = height // patch_size, width // patch_size

  # Split both spatial axes into (grid index, offset inside the patch).
  grid = x.reshape(batch, channels, rows, patch_size, cols, patch_size)
  # Bring the grid axes forward: [B, rows, cols, C, p_H, p_W].
  grid = grid.permute(0, 2, 4, 1, 3, 5)
  # Merge the grid axes into one token axis and the rest into a feature axis.
  return grid.reshape(batch, rows * cols, channels * patch_size * patch_size)

In [None]:
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Sanity check: patchify the images of one training batch and show the shape.
sample_batch = next(iter(train_loader))
sample_images = sample_batch[0].to(device)
img_to_patch(sample_images, 4).shape

In [6]:
class AttentionBlock(nn.Module):
  """Pre-norm Transformer encoder block: self-attention followed by an MLP.

  Both sub-layers are wrapped in a residual connection and are preceded by
  their own LayerNorm.

  Args:
    embed_dim: size of each token embedding (input and output).
    hidden_dim: width of the hidden layer inside the MLP.
    num_heads: number of attention heads.
    dropout: dropout probability used in attention and in the MLP.
  """

  def __init__(self, embed_dim, hidden_dim, num_heads, dropout=.0):
    super().__init__()

    # --- self-attention sub-layer ---
    self.layer_norm_1 = nn.LayerNorm(embed_dim)
    self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)

    # --- position-wise feed-forward sub-layer ---
    self.layer_norm_2 = nn.LayerNorm(embed_dim)
    feed_forward = [
        nn.Linear(embed_dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, embed_dim),
        nn.Dropout(dropout),
    ]
    self.linear = nn.Sequential(*feed_forward)

  def forward(self, x):
    """Apply the block to `x` and return a tensor of the same shape.

    The attention module is created without `batch_first`, so a batched
    input is expected sequence-first, i.e. [T, B, embed_dim].
    """
    normed = self.layer_norm_1(x)
    # Query, key and value are all the normalised input (self-attention);
    # the returned attention weights are not needed here.
    attended, _ = self.attn(normed, normed, normed)
    residual = x + attended
    return residual + self.linear(self.layer_norm_2(residual))

In [7]:
class VisionTransformer(nn.Module):
  """Vision Transformer classifier operating on flattened image patches.

  Args:
    embed_dim: dimensionality of the token embeddings.
    hidden_dim: hidden width of the MLP inside each AttentionBlock.
    num_channels: number of image channels.
    num_heads: number of attention heads per block.
    num_layers: number of stacked AttentionBlocks.
    num_classes: size of the output logit vector.
    patch_size: side length of the square patches.
    num_patches: maximum number of patches per image (sizes the positional
      embedding, which also has one extra slot for the class token).
    dropout: dropout probability used on the embeddings and in the blocks.
  """

  def __init__(self, embed_dim, hidden_dim, num_channels, num_heads, num_layers,
               num_classes, patch_size, num_patches, dropout=.0):
    super().__init__()

    self.patch_size = patch_size

    # Linear projection of each flattened patch into the embedding space.
    patch_dim = num_channels * (patch_size ** 2)
    self.input_layer = nn.Linear(patch_dim, embed_dim)

    encoder_blocks = [AttentionBlock(embed_dim, hidden_dim, num_heads, dropout)
                      for _ in range(num_layers)]
    self.transformer = nn.Sequential(*encoder_blocks)

    # Classification head applied to the class-token output.
    self.mlp_head = nn.Sequential(
        nn.LayerNorm(embed_dim),
        nn.Linear(embed_dim, num_classes),
    )

    self.dropout = nn.Dropout(dropout)

    # Learnable class token and positional embedding (class token + patches).
    self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
    self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))

  def forward(self, x):
    """Map a batch of images [B, C, H, W] to logits [B, num_classes]."""
    patches = img_to_patch(x, self.patch_size)
    batch, num_tokens, _ = patches.shape
    tokens = self.input_layer(patches)

    # Prepend the class token to every sequence, then add position information.
    cls_tokens = self.cls_token.repeat(batch, 1, 1)
    tokens = torch.cat([cls_tokens, tokens], dim=1)
    tokens = tokens + self.pos_embedding[:, :num_tokens + 1]

    tokens = self.dropout(tokens)
    # The attention blocks work sequence-first: [B, T, D] -> [T, B, D].
    encoded = self.transformer(tokens.transpose(0, 1))

    # Position 0 along the sequence axis is the class token.
    return self.mlp_head(encoded[0])

In [None]:
# 32x32 CIFAR-10 images cut into 4x4 patches give 8 * 8 = 64 patches per image.
model = VisionTransformer(
    embed_dim=256,
    hidden_dim=512,
    num_channels=3,
    num_heads=8,
    num_layers=6,
    num_classes=10,
    patch_size=4,
    num_patches=64,
    dropout=.2,
).to(device)
model(sample_images).shape

In [26]:
# Helpers to evaluate and train a model.
# You may need to convert tensors to float32 (before passing them to the criterion)
# and move them to the right device with .to(device).

def compute_error(model, data_loader, criterion, c_sum=False):
    """Average a per-batch criterion over every sample yielded by `data_loader`.

    The model is switched to eval mode (and left there) and gradients are
    disabled while iterating. Batches are moved to the module-level `device`.

    Args:
        model: network to evaluate.
        data_loader: iterable of (inputs, targets) batches.
        criterion: callable `criterion(outputs, targets)` returning one value
            per batch.
        c_sum: set to True when `criterion` already returns a sum over the
            batch. When False the value is multiplied by the batch length
            first (i.e. it is treated as a per-batch mean).

    Returns:
        Accumulated criterion value divided by the total number of samples.
    """
    model.eval()
    total, seen = 0, 0
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_len = len(targets)
            batch_value = criterion(model(inputs), targets)
            if not c_sum:
                # Turn the batch mean back into a batch sum.
                batch_value *= batch_len
            total += batch_value
            seen += batch_len
    return total / seen


def train_model(model: nn.Module,
              train_loader: DataLoader,
              valid_loader: DataLoader,
              num_epochs: int,
              optimizer: torch.optim.Optimizer,
              criterion,
              verbose: bool = True,
              verbose_plot: bool = False
              ) -> float:
    """Train `model` and restore the state of its best validation epoch.

    After every epoch the validation loss is computed with `compute_error`.
    Whenever it improves, a CPU copy of the full `state_dict` (parameters
    and buffers) is stored; that copy is loaded back into the model once
    training finishes. Batches are moved to the module-level `device`.

    Args:
        model: network to optimise (modified in place).
        train_loader: yields (inputs, targets) training batches.
        valid_loader: yields (inputs, targets) validation batches.
        num_epochs: number of passes over `train_loader`.
        optimizer: optimizer already bound to the model's parameters.
        criterion: loss callable `criterion(outputs, targets)`.
        verbose: print the running loss every 100 mini-batches and a
            summary line after each epoch (the cell output is cleared
            before each summary).
        verbose_plot: additionally evaluate the training loss after each
            epoch and plot the train/validation curves at the end.

    Returns:
        The best validation loss as a Python float (`inf` if no epoch
        produced a loss below infinity, e.g. when `num_epochs` is 0).
    """
    best_epoch = None
    best_state = None
    best_val_loss = float("inf")
    train_losses, valid_losses = [], []

    for epoch in range(num_epochs):
        # compute_error leaves the model in eval mode, so re-enable training.
        model.train()
        for step, (inputs, targets) in enumerate(train_loader, start=1):
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            if verbose and step % 100 == 0:
                print(f"Minibatch {step:>6}    |  loss {loss.item():>5.2f}  |")

        # Convert to a plain float so comparisons, formatting, plotting and
        # the return value do not depend on a (possibly GPU) tensor.
        val_loss = float(compute_error(model, valid_loader, criterion))

        if val_loss < best_val_loss:
            best_epoch = epoch
            best_val_loss = val_loss
            # Snapshot parameters AND buffers; clone so that later optimizer
            # steps cannot modify the saved copy (``.cpu()`` is a no-op for
            # tensors that already live on the CPU).
            best_state = {name: tensor.detach().cpu().clone()
                          for name, tensor in model.state_dict().items()}

        if verbose:
            clear_output(True)
            m = f"After epoch {epoch:>2} | valid loss: {val_loss:>5.2f}"
            print("{0}\n{1}\n{0}".format("-" * len(m), m))

        if verbose_plot:
            train_losses.append(float(compute_error(model, train_loader, criterion)))
            valid_losses.append(val_loss)

    if best_state is not None:
        if verbose:
            print(f"\nLoading best params on validation set in epoch {best_epoch} with loss {best_val_loss:.2f}")
        # load_state_dict copies the saved tensors back onto the model's device.
        model.load_state_dict(best_state)

    if verbose_plot:
        plt.figure(figsize=(6, 3))
        plt.plot(train_losses, c='b', label='train')
        plt.plot(valid_losses, c='r', label='valid')
        plt.grid(ls=':')
        plt.legend()
        plt.show()

    return best_val_loss

In [None]:
# Training setup: cross-entropy loss optimised with AdamW over all model parameters.
n_epochs = 5
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

train_model(
    model, train_loader, valid_loader, n_epochs, optimizer, criterion,
    verbose=True, verbose_plot=True,
)

In [None]:
def accuracy_multiple(outputs, y):
    """Count how many predictions in a batch match their targets.

    Args:
        outputs: class scores with the class axis at dim 1.
        y: integer class targets, shaped like `outputs` without dim 1.

    Returns:
        0-dim integer tensor holding the number of correct predictions.
    """
    pred = outputs.argmax(dim=1)
    # Tensor.sum() reduces in a single vectorised op and always yields a
    # scalar; the builtin sum() would loop over the first axis in Python.
    return (pred == y).sum()

# With c_sum=True the per-batch correct counts are summed and divided by the
# number of samples, so this prints the fraction of correct predictions.
valid_accuracy = compute_error(model, valid_loader, accuracy_multiple, c_sum=True)
print(valid_accuracy)