In [2]:
from einops import rearrange
import torch
import torch.nn as nn
import numpy as np

In [19]:
from transformers import DistilBertTokenizerFast, DistilBertModel


def get_embedding(sentence):
    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    tokens = tokenizer.encode(sentence, return_tensors='pt', padding="max_length", max_length=20)
    model = DistilBertModel.from_pretrained("distilbert-base-uncased")
    return model.embeddings.word_embeddings(tokens)

In [20]:
x1 = get_embedding('my name is jungwoo')
x2 = get_embedding('hi bye')

In [33]:
batch_sample = torch.cat([x1, x2])
batch_sample.size()

d_model = 768
max_length = 20

In [64]:
import math
class PositionalEncoding(nn.Module):
    def __init__(
            self, 
            d_model: int, 
            dropout: float, 
            max_length: int,
        ):
        super().__init__()
        
        self.dropout = nn.Dropout(dropout)
 
        # Encoding - From formula
        pos_encoding = torch.zeros(max_length, d_model)
        positions = rearrange(torch.arange(0, max_length, dtype=torch.float), 'm -> m 1')

        division_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0)) / d_model) # 1000^(2i/dim_model)

        pos_encoding[:, 0::2] = torch.sin(positions * division_term)
        pos_encoding[:, 1::2] = torch.cos(positions * division_term)

        # Saving buffer (same as parameter without gradients needed)
        self.pos_encoding = rearrange(pos_encoding, 'm d -> m 1 d')
 
    def forward(self, token_embedding: torch.tensor) -> torch.tensor:
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])

In [65]:
pe = PositionalEncoding(d_model, 0.1, max_length)

In [67]:
batch_embedding = pe(batch_sample)

In [111]:
attn = MultiHeadAttention(d_model=d_model, nheads=8, dropout=0.1)

In [112]:
attn(batch_embedding, batch_embedding,batch_embedding)

tensor([[[-0.0000, -0.1108, -0.0000,  ..., -0.0000, -0.2244,  0.1460],
         [-0.4496, -0.1814, -0.0760,  ..., -0.3067, -0.2866,  0.0262],
         [-0.0000, -0.1287, -0.1443,  ..., -0.2137, -0.1096,  0.0155],
         ...,
         [-0.4968, -0.0000, -0.2834,  ..., -0.1626, -0.1893,  0.0697],
         [-0.4083, -0.0000, -0.2405,  ..., -0.2550, -0.0000,  0.0725],
         [-0.3810, -0.1904, -0.0164,  ..., -0.2757, -0.1953,  0.1651]],

        [[-0.4527, -0.0000, -0.2503,  ..., -0.3200, -0.4835,  0.0000],
         [-0.0000, -0.0000, -0.3693,  ..., -0.3026, -0.5328,  0.0376],
         [-0.3926, -0.2420, -0.2051,  ..., -0.3418, -0.2692,  0.0199],
         ...,
         [-0.4813, -0.1311, -0.3769,  ..., -0.2857, -0.4601, -0.0304],
         [-0.4268, -0.0524, -0.3326,  ..., -0.0000, -0.4997, -0.0761],
         [-0.4626, -0.0930, -0.0000,  ..., -0.3061, -0.4862,  0.0489]]],
       grad_fn=<MulBackward0>)

In [136]:
class SelfAttention(nn.Module):

    def __init__(
            self,
            temperature: float = 0.1,
            attn_dropout: float = 0.1

    ):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(
            self,
            q,
            k,
            v,
            mask = None   
    ):
        d_k = q.size(0)
        k = rearrange(k, 'b n m d -> b n d m')
        attn = torch.matmul(q / self.temperature, k)
        attn = torch.div(attn, np.sqrt(d_k))
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        
        attn = F.softmax(attn, dim = -1)
        attn = self.dropout(attn)
        output = torch.matmul(attn, v)
        return output 



class MultiHeadAttention(nn.Module):
    """
    b: batch size
    m: max_seq length
    n: nheads
    h: head_dim
    """
    def __init__(
        self,
        d_model,
        nheads,
        dropout: float = 0.1,
        attn_dropout: float = 0.1,
        bias = True,

    ):
        super().__init__()
        self.d_model = d_model
        self.nheads = nheads
        self.dropout = nn.Dropout(dropout)
        assert self.d_model % nheads == 0

        self.q = nn.Linear(d_model, d_model, bias = bias)
        self.k = nn.Linear(d_model, d_model, bias = bias)
        self.v = nn.Linear(d_model, d_model, bias = bias)
        self.selfattn = SelfAttention(
            temperature = 0.1,
            attn_dropout = attn_dropout
        )
        self.o = nn.Linear(d_model, d_model, bias = bias)

    def forward(
            self,
            src,
            src_mask = None
    ):
        # MultiheadAttention
        q = rearrange(self.q(src), 'b m (n h) -> b n m h', n = self.nheads)
        k = rearrange(self.k(src), 'b m (n h) -> b n m h', n = self.nheads)
        v = rearrange(self.k(src), 'b m (n h) -> b n m h', n = self.nheads)

        output = self.selfattn(q, k, v, mask = src_mask)
        output = rearrange(output, 'b n m h -> b m (n h)', n = self.nheads)
        output = self.o(output)
        output = self.dropout(output)
        return output

class FeedForwardBlock(nn.Module):

    def __init__(
        self,
        d_model: int,
        dim_feedforward: int,
        activation = nn.SiLU(),
        dropout: float = 0.1,
    ):
        super().__init__()
        
        self.feedforward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            activation,
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.Dropout(dropout)
        )
        
    def forward(self, x):
        return self.feedforward(x)

class ResidualConnection(nn.Module):
    
    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x, **kwargs):
        return self.layer(x, **kwargs) + x

class PostNormalization(nn.Module):

    def __init__(self, layer, d_model):
        super().__init__()
        self.layer = layer
        self.d_model = d_model
        self.layernorm = nn.LayerNorm(d_model)

    def forward(self, x, **kwargs):
        return self.layernorm(self.layer(x, **kwargs))
    
class PreNormalization(nn.Module):

    def __init__(self, layer, d_model):
        super().__init__()
        self.layer = layer
        self.d_model = d_model
        self.layernorm = nn.LayerNorm(d_model)

    def forward(self, x, **kwargs):
        return self.layer(self.layernorm(x), **kwargs)

class VanillaEncoderLayer(nn.Module):
    def __init__(self, 
            d_model, 
            nheads, 
            dim_feedforward,  
            dropout,
            attn_dropout 
                 
    ):
        super().__init__()
        
        self.attn_layer_norm = nn.LayerNorm(d_model)
        self.ff_layer_norm = nn.LayerNorm(d_model)
        self.attn = PostNormalization(
            ResidualConnection(
                MultiHeadAttention(
                d_model = d_model, 
                nheads = nheads, 
                dropout = dropout, 
                attn_dropout = attn_dropout
                )
            ),
            d_model = d_model
        )
        
        
        self.feedforward = PostNormalization(
            ResidualConnection(
                FeedForwardBlock(
                d_model = d_model,
                dim_feedforward = dim_feedforward,
                dropout = dropout
                )
            ),
            d_model = d_model
        )

        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src, src_mask = None):
        

        #self attention & residual connection
        output = self.attn(src, src_mask)
        output = self.feedforward(output)
        
        return src

In [137]:
d_model//8

96

In [138]:
enc_layer = VanillaEncoderLayer(d_model=d_model, nheads=n_heads, dim_feedforward=1024,dropout=0.1, attn_dropout=0.1)

In [139]:
enc_layer(batch_embedding)

TypeError: PostNormalization.forward() takes 2 positional arguments but 3 were given

In [118]:
n_heads = 8
head_dim = d_model // n_heads
Q = batch_embedding.view(2, -1, n_heads, head_dim).permute(0, 2, 1, 3)

In [89]:
Q.size()

torch.Size([2, 8, 20, 96])

In [87]:
rearrange(batch_embedding, 'b m (n h) -> b n m h', n=n_heads).size()

torch.Size([2, 8, 20, 96])

In [83]:
batch_embedding.size()

torch.Size([2, 20, 768])