In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [13]:
class Attention(nn.Module):
    def __init__(self, mask = None):
        super(Attention, self).__init__()
        self.mask = mask

    def set_mask(self, mask):
        self.mask = mask

    def forward(self, Q, K, V, d_k):
        QK_T = torch.matmul(Q, torch.transpose(K, -1, -2))
        QK_T_d_k = torch.div(QK_T, torch.sqrt(d_k))

        if self.mask is not None:
            QK_T_d_k += self.mask

        softmax = F.softmax(QK_T_d_k, dim = -1)
        return torch.matmul(softmax, V)

In [14]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int, mask = None):
        super(MultiHeadAttention, self).__init__()

        self.h = h
        self.d_k_value = torch.Tensor([d_k])
        self.linear = nn.ModuleList()
        self.W_O = nn.Parameter(torch.Tensor(h*d_v, d_model))
        self.attention = Attention(mask)

        for _ in range(self.h):
            linear = nn.ModuleList([nn.Linear(d_k, d_model), nn.Linear(d_k, d_model), nn.Linear(d_v, d_model)])
            self.linear.append(linear)

    def set_mask(self, mask):
        self.attention.set_mask(mask)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
        assert len(Q.shape) == len(K.shape) == len(V.shape), f"invalid dimensions, got Q:{Q.shape}, K: {K.shape}, V:{V.shape}"

        heads = [self.attention(layer[0](Q), layer[1](K), layer[2](V), self.d_k_value) for layer in self.linear]
        concat_heads = torch.cat(heads, dim = -1)
        return torch.matmul(concat_heads, self.W_O)

In [15]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, in_features: int, out_features: int, intermediate_features = None):
        super(PositionWiseFeedForward, self).__init__()
        if intermediate_features is None:
            self.intermediate_features = in_features*4

        self.layers = nn.Sequential(
            nn.Linear(in_features, self.intermediate_features),
            nn.ReLU(),
            nn.Linear(self.intermediate_features, out_features)
        )

    def forward(self, X):
        return self.layers(X)

In [16]:
class EncoderStack(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int):
        super(EncoderStack, self).__init__()

        self.multi_head_attention = MultiHeadAttention(d_model, d_k, d_v, h)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.position_wise_feed_forward = PositionWiseFeedForward(d_v, d_v)
        self.layer_norm2 = nn.LayerNorm(d_model)

    def set_mask(self, mask):
        self.multi_head_attention.set_mask(mask)

    def forward(self, X):
        sublayer_1_output = self.multi_head_attention(X, X, X)
        sublayer_1_normalised = self.layer_norm1(X + sublayer_1_output)
        sublayer_2_output = self.position_wise_feed_forward(sublayer_1_normalised)
        output = self.layer_norm2(sublayer_1_normalised + sublayer_2_output)
        return output

In [17]:
class DecoderStack(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int, mask = None):
        super(DecoderStack, self).__init__()

        self.multi_head_attention1 = MultiHeadAttention(d_model, d_k, d_v, h, mask)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.multi_head_attention2 = MultiHeadAttention(d_model, d_k, d_v, h)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.position_wise_feed_forward = PositionWiseFeedForward(d_v, d_v)
        self.layer_norm3 = nn.LayerNorm(d_model)

    def set_mask(self, mask):
        self.multi_head_attention1.set_mask(mask)

    def forward(self, X, Q, K):

        V = self.multi_head_attention1(X, X, X)
        V_norm = self.layer_norm1(X + V)
        sublayer_2_output = self.multi_head_attention2(Q, K, V_norm)
        sublayer_2_normalised = self.layer_norm2(V_norm + sublayer_2_output)
        sublayer_3_output = self.position_wise_feed_forward(sublayer_2_normalised)
        output = self.layer_norm3(sublayer_2_normalised + sublayer_3_output)

        return output

# Tests

In [18]:
multi_head_attention = MultiHeadAttention(512, 512, 512, 8)
multi_head_attention.linear

ModuleList(
  (0-7): 8 x ModuleList(
    (0-2): 3 x Linear(in_features=512, out_features=512, bias=True)
  )
)

In [26]:
batch_size = 2    # Example batch size
seq_length = 5    # Length of the sequence (number of tokens)
d_k = 64         # Dimension of the keys (and queries)
model_dim = 64  # Dimension of the model
h = 8

# Generate random tensors for Q, K, V
Q = torch.randn(batch_size, seq_length, d_k)  # Queries
K = torch.randn(batch_size, seq_length, d_k)  # Keys
V = torch.randn(batch_size, seq_length, model_dim)  # Values

# Print the shapes for confirmation
print("Shape of Q:", Q.shape)  # Expected: (batch_size, seq_length, d_k)
print("Shape of K:", K.shape)  # Expected: (batch_size, seq_length, d_k)
print("Shape of V:", V.shape)  # Expected: (batch_size, seq_length, model_dim)

multi_head_attention = MultiHeadAttention(model_dim, d_k, model_dim, h)
output = multi_head_attention(Q, K, V)
print(f'Shape of output:{output.shape}')
print(output)

Shape of Q: torch.Size([2, 5, 64])
Shape of K: torch.Size([2, 5, 64])
Shape of V: torch.Size([2, 5, 64])
Shape of output:torch.Size([2, 5, 64])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [28]:
batch_size = 2    # Example batch size
seq_length = 5    # Length of the sequence (number of tokens)
d_k = 64         # Dimension of the keys (and queries)
model_dim = d_k  # Dimension of the model
h = 8

# Generate random tensors for Q, K, V
X = torch.randn(batch_size, seq_length, d_k)  # Queries
encoder_stack = EncoderStack(model_dim, d_k, model_dim, h)
encoder_output = encoder_stack(X)
print(encoder_output.shape)
print(encoder_output)

decoder_stack = DecoderStack(model_dim, d_k, model_dim, h, None)
decoder_output = decoder_stack(X, encoder_output, encoder_output)
print(decoder_output.shape)
print(decoder_output)

torch.Size([2, 5, 64])
tensor([[[-4.3076e-01,  6.0484e-01,  5.4121e-02, -5.1875e-01,  1.1079e+00,
           1.4728e+00,  1.5404e+00, -2.3107e+00, -1.7760e+00, -3.9833e-01,
           1.6302e+00,  1.7585e+00, -5.2821e-01, -1.5424e+00,  5.9872e-01,
           1.3088e+00, -1.8706e-01, -1.9072e-01, -7.2235e-01, -3.9332e-01,
          -6.6101e-01,  8.5264e-01,  1.2726e+00, -6.0114e-01, -2.6549e-01,
           1.1766e+00,  3.0314e-01, -1.4057e+00, -1.1424e+00,  8.1405e-01,
           6.5785e-01,  2.3066e-01,  3.4055e-01, -1.7748e-01,  3.4725e-01,
          -1.2393e+00, -1.9766e+00,  5.8639e-01, -5.5633e-01, -6.5792e-01,
          -3.2908e-01,  3.4763e-02,  7.1937e-01,  4.5273e-02,  1.9821e-01,
          -1.1837e-01,  6.7208e-01,  7.0503e-01, -2.4000e+00, -7.1432e-02,
           1.1389e+00,  8.2356e-01, -2.8059e-02,  1.8813e+00, -7.5849e-01,
           7.8319e-01, -3.1497e-01, -8.8514e-01,  1.5571e+00, -5.3806e-01,
          -4.2848e-01,  6.4868e-01, -5.9181e-01, -1.7197e+00],
         [ 1.6