In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [29]:
class Attention(nn.Module):
    def __init__(self, mask = None):
        super(Attention, self).__init__()
        self.mask = mask

    def set_mask(self, mask):
        self.mask = mask

    def forward(self, Q, K, V, d_k):
        QK_T = torch.matmul(Q, torch.transpose(K, -1, -2))
        QK_T_d_k = torch.div(QK_T, torch.sqrt(d_k))

        if self.mask is not None:
            QK_T_d_k += self.mask

        softmax = F.softmax(QK_T_d_k, dim = -1)
        return torch.matmul(softmax, V)

In [30]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int, mask = None):
        super(MultiHeadAttention, self).__init__()

        self.h = h
        self.d_k_value = torch.Tensor([d_k])
        self.linear = nn.ModuleList()
        self.W_O = nn.Parameter(torch.Tensor(h*d_v, d_model))
        self.attention = Attention(mask)

        for _ in range(self.h):
            linear = nn.ModuleList([nn.Linear(d_k, d_model), nn.Linear(d_k, d_model), nn.Linear(d_v, d_model)])
            self.linear.append(linear)

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def set_mask(self, mask):
        self.attention.set_mask(mask)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
        assert len(Q.shape) == len(K.shape) == len(V.shape), f"invalid dimensions, got Q:{Q.shape}, K: {K.shape}, V:{V.shape}"

        heads = [self.attention(layer[0](Q), layer[1](K), layer[2](V), self.d_k_value) for layer in self.linear]
        concat_heads = torch.cat(heads, dim = -1)
        return torch.matmul(concat_heads, self.W_O)

In [31]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, in_features: int, out_features: int, intermediate_features = None):
        super(PositionWiseFeedForward, self).__init__()
        if intermediate_features is None:
            self.intermediate_features = in_features*4

        self.layers = nn.Sequential(
            nn.Linear(in_features, self.intermediate_features),
            nn.ReLU(),
            nn.Linear(self.intermediate_features, out_features)
        )

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def forward(self, X):
        return self.layers(X)

In [32]:
class EncoderStack(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int):
        super(EncoderStack, self).__init__()

        self.multi_head_attention = MultiHeadAttention(d_model, d_k, d_v, h)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.position_wise_feed_forward = PositionWiseFeedForward(d_v, d_v)
        self.layer_norm2 = nn.LayerNorm(d_model)

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def forward(self, X):
        sublayer_1_output = self.multi_head_attention(X, X, X)
        sublayer_1_normalised = self.layer_norm1(X + sublayer_1_output)
        sublayer_2_output = self.position_wise_feed_forward(sublayer_1_normalised)
        output = self.layer_norm2(sublayer_1_normalised + sublayer_2_output)
        return output

In [33]:
class DecoderStack(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int, mask = None):
        super(DecoderStack, self).__init__()

        self.multi_head_attention1 = MultiHeadAttention(d_model, d_k, d_v, h, mask)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.multi_head_attention2 = MultiHeadAttention(d_model, d_k, d_v, h)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.position_wise_feed_forward = PositionWiseFeedForward(d_v, d_v)
        self.layer_norm3 = nn.LayerNorm(d_model)

    def set_mask(self, mask):
        self.multi_head_attention1.set_mask(mask)

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def forward(self, X, Q, K):

        V = self.multi_head_attention1(X, X, X)
        V_norm = self.layer_norm1(X + V)
        sublayer_2_output = self.multi_head_attention2(Q, K, V_norm)
        sublayer_2_normalised = self.layer_norm2(V_norm + sublayer_2_output)
        sublayer_3_output = self.position_wise_feed_forward(sublayer_2_normalised)
        output = self.layer_norm3(sublayer_2_normalised + sublayer_3_output)

        return output

In [None]:
class AttentionIsAllYouNeed(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int, number_of_encoder_stacks: int, number_of_decoder_stacks: int):
        super(AttentionIsAllYouNeed, self).__init__()

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h


        encoder_list = nn.ModuleList()

        for _ in range(number_of_encoder_stacks):
            encoder_list.append(EncoderStack(d_model, d_k, d_v, h))

        self.encoder = nn.Sequential(*encoder_list)

        self.decoder = nn.ModuleList()

        for _ in range(number_of_decoder_stacks):
            self.decoder.append(DecoderStack(d_model, d_k, d_v, h))

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def set_mask(self, mask):
        for layer in self.decoder:
            layer.set_mask(mask)





# Tests

In [34]:
multi_head_attention = MultiHeadAttention(512, 512, 512, 8)
multi_head_attention.linear

ModuleList(
  (0-7): 8 x ModuleList(
    (0-2): 3 x Linear(in_features=512, out_features=512, bias=True)
  )
)

In [35]:
batch_size = 2    # Example batch size
seq_length = 5    # Length of the sequence (number of tokens)
d_k = 64         # Dimension of the keys (and queries)
model_dim = 64  # Dimension of the model
h = 8

# Generate random tensors for Q, K, V
Q = torch.randn(batch_size, seq_length, d_k)  # Queries
K = torch.randn(batch_size, seq_length, d_k)  # Keys
V = torch.randn(batch_size, seq_length, model_dim)  # Values

# Print the shapes for confirmation
print("Shape of Q:", Q.shape)  # Expected: (batch_size, seq_length, d_k)
print("Shape of K:", K.shape)  # Expected: (batch_size, seq_length, d_k)
print("Shape of V:", V.shape)  # Expected: (batch_size, seq_length, model_dim)

multi_head_attention = MultiHeadAttention(model_dim, d_k, model_dim, h)
output = multi_head_attention(Q, K, V)
print(f'Shape of output:{output.shape}')
print(output)

Shape of Q: torch.Size([2, 5, 64])
Shape of K: torch.Size([2, 5, 64])
Shape of V: torch.Size([2, 5, 64])
Shape of output:torch.Size([2, 5, 64])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [37]:
batch_size = 2    # Example batch size
seq_length = 5    # Length of the sequence (number of tokens)
d_k = 64         # Dimension of the keys (and queries)
model_dim = d_k  # Dimension of the model
h = 8

# Generate random tensors for Q, K, V
X = torch.randn(batch_size, seq_length, d_k)  # Queries
encoder_stack = EncoderStack(model_dim, d_k, model_dim, h)
encoder_output = encoder_stack(X)
print(encoder_output.shape)
print(encoder_output)

decoder_stack = DecoderStack(model_dim, d_k, model_dim, h, None)
decoder_output = decoder_stack(X, encoder_output, encoder_output)
print(decoder_output.shape)
print(decoder_output)

torch.Size([2, 5, 64])
tensor([[[-5.1103e-02, -1.5536e-01,  8.9463e-02, -1.3167e+00,  2.7258e+00,
          -9.3958e-01,  9.1497e-01,  1.2106e+00,  3.7205e-01,  3.5637e-01,
          -9.8485e-01,  8.5567e-01,  4.6632e-01,  1.9161e+00, -8.8724e-01,
           2.8030e-01, -7.9871e-01,  1.5041e+00,  1.3065e+00, -2.2650e-01,
          -5.8580e-01, -5.5293e-01, -5.7641e-01,  1.6829e+00, -8.5277e-01,
          -2.0138e+00, -9.7452e-01,  3.1980e-01, -8.7095e-01, -6.4506e-01,
          -5.8573e-01, -1.3367e+00,  5.3541e-01,  5.7812e-01, -9.1546e-01,
          -4.1406e-01,  1.4907e-01, -5.6542e-01,  1.1760e+00,  9.9776e-01,
          -3.5892e-01,  3.2156e-01, -7.2271e-01,  2.7709e+00,  2.0121e-01,
           5.6360e-01, -3.7481e-01, -1.1145e+00, -7.7653e-01, -2.0065e-01,
           4.9055e-01, -1.0025e+00,  1.6620e+00,  1.1149e+00, -9.8769e-01,
           2.2635e-01,  3.1325e-01,  6.6318e-01, -1.5772e+00, -2.2373e-01,
          -1.1477e+00, -9.7737e-01,  4.6852e-01, -5.1943e-01],
         [-6.8