In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# 1. Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence, then apply dropout.

    Even feature indices receive ``sin(pos * 10000^(-2i/d_model))`` and odd feature
    indices receive ``cos`` of the same angle. Both even and odd ``d_model`` work.

    Args:
        d_model: feature size of the inputs.
        max_len: number of positions in the precomputed table; inputs must not be
            longer than this.
        dropout: dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Build the PE table of shape (max_len, 1, d_model).
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there are only floor(d_model / 2) odd slots, one fewer than
        # the number of angles, so drop the last angle to keep the shapes aligned.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Registered as a buffer: saved in state_dict and moved by .to(), not trained.
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to ``x`` and apply dropout.

        Args:
            x: tensor of shape (seq_len, batch, d_model) with seq_len <= max_len.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [3]:
# 2. Scaled Dot-Product Attention
def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
    """Compute ``softmax(query @ key^T / sqrt(d_k)) @ value``.

    Args:
        query: tensor (..., q_len, d_k).
        key: tensor (..., k_len, d_k).
        value: tensor (..., k_len, d_v).
        mask: optional tensor broadcastable to (..., q_len, k_len). Positions where
            ``mask == 0`` (or False) are excluded from attention.
        dropout: optional callable (e.g. ``nn.Dropout``) applied to the attention weights.

    Returns:
        Tuple ``(output, attn)`` with output (..., q_len, d_v) and attention weights
        (..., q_len, k_len).
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # Use the most negative finite value rather than -inf. A row whose keys are all
        # masked then gets finite (uniform) weights instead of softmax(all -inf) = NaN,
        # which would poison the output and the gradients. For partially masked rows
        # exp(finfo.min - row_max) underflows to exactly 0, so results are unchanged.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        attn = dropout(attn)

    return torch.matmul(attn, value), attn

In [4]:
# 3. Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Multi-head attention over batch-first inputs.

    Projects query/key/value with separate linear layers and splits each projection
    into ``num_heads`` heads of size ``d_model // num_heads``. It then runs scaled
    dot-product attention per head, concatenates the heads and applies an output
    projection.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_k = d_model // num_heads
        self.num_heads = num_heads

        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.attn = None  # attention weights from the last forward pass, for debugging

    def _split_heads(self, x, batch_size):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_k)."""
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: (batch, q_len, d_model).
            key: (batch, k_len, d_model).
            value: (batch, k_len, d_model).
            mask: optional; entries equal to 0/False are masked out. Accepted shapes
                are (q_len, k_len), a per-example (batch, q_len or 1, k_len), or any
                shape already broadcastable to (batch, num_heads, q_len, k_len).

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size = query.size(0)

        if mask is not None and mask.dim() == 3:
            # A per-example mask (batch, q_len|1, k_len) would be right-aligned against
            # (batch, heads, q_len, k_len), pairing its batch axis with the head axis.
            # Insert a head axis so the mask broadcasts over heads instead.
            mask = mask.unsqueeze(1)

        # Linear projections, then split into heads.
        Q = self._split_heads(self.linear_q(query), batch_size)
        K = self._split_heads(self.linear_k(key), batch_size)
        V = self._split_heads(self.linear_v(value), batch_size)

        # Attention on all heads at once.
        x, self.attn = scaled_dot_product_attention(Q, K, V, mask=mask, dropout=self.dropout)

        # Concatenate heads back to (batch, q_len, d_model) and apply the output projection.
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.linear_out(x)

In [5]:
# 4. Feed Forward (Position-wise FFN)
class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every position.

    Computes ``linear2(dropout(relu(linear1(x))))``, expanding the last dimension
    from ``d_model`` to ``d_ff`` and projecting it back to ``d_model``.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Creation order matches the parameter-init and state_dict key order.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

In [6]:
# 5. Encoder Layer
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder block: self-attention, then feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Registration order is unchanged so parameter init and state_dict keys match.
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, residual, sublayer_out, norm):
        """Apply dropout to the sub-layer output, add the residual, then normalise."""
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, x, mask=None):
        hidden = self._add_and_norm(x, self.self_attn(x, x, x, mask=mask), self.norm1)
        return self._add_and_norm(hidden, self.feed_forward(hidden), self.norm2)

In [7]:
# 6. Decoder Layer
class DecoderLayer(nn.Module):
    """One post-norm Transformer decoder block.

    Three sub-layers, each wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``:
      1. self-attention over the decoder input, masked by ``tgt_mask``;
      2. cross-attention from the decoder to ``enc_output``, masked by ``src_mask``;
      3. the position-wise feed-forward network.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Registration order is unchanged so parameter init and state_dict keys match.
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, residual, sublayer_out, norm):
        """Apply dropout to the sub-layer output, add the residual, then normalise."""
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        hidden = self._add_and_norm(x, self.self_attn(x, x, x, mask=tgt_mask), self.norm1)
        cross = self.cross_attn(hidden, enc_output, enc_output, mask=src_mask)
        hidden = self._add_and_norm(hidden, cross, self.norm2)
        return self._add_and_norm(hidden, self.feed_forward(hidden), self.norm3)

In [8]:
# 7. Full Transformer
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing per-position target-vocabulary logits.

    Inputs are batch-first token ids: ``src`` (batch, src_len) and ``tgt``
    (batch, tgt_len). The output has shape (batch, tgt_len, tgt_vocab_size).
    ``src_mask`` is passed to every encoder self-attention and every decoder
    cross-attention; ``tgt_mask`` is passed to every decoder self-attention.
    """

    def __init__(self, src_vocab_size: int, tgt_vocab_size: int,
                 d_model: int = 512, num_heads: int = 8, num_layers: int = 6,
                 d_ff: int = 2048, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()

        # Registration order is unchanged so parameter init and state_dict keys match.
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len, dropout)

        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )

        self.final_linear = nn.Linear(d_model, tgt_vocab_size)
        # Not used by forward(); kept so the module structure stays the same.
        self.dropout = nn.Dropout(dropout)

    def _embed(self, tokens, embedding):
        """Scale token embeddings by sqrt(d_model), then add positional encoding."""
        scaled = embedding(tokens) * math.sqrt(embedding.embedding_dim)
        # pos_encoder expects (seq, batch, d_model); swap into that layout and back.
        return self.pos_encoder(scaled.transpose(0, 1)).transpose(0, 1)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # Embed source first, then target, to keep the dropout RNG call order.
        memory = self._embed(src, self.src_embedding)
        hidden = self._embed(tgt, self.tgt_embedding)

        for enc_layer in self.encoder_layers:
            memory = enc_layer(memory, src_mask)

        for dec_layer in self.decoder_layers:
            hidden = dec_layer(hidden, memory, src_mask, tgt_mask)

        return self.final_linear(hidden)


# Causal mask (essential for decoder self-attention)
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    """Build a causal (look-ahead) mask for decoder self-attention.

    Returns an (sz, sz) bool tensor where entry [i, j] is True when query position
    i may attend to key position j, i.e. when j <= i. False marks future positions,
    which are hidden because the attention code masks entries where ``mask == 0``.
    """
    lower_triangle = torch.ones(sz, sz).tril()
    return lower_triangle.bool()

In [9]:
# Usage example: one forward pass on random token ids to check the output shape.
if __name__ == "__main__":
    torch.manual_seed(0)  # reproducible weights and random inputs

    SRC_VOCAB_SIZE = 10000
    TGT_VOCAB_SIZE = 10000
    BATCH_SIZE, SRC_LEN, TGT_LEN = 2, 10, 12

    model = Transformer(
        src_vocab_size=SRC_VOCAB_SIZE,
        tgt_vocab_size=TGT_VOCAB_SIZE,
        d_model=512,
        num_heads=8,
        num_layers=6,
        d_ff=2048,
        dropout=0.1
    )
    model.eval()  # disable dropout for a deterministic demo pass

    # Token ids must be drawn below each side's own vocabulary size.
    src = torch.randint(0, SRC_VOCAB_SIZE, (BATCH_SIZE, SRC_LEN))
    tgt = torch.randint(0, TGT_VOCAB_SIZE, (BATCH_SIZE, TGT_LEN))

    tgt_mask = generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

    with torch.no_grad():  # inference only; no autograd graph needed
        output = model(src, tgt, tgt_mask=tgt_mask)
    assert output.shape == (BATCH_SIZE, TGT_LEN, TGT_VOCAB_SIZE)
    print(output.shape)

torch.Size([2, 12, 10000])
