info


In [2]:
import torch
import torch.nn as nn


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class PatchEmbedding(nn.Module):
    """Convolutional patch embedding with a [CLS] token and position embedding.

    Pipeline: Conv2d projection -> flatten -> prepend [CLS] -> add position.

    Args:
        image_size: Side length of the (square) input image.
        num_patch: Total number of patches per image. Must be a perfect
            square so patches tile a ``sqrt(num_patch)`` x ``sqrt(num_patch)``
            grid.
        in_dim: Number of channels of the input image.
        latent_dim: Size ``D`` of each output token.

    Raises:
        ValueError: If ``num_patch`` is not a positive perfect square, or if
            ``image_size`` is not divisible by the grid side.
    """

    def __init__(self, image_size, num_patch, in_dim, latent_dim) -> None:
        super().__init__()

        grid_size = int(round(num_patch ** 0.5))
        if num_patch < 1 or grid_size * grid_size != num_patch:
            raise ValueError(
                f"num_patch must be a positive perfect square, got {num_patch}"
            )
        if image_size % grid_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by "
                f"sqrt(num_patch) ({grid_size})"
            )

        # P * P = H * W / N: each conv window must cover one patch, not the
        # whole image, otherwise only a single token would be produced.
        patch_size = image_size // grid_size

        self.num_patch = num_patch
        self.patch_size = patch_size

        # Non-overlapping windows (stride == kernel) project each patch to D.
        self.project = nn.Conv2d(
            in_channels=in_dim,
            out_channels=latent_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        # One learnable [CLS] token, shared across the batch.
        self.cls_token = nn.Parameter(torch.zeros((1, 1, latent_dim)))
        # One position vector per token: [CLS] + num_patch patches.
        self.pos_embed = nn.Parameter(torch.rand((num_patch + 1, latent_dim)))

    def forward(self, img: torch.Tensor):
        """Embed a batch of images into a token sequence.

        Args:
            img: Tensor of shape ``(B, in_dim, image_size, image_size)``.

        Returns:
            Tensor of shape ``(B, num_patch + 1, latent_dim)``; index 0 along
            dim 1 is the [CLS] token.
        """
        # NOTE(review): the original comment associates this
        # projection-first ordering with TNT and CaiT -- not verified here.
        tokens = self.project(img)  # (B, C, H, W) -> (B, D, H/P, W/P)
        tokens = tokens.flatten(2).transpose(1, 2)  # (B, num_patch, D)

        # torch.cat does not broadcast, so the [CLS] token has to be expanded
        # to the batch size explicitly (expand creates a view, no copy).
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)  # (B, num_patch + 1, D)

        # (B, num_patch + 1, D) + (num_patch + 1, D): broadcast over batch.
        return tokens + self.pos_embed

In [4]:
class PatchEmbedding_ViT(nn.Module):
    """Linear patch embedding with a [CLS] token and position embedding.

    Pipeline: split into patches -> flatten -> Linear projection ->
    prepend [CLS] -> add position.

    Args:
        image_size: Side length of the (square) input image.
        num_patch: Total number of patches per image. Must be a perfect
            square so patches tile a ``sqrt(num_patch)`` x ``sqrt(num_patch)``
            grid.
        in_dim: Number of channels of the input image.
        latent_dim: Size ``D`` of each output token.

    Raises:
        ValueError: If ``num_patch`` is not a positive perfect square, or if
            ``image_size`` is not divisible by the grid side.
    """

    def __init__(self, image_size, num_patch, in_dim, latent_dim) -> None:
        super().__init__()

        grid_size = int(round(num_patch ** 0.5))
        if num_patch < 1 or grid_size * grid_size != num_patch:
            raise ValueError(
                f"num_patch must be a positive perfect square, got {num_patch}"
            )
        if image_size % grid_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by "
                f"sqrt(num_patch) ({grid_size})"
            )

        patch_size = image_size // grid_size  # P * P = H * W / N

        # Stored because forward() needs them to cut the image into patches.
        self.num_patch = num_patch
        self.grid_size = grid_size
        self.patch_size = patch_size

        self.project = nn.Linear(
            in_features=patch_size * patch_size * in_dim,
            out_features=latent_dim,
        )
        # One learnable [CLS] token, shared across the batch.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, latent_dim))
        # One position vector per token: [CLS] + num_patch patches.
        self.pos_embed = nn.Parameter(torch.zeros(num_patch + 1, latent_dim))

    def forward(self, img):
        """Embed a batch of images into a token sequence.

        Args:
            img: Tensor of shape ``(B, in_dim, image_size, image_size)``.

        Returns:
            Tensor of shape ``(B, num_patch + 1, latent_dim)``; index 0 along
            dim 1 is the [CLS] token.
        """
        batch, channels, _, _ = img.shape
        grid, side = self.grid_size, self.patch_size

        # A plain reshape to (B, N, -1) would only re-slice contiguous memory
        # and mix rows/channels; split both spatial axes into (grid, side)
        # and bring the grid axes forward so each row is one spatial patch.
        patches = img.reshape(batch, channels, grid, side, grid, side)
        patches = patches.permute(0, 2, 4, 3, 5, 1)  # (B, g, g, P, P, C)
        patches = patches.reshape(batch, self.num_patch, -1)  # (B, N, P*P*C)

        tokens = self.project(patches)  # (B, N, P*P*C) -> (B, N, D)

        # torch.cat does not broadcast, so expand [CLS] to the batch size and
        # concatenate along the token axis.
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)  # (B, N + 1, D)

        # (B, N + 1, D) + (N + 1, D): broadcast over batch.
        return tokens + self.pos_embed