In [None]:
import torch
from torch import nn

# We simulate data like this:
BATCH_SIZE = 2
SEQ_LENGTH = 10
EMBEDDING_DIM = 64
MODEL_DIM = 64
SEED = 0

# The layers below feed the raw embeddings into nn.Linear(MODEL_DIM, ...) and add
# them to their own outputs (residual connections), so the two sizes must agree.
assert EMBEDDING_DIM == MODEL_DIM, "EMBEDDING_DIM must equal MODEL_DIM"

# Seed the global torch RNG so the fake input (and any weights initialised after
# this cell) are identical on every Restart & Run All.
torch.manual_seed(SEED)

# Fake input: (BATCH_SIZE, SEQ_LENGTH, EMBEDDING_DIM)
dummy_input = torch.rand(BATCH_SIZE, SEQ_LENGTH, EMBEDDING_DIM)

# Individual sequences, each kept as a batch of one: (1, SEQ_LENGTH, EMBEDDING_DIM).
s1, s2 = dummy_input[0:1], dummy_input[1:2]

In [228]:
from torch import nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Single-head scaled dot-product attention with bias-free Q/K/V projections.

    Queries are always projected from ``batch``. Keys and values are projected
    from ``encoder_output`` when it is given (cross-attention) and from ``batch``
    otherwise (self-attention).
    """

    def __init__(self, model_dim):
        """
        Args:
            model_dim: size of the last dimension of the inputs; also the size of
                the query, key and value projections.
        """
        super().__init__()

        self.w_key = nn.Linear(model_dim, model_dim, bias=False)
        self.w_query = nn.Linear(model_dim, model_dim, bias=False)
        self.w_value = nn.Linear(model_dim, model_dim, bias=False)
        # Scores are divided by sqrt(model_dim) before the softmax.
        self.scale = model_dim ** 0.5

    def forward(self, batch, encoder_output=None, masked=False):
        '''
        Compute attention for ``batch``.

        If ``encoder_output`` is None it calculates self-attention; otherwise
        cross-attention, with keys and values taken from ``encoder_output``.

        Args:
            batch: query-side input whose last dimension is ``model_dim``,
                e.g. ``(batch, seq_q, model_dim)`` or unbatched ``(seq_q, model_dim)``.
            encoder_output: optional key/value-side input with the same last
                dimension, e.g. ``(batch, seq_k, model_dim)``.
            masked: if True, apply a causal mask so that query position ``i``
                cannot attend to key positions ``j > i``.

        Returns:
            Tensor with the leading shape of ``batch`` and last dimension ``model_dim``.
        '''
        source = batch if encoder_output is None else encoder_output
        key = self.w_key(source)
        query = self.w_query(batch)
        value = self.w_value(source)

        attention_scores = query @ key.transpose(-2, -1)
        if masked:
            # Size the mask from the score matrix itself rather than batch.shape[1]:
            # that index is the feature dim for unbatched input, and query/key
            # lengths may differ. Either mismatch made masked_fill fail.
            q_len, k_len = attention_scores.shape[-2:]
            mask = torch.triu(
                torch.ones(q_len, k_len, device=attention_scores.device), diagonal=1
            ).bool()
            attention_scores = attention_scores.masked_fill(mask, float('-inf'))

        weights = F.softmax(attention_scores / self.scale, dim=-1)
        return weights @ value

In [229]:
class Encoder(nn.Module):
    """One post-norm Transformer encoder layer.

    Unmasked self-attention followed by a two-layer ReLU feed-forward network;
    each sub-layer is wrapped as ``LayerNorm(input + sublayer(input))``.
    """

    def __init__(self, model_dim):
        super().__init__()
        hidden_dim = model_dim * 4  # inner width of the feed-forward network
        self.attention = Attention(model_dim)
        self.layer_norm_attn = nn.LayerNorm(normalized_shape=model_dim)
        self.layer_norm_ffn = nn.LayerNorm(normalized_shape=model_dim)
        self.ffn = nn.Sequential(
            nn.Linear(model_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, model_dim),
        )

    def add_and_norm(self, input_embedding, attention_vector, norm):
        """Residual connection followed by ``norm``."""
        residual_sum = input_embedding + attention_vector
        return norm(residual_sum)

    def forward(self, x):
        """Apply the attention sub-layer, then the feed-forward sub-layer, to ``x``."""
        attended = self.add_and_norm(
            x, self.attention(x, masked=False), norm=self.layer_norm_attn
        )
        return self.add_and_norm(
            attended, self.ffn(attended), norm=self.layer_norm_ffn
        )

In [230]:
class Decoder(nn.Module):
    """One post-norm Transformer decoder layer.

    Causally masked self-attention, then cross-attention over the encoder output,
    then a two-layer ReLU feed-forward network. Each sub-layer is wrapped as
    ``LayerNorm(input + sublayer(input))`` with its own LayerNorm.
    """

    def __init__(self, model_dim):
        super().__init__()
        hidden_dim = 4 * model_dim  # inner width of the feed-forward network
        self.self_attention = Attention(model_dim)
        self.cross_attention = Attention(model_dim)
        self.layer_norm_self_attn = nn.LayerNorm(normalized_shape=model_dim)
        self.layer_norm_cross_attn = nn.LayerNorm(normalized_shape=model_dim)
        self.layer_norm_fnn = nn.LayerNorm(normalized_shape=model_dim)
        # NOTE: the "fnn" spelling differs from Encoder's "ffn" but is kept,
        # since attribute names become state_dict keys.
        self.fnn = nn.Sequential(
            nn.Linear(model_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, model_dim)
        )

    def add_and_norm(self, input_embedding, attention_vector, norm):
        """Residual connection followed by ``norm``."""
        residual_sum = input_embedding + attention_vector
        return norm(residual_sum)

    def forward(self, x, encoder_output):
        """Run masked self-attention, cross-attention and the feed-forward net over ``x``."""
        hidden = self.add_and_norm(
            x,
            self.self_attention(x, masked=True),
            norm=self.layer_norm_self_attn,
        )
        hidden = self.add_and_norm(
            hidden,
            self.cross_attention(hidden, encoder_output=encoder_output, masked=False),
            norm=self.layer_norm_cross_attn,
        )
        return self.add_and_norm(hidden, self.fnn(hidden), norm=self.layer_norm_fnn)

In [231]:
encoder = Encoder(model_dim=MODEL_DIM)
decoder = Decoder(model_dim=MODEL_DIM)

encoder_outputs = encoder(dummy_input)
decoder_outputs = decoder(dummy_input, encoder_outputs)

# Both layers are shape-preserving: residual + LayerNorm keep the input's shape.
assert encoder_outputs.shape == dummy_input.shape, encoder_outputs.shape
assert decoder_outputs.shape == dummy_input.shape, decoder_outputs.shape
decoder_outputs.shape

In [232]:
# attention = Attention(model_dim=MODEL_DIM)
# attention(dummy_input, masked=True)