In [1]:
from pathlib import Path
import torch
import numpy as np
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
# Uncomment this line for MNIST training.
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
from tqdm import tqdm, trange
from torch.optim import Adam

np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7e5f783b44d0>

In [2]:
def patchify(images, n_patches):
    """
    In order to "sequentially" pass in the images, we can break down the main image into multiple sub-images
    and map them to a vector. This is exactly what this function does.

    Arguments:
    images: The image passed into this function
    n_patches: The number of patches to split the image into.

    Returns our patches aka the sub-images.
    """
    n, c, h, w = images.shape

    assert h == w, "Only for square images"

    patches = torch.zeros(n, n_patches ** 2, h * w * c // n_patches ** 2)
    patch_size = h // n_patches

    for idx, image in enumerate(images):
        for i in range(n_patches):
            for j in range(n_patches):
                patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]
                patches[idx, i * n_patches + j] = patch.flatten()
    return patches


def positional_embeddings(sequence_length, d):
    """
    In order for the model to know where to place each image, one can use positional embeddings where high freq values
    are classified into the first few dimensions while low frequency values are added on to the latter dimensions. This
    function performs exactly that. It has two parameters.

    Arguments:
    sequence_length: The number of tokens for the dataset.
    d: The dimensionality for each token.

    Returns a matrix where each (i,j) is added as token i in dimension j.
    """
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** (j / d)))
    return result


class MSA(torch.nn.Module):
    """
    This is the template implementation of the "Multi-Scale Attention" Layer.

    The query, key and value mapping are matrix-multipled against each other in order to
    find the attention, or, the relation of a word and its interaction with surrounding words.
    """
    def __init__(self, d, n_heads=4):
        super(MSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0  # Shouldn't divide dimension (d) into n_heads

        d_head = int(d / n_heads)
        self.q_mappings = torch.nn.ModuleList([torch.nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = torch.nn.ModuleList([torch.nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = torch.nn.ModuleList([torch.nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, sequences):
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])


class Residual(torch.nn.Module):
    """
    This is how a Residual Layer is built. The MSA that we have written will be a part
    of this residual block right here.
    """

    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(Residual, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads
        self.norm1 = torch.nn.LayerNorm(hidden_d)
        self.mhsa = MSA(hidden_d, n_heads)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_d, mlp_ratio * hidden_d),
            torch.nn.GELU(),
            torch.nn.Linear(mlp_ratio * hidden_d, hidden_d)
        )

    def forward(self, x):
        out = x + self.mhsa(self.norm1(x))
        out = out + self.mlp(self.norm2(out))
        return out


class ViT(torch.nn.Module):
    """
    The workflow will be as follows.
        1. Find the linear mapping of the input
        2. Embed them using the function that we have written
        3. Use 'n' MSA blocks and add a linear and a softmax layer at the end
    """

    def __init__(self, chw, n_patches=16, n_blocks=2, hidden_d=8, n_heads=4, out_d=10):
        super(ViT, self).__init__()

        self.chw = chw
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d

        # Input and patch sizes
        assert chw[1] % n_patches == 0
        assert chw[2] % n_patches == 0
        self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)

        # Linear mapping
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = torch.nn.Linear(self.input_d, self.hidden_d)

        # Classification token
        self.v_class = torch.nn.Parameter(torch.rand(1, self.hidden_d))

        # Positional embedding
        self.register_buffer('positional_embeddings', positional_embeddings(n_patches ** 2 + 1, hidden_d),
                             persistent=False)

        # Encoder blocks
        self.blocks = torch.nn.ModuleList([MSA(hidden_d, n_heads) for _ in range(n_blocks)])

        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_d, out_d),
            torch.nn.Softmax(dim=-1)
        )

    def forward(self, images):
        n, c, h, w = images.shape
        patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)

        # rutorch.nning tokenization
        tokens = self.linear_mapper(patches)
        tokens = torch.cat((self.v_class.expand(n, 1, -1), tokens), dim=1)
        out = tokens + self.positional_embeddings.repeat(n, 1, 1)

        for block in self.blocks:
            out = block(out)

        out = out[:, 0]
        return self.mlp(out)

In [8]:
def main(train_loader, test_loader):
    """
    This code contains the training and testing loop for training the vision transformers model. It requires two
    parameters

    :param train_loader: The dataloader for the training set for training the model.
    :param test_loader: The dataloader for the testing set during evaluation phase.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    mnist_model = ViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

    epochs = 1
    lr = 0.005

    optimizer = Adam(mnist_model.parameters(), lr=lr)
    criterion = CrossEntropyLoss()
    for epoch in trange(epochs, desc="train"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = mnist_model(x)
            loss = criterion(y_hat, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch + 1}/{epochs} loss: {train_loss:.2f}")

    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = mnist_model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)

        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

In [10]:
# For MNIST: comment out the lines above and uncomment the lines below.

transform = transforms.ToTensor()
train_mnist = MNIST(root='./mnist', train=True, download=True, transform=transform)
test_mnist = MNIST(root='./mnist', train=False, download=True, transform=transform)
train_loader = DataLoader(train_mnist, shuffle=True, batch_size=128)
test_loader = DataLoader(test_mnist, shuffle=False, batch_size=128)
main(train_loader=train_loader, test_loader=test_loader)

Using device:  cuda (Tesla T4)


train:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 1 in training:   0%|          | 1/469 [00:00<05:28,  1.43it/s][A
Epoch 1 in training:   0%|          | 2/469 [00:01<05:27,  1.42it/s][A
Epoch 1 in training:   1%|          | 3/469 [00:02<05:28,  1.42it/s][A
Epoch 1 in training:   1%|          | 4/469 [00:02<05:26,  1.43it/s][A
Epoch 1 in training:   1%|          | 5/469 [00:03<05:25,  1.43it/s][A
Epoch 1 in training:   1%|▏         | 6/469 [00:04<05:26,  1.42it/s][A
Epoch 1 in training:   1%|▏         | 7/469 [00:04<05:26,  1.42it/s][A
Epoch 1 in training:   2%|▏         | 8/469 [00:05<05:35,  1.38it/s][A
Epoch 1 in training:   2%|▏         | 9/469 [00:06<05:30,  1.39it/s][A
Epoch 1 in training:   2%|▏         | 10/469 [00:07<05:28,  1.40it/s][A
Epoch 1 in training:   2%|▏         | 11/469 [00:07<05:26,  1.40it/s][A
Epoch 1 in training:   3%|▎         | 12/469 [00:08<05:45,  1.32it/s][A
Epoch 1 in training:   3

Epoch 1/1 loss: 2.15


Testing: 100%|██████████| 79/79 [00:33<00:00,  2.32it/s]

Test loss: 2.08
Test accuracy: 37.75%



