# In [2]:
import math
import torch
from torch import nn
import torch.nn.functional as F

# In [25]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected to ``hidden_dim``, split
    into ``num_head`` heads of size ``hidden_dim // num_head``, attended per
    head, merged back and passed through an output projection.

    Args:
        q_dim: Feature size of the query input.
        k_dim: Feature size of the key input.
        v_dim: Feature size of the value input.
        hidden_dim: Total projected feature size (all heads together).
        num_head: Number of attention heads; must divide ``hidden_dim``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_head``.
    """

    def __init__(self, q_dim, k_dim, v_dim, hidden_dim, num_head, dropout):
        super().__init__()
        if hidden_dim % num_head != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_head ({num_head})"
            )
        self.num_head = num_head
        self.hidden_dim = hidden_dim
        self.W_q = nn.Linear(q_dim, hidden_dim)
        self.W_k = nn.Linear(k_dim, hidden_dim)
        self.W_v = nn.Linear(v_dim, hidden_dim)
        self.W_o = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, valid_len=None, causal=False):
        """Compute attention of ``query`` over ``key``/``value``.

        Args:
            query: Tensor of shape ``(B, Lq, q_dim)``.
            key: Tensor of shape ``(B, Lk, k_dim)``.
            value: Tensor of shape ``(B, Lv, v_dim)``; ``Lv`` has to match
                ``Lk`` for the weighted sum to be defined.
            valid_len: Optional 1-D tensor of length ``B``. Key positions
                ``>= valid_len[b]`` are masked out for batch item ``b``.
            causal: If true, query position ``i`` may only attend to key
                positions ``j <= i``.

        Returns:
            Tensor of shape ``(B, Lq, hidden_dim)``.
        """
        head_dim = self.hidden_dim // self.num_head
        B, Lq, _ = query.shape
        _, Lk, _ = key.shape
        _, Lv, _ = value.shape

        # Project, then split the feature axis into heads: (B, H, L, D).
        Q = self.W_q(query).reshape(B, Lq, self.num_head, head_dim).permute(0, 2, 1, 3)
        K = self.W_k(key).reshape(B, Lk, self.num_head, head_dim).permute(0, 2, 1, 3)
        V = self.W_v(value).reshape(B, Lv, self.num_head, head_dim).permute(0, 2, 1, 3)

        scores = Q @ K.transpose(2, 3) / (head_dim ** 0.5)  # (B, H, Lq, Lk)

        # Build the masks on the same device as the scores so the module
        # also works when the inputs live on an accelerator.
        device = scores.device
        mask = None
        if valid_len is not None:
            key_pos = torch.arange(Lk, device=device)
            lengths = valid_len.to(device)
            mask = key_pos[None, None, None, :] >= lengths[:, None, None, None]
        if causal:
            query_pos = torch.arange(Lq, device=device)
            key_pos = torch.arange(Lk, device=device)
            # True strictly above the diagonal, i.e. for "future" keys.
            future = (key_pos[None, :] > query_pos[:, None])[None, None, :, :]
            mask = future if mask is None else (mask | future)
        if mask is not None:
            # The most negative finite value is representable in every float
            # dtype (a literal such as -1e6 is not, in half precision) and,
            # unlike -inf, keeps fully masked rows free of NaNs.
            scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)

        weight = self.dropout(F.softmax(scores, dim=-1))
        out = weight @ V  # (B, H, Lq, D)
        # Move the head axis back next to the feature axis before merging;
        # reshaping (B, H, Lq, D) directly would interleave heads and positions.
        out = out.permute(0, 2, 1, 3).reshape(B, Lq, self.hidden_dim)
        return self.W_o(out)

# In [26]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        hidden_dim: Size of the input and output features.
        ff_dim: Size of the intermediate layer.
        dropout: Dropout probability applied after the ReLU activation.
    """

    def __init__(self, hidden_dim, ff_dim, dropout):
        # nn.Module must be initialised before sub-modules are assigned.
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the network to the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension has size ``hidden_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        hidden = F.relu(self.fc1(x))
        return self.fc2(self.dropout(hidden))

# In [27]:
class EncoderBlock(nn.Module):
    """Transformer encoder layer with post-normalisation.

    Runs self-attention and a feed-forward network; each sub-layer output is
    added to its input and passed through LayerNorm.

    Args:
        hidden_dim: Feature size of the input and output.
        num_head: Number of attention heads.
        ff_dim: Intermediate size of the feed-forward network.
        dropout: Dropout probability forwarded to both sub-layers.
    """

    def __init__(self, hidden_dim, num_head, ff_dim, dropout):
        # nn.Module must be initialised before sub-modules are assigned.
        super().__init__()
        self.attention = MultiHeadAttention(
            hidden_dim, hidden_dim, hidden_dim, hidden_dim, num_head, dropout
        )
        self.mlp = FeedForward(hidden_dim, ff_dim, dropout)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x, valid_len=None):
        """Apply the encoder layer.

        Args:
            x: Tensor of shape ``(B, L, hidden_dim)``.
            valid_len: Optional per-sample valid lengths, passed through to
                the attention layer to mask padded key positions. ``None``
                (the default) disables masking.

        Returns:
            Tensor with the same shape as ``x``.
        """
        att = self.attention(x, x, x, valid_len)
        x = self.norm1(x + att)
        x = self.norm2(x + self.mlp(x))
        return x

# In [29]:
class Bert(nn.Module):
    """BERT-style encoder with masked-LM and next-sentence-prediction heads.

    The input embedding is the sum of token, learned position and segment
    embeddings. It is passed through ``num_layer`` encoder blocks and a final
    LayerNorm before the two linear heads.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        hidden_dim: Embedding / hidden feature size.
        num_layer: Number of encoder blocks.
        num_head: Number of attention heads per block.
        ff_dim: Intermediate size of each block's feed-forward network.
        max_len: Maximum supported sequence length (size of the position table).
        dropout: Dropout probability used inside the encoder blocks.
    """

    def __init__(self, vocab_size, hidden_dim, num_layer, num_head, ff_dim, max_len=1000, dropout=0.5):
        super().__init__()
        self.num_layer = num_layer
        self.tokenEmbedding = nn.Embedding(vocab_size, hidden_dim)
        self.positionEmbedding = nn.Embedding(max_len, hidden_dim)
        self.segEmbedding = nn.Embedding(2, hidden_dim)
        self.encoder = nn.ModuleList(
            [EncoderBlock(hidden_dim, num_head, ff_dim, dropout) for _ in range(num_layer)]
        )
        self.norm = nn.LayerNorm(hidden_dim)
        self.mlm = nn.Linear(hidden_dim, vocab_size)
        self.nsp = nn.Linear(hidden_dim, 2)

    def forward(self, tokens, seg):
        """Encode a batch of token sequences.

        Args:
            tokens: Integer tensor of shape ``(B, L)`` with token ids.
            seg: Integer tensor of shape ``(B, L)`` with segment ids (0 or 1).

        Returns:
            A tuple ``(mlm_head, nsp_head)``: per-token vocabulary logits of
            shape ``(B, L, vocab_size)`` and logits of shape ``(B, 2)``
            computed from the first position of each sequence.

        Raises:
            ValueError: If ``L`` exceeds the size of the position table.
        """
        B, L = tokens.shape
        max_len = self.positionEmbedding.num_embeddings
        if L > max_len:
            # Fail early with a clear message instead of an out-of-range
            # embedding lookup.
            raise ValueError(f"sequence length {L} exceeds max_len {max_len}")
        # Create the indices on the input's device so the lookup also works
        # when the model and inputs are not on the CPU.
        pos = torch.arange(L, device=tokens.device).unsqueeze(0).expand(B, L)
        X = self.tokenEmbedding(tokens) + self.positionEmbedding(pos) + self.segEmbedding(seg)
        for layer in self.encoder:
            X = layer(X)
        X = self.norm(X)
        mlm_head = self.mlm(X)
        nsp_head = self.nsp(X[:, 0, :])
        return mlm_head, nsp_head

# In [6]:
class PatchEmbedding(nn.Module):
    """Cut an image batch into square patches and linearly embed each patch.

    Each ``patch_size x patch_size`` patch is flattened together with all its
    channels into a vector of length ``patch_size**2 * channel`` and projected
    to ``hidden_dim`` features.
    """

    def __init__(self, image_size=224, patch_size=16, channel=3, hidden_dim=768):
        super().__init__()
        patches_per_side = image_size // patch_size
        self.hidden_dim = hidden_dim
        self.patch_size = patch_size
        self.num_patchs = patches_per_side ** 2
        self.patch_dim = patch_size ** 2 * channel
        self.proj = nn.Linear(self.patch_dim, hidden_dim)

    def forward(self, X):
        """Map images ``(B, C, H, W)`` to patch embeddings ``(B, num_patchs, hidden_dim)``."""
        batch, channels, _, _ = X.shape
        step = self.patch_size

        # Slide a non-overlapping window along height, then along width:
        # (B, C, H/p, W/p, p, p).
        row_windows = X.unfold(2, step, step)
        tiles = row_windows.unfold(3, step, step)

        # Collapse the patch grid into one axis and the pixels of a patch
        # into another: (B, C, N, p*p).
        tiles = tiles.contiguous().view(batch, channels, self.num_patchs, -1)

        # Put the patch axis first so channels and pixels of the same patch
        # end up adjacent, then flatten them: (B, N, C*p*p).
        per_patch = tiles.permute(0, 2, 1, 3).contiguous()
        flat = per_patch.view(batch, -1, self.patch_dim)
        return self.proj(flat)
        

# In [10]:
class ViT(nn.Module):
    """Vision Transformer classifier.

    Images are turned into patch embeddings, a learnable class token is
    prepended, learnable position embeddings are added, and the sequence is
    run through a stack of encoder blocks. The class-token output is fed to a
    linear layer that produces ``num_class`` logits.
    """

    def __init__(self, image_size, patch_size, channel, hidden_dim, num_layer, num_head, ff_dim, num_class, dropout=0.5):
        super().__init__()
        self.patchembed = PatchEmbedding(image_size, patch_size, channel, hidden_dim)
        self.num_layer = num_layer
        self.image_size = image_size
        self.patch_size = patch_size

        # One position per patch plus one for the class token.
        seq_len = (image_size // patch_size) ** 2 + 1
        self.clsToken = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        self.posEmbed = nn.Parameter(torch.zeros(1, seq_len, hidden_dim))

        blocks = []
        for _ in range(num_layer):
            blocks.append(EncoderBlock(hidden_dim, num_head, ff_dim, dropout))
        self.encoder = nn.ModuleList(blocks)
        self.mlp = nn.Linear(hidden_dim, num_class)

    def forward(self, X):
        """Return class logits of shape ``(B, num_class)`` for an image batch ``X``."""
        batch = X.size(0)
        patches = self.patchembed(X)

        # Prepend one copy of the class token to every sample's patch sequence.
        cls_tokens = self.clsToken.expand(batch, -1, -1)
        seq = torch.cat((cls_tokens, patches), dim=1)

        # Add position embeddings, truncated to the actual sequence length.
        seq = seq + self.posEmbed[:, :seq.size(1), :]

        for block in self.encoder:
            seq = block(seq)

        # Classify from the class-token position only.
        return self.mlp(seq[:, 0, :])
        
        
        