In [22]:
import torch
from torch.nn import functional as F
from torch import nn

# Show which PyTorch version this notebook was run with
print(torch.__version__)

2.0.0


# General transformer modules

In [23]:
class Attention(nn.Module):
    """A single head of scaled dot-product self-attention.

    Args:
        emb_dim: Size of the last dimension of the input.
        head_size: Size of the query/key/value projections, which is also the
            size of the last dimension of the output.
        masked: If True, every position may only attend to itself and to
            positions to its left (causal attention).
    """

    def __init__(self, emb_dim, head_size, masked=False):
        super().__init__()
        self.emb_dim = emb_dim
        self.head_size = head_size
        self.masked = masked

        #TODO: Check if these projections should have bias or not
        self.toquery = nn.Linear(emb_dim, head_size)
        self.tokey = nn.Linear(emb_dim, head_size)
        self.tovalue = nn.Linear(emb_dim, head_size)

    def forward(self, x):
        """Apply self-attention to `x`.

        Args:
            x: Tensor of shape (b, t, emb_dim).

        Returns:
            Tensor of shape (b, t, head_size).
        """
        _, t, _ = x.size()
        # Project input into query, key, and value
        Q = self.toquery(x) # b, t, head_size
        K = self.tokey(x) # b, t, head_size
        V = self.tovalue(x) # b, t, head_size

        # transpose K to swap the second-to-last with the last dimension before matrix multiplication
        att = Q @ K.transpose(-2,-1) # (b, t, head_size) @ (b, head_size, t) = b, t, t
        # Scale by the square root of the query/key dimension. The dot products are
        # taken over head_size components, so that (not emb_dim) is the right scale.
        att_scaled = att / (self.head_size ** 0.5)

        # Apply masking to allow tokens to only attend to the left, not to the right
        if self.masked:
            # Build the mask on the same device as the input so this also works on GPU
            mask = torch.tril(torch.ones(t, t, dtype=torch.bool, device=x.device))
            att_scaled = att_scaled.masked_fill(~mask, float('-inf'))

        # Softmax scores to get weights
        weights = F.softmax(att_scaled, dim=-1) # b, t, t

        # Multiply softmaxed weights with values
        out = weights @ V # (b, t, t) @ (b, t, head_size) = b, t, head_size

        return out

In [24]:
class MultiHeadAttention(nn.Module):
    """Several attention heads run on the same input, then mixed by a linear layer.

    Args:
        emb_dim: Size of the last dimension of the input and of the output.
        n_heads: Number of attention heads; must divide `emb_dim` evenly.
        masked: Forwarded to every head.
    """

    def __init__(self, emb_dim, n_heads, masked=False):
        super().__init__()

        self.emb_dim, self.n_heads = emb_dim, n_heads

        # The embedding is shared out evenly between the heads, so it has to be
        # divisible by the number of heads
        assert emb_dim % n_heads == 0
        per_head = emb_dim // n_heads

        head_modules = []
        for _ in range(n_heads):
            head_modules.append(Attention(emb_dim, per_head, masked))
        self.heads = nn.ModuleList(head_modules)

        self.unifyheads = nn.Linear(emb_dim, emb_dim, bias=False)

    def forward(self, x):
        """Return the mixed output of all heads for `x`."""
        # One output per head
        per_head_out = tuple(head(x) for head in self.heads)
        # Join the head outputs side by side along the last dimension
        merged = torch.cat(per_head_out, dim=-1)
        # Linear layer that mixes information across heads
        return self.unifyheads(merged)

In [25]:
class TransformerBlock(nn.Module):
    """Transformer block with layer norm applied before each sub-layer.

    The input goes through multi-head attention and then a two-layer ReLU
    feed-forward network, each wrapped in a residual connection.

    Args:
        emb_dim: Size of the last dimension of the input and of the output.
        ff_dim: Hidden size of the feed-forward network.
        n_heads: Number of attention heads.
        masked: Forwarded to the multi-head attention module.
    """

    def __init__(self, emb_dim, ff_dim, n_heads, masked=False):
        super().__init__()
        self.mha = MultiHeadAttention(emb_dim, n_heads, masked)

        widen = nn.Linear(emb_dim, ff_dim)
        narrow = nn.Linear(ff_dim, emb_dim)
        self.ff = nn.Sequential(widen, nn.ReLU(), narrow)

        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        """Run `x` through the attention and feed-forward sub-layers."""
        # Residual around (norm -> attention)
        after_attention = x + self.mha(self.ln1(x))
        # Residual around (norm -> feed-forward)
        return after_attention + self.ff(self.ln2(after_attention))


# Vision transformer

In [28]:
# Dimensions of the dummy image batch: batch size, channels, height, width
B, C, H, W = 32, 3, 224, 224

# Patch size used when splitting the image
P = 16

In [49]:
# Dummy image batch. `high` is exclusive, so 256 is needed for values to span 0..255
img = torch.randint(low=0, high=256, size=(B, C, H, W), dtype=torch.float32)
img.shape

torch.Size([32, 3, 224, 224])

In [50]:
# B, H*W/P^2, P^2*C
# A plain .view() would only cut the flat (C, H, W) memory into consecutive chunks,
# which are not spatial patches. unfold gathers each non-overlapping P x P window
# across all channels; the transpose then puts the patch index before the features.
img_patch = F.unfold(img, kernel_size=P, stride=P).transpose(1, 2)
img_patch.shape

torch.Size([32, 196, 768])

In [77]:
class VisionTransformer(nn.Module):
    """Vision transformer classifier operating on non-overlapping image patches.

    The image is cut into square patches, each patch is flattened and linearly
    embedded, a learned cls token is prepended, learned position embeddings are
    added, and the sequence is passed through a stack of unmasked transformer
    blocks. The output at the cls position is fed to an MLP classification head.

    Args:
        model_dim: Embedding size used throughout the transformer.
        tf_ff_dim: Hidden size of the feed-forward network in each block.
        tf_layers: Number of transformer blocks.
        tf_heads: Number of attention heads per block.
        patch_size: Side length of the square patches.
        max_patches: Number of position embeddings. The cls token occupies
            position 0, so this must be at least (number of patches + 1).
        img_channels: Number of channels of the input images.
        cls_hidden_dim: Hidden size of the classification MLP.
        n_classes: Number of output classes.
    """

    def __init__(self, model_dim, tf_ff_dim, tf_layers, tf_heads, patch_size, max_patches, img_channels, cls_hidden_dim, n_classes):
        super().__init__()
        self.patch_size = patch_size

        self.embed_patches = nn.Linear(patch_size ** 2 * img_channels, model_dim)
        self.embed_positions = nn.Embedding(max_patches, model_dim)
        self.cls_token = nn.Parameter(data=torch.rand((1, 1, model_dim))) # batch, patch, model_dim

        self.transformer_blocks = nn.Sequential(*[
            TransformerBlock(model_dim, tf_ff_dim, tf_heads) for _ in range(tf_layers)
        ])

        self.classification_head = nn.Sequential(
            # Dosovitskiy et al. (2021) mention that at pre-training ViT uses a MLP with one hidden layer,
            # but no further details about layer size or activation function
            nn.Linear(model_dim, cls_hidden_dim),
            nn.ReLU(),
            nn.Linear(cls_hidden_dim, n_classes)
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Float tensor of shape (b, c, h, w); h and w must be multiples of
                the patch size.

        Returns:
            Tensor of shape (b, n_classes).

        Raises:
            ValueError: If h or w is not a multiple of the patch size.
        """
        b, c, h, w = x.size()
        p = self.patch_size

        if h % p != 0 or w % p != 0:
            raise ValueError(
                f"Image size ({h}x{w}) must be divisible by the patch size ({p})"
            )

        # Cut the image into non-overlapping p x p patches and flatten each one.
        # A plain .view() would only chop the flat (c, h, w) memory into chunks that
        # do not correspond to spatial patches; unfold gathers the real windows.
        img_patch = F.unfold(x, kernel_size=p, stride=p) # b, c*p*p, n_patches
        img_patch = img_patch.transpose(1, 2) # b, n_patches, c*p*p

        # Embed patches using linear projection
        patch_embeddings = self.embed_patches(img_patch)

        # Broadcast the cls token along the batch dimension (no copy needed, cat copies anyway)
        cls_tokens = self.cls_token.expand(b, -1, -1) # b, 1, model_dim

        # Prepend cls tokens so that they are at position 0 of the image patches
        patch_embeddings = torch.cat((cls_tokens, patch_embeddings), dim=1)
        t = patch_embeddings.size(1)

        # Retrieve position embeddings for each patch; the indices must live on the
        # same device as the input so the lookup also works on GPU
        positions = torch.arange(t, device=x.device)
        position_embeddings = self.embed_positions(positions)

        # Add patch and position embeddings to get the input into the transformer blocks
        tf_input = patch_embeddings + position_embeddings

        # Run input through transformer blocks
        tf_output = self.transformer_blocks(tf_input)

        # Pass output at cls position into classification MLP
        out = self.classification_head(tf_output[:, 0, :])

        return out

In [78]:
# Small ViT configuration used to smoke-test the forward pass on the dummy batch
vt = VisionTransformer(
    model_dim=64,
    tf_ff_dim=64*4,
    tf_layers=8,
    tf_heads=4,
    patch_size=16,
    max_patches=1000,
    img_channels=3,
    cls_hidden_dim=300,
    n_classes=10,
)
out = vt(img)
out.shape

torch.Size([32, 10])