In [21]:
import torch
from torch import nn

# Toy dimensions shared by the demo cells of this notebook.
BATCH_SIZE, SEQ_LENGTH = 2, 10
EMBEDDING_DIM = MODEL_DIM = 64

# Random stand-in for an already-embedded batch:
# [BATCH_SIZE, SEQ_LENGTH, EMBEDDING_DIM]
dummy_input = torch.rand((BATCH_SIZE, SEQ_LENGTH, EMBEDDING_DIM))

# The first two sequences, each kept as a batch of one:
# [1, SEQ_LENGTH, EMBEDDING_DIM]
s1 = dummy_input[0].unsqueeze(0)
s2 = dummy_input[1].unsqueeze(0)

# Random integer ids in [0, 100) standing in for raw token ids:
# [BATCH_SIZE, SEQ_LENGTH]
dummy_raw = torch.randint(low=0, high=100, size=(BATCH_SIZE, SEQ_LENGTH))

In [2]:
import math

class PositionalEncoder(nn.Module):
    """Adds fixed (non-learned) sinusoidal position encodings to embeddings.

    Even feature columns hold ``sin(pos * w_i)`` and odd feature columns hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / model_dim)``. The table is
    precomputed once for ``max_len`` positions and stored as the buffer
    ``pe`` of shape ``[1, max_len, model_dim]``.

    Args:
        model_dim: size of the last (feature) dimension of the inputs.
            Both even and odd values are supported.
        max_len: number of positions to precompute; the longest sequence
            ``forward`` accepts.
    """

    def __init__(self, model_dim, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per (sin, cos) column pair.
        div_term = torch.exp(
            torch.arange(0, model_dim, 2).float() * (-math.log(10000.0) / model_dim)
        )
        angles = position * div_term  # [max_len, ceil(model_dim / 2)]

        pe = torch.zeros(max_len, model_dim)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd model_dim there is one fewer cosine column than sine
        # column, so drop the last frequency; for even model_dim the slice
        # keeps every column.
        pe[:, 1::2] = torch.cos(angles[:, : model_dim // 2])

        # Leading batch axis of size 1 so the table broadcasts over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.shape[1]`` positions.

        Args:
            x: tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``model_dim``.

        Raises:
            ValueError: if the sequence is longer than the precomputed table.
        """
        seq_len = x.shape[1]
        max_len = self.pe.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional table size {max_len}"
            )
        return x + self.pe[:, :seq_len, :]

In [3]:
# pe = PositionalEncoder(MODEL_DIM)
# pe(dummy_input).shape

In [29]:
from torch import nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Single-head scaled dot-product attention (self- or cross-attention).

    Queries, keys and values are bias-free linear projections of size
    ``model_dim``; scores are divided by ``sqrt(model_dim)`` before the
    softmax.
    """

    def __init__(self, model_dim):
        super().__init__()

        self.w_key = nn.Linear(model_dim, model_dim, bias=False)
        self.w_query = nn.Linear(model_dim, model_dim, bias=False)
        self.w_value = nn.Linear(model_dim, model_dim, bias=False)
        self.scale = model_dim ** 0.5

    def forward(self, batch, encoder_output=None, mask=None):
        '''
        Attend from ``batch`` to ``encoder_output`` (cross-attention) or, when
        ``encoder_output`` is None, to ``batch`` itself (self-attention).

        Args:
            batch: query-side input; the last dimension is ``model_dim``.
            encoder_output: optional key/value-side input.
            mask: optional tensor broadcastable to the score matrix
                ``[..., query_len, key_len]``. It is converted with
                ``.bool()``; ``True`` entries are hidden from attention.

        Returns:
            Attention output with the same shape as ``batch``. A query row
            whose keys are *all* masked yields a zero vector (instead of NaN).
        '''
        memory = batch if encoder_output is None else encoder_output
        key = self.w_key(memory)
        query = self.w_query(batch)
        value = self.w_value(memory)

        # [batch_size, query_len, key_len]
        attention_scores = (query @ key.transpose(-2, -1)) / self.scale

        fully_masked = None
        if mask is not None:
            mask = mask.bool()
            attention_scores = attention_scores.masked_fill(mask, float('-inf'))
            # A row made only of -inf turns into NaN under softmax (and poisons
            # gradients). This happens e.g. when a causal mask is OR-ed with a
            # padding mask and the first token is padding. Give such rows
            # finite scores here and zero their weights after the softmax.
            fully_masked = mask.all(dim=-1, keepdim=True)
            attention_scores = attention_scores.masked_fill(fully_masked, 0.0)

        attention_weights = F.softmax(attention_scores, dim=-1)
        if fully_masked is not None:
            attention_weights = attention_weights.masked_fill(fully_masked, 0.0)

        return attention_weights @ value

In [5]:
# attention = Attention(model_dim=MODEL_DIM)
# attention(dummy_input)

In [6]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then a position-wise feed-forward net.

    Each sub-layer is wrapped as ``LayerNorm(input + sublayer(input))``.
    The feed-forward net expands to ``4 * model_dim`` with a ReLU in between.
    """

    def __init__(self, model_dim):
        super().__init__()
        hidden_dim = model_dim * 4

        self.attention = Attention(model_dim)
        self.layer_norm_attn = nn.LayerNorm(model_dim)
        self.layer_norm_ffn = nn.LayerNorm(model_dim)
        self.ffn = nn.Sequential(
            nn.Linear(model_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, model_dim),
        )

    def add_and_norm(self, input_embedding, attention_vector, norm):
        """Residual connection followed by the given normalisation module."""
        residual_sum = input_embedding + attention_vector
        return norm(residual_sum)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is forwarded to the attention module."""
        attended = self.add_and_norm(
            x,
            self.attention(x, mask=mask),
            norm=self.layer_norm_attn,
        )
        return self.add_and_norm(
            attended,
            self.ffn(attended),
            norm=self.layer_norm_ffn,
        )

In [77]:
class DecoderLayer(nn.Module):
    """One decoder block: causal self-attention, cross-attention, feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(input + sublayer(input))``.
    The feed-forward net expands to ``4 * model_dim`` with a ReLU in between.
    """

    def __init__(self, model_dim):
        super().__init__()
        self.self_attention = Attention(model_dim)
        self.cross_attention = Attention(model_dim)
        self.layer_norm_self_attn = nn.LayerNorm(normalized_shape=model_dim)
        self.layer_norm_cross_attn = nn.LayerNorm(normalized_shape=model_dim)
        self.layer_norm_fnn = nn.LayerNorm(normalized_shape=model_dim)
        self.fnn = nn.Sequential(
            nn.Linear(model_dim, model_dim * 4),
            nn.ReLU(),
            nn.Linear(model_dim * 4, model_dim)
        )

    def add_and_norm(self, input_embedding, attention_vector, norm):
        """Residual connection followed by the given normalisation module."""
        return norm(
            input_embedding + attention_vector
        )

    def forward(self, x, encoder_output, tgt_mask=None, src_mask=None):
        """Run the block on the target-side input ``x``.

        Args:
            x: target-side input; dimension -2 is the sequence axis.
            encoder_output: tensor the cross-attention reads keys/values from.
            tgt_mask: optional mask for the self-attention, combined (logical
                OR) with the causal mask; ``True`` entries are hidden. When
                omitted, only the causal mask is applied.
            src_mask: optional mask passed unchanged to the cross-attention.
        """
        # Causal self-attention: True above the diagonal hides future positions.
        seq_len = x.shape[-2]
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        if tgt_mask is None:
            # Previously this path crashed on `None.to(...)`.
            self_attention_mask = causal_mask
        else:
            self_attention_mask = causal_mask | tgt_mask.to(x.device).bool()
        self_attention_vector = self.self_attention(x, mask=self_attention_mask)
        x = self.add_and_norm(x, self_attention_vector, norm=self.layer_norm_self_attn)

        # Cross-attention over the encoder output.
        cross_attention_vector = self.cross_attention(x, encoder_output=encoder_output, mask=src_mask)
        x = self.add_and_norm(x, cross_attention_vector, norm=self.layer_norm_cross_attn)

        # Position-wise feed-forward network.
        fnn_output = self.fnn(x)
        x = self.add_and_norm(x, fnn_output, norm=self.layer_norm_fnn)

        return x

In [83]:
class Transformer(nn.Module):
    """Single-head encoder-decoder Transformer over token ids.

    Args:
        model_dim: embedding / hidden size used throughout the model.
        num_layers: number of encoder layers and of decoder layers.
        src_vocab_size: number of rows in the source embedding table.
        tgt_vocab_size: number of rows in the target embedding table and
            size of the output logits.
        padding_idx: token id treated as padding; it is passed to both
            embeddings and used to build the attention padding masks.
        max_len: number of positions precomputed by the positional encoder
            (defaults to the previously hard-coded 5000).
    """

    def __init__(self, model_dim, num_layers, src_vocab_size, tgt_vocab_size, padding_idx, max_len=5000):
        super().__init__()
        self.model_dim = model_dim
        self.padding_idx = padding_idx

        self.encoder_embedding = nn.Embedding(
            num_embeddings=src_vocab_size,
            embedding_dim=model_dim,
            padding_idx=padding_idx
        )

        self.decoder_embedding = nn.Embedding(
            num_embeddings=tgt_vocab_size,
            embedding_dim=model_dim,
            padding_idx=padding_idx
        )

        # One positional encoder is shared by the source and target sides.
        self.positional_encoder = PositionalEncoder(model_dim, max_len=max_len)

        self.encoder_layers = nn.ModuleList(EncoderLayer(model_dim) for _ in range(num_layers))
        self.decoder_layers = nn.ModuleList(DecoderLayer(model_dim) for _ in range(num_layers))

        self.fc_out = nn.Linear(model_dim, tgt_vocab_size)

    def forward(self, src, tgt):
        """Compute output logits for every target position.

        Args:
            src: source token ids, ``[batch_size, src_len]``.
            tgt: target token ids, ``[batch_size, tgt_len]``. They are used
                exactly as given; no shifting is done here.

        Returns:
            Logits of shape ``[batch_size, tgt_len, tgt_vocab_size]``.
        """
        # Embeddings are scaled by sqrt(model_dim) before adding positions.
        embedding_scale = math.sqrt(self.model_dim)
        src_embedding = self.encoder_embedding(src) * embedding_scale
        tgt_embedding = self.decoder_embedding(tgt) * embedding_scale
        # [batch_size, seq_len, model_dim]

        src_embedding = self.positional_encoder(src_embedding)
        tgt_embedding = self.positional_encoder(tgt_embedding)
        # [batch_size, seq_len, model_dim]

        # --- Encoder ---
        # Padding mask: True where the token is padding. The extra axis lets
        # it broadcast over query positions: [batch_size, 1, seq_len].
        # NOTE: multi-head attention would need one more unsqueezed axis
        # for the heads.
        src_mask = (src == self.padding_idx).unsqueeze(1)
        for encoder in self.encoder_layers:
            src_embedding = encoder(src_embedding, mask=src_mask)
        encoder_outputs = src_embedding  # just a rename for better readability

        # --- Decoder ---
        tgt_mask = (tgt == self.padding_idx).unsqueeze(1)
        for decoder in self.decoder_layers:
            tgt_embedding = decoder(tgt_embedding, encoder_outputs, tgt_mask=tgt_mask, src_mask=src_mask)
        decoder_outputs = tgt_embedding

        return self.fc_out(decoder_outputs)

In [84]:
# Smoke run: build a 5-layer model with 30000-token vocabularies and feed the
# same random id batch as both source and target.
transformer = Transformer(
    MODEL_DIM,
    num_layers=5,
    src_vocab_size=30000,
    tgt_vocab_size=30000,
    padding_idx=0,
)

transformer(src=dummy_raw, tgt=dummy_raw)

tensor([[[ 0.5174, -0.1248,  0.5785,  ..., -1.1002,  0.1090,  1.2529],
         [ 0.2743, -0.1104,  0.2703,  ..., -1.2902, -0.6096,  1.2949],
         [-0.1630,  0.0934,  0.3706,  ..., -1.1377, -0.2118,  1.1541],
         ...,
         [ 0.4326,  0.2181,  0.4488,  ..., -0.4757, -0.4084,  1.2081],
         [-0.4088,  0.1215,  1.0486,  ..., -1.3274,  0.2844,  1.7257],
         [ 0.2937, -0.0621,  0.7997,  ..., -0.5898,  0.5246,  0.9784]],

        [[-0.8206,  0.5910, -0.5584,  ...,  0.6371, -0.4118, -0.1085],
         [-1.0866,  0.7006, -0.8822,  ...,  0.6030, -0.6748, -0.0923],
         [-1.1180,  0.6025, -1.1476,  ...,  0.3079, -0.2585,  0.4914],
         ...,
         [-0.6805,  0.5788, -0.4112,  ...,  0.1675, -0.6317,  0.8033],
         [-0.5054,  0.5177, -0.3869,  ..., -0.1380, -0.3360,  0.7031],
         [-0.8101,  1.0963, -0.8200,  ..., -0.0516, -0.3498,  1.3404]]],
       grad_fn=<ViewBackward0>)

In [85]:
# encoder = EncoderLayer(model_dim=MODEL_DIM)
# decoder = DecoderLayer(model_dim=MODEL_DIM)

# encoder_outputs = encoder(dummy_input)
# decoder_outputs = decoder(dummy_input, encoder_outputs)

In [10]:
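# attention = Attention(model_dim=MODEL_DIM)
# causal_mask = torch.triu(torch.ones(SEQ_LENGTH, SEQ_LENGTH, dtype=torch.bool), diagonal=1)
# attention(dummy_input, mask=causal_mask)  # forward() takes a boolean `mask` tensor (True = hidden), not a `masked` flag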