In [75]:
import torch.nn as nn
import torch
import math

### InputEmbedding

In [76]:
class InputEmbedding(nn.Module):
    """Look up token ids in a learned table and scale by sqrt(embed_dim).

    Args:
        embed_dim: size of each embedding vector.
        vocab_size: number of distinct token ids the table can hold.
    """

    def __init__(self, embed_dim, vocab_size):
        super().__init__()
        self.vocab_size, self.embed_dim = vocab_size, embed_dim
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)

    def forward(self, x):
        """Return ``embedding(x)`` multiplied by sqrt(embed_dim)."""
        scale = math.sqrt(self.embed_dim)
        return scale * self.embedding(x)

In [77]:
# Sanity check: a (batch=2, seq_len=4) grid of token ids should give (2, 4, 512).
embedding_layer = InputEmbedding(embed_dim=512, vocab_size=10_000)
embedded_output = embedding_layer(torch.arange(1, 9).reshape(2, 4))
embedded_output.shape

torch.Size([2, 4, 512])

### Position Embedding

In [78]:
# Column vector of positions 0..13, shape (14, 1), so it broadcasts
# against a (1, d/2) row of frequencies.
position = torch.arange(14)[:, None]
position.shape

torch.Size([14, 1])

In [79]:
# Frequencies 10000^(-2i/512) for i = 0..255, computed in log space,
# then shaped as a (1, 256) row for broadcasting.
div_term = (
    torch.arange(0, 512, 2, dtype=torch.float)
    .mul(-(math.log(10000) / 512))
    .exp()[None, :]
)

div_term.shape

torch.Size([1, 256])

In [80]:
# Broadcasted outer product (14, 1) * (1, 256) -> (14, 256) angles,
# plus a leading batch axis -> (1, 14, 256).
value = (position * div_term)[None]
value.shape

torch.Size([1, 14, 256])

In [81]:
class PositionalEncoding(nn.Module):
    """Position embedding = learned lookup table + fixed sinusoidal table.

    ``forward`` takes a tensor of integer position ids whose dim 1 is the
    sequence axis (e.g. ``arange(seq_len)`` expanded over the batch) and
    returns ``self.embeddim(x) + pe[:, :x.size(1)]``. That is a learned
    ``nn.Embedding`` over position ids plus the sinusoidal encoding
    sin/cos(pos * 10000^(-2i/embed_dim)).

    Args:
        embed_dim: size of each position vector. Odd sizes are supported: the
            last column (an even index) then holds a sine with no cosine partner.
        max_len: number of positions in both the learned table and ``pe``.
    """

    def __init__(self, embed_dim, max_len=5000):
        super().__init__()
        N = 10000.0
        pe = torch.zeros(max_len, embed_dim)
        # NOTE(review): a *learned* position table added on top of the fixed
        # sinusoidal one is unusual (normally it is one or the other). It is
        # kept, under its original attribute name, to preserve behavior and
        # state_dict keys.
        self.embeddim = nn.Embedding(max_len, embed_dim)

        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        # One frequency per (sin, cos) column pair, hence step 2. This matches
        # the step used below to fill pe's even / odd columns.
        div_term = torch.exp(torch.arange(0, embed_dim, 2, dtype=torch.float) * -(math.log(N) / embed_dim))
        angles = position * div_term  # (max_len, ceil(embed_dim / 2))

        pe[:, 0::2] = torch.sin(angles)
        # With an odd embed_dim there is one fewer odd column than there are
        # frequencies, so drop the last frequency for the cosine half.
        pe[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])

        # Add a leading dim of size 1 so 'pe' broadcasts over the batch of 'x'.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return learned embeddings of ``x`` plus the first ``x.size(1)`` sinusoidal rows."""
        init_position_embedding = self.embeddim(x)
        # Slice 'pe' to x's sequence length; it broadcasts across the batch dim.
        return init_position_embedding + self.pe[:, :x.size(1)]
    

In [82]:
## Test
# Three identical rows of positions 0..9, fed through a table sized for 14 positions.
pos_encoding_layer = PositionalEncoding(embed_dim=512, max_len=14)
positions = torch.arange(10).expand(3, -1)
pos_encoding = pos_encoding_layer(positions)
pos_encoding.shape

torch.Size([3, 10, 512])

### Input Embedding, Positional Encoding

In [83]:
# Broadcast view (no copy): every one of the 3 rows is 0..9.
positions = torch.arange(10)[None, :].expand(3, 10)
positions

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

In [84]:
class TokenAndPositionEmbedding(nn.Module):
    """Sum of scaled token embeddings and position embeddings.

    Args:
        vocab_size: number of distinct token ids.
        embed_dim: size of each embedding vector.
        max_length: number of positions the position table supports.
        device: kept for backward compatibility and stored on ``self.device``.
            Position ids are created on the same device as the input ``x``
            (see ``forward``), so moving the module with ``.to(...)`` just works.
    """

    def __init__(self, vocab_size, embed_dim, max_length, device='cpu'):
        super().__init__()
        self.embed_dim = embed_dim
        self.device = device
        self.token_embedding = InputEmbedding(embed_dim, vocab_size)
        self.position_embedding = PositionalEncoding(embed_dim, max_length)

    def forward(self, x):
        """Embed a 2-D (batch, seq_len) tensor of token ids."""
        N, seq_len = x.size()
        # Build the position ids on x's device instead of the device string
        # fixed at construction time. Otherwise, after .cuda()/.to(...), the ids
        # would sit on a different device than the embedding weights.
        positions = torch.arange(seq_len, device=x.device).expand(N, seq_len)
        return self.token_embedding(x) + self.position_embedding(positions)

## Encoder

In [85]:
class TransformerEncoderBlock(nn.Module):
    """One post-norm Transformer encoder layer: attention + feed-forward.

    Each sub-layer is followed by dropout, a residual connection and LayerNorm.
    All tensors are laid out as (batch, seq_len, embed_dim).

    Args:
        embed_dim: model width.
        num_heads: number of attention heads (must divide embed_dim).
        ff_dims: hidden width of the position-wise feed-forward network.
        prob_drop: dropout probability after each sub-layer.
    """

    def __init__(self, embed_dim, num_heads, ff_dims, prob_drop):
        super().__init__()

        # batch_first=True: the embeddings fed in are (batch, seq_len, embed_dim).
        # With the default (seq_len, batch, embed_dim) layout, attention would
        # mix different samples of the batch instead of tokens of a sequence.
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

        self.dropout1 = nn.Dropout(prob_drop)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dims),
            nn.ReLU(),
            nn.Linear(ff_dims, embed_dim)
        )
        self.dropout2 = nn.Dropout(prob_drop)
        self.layer_norm2 = nn.LayerNorm(embed_dim)

    def forward(self, k, q, v):
        """Attend from ``q`` over ``k``/``v``, then apply the feed-forward sub-layer.

        Args:
            k: keys, (batch, src_len, embed_dim).
            q: queries, (batch, tgt_len, embed_dim); also the residual input.
            v: values, (batch, src_len, embed_dim).

        Returns:
            Tensor of shape (batch, tgt_len, embed_dim).
        """
        # nn.MultiheadAttention takes (query, key, value), in that order.
        attn_output, _ = self.multihead_attn(q, k, v)
        attn_output = self.dropout1(attn_output)
        attn_output = self.layer_norm1(attn_output + q)

        ff_output = self.ffn(attn_output)
        ff_output = self.dropout2(ff_output)
        output = self.layer_norm2(ff_output + attn_output)

        return output

class TransformerEncoder(nn.Module):
    """Token + position embedding followed by ``n_layers`` encoder blocks.

    Args:
        vocab_size: number of distinct token ids.
        max_length: number of positions supported by the position table.
        n_layers: number of stacked encoder blocks.
        embed_dim: model width.
        num_heads: attention heads per block.
        ff_dims: hidden width of each block's feed-forward network.
        prob_drop: dropout probability inside each block.
        device: forwarded to the embedding module.
    """

    def __init__(self, vocab_size, max_length, n_layers, embed_dim, num_heads, ff_dims, prob_drop, device='cpu'):
        super().__init__()
        self.token_pos_embedding = TokenAndPositionEmbedding(vocab_size, embed_dim, max_length, device)
        blocks = [TransformerEncoderBlock(embed_dim, num_heads, ff_dims, prob_drop) for _ in range(n_layers)]
        self.encoder_blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Embed token ids ``x`` and run them through every block (self-attention)."""
        hidden = self.token_pos_embedding(x)
        for block in self.encoder_blocks:
            hidden = block(hidden, hidden, hidden)
        return hidden


In [87]:
# Test Encoder
# Smoke test: a batch of 2 sequences of 50 token ids, each id drawn from {0, 1}.
src = torch.randint(0, 2, (2, 50), dtype=torch.int64)

test_x = TransformerEncoder(
    vocab_size=10_000,
    max_length=50,
    n_layers=8,
    embed_dim=512,
    num_heads=8,
    ff_dims=2048,
    prob_drop=0.1,
)(src)
test_x.shape

torch.Size([2, 50, 512])

## Decoder

In [None]:
class DecoderTransformersBlock(nn.Module):
    """One post-norm Transformer decoder layer.

    Sub-layers, each followed by dropout, a residual connection and LayerNorm:
    masked self-attention over the target, cross-attention from the target to
    the encoder output, and a position-wise feed-forward network.
    All tensors are laid out as (batch, seq_len, embed_dim).

    Args:
        embed_dim: model width.
        num_heads: number of attention heads (must divide embed_dim).
        ff_dims: hidden width of the feed-forward network.
        prob_drop: dropout probability after each sub-layer.
    """

    def __init__(self, embed_dim, num_heads, ff_dims, prob_drop):
        # nn.Module.__init__ accepts no model hyper-parameters.
        super().__init__()
        # nn.MultiheadAttention has no `mask` option. Causality comes from the
        # `attn_mask` passed at call time (tgt_mask in forward).
        # batch_first=True because inputs are (batch, seq_len, embed_dim).
        self.masked_multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.dropout1 = nn.Dropout(prob_drop)
        self.layer_norm1 = nn.LayerNorm(embed_dim)

        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.dropout2 = nn.Dropout(prob_drop)
        self.layer_norm2 = nn.LayerNorm(embed_dim)

        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dims),
            nn.ReLU(),
            nn.Linear(ff_dims, embed_dim)
        )
        self.dropout3 = nn.Dropout(prob_drop)
        self.layer_norm3 = nn.LayerNorm(embed_dim)

    def forward(self, target, enc_output, tgt_mask=None, src_mask=None):
        """Run one decoder layer.

        Args:
            target: decoder input, (batch, tgt_len, embed_dim).
            enc_output: encoder output, (batch, src_len, embed_dim).
            tgt_mask: optional ``attn_mask`` for self-attention, e.g. a
                (tgt_len, tgt_len) boolean causal mask where True = blocked.
            src_mask: optional ``attn_mask`` for cross-attention, shaped
                (tgt_len, src_len) or (batch * num_heads, tgt_len, src_len).
                NOTE(review): a (batch, src_len) source *padding* mask belongs
                in ``key_padding_mask`` instead; confirm what callers pass.

        Returns:
            Tensor of shape (batch, tgt_len, embed_dim).
        """
        masked_attn_output, _ = self.masked_multihead_attn(
            target, target, target, attn_mask=tgt_mask)
        masked_attn_output = self.dropout1(masked_attn_output)
        masked_attn_output = self.layer_norm1(masked_attn_output + target)

        # Queries come from the decoder; keys/values come from the encoder.
        attn_output, _ = self.multihead_attn(
            masked_attn_output, enc_output, enc_output, attn_mask=src_mask)

        attn_output = self.dropout2(attn_output)
        attn_output = self.layer_norm2(attn_output + masked_attn_output)

        ffn_output = self.ffn(attn_output)
        ffn_output = self.dropout3(ffn_output)
        output = self.layer_norm3(ffn_output + attn_output)

        return output

class TransformerDecoder(nn.Module):
    """Token + position embedding followed by ``n_layers`` decoder blocks.

    Note the constructor takes ``n_layers`` after ``prob_drop``. This order
    differs from ``TransformerEncoder``, so pass arguments by keyword.

    Args:
        vocab_size: number of distinct target token ids.
        max_length: number of positions supported by the position table.
        embed_dim: model width.
        num_heads: attention heads per block.
        ff_dims: hidden width of each block's feed-forward network.
        prob_drop: dropout probability inside each block.
        n_layers: number of stacked decoder blocks.
        device: forwarded to the embedding module.
    """

    def __init__(self, vocab_size, max_length, embed_dim, num_heads, ff_dims, prob_drop, n_layers, device='cpu'):
        super().__init__()
        self.token_pos_embedding = TokenAndPositionEmbedding(vocab_size, embed_dim, max_length, device)
        layers = [DecoderTransformersBlock(embed_dim, num_heads, ff_dims, prob_drop) for _ in range(n_layers)]
        self.decoder_blocks = nn.ModuleList(layers)

    def forward(self, x, enc_output, tgt_mask, src_mask):
        """Embed target ids ``x`` and pass them through every block with the encoder output."""
        hidden = self.token_pos_embedding(x)
        for layer in self.decoder_blocks:
            hidden = layer(hidden, enc_output, tgt_mask, src_mask)
        return hidden