In [1]:
import torch
import torch.nn as nn
from einops import rearrange, repeat

### Patch Embedding
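Patch embedding turns an image into the token sequence consumed by the transformer. It has three steps:
- Split the image into a sequence of flattened patches and project each patch to the embedding dimension
- Prepend a learnable CLS token to the sequence of patch embeddings
- Add a learnable positional embedding to every token (the CLS token included)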

In [36]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of token embeddings.

    The forward pass does three things:
      1. splits each image into non-overlapping patches, flattens every patch
         and projects it to ``emb_dim`` (LayerNorm -> Linear -> LayerNorm);
      2. prepends a learnable CLS token to every sequence;
      3. adds a learnable positional embedding and applies dropout.

    Args:
        config: Mapping providing the keys ``image_height``, ``image_width``,
            ``im_channels``, ``emb_dim``, ``patch_emb_dropout``,
            ``patch_height`` and ``patch_width``.

    Raises:
        ValueError: If the image size is not an exact multiple of the patch
            size in either dimension.
    """

    def __init__(self, config):
        super().__init__()
        image_height = config["image_height"]
        image_width = config["image_width"]
        im_channels = config["im_channels"]
        emb_dim = config["emb_dim"]  # Transformer dimension (D)
        patch_embd_dropout = config["patch_emb_dropout"]

        self.patch_height = config["patch_height"]
        self.patch_width = config["patch_width"]

        # Fail at construction time: with a non-divisible size the floor
        # division below would silently build a pos_emb that cannot match the
        # patch sequence produced in forward().
        if image_height % self.patch_height != 0 or image_width % self.patch_width != 0:
            raise ValueError(
                f"Image dimensions ({image_height}x{image_width}) must be divisible "
                f"by the patch size ({self.patch_height}x{self.patch_width})."
            )

        num_patches = (image_height // self.patch_height) * (image_width // self.patch_width)

        # Length of one flattened patch.
        patch_dim = im_channels * self.patch_height * self.patch_width

        # W belongs to R^(patch_dim x emb_dim)
        self.patch_emb = nn.Sequential(
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, emb_dim),
            nn.LayerNorm(emb_dim),
        )

        # Positional information needs to be added to cls as well so 1+num_patches
        self.pos_emb = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))
        self.cls_token = nn.Parameter(torch.randn(emb_dim))    # CLS token belongs to R^emb_dim
        self.patch_emb_dropout = nn.Dropout(patch_embd_dropout)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Tensor of shape ``(batch, im_channels, image_height, image_width)``
                matching the sizes given in the config.

        Returns:
            Tensor of shape ``(batch, num_patches + 1, emb_dim)``; index 0 along
            the sequence axis is the CLS token.
        """
        batch_size = x.shape[0]

        # Split the image into patches and flatten each one:
        # (b, c, H, W) -> (b, num_patches, patch_dim)
        out = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_height, p2=self.patch_width)

        out = self.patch_emb(out)

        # One copy of the CLS token per sample, placed in front of the patches.
        cls_token = repeat(self.cls_token, 'd -> b n d', b=batch_size, n=1)
        out = torch.cat([cls_token, out], dim=1)

        # pos_emb has a leading dim of 1 and broadcasts over the batch.
        out = out + self.pos_emb
        out = self.patch_emb_dropout(out)

        return out
