In [2]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class Attention(nn.Module):
    def __init__(self, mask = None):
        super(Attention, self).__init__()
        self.mask = mask

    def set_mask(self, mask):
        self.mask = mask

    def forward(self, Q, K, V, d_k):
        QK_T = torch.matmul(Q, torch.transpose(K, -1, -2))
        QK_T_d_k = torch.div(QK_T, torch.sqrt(d_k))

        if self.mask is not None:
            QK_T_d_k += self.mask

        softmax = F.softmax(QK_T_d_k, dim = -1)
        return torch.matmul(softmax, V)

In [4]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int, mask = None):
        super(MultiHeadAttention, self).__init__()

        self.h = h
        self.d_k_value = torch.Tensor([d_k])
        self.linear = nn.ModuleList()
        self.W_O = nn.Parameter(torch.Tensor(h*d_v, d_model))
        self.attention = Attention(mask)

        for _ in range(self.h):
            linear = nn.ModuleList([nn.Linear(d_k, d_model), nn.Linear(d_k, d_model), nn.Linear(d_v, d_model)])
            self.linear.append(linear)

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def set_mask(self, mask):
        self.attention.set_mask(mask)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
        assert len(Q.shape) == len(K.shape) == len(V.shape), f"invalid dimensions, got Q:{Q.shape}, K: {K.shape}, V:{V.shape}"

        heads = [self.attention(layer[0](Q), layer[1](K), layer[2](V), self.d_k_value) for layer in self.linear]
        concat_heads = torch.cat(heads, dim = -1)
        return torch.matmul(concat_heads, self.W_O)

In [5]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, in_features: int, out_features: int, intermediate_features = None):
        super(PositionWiseFeedForward, self).__init__()
        if intermediate_features is None:
            self.intermediate_features = in_features*4

        self.layers = nn.Sequential(
            nn.Linear(in_features, self.intermediate_features),
            nn.ReLU(),
            nn.Linear(self.intermediate_features, out_features)
        )

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def forward(self, X):
        return self.layers(X)

In [6]:
class EncoderStack(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int):
        super(EncoderStack, self).__init__()

        self.multi_head_attention = MultiHeadAttention(d_model, d_k, d_v, h)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.position_wise_feed_forward = PositionWiseFeedForward(d_v, d_v)
        self.layer_norm2 = nn.LayerNorm(d_model)

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def forward(self, X):
        sublayer_1_output = self.multi_head_attention(X, X, X)
        sublayer_1_normalised = self.layer_norm1(X + sublayer_1_output)
        sublayer_2_output = self.position_wise_feed_forward(sublayer_1_normalised)
        output = self.layer_norm2(sublayer_1_normalised + sublayer_2_output)
        return output

In [7]:
class DecoderStack(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int, mask = None):
        super(DecoderStack, self).__init__()

        self.multi_head_attention1 = MultiHeadAttention(d_model, d_k, d_v, h, mask)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.multi_head_attention2 = MultiHeadAttention(d_model, d_k, d_v, h)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.position_wise_feed_forward = PositionWiseFeedForward(d_v, d_v)
        self.layer_norm3 = nn.LayerNorm(d_model)

    def set_mask(self, mask):
        self.multi_head_attention1.set_mask(mask)

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def forward(self, X, Q, K):

        V = self.multi_head_attention1(X, X, X)
        V_norm = self.layer_norm1(X + V)
        sublayer_2_output = self.multi_head_attention2(Q, K, V_norm)
        sublayer_2_normalised = self.layer_norm2(V_norm + sublayer_2_output)
        sublayer_3_output = self.position_wise_feed_forward(sublayer_2_normalised)
        output = self.layer_norm3(sublayer_2_normalised + sublayer_3_output)

        return output

In [8]:
class AttentionIsAllYouNeed(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int, number_of_encoder_stacks: int, number_of_decoder_stacks: int):
        super(AttentionIsAllYouNeed, self).__init__()

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h


        encoder_list = nn.ModuleList()

        for _ in range(number_of_encoder_stacks):
            encoder_list.append(EncoderStack(d_model, d_k, d_v, h))

        self.encoder = nn.Sequential(*encoder_list)

        self.decoder = nn.ModuleList()

        for _ in range(number_of_decoder_stacks):
            self.decoder.append(DecoderStack(d_model, d_k, d_v, h))

        self.final_layer = nn.Linear(d_v, d_v)

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def set_mask(self, mask):
        for layer in self.decoder:
            layer.set_mask(mask)


    def forward(self, X, Y):
        X_encoded = self.encoder(X)
        for stack in self.decoder:
            Y = stack(Y, X_encoded, X_encoded)

        Y_hat = self.final_layer(Y)

        return Y_hat

# Tests

In [9]:
multi_head_attention = MultiHeadAttention(512, 512, 512, 8)
multi_head_attention.linear

ModuleList(
  (0-7): 8 x ModuleList(
    (0-2): 3 x Linear(in_features=512, out_features=512, bias=True)
  )
)

In [22]:
batch_size = 2    # Example batch size
seq_length = 5    # Length of the sequence (number of tokens)
d_k = 64         # Dimension of the keys (and queries)
model_dim = 64  # Dimension of the model
h = 8

# Generate random tensors for Q, K, V
Q = torch.randn(batch_size, seq_length, d_k)  # Queries
K = torch.randn(batch_size, seq_length, d_k)  # Keys
V = torch.randn(batch_size, seq_length, model_dim)  # Values

# Print the shapes for confirmation
print("Shape of Q:", Q.shape)  # Expected: (batch_size, seq_length, d_k)
print("Shape of K:", K.shape)  # Expected: (batch_size, seq_length, d_k)
print("Shape of V:", V.shape)  # Expected: (batch_size, seq_length, model_dim)

multi_head_attention = MultiHeadAttention(model_dim, d_k, model_dim, h)
output = multi_head_attention(Q, K, V)
print(f'Shape of output:{output.shape}')
print(output)

Shape of Q: torch.Size([2, 5, 64])
Shape of K: torch.Size([2, 5, 64])
Shape of V: torch.Size([2, 5, 64])
Shape of output:torch.Size([2, 5, 64])
tensor([[[-7.7158e+36, -3.9808e+35, -6.3614e+36, -1.1052e+34,  1.9115e+36,
          -1.5771e+34,  1.2227e+35, -6.0128e+36, -9.0464e+35,  1.0916e+37,
           1.9311e+35, -2.3737e+37, -1.0969e+36, -2.4792e+34, -2.3777e+34,
           2.8126e+36,  8.5865e+36, -7.0440e+36,  5.9270e+34, -1.1741e+35,
           3.2193e+34,  1.0127e+34, -5.5693e+34, -2.9127e+35,  4.4485e+35,
           2.9197e+32, -2.6269e+36, -5.2739e+36,  5.5094e+36,  2.5264e+35,
          -2.7202e+35,  8.2020e+36, -1.2285e+35, -3.8584e+36, -1.1080e+37,
          -7.2806e+36, -7.1910e+34,  6.5290e+36, -4.3211e+36, -1.9017e+36,
          -2.5199e+36, -6.7443e+35, -1.1761e+37, -4.0528e+36, -3.9418e+34,
          -1.2263e+34, -5.7671e+34,  1.2610e+37, -8.7630e+35,  2.4204e+36,
          -4.2404e+36, -1.6847e+36,  2.8929e+36, -4.0092e+36,  4.4405e+34,
          -9.0855e+36, -4.8206e

In [26]:
batch_size = 2    # Example batch size
seq_length = 5    # Length of the sequence (number of tokens)
d_k = 64         # Dimension of the keys (and queries)
model_dim = d_k  # Dimension of the model
h = 8

# Generate random tensors for Q, K, V
X = torch.randn(batch_size, seq_length, d_k)  # Queries
encoder_stack = EncoderStack(model_dim, d_k, model_dim, h)
encoder_output = encoder_stack(X)
print(encoder_output.shape)
print(encoder_output)

decoder_stack = DecoderStack(model_dim, d_k, model_dim, h, None)
decoder_output = decoder_stack(X, encoder_output, encoder_output)
print(decoder_output.shape)
print(decoder_output)

torch.Size([2, 5, 64])
tensor([[[-0.1039,  1.4438,  0.8375, -1.0493, -1.2741,  0.7297,  3.1908,
           0.0864, -0.7791, -0.2108, -0.5236, -0.2370, -0.7821,  1.7480,
           0.4905, -0.4368, -0.9548, -0.5537, -0.6005,  1.2184,  1.1913,
          -0.6256, -0.7231, -0.3026, -2.4158,  0.4136,  2.2691,  1.7100,
           0.5128,  0.3327, -1.1567, -0.2180,  0.6617,  0.1009, -0.3515,
           0.4473, -0.9262,  0.3200, -0.8453, -0.0689,  0.5461, -0.5779,
          -1.4740,  1.2302,  1.5161, -0.7276,  0.8295, -0.4684,  0.0144,
           0.8289, -1.1465,  0.2345, -0.9876, -0.4537, -0.3000,  0.1688,
          -1.3168,  0.3393, -0.6929,  1.1514, -1.6150,  0.7342, -0.7498,
           0.3517],
         [-0.7210, -0.6491,  0.1282,  0.1741,  1.3528,  0.1402, -0.5439,
          -0.9692, -0.3829, -0.8711,  0.5084, -0.7831, -0.0687,  0.0108,
          -0.9466, -0.3992,  0.0377, -1.8376,  0.2573, -0.2666, -0.6284,
           0.8730, -0.4118,  0.7247, -0.1336,  0.1962,  1.6095,  0.0335,
        

In [34]:
batch_size = 2    # Example batch size
seq_length = 5    # Length of the sequence (number of tokens)
d_k = 64         # Dimension of the keys (and queries)
model_dim = d_k  # Dimension of the model
h = 8

# Generate random tensors for Q, K, V
X = torch.randn(batch_size, seq_length, d_k)
Y = torch.randn(batch_size, seq_length, d_k)

attention_is_all_you_need = AttentionIsAllYouNeed(model_dim, d_k, model_dim, h, 1, 1)
output = attention_is_all_you_need(X, Y)
print(output.shape)
print(output)

torch.Size([2, 5, 64])
tensor([[[ 0.3942, -1.6487,  0.4268,  0.8895,  0.9078, -0.1566,  0.3509,
           0.3038, -0.4853, -0.5535, -0.0410, -0.6531, -0.7741,  0.2140,
           0.1809,  0.4186,  0.3990,  0.2221,  0.6776,  0.2289,  0.2376,
           0.4875,  0.7917,  0.1212, -0.4958,  0.6184, -0.0243,  0.8819,
           0.0215,  0.8609,  0.4352, -1.9928, -0.1582,  1.0375, -0.6470,
          -0.5815,  0.3517,  0.6750,  0.2117,  0.8613, -0.6931,  0.1117,
           0.6397,  0.0764, -0.5994,  1.1049,  0.2024,  0.4491, -0.0975,
           0.4945,  1.1976, -0.5432, -0.2462, -0.3191,  0.5830,  0.7496,
           0.0413,  0.1388,  0.4894, -0.7450,  0.2143,  0.6303,  0.8272,
          -0.8080],
         [ 0.1982, -0.0075,  0.7543, -0.4238, -0.4128, -0.0271, -0.4509,
           0.4221, -1.2759, -0.3138,  0.2374, -0.1568, -0.7380,  0.6623,
          -0.0675, -0.6450,  0.4892, -0.6939, -0.3818, -0.6615,  0.5972,
          -0.9151,  1.3200, -0.0820,  1.2155, -0.7662,  0.3730, -0.4824,
        