In [1]:
#| default_exp core

# Vision Transformer

> In this notebook, we will walk through a basic implementation of a Vision Transformer (ViT) in `PyTorch`.

In [4]:
#| export
import einops as ein

import torch
import torch.nn as nn
import torch.nn.functional as F

## Extracting patches and projecting to embedding space

> The first step will be breaking the image into patches and projecting them into the embedding space.

In [26]:
class PatchEmbedding(nn.Module):
    """Breaks an input image into patches and projects them into an embedding space."""
    
    def __init__(self,
                 patch_size, # Patch size: an int (square patches) or a (height, width) pair.
                 d_emb, # Embedding dim.
                 in_channels=3, # Image channels.
                 ):
        super().__init__()
        # Normalise to a (height, width) tuple; lists are accepted as well as tuples.
        if isinstance(patch_size, (tuple, list)):
            self.patch_size = tuple(patch_size)
        else:
            self.patch_size = (patch_size, patch_size)
        self.d_emb = d_emb

        # Number of values in one flattened patch. Plain integer arithmetic keeps `in_features`
        # a Python int; `torch.multiply` is a tensor op and does not return one.
        patch_height, patch_width = self.patch_size
        self.embedding = nn.Linear(patch_height * patch_width * in_channels, d_emb)

    def forward(self,
                inputs, # Image tensor [Batch, Channels, Height, Width]; Height and Width must be multiples of the patch size.
                ):

        ## 1. Break the image into non-overlapping patches and flatten each one: [Batch, N_patch, Pixels]
        patches = ein.rearrange(inputs, "b c (h h2) (w w2) -> b (h w) (c h2 w2)", h2=self.patch_size[0], w2=self.patch_size[1])

        ## 2. Project them into the embedding space: [Batch, N_patch, D_emb]
        ## `nn.Linear` only acts on the last dim, so the batch and patch dims do not need to be merged first.
        return self.embedding(patches)

In [30]:
patch_size = (14, 14)
batch_size = 4
d_emb = 50
# Draw the sample from a dedicated seeded generator so that re-running the notebook
# reproduces the same values without touching the global RNG state.
sample_input = torch.randn(size=(batch_size,3,28,28), generator=torch.Generator().manual_seed(0))
patches = ein.rearrange(sample_input, "b c (h h2) (w w2) -> (b h w) (c h2 w2)", h2=patch_size[0], w2=patch_size[1])
# 28x28 images cut into 14x14 patches give 2*2 patches per image, each holding 3*14*14 values.
assert patches.shape == (batch_size * 2 * 2, 3 * 14 * 14)
patches.shape

torch.Size([16, 588])

In [32]:
pe = PatchEmbedding(patch_size=patch_size, in_channels=3, d_emb=d_emb)
sample_pe = pe(sample_input)
# One embedding per patch is expected: [Batch, N_patch, D_emb].
n_patches = (sample_input.shape[2] // patch_size[0]) * (sample_input.shape[3] // patch_size[1])
assert sample_pe.shape == (sample_input.shape[0], n_patches, d_emb)
sample_pe.shape

torch.Size([4, 4, 50])

## Class Token

> To be able to perform classification tasks, the *ViT* requires a class token (see this [DSExchange explanation](https://datascience.stackexchange.com/questions/90649/class-token-in-vit-and-bert)).

In [62]:
class ClassToken(nn.Module):
    """Learnable CLASS token that is prepended to a sequence of patch embeddings."""

    def __init__(self,
                 d_emb, # Embedding dim.
                 ):
        super().__init__()
        # The token starts as an all-zero vector of size `d_emb` and is registered as a parameter.
        zeros = torch.zeros(d_emb)
        self.cls_token = nn.Parameter(zeros)

    def forward(self,
                inputs, # Embeddings tensor [Batch, N_patch, D_emb] that receives the class token.
                ):
        # The three-way unpacking only accepts a 3-D input; just the batch size is used.
        n_batch, _, _ = inputs.shape

        # Repeat the single token to [Batch, 1, D_emb] so it can be concatenated with the embeddings.
        token_per_sample = ein.repeat(
            self.cls_token,
            "d_emb -> batch n_patches d_emb",
            batch=n_batch,
            n_patches=1,
        )

        # The token goes first along the sequence dim, giving [Batch, N_patch + 1, D_emb].
        sequence = (token_per_sample, inputs)
        return torch.cat(sequence, dim=1)

In [63]:
ct = ClassToken(d_emb=d_emb)
sample_pec = ct(sample_pe)
# The class token adds exactly one element at the front of the sequence dim
# and must leave the original embeddings untouched after it.
assert sample_pec.shape == (sample_pe.shape[0], sample_pe.shape[1] + 1, sample_pe.shape[2])
assert torch.equal(sample_pec[:, 1:], sample_pe)
sample_pec.shape

torch.Size([4, 5, 50])

## Position Embedding

> To give the model a notion of where each patch sits in the image, we will add position information to each patch embedding using a position encoding.

In [66]:
class PositionEncoding(nn.Module):
    """Adds a learnable position encoding to a sequence of embeddings."""

    def __init__(self,
                 d_emb, # Embedding dim.
                 n_positions=None, # Sequence length (N_patch + 1 with a class token). `None` keeps a single vector shared by all positions.
                 ):
        super().__init__()
        # A single `d_emb` vector is added identically to every token, so it cannot tell positions apart.
        # Passing `n_positions` allocates one learnable vector per position instead.
        shape = (d_emb,) if n_positions is None else (n_positions, d_emb)
        self.position_encoding = nn.Parameter(torch.zeros(shape))

    def forward(self,
                inputs, # Embeddings tensor [Batch, N_patch + 1, D_emb] to add the position encoding to.
                ):

        ## 0. A per-position encoding only fits sequences of the length it was built for.
        if self.position_encoding.dim() == 2 and inputs.shape[-2] != self.position_encoding.shape[0]:
            raise ValueError(
                f"Expected {self.position_encoding.shape[0]} positions, got {inputs.shape[-2]}."
            )

        ## 1. Add the position encoding to the input; broadcasting takes care of the batch dim.
        return inputs + self.position_encoding

In [69]:
pose = PositionEncoding(d_emb=d_emb)
sample_pecpose = pose(sample_pec)
# The encoding is initialised to 0, so the output must match the input exactly.
# `torch.equal` also compares shapes, so broadcasting cannot hide a shape mismatch.
assert torch.equal(sample_pecpose, sample_pec), "a zero-initialised position encoding must leave the input unchanged"
sample_pecpose.shape

torch.Size([4, 5, 50])