In [8]:
import torch.nn as nn
import torch
import math

### Input Embedding

In [9]:
class InputEmbedding(nn.Module):
    """Learned token embedding whose output is multiplied by sqrt(embed_dim).

    Args:
        embed_dim: length of each embedding vector.
        vocab_size: number of rows in the lookup table (valid ids are
            0 .. vocab_size - 1).

    Calling the module on an integer id tensor of shape S returns a float
    tensor of shape S + (embed_dim,).
    """

    def __init__(self, embed_dim, vocab_size):
        super().__init__()
        self.vocab_size, self.embed_dim = vocab_size, embed_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)

    def forward(self, x):
        # Scale by sqrt(embed_dim), as in the Transformer paper's embedding layers.
        scale = math.sqrt(self.embed_dim)
        looked_up = self.embedding(x)
        return looked_up * scale

In [10]:
# Embed a toy batch of 2 sequences x 4 token ids (ids 1..8); expected shape (2, 4, 512).
embedding_layer = InputEmbedding(embed_dim=512, vocab_size=10_000)
embedded_output = embedding_layer(torch.arange(1, 9).view(2, 4))
embedded_output.shape

torch.Size([2, 4, 512])

### Positional Encoding (sinusoidal)

In [11]:
# Column vector of positions 0..13 -> shape (14, 1), ready to broadcast against a row of frequencies.
position = torch.arange(14)[:, None]
position.shape

torch.Size([14, 1])

In [12]:
# Frequency term 10000^(-2i/512) for every even index 2i in [0, 512), laid out as a (1, 256) row.
div_term = torch.exp(
    torch.arange(0, 512, 2, dtype=torch.float) * (-math.log(10000) / 512)
).unsqueeze(dim=0)

div_term.shape

torch.Size([1, 256])

In [13]:
# Broadcast (14, 1) * (1, 256) -> (14, 256) angle table, then prepend a batch axis.
value = (position * div_term)[None, ...]
value.shape

torch.Size([1, 14, 256])

In [20]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to its input.

    The table follows "Attention Is All You Need" with N = 10000:
        pe[pos, 2i]   = sin(pos / N^(2i / embed_dim))
        pe[pos, 2i+1] = cos(pos / N^(2i / embed_dim))

    Args:
        embed_dim: size of the last dimension of the inputs. Odd values are
            supported; the last (even-indexed) column then has no cosine partner.
        max_len: number of positions precomputed. Inputs with a longer
            sequence axis are rejected in ``forward``.
    """

    def __init__(self, embed_dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        N = 10000.0
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)

        # One frequency per (sin, cos) column pair, hence the step of 2 --
        # pe is also filled two columns at a time below.
        div_term = torch.exp(torch.arange(0, embed_dim, 2, dtype=torch.float) * -(math.log(N) / embed_dim))
        angles = position * div_term  # (max_len, ceil(embed_dim / 2))

        pe[:, 0::2] = torch.sin(angles)
        # pe[:, 1::2] has floor(embed_dim / 2) columns, one fewer than `angles`
        # when embed_dim is odd, so trim the angles to fit.
        pe[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])

        # Add a leading dim so 'pe' broadcasts over dim 0 (batch) of 'x'.
        # A buffer moves with .to()/.cuda() and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: tensor whose dim 1 is the sequence axis and whose last dim is
                ``embed_dim``, e.g. (batch, seq_len, embed_dim).

        Raises:
            ValueError: if ``x.size(1)`` exceeds ``max_len``. Without this check,
                a longer input either fails with a broadcast error or, when
                max_len == 1, silently reuses position 0 for every position.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={self.pe.size(1)} "
                "of PositionalEncoding"
            )
        # Slice 'pe' so it has the same sequence length as 'x'.
        return x + self.pe[:, :seq_len]
    

In [15]:
## Test: the encoding is added element-wise, so the embedding shape is preserved
pos_encoding_layer = PositionalEncoding(max_len=14, embed_dim=512)
pos_encoding = pos_encoding_layer(embedded_output)
pos_encoding.shape

torch.Size([2, 4, 512])

### Token Embedding + Positional Encoding

In [16]:
# Repeat the position row 0..9 for a batch of 3 (a broadcast view, no copy).
positions = torch.arange(10).unsqueeze(0).expand(3, -1)
positions

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

In [17]:
class TokenAndPositionEmbedding(nn.Module):
    """Scaled token embeddings with sinusoidal positional encodings added.

    Args:
        vocab_size: number of token ids covered by the lookup table.
        embed_dim: size of each embedding vector.
        max_length: longest sequence the positional table supports.
        device: kept for backward compatibility and stored on the module. The
            forward pass creates no tensors of its own, so outputs simply follow
            the device of the module and its input.
    """

    def __init__(self, vocab_size, embed_dim, max_length, device='cpu'):
        super(TokenAndPositionEmbedding, self).__init__()
        self.device = device
        self.token_embedding = InputEmbedding(embed_dim, vocab_size)
        self.position_embedding = PositionalEncoding(embed_dim, max_length)

    def forward(self, x):
        """Embed a batch of token ids.

        Args:
            x: integer tensor of token ids with shape (batch, seq_len).

        Returns:
            Float tensor of shape (batch, seq_len, embed_dim).

        Raises:
            ValueError: if ``x`` is not 2-D.
        """
        if x.dim() != 2:
            raise ValueError(f"expected a 2-D (batch, seq_len) tensor of token ids, got shape {tuple(x.shape)}")
        # PositionalEncoding adds its own sin/cos table to the tensor it is
        # given, so it must receive the token embeddings. Feeding it integer
        # position indices (as before) neither matches its (batch, seq, dim)
        # layout nor belongs in the sum.
        return self.position_embedding(self.token_embedding(x))

In [19]:
## Test: build the combined layer and display its sub-modules
token_pos_embedding_layer = TokenAndPositionEmbedding(
    vocab_size=10_000,
    embed_dim=512,
    max_length=50,
)
token_pos_embedding_layer

TokenAndPositionEmbedding(
  (token_embedding): InputEmbedding(
    (embedding): Embedding(10000, 512)
  )
  (position_embedding): PositionalEncoding()
)

## Encoder

In [None]:
class TransformerEncoderBlock(nn.Module):
    """Post-norm Transformer encoder layer: multi-head attention + feed-forward.

    Each sub-layer is followed by dropout, a residual connection and LayerNorm:
        h   = LayerNorm(q + Dropout(MHA(query=q, key=k, value=v)))
        out = LayerNorm(h + Dropout(FFN(h)))

    Args:
        embed_dim: model width (last dimension of q, k and v).
        num_heads: number of attention heads (nn.MultiheadAttention requires
            embed_dim % num_heads == 0).
        ff_dims: hidden width of the two-layer ReLU feed-forward network.
        prob_drop: dropout probability applied after each sub-layer.
        device: accepted for API compatibility and stored. It is not used to
            place parameters; move the module with ``.to(...)``.
        batch_first: if True (default), inputs are (batch, seq, embed_dim),
            matching the (batch, seq, ...) tensors produced by the embedding
            layers in this notebook. Pass False for (seq, batch, embed_dim).
    """

    def __init__(self, embed_dim, num_heads, ff_dims, prob_drop, device='cpu', batch_first=True):
        super().__init__()
        self.device = device

        # nn.MultiheadAttention defaults to seq-first; without batch_first=True
        # it would attend across the batch axis of (batch, seq, dim) inputs.
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=batch_first)

        self.dropout1 = nn.Dropout(prob_drop)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dims),
            nn.ReLU(),
            nn.Linear(ff_dims, embed_dim)
        )
        self.dropout2 = nn.Dropout(prob_drop)
        self.layer_norm2 = nn.LayerNorm(embed_dim)

    def forward(self, k, q, v, key_padding_mask=None, attn_mask=None):
        """Run attention of ``q`` over ``k``/``v``, then the feed-forward sub-layer.

        Args:
            k: keys, (batch, src_len, embed_dim) when batch_first.
            q: queries, (batch, tgt_len, embed_dim) when batch_first. These are
                also the residual input.
            v: values, same length as ``k``.
            key_padding_mask: optional mask forwarded to nn.MultiheadAttention
                (True marks key positions to ignore).
            attn_mask: optional attention mask forwarded to nn.MultiheadAttention.

        Returns:
            Tensor with the same shape as ``q``.
        """
        # nn.MultiheadAttention takes (query, key, value) positionally. The
        # residual below is taken from q, so q must be the query; the original
        # call passed k first.
        attn_output, _ = self.multihead_attn(
            q, k, v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            need_weights=False,  # weights are discarded; skip computing them
        )
        attn_output = self.dropout1(attn_output)
        attn_output = self.layer_norm1(attn_output + q)

        ff_output = self.ffn(attn_output)
        ff_output = self.dropout2(ff_output)
        output = self.layer_norm2(ff_output + attn_output)

        return output

