In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


In [2]:
class PatchEmbedding(nn.Module):
    """Cut an image batch into non-overlapping square patches and project each patch linearly.

    Args:
        in_channel: Number of channels of the input images.
        patch_size: Side length, in pixels, of every square patch.
        emb_dim: Width of the projected tokens. When ``None`` it defaults to
            the flattened patch size ``in_channel * patch_size ** 2``.
    """

    def __init__(self,
                 in_channel,
                 patch_size,
                 emb_dim=None):
        super().__init__()

        # Number of scalar values in one flattened patch.
        d = in_channel * (patch_size ** 2)

        if emb_dim is None:
            emb_dim = d

        self.patch_size = patch_size
        self.linear_projection = nn.Linear(in_features=d,
                                           out_features=emb_dim)

    def _tokenize(self, x: Tensor) -> Tensor:
        """Split ``x`` into a row-major sequence of patches.

        Args:
            x: Image batch of shape ``(batch, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, rows * cols, channels, patch, patch)``
            where ``rows = height // patch`` and ``cols = width // patch``.
            Patches are ordered top-to-bottom, then left-to-right. Trailing
            pixels that do not fill a whole patch are dropped.

        Raises:
            ValueError: If ``x`` is not 4-D or is smaller than one patch
                along either spatial axis.
        """
        p = self.patch_size
        batch, channels, height, width = x.shape
        rows, cols = height // p, width // p
        if rows == 0 or cols == 0:
            raise ValueError(
                f"input of spatial size {height}x{width} is smaller than "
                f"one {p}x{p} patch"
            )

        # Keep only the region that is covered by whole patches, so every
        # patch has the same shape regardless of the input size.
        x = x[..., :rows * p, :cols * p]
        # (B, C, H, W) -> (B, C, rows, p, cols, p): split each spatial axis
        # into (patch index, offset inside the patch).
        x = x.reshape(batch, channels, rows, p, cols, p)
        # Bring the two patch indices forward so that flattening them yields
        # the row-major patch order.
        x = x.permute(0, 2, 4, 1, 3, 5)
        return x.reshape(batch, rows * cols, channels, p, p)

    def forward(self, x: Tensor) -> Tensor:
        """Return projected patch tokens of shape ``(batch, num_patches, emb_dim)``."""
        out = self._tokenize(x)
        # Flatten each (channels, patch, patch) patch into a single vector.
        out = out.flatten(start_dim=2)
        out = self.linear_projection(out)
        return out

In [3]:
class PositionalEmbedding(nn.Module):
    """Append a learnable class token to a token sequence and add learnable position embeddings.

    Both parameters are stored with a leading dimension of 1 and broadcast
    over the batch, so the module works with any batch size and every sample
    shares the same class token and position table.

    Args:
        batch_size: Accepted for backward compatibility only; it no longer
            influences the parameter shapes.
        num_patches: Number of tokens in the incoming sequence.
        embedded_dim: Width of each token.
    """

    def __init__(self, batch_size, num_patches, embedded_dim):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedded_dim))
        # One extra position for the class token.
        self.position = nn.Parameter(torch.randn(1, num_patches+1, embedded_dim))

    def forward(self, x: Tensor):
        """Return ``x`` with the class token appended as the last token, plus position embeddings.

        Args:
            x: Tokens of shape ``(batch, num_patches, embedded_dim)``.

        Returns:
            Tensor of shape ``(batch, num_patches + 1, embedded_dim)``.
        """
        # expand() creates a broadcasted view; no per-sample copy is stored.
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        out = torch.cat([x, cls_token], dim=1)
        return out + self.position

In [4]:
class EmbeddingBlock(nn.Module):
    """Chain patch embedding and positional embedding into a single module.

    Args:
        batch_size: Forwarded to ``PositionalEmbedding``.
        in_channel: Number of channels of the input images.
        patch_size: Side length of each square patch.
        num_patches: Number of patches per image, forwarded to
            ``PositionalEmbedding``.
    """

    def __init__(self,
                 batch_size: int,
                 in_channel: int,
                 patch_size: int,
                 num_patches: int):
        super().__init__()
        # PatchEmbedding is built without an explicit emb_dim, so its tokens
        # have the flattened patch width; the positional stage must match it.
        token_width = in_channel * (patch_size ** 2)

        self.patch_embedding = PatchEmbedding(
            in_channel=in_channel,
            patch_size=patch_size,
        )
        self.positional_embedding = PositionalEmbedding(
            batch_size=batch_size,
            num_patches=num_patches,
            embedded_dim=token_width,
        )

    def forward(self, x):
        """Embed the image batch ``x`` into patch tokens, then add the positional stage."""
        patch_tokens = self.patch_embedding(x)
        return self.positional_embedding(patch_tokens)

In [67]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        embed_dim: Width of the input and output tokens.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim`` evenly.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()

        # Fail at construction time rather than with an opaque view() error
        # during the first forward pass.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        # sqrt(head_dim); attribute name kept as-is for compatibility.
        self.scailing_factor = (embed_dim // num_heads) ** 0.5

        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.fc_layer = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tokens of shape ``(batch, num_tokens, embed_dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        queries = self._split_dimension(self.q_linear(x))
        keys = self._split_dimension(self.k_linear(x))
        values = self._split_dimension(self.v_linear(x))

        # (B, heads, N, head_dim) @ (B, heads, head_dim, N) -> (B, heads, N, N)
        attention_score = torch.matmul(queries, keys.transpose(-1, -2))
        attention_score = attention_score / self.scailing_factor

        attention_weight = F.softmax(attention_score, dim=-1)
        attention = torch.matmul(attention_weight, values)

        # Merge the heads back: (B, heads, N, head_dim) -> (B, N, embed_dim).
        attention = attention.transpose(1, 2)
        batch_size, num_patches = attention.shape[:2]
        attention = attention.reshape(batch_size, num_patches, self.embed_dim)
        out = self.fc_layer(attention)
        return out

    def _split_dimension(self, x: Tensor):
        """Reshape ``(batch, tokens, embed_dim)`` into ``(batch, heads, tokens, head_dim)``."""
        batch_size, num_patches, embed_dim = x.shape
        x = x.view(batch_size, num_patches, self.num_heads, embed_dim // self.num_heads)
        return x.transpose(1, 2)


In [79]:
class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        embed_dim: Width of the input and output tokens.
        expansion: Multiplier applied to ``embed_dim`` to get the hidden width.
        drop_out: Dropout probability applied after the activation.
    """

    def __init__(self, embed_dim, expansion=4, drop_out=0.1):
        hidden_dim = embed_dim * expansion
        # Order matters: nn.Sequential names sub-modules by position.
        layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop_out),
            nn.Linear(hidden_dim, embed_dim),
        ]
        super().__init__(*layers)

In [80]:
class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    Computes ``y = x + MSA(LN1(x))`` followed by ``y + MLP(LN2(y))``.

    Args:
        embed_dim: Width of the input and output tokens.
        num_heads: Number of attention heads used by the attention sub-layer.
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 ):
        super().__init__()
        self.lm_layer1 = nn.LayerNorm(embed_dim)
        self.msa_block = MultiHeadAttention(embed_dim, num_heads)
        self.lm_layer2 = nn.LayerNorm(embed_dim)
        self.mlp_block = FeedForwardBlock(embed_dim)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers, each with a residual connection.

        Args:
            x: Tokens of shape ``(batch, num_tokens, embed_dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Attention sub-layer: normalize, attend, then add the residual.
        # Out-of-place adds keep the caller's tensor untouched.
        out = x + self.msa_block(self.lm_layer1(x))
        # Feed-forward sub-layer with its own normalization and residual.
        out = out + self.mlp_block(self.lm_layer2(out))
        return out

In [86]:
class Encoder(nn.Module):
    """A stack of ``num_repeats`` encoder blocks applied one after another.

    Args:
        embed_dim: Token width passed to every block.
        num_heads: Number of attention heads passed to every block.
        num_repeats: How many blocks to stack.
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 num_repeats,
                 ):
        super().__init__()
        # ModuleList registers every block so its parameters are tracked.
        self.blocks = nn.ModuleList(
            [EncoderBlock(embed_dim, num_heads) for _ in range(num_repeats)]
        )

    def forward(self, x):
        """Feed ``x`` through each block in order and return the final result."""
        for layer in self.blocks:
            x = layer(x)
        return x

In [87]:
# Smoke test: embed a random batch of four 3-channel 224x224 images.
# With 16x16 patches there are (224 / 16) ** 2 = 196 patches per image.
embedding_block = EmbeddingBlock(
    batch_size=4,
    in_channel=3,
    patch_size=16,
    num_patches=196,
)
input_image = torch.randn(4, 3, 224, 224)
out = embedding_block(input_image)

In [88]:
# Expected (4, 197, 768): 4 images, 196 patches plus one class token,
# each token of width 3 * 16 * 16 = 768.
out.shape

torch.Size([4, 197, 768])

In [89]:
# Build a 16-block encoder for 768-wide tokens with 8 attention heads.
encoder = Encoder(embed_dim=768, num_heads=8, num_repeats=16)


In [91]:
# Run the embedded tokens through the encoder stack; the recorded output
# shows the shape is unchanged, (4, 197, 768).
encoder(out).shape

torch.Size([4, 197, 768])