## ViT Walkthrough

1. Patch Embedding
2. Self-Attention
3. Feed Forward Network
4. Transformer Block
5. ViT


In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torchsummary import summary
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
# Load the sample image and turn it into a batched float tensor.
transform = Compose([Resize((224, 224)), ToTensor()])
with Image.open("porsche918.jpg") as raw_img:  # close the file handle once the pixels are read
    # convert("RGB") guards against grayscale / RGBA / CMYK files, which would give a
    # channel count that the 3-channel defaults used further down cannot consume
    img = transform(raw_img.convert("RGB")).unsqueeze(0)  # unsqueeze adds the batch dimension

### Patch Embedding

- To feed a 2D image to a transformer, we reshape it from `H x W x C` into a sequence of flattened patches of shape `N x (P^2 * C)`, where `H x W` is the height and width of the image, `C` is the number of channels, `(P, P)` is the resolution of each patch, and `N = HW / P^2` is the number of patches in the image.


In [None]:
patch_size = 16
b, c, h, w = img.shape
patch_dim = patch_size * patch_size * c  # length of one flattened patch
num_patches = (h // patch_size) * (w // patch_size)  # also correct when h != w
embed_size = 768
patches = nn.Sequential(
    # cut the image into flattened patches first, so that both LayerNorms act on
    # per-patch features (a LayerNorm on the raw image would normalise each pixel
    # row over the width, and would only accept images with w == h)
    Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size, p2=patch_size),
    nn.LayerNorm(patch_dim),
    nn.Linear(patch_dim, embed_size),
    nn.LayerNorm(embed_size),
)

# learnable class token (one vector) and one learnable position per token,
# the "+ 1" being the slot for the class token
cls_token = nn.Parameter(torch.randn(embed_size))
pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_size))

In [None]:
# run the image through the patch pipeline: one embed_size vector per patch
img_patches = patches(img)
img_patches.shape

In [None]:
# at this point the class token is still a single vector of length embed_size
cls_token.shape

In [None]:
# give every image in the batch its own copy of the class token:
# (embed_size,) -> (b, 1, embed_size)
# NOTE: this rebinds `cls_token` from the nn.Parameter to the expanded tensor, so
# re-running this cell without recreating the parameter no longer matches the "d" pattern.
cls_token = repeat(cls_token, "d -> b 1 d", b=b)

In [None]:
# put the class token in front of the patch tokens along the sequence axis (dim 1)
img_patches = torch.cat((cls_token, img_patches), dim=1)
# one extra token per image -- this is why pos_embed was created with num_patches + 1 rows
img_patches.shape

In [None]:
# add the position embedding in place; its leading dimension of 1 broadcasts over the batch
img_patches.add_(pos_embed)
img_patches.shape

that completes the patch_embedding section of the vision transformer! now, let's put it all together


In [None]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of embedded patch tokens.

    Each image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches; every patch is flattened, layer-normalised, projected to
    ``emb_size`` and normalised again. A learnable class token is prepended
    and a learnable position embedding is added before dropout.

    Note: the first LayerNorm now acts on the flattened patches
    (``patch_dim`` features). It used to be ``LayerNorm(img_size)`` applied to
    the raw image, which normalised each pixel row over the width, so the
    parameter shapes inside ``patch_embedding`` differ from that version.

    Args:
        in_channels: number of channels of the input images.
        patch_size: side length of one square patch.
        emb_size: size of every output token.
        img_size: side length of the (square) input images; fixes the number
            of position embeddings.
        dropout: dropout probability applied to the final tokens.
    """

    def __init__(
        self,
        in_channels: int = 3,
        patch_size: int = 16,
        emb_size: int = 768,
        img_size: int = 224,
        dropout: float = 1e-3,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        # each flattened patch holds patch_size x patch_size pixels for every channel
        patch_dim = patch_size**2 * in_channels
        # (h // patch_size) * (w // patch_size) with h == w == img_size
        num_patches = (img_size // patch_size) ** 2

        self.patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size, p2=patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, emb_size),
            nn.LayerNorm(emb_size),
        )

        self.cls_token = nn.Parameter(torch.randn(emb_size))
        # "+ 1" leaves a position for the class token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, emb_size))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` of shape (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, num_patches + 1, emb_size); index 0 along
            the sequence axis is the class token.
        """
        batch_size = x.shape[0]
        x = self.patch_embedding(x)
        # one copy of the class token per image, placed in front of its patches
        cls_tokens = repeat(self.cls_token, "d -> b 1 d", b=batch_size)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embedding  # leading dimension of 1 broadcasts over the batch
        return self.dropout(x)

In [None]:
# sanity check: a freshly initialised module applied to the same image
PatchEmbedding()(img).shape  # same shape as above

MultiHead Self-Attention (MHA)

- this is where the magic happens!


In [None]:
num_heads = 8
# four independent embed_size -> embed_size projections, created in this order:
# queries, keys, values, and the output projection applied after the heads are merged
queries, keys, values, proj = (nn.Linear(embed_size, embed_size) for _ in range(4))

In [None]:
# the projection maps embed_size -> embed_size, so the token shape is unchanged
queries(img_patches).shape  # same shape for the keys and values

In [None]:
# project the tokens, then split the feature axis into `num_heads` heads:
# (b, n, h * d) -> (b, h, n, d); the three names now hold tensors, not layers
queries, keys, values = (
    rearrange(layer(img_patches), "b n (h d) -> b h n d", h=num_heads)
    for layer in (queries, keys, values)
)

In [None]:
# compute the dot product between the queries and keys
# einsum letters: b = batch, h = head, a = query position, c = key position,
# d = per-head feature, which is summed over
scores = torch.einsum("bhad, bhcd -> bhac", queries, keys)
scores.shape

In [None]:
# scaled dot-product attention divides by sqrt(d_k), where d_k is the size of the
# vectors that were dotted together -- the per-head size, not the full embed_size
scores = scores / (embed_size // num_heads) ** 0.5
# the glorious attention! softmax over the last axis makes every row non-negative
# and sum to 1
attn = F.softmax(scores, dim=-1)
attn.shape

In [None]:
# apply attn to the values matrix
# NOTE: the einsum letters change meaning here -- `d` is now the position being
# summed over and `c` is the per-head feature axis of `values`
out = torch.einsum("bhad, bhdc -> bhac", attn, values)
out.shape

In [None]:
# merge the heads back into the feature axis: (b, h, n, d) -> (b, n, h * d)
out = rearrange(out, "b h n d -> b n (h d)")
out.shape

In [None]:
# lastly, apply the projection
# (`proj` maps embed_size -> embed_size, so the shape stays the same)
out = proj(out)
out.shape

again, you know the drill, let's put it together!


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input tokens are projected to queries, keys and values, split into
    ``num_heads`` heads, attended per head, merged again and passed through
    an output projection.

    Args:
        embed_size: size of every input and output token.
        num_heads: number of attention heads; must divide ``embed_size``.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``embed_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_size: int = 768, num_heads: int = 8, dropout: float = 1e-3) -> None:
        super().__init__()
        if embed_size % num_heads != 0:
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.embed_size = embed_size
        self.head_dim = embed_size // num_heads  # feature size seen by each head
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)
        self.attn_dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_size, embed_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape (batch, tokens, embed_size).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # (b, n, h * d) -> (b, h, n, d)
        queries = rearrange(self.query(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.key(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.value(x), "b n (h d) -> b h n d", h=self.num_heads)
        # dot product between queries and keys: a = query position, c = key position
        scores = torch.einsum("bhad, bhcd -> bhac", queries, keys)
        # scale by sqrt(d_k): the dot products run over head_dim features per head,
        # so that -- not the full embed_size -- is the right normaliser
        scores = scores / self.head_dim**0.5
        attn = F.softmax(scores, dim=-1)
        attn = self.attn_dropout(attn)
        # weighted sum of the values; here `d` is the summed position and `c` the feature
        out = torch.einsum("bhad, bhdc -> bhac", attn, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.proj(out)

In [None]:
# sanity check: a freshly initialised attention layer applied to the embedded tokens
MultiHeadAttention()(img_patches).shape  # nice

Feed Forward Network (FFN)

- this allows the model to learn more about the features of the image


In [None]:
expansion, dropout = 4, 1e-3
# widen each token to embed_size * expansion, apply GELU, then project back;
# dropout follows the activation and the final projection
ffn = nn.Sequential(
    nn.Linear(in_features=embed_size, out_features=embed_size * expansion),
    nn.GELU(),
    nn.Dropout(p=dropout),
    nn.Linear(in_features=embed_size * expansion, out_features=embed_size),
    nn.Dropout(p=dropout),
)

In [None]:
# feed the attention output through the FFN; it expands to embed_size * expansion
# and projects back to embed_size
out = ffn(out)
out.shape

you know the drill!


In [None]:
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        embed_size: size of every input and output token.
        expansion: the hidden layer has ``embed_size * expansion`` units.
        dropout: probability used by both dropout layers.
    """

    def __init__(self, embed_size: int = 768, expansion: int = 4, dropout: float = 1e-3) -> None:
        super().__init__()
        hidden_size = embed_size * expansion
        layers = [
            nn.Linear(embed_size, hidden_size),  # widen
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, embed_size),  # project back
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network to the last dimension of ``x``; the shape is preserved."""
        out = self.ffn(x)
        return out

In [None]:
# sanity check: a freshly initialised FFN applied to the previous output
FFN()(out).shape  # nice

## Transformer Block

- let's put it all (well, most) together


In [None]:
class TransformerBlock(nn.Module):
    """One transformer encoder block: self-attention followed by a feed-forward network.

    Both sub-layers are applied to a layer-normalised copy of their input and
    added back to it (residual connection).

    Args:
        embed_size: size of every input and output token.
        num_heads: number of attention heads.
        dropout: dropout probability handed to both the attention and the FFN.
        expansion: hidden-size multiplier of the FFN (previously fixed at 4).
    """

    def __init__(
        self,
        embed_size: int = 768,
        num_heads: int = 8,
        dropout: float = 1e-3,
        expansion: int = 4,
    ) -> None:
        super().__init__()

        self.attn = MultiHeadAttention(embed_size, num_heads, dropout)
        self.ffn = FFN(embed_size, expansion, dropout)
        self.ffn_norm = nn.LayerNorm(embed_size)
        self.attn_norm = nn.LayerNorm(embed_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a tensor of the same shape as ``x`` (both sub-layers are residual)."""
        # normalise first, then add the sub-layer output back onto the untouched input
        x = x + self.attn(self.attn_norm(x))
        x = x + self.ffn(self.ffn_norm(x))
        return x

In [None]:
# sanity check: a freshly initialised block applied to the embedded tokens
TransformerBlock()(img_patches).shape  # nice

ViT

- let's _actually_ put it together now


In [None]:
class ViT(nn.Module):
    """Vision Transformer image classifier.

    Images are embedded into patch tokens (with a class token in front), run
    through ``depth`` transformer blocks, and the final class token is mapped
    to ``num_classes`` logits.

    Note: the head used to be applied to every token, which produced
    (batch, tokens, num_classes). It is now applied to the class token only,
    so the output is one logit vector per image.

    Args:
        in_channels: number of channels of the input images.
        patch_size: side length of one square patch.
        img_size: side length of the (square) input images.
        embed_size: size of every token.
        num_heads: number of attention heads in every block.
        depth: number of transformer blocks.
        num_classes: number of output logits.
        dropout: dropout probability used throughout the model.
    """

    def __init__(
        self,
        in_channels: int = 3,
        patch_size: int = 16,
        img_size: int = 224,
        embed_size: int = 768,
        num_heads: int = 8,
        depth: int = 12,
        num_classes: int = 1000,
        dropout: float = 1e-3,
    ) -> None:
        super().__init__()

        # PatchEmbedding takes (in_channels, patch_size, emb_size, img_size, dropout)
        self.patch_embed = PatchEmbedding(in_channels, patch_size, embed_size, img_size, dropout)
        self.blocks = nn.ModuleList([TransformerBlock(embed_size, num_heads, dropout) for _ in range(depth)])
        self.fc = nn.Linear(embed_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify ``x`` of shape (batch, channels, height, width).

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.patch_embed(x)
        for block in self.blocks:
            x = block(x)
        # the patch embedding puts the class token at sequence index 0; only that
        # token summarises the whole image and is fed to the classification head
        cls_out = x[:, 0]
        return self.fc(cls_out)

In [None]:
# run the image through a freshly initialised ViT and inspect the output shape
ViT()(img).shape  # nice

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViT().to(device)
# summarise the model that was just moved to `device` (not a second, CPU-resident
# ViT); torchsummary expects the device as the string "cuda" or "cpu", hence device.type
summary(model, (3, 224, 224), device=device.type)