In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchtext.vocab import GloVe, vocab

import string

# Model

In [2]:
class Attention(nn.Module):
    def __init__(self, mask = None):
        super(Attention, self).__init__()
        self.mask = mask

    def set_mask(self, mask):
        self.mask = mask

    def forward(self, Q, K, V, d_k):
        QK_T = torch.matmul(Q, torch.transpose(K, -1, -2))
        QK_T_d_k = torch.div(QK_T, torch.sqrt(d_k))

        if self.mask is not None:
            QK_T_d_k += self.mask

        softmax = F.softmax(QK_T_d_k, dim = -1)
        return torch.matmul(softmax, V)

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int, mask = None):
        super(MultiHeadAttention, self).__init__()

        self.h = h
        self.d_k_value = torch.Tensor([d_k])
        self.linear = nn.ModuleList()
        self.W_O = nn.Parameter(torch.Tensor(h*d_v, d_model))
        self.attention = Attention(mask)

        for _ in range(self.h):
            linear = nn.ModuleList([nn.Linear(d_k, d_model), nn.Linear(d_k, d_model), nn.Linear(d_v, d_model)])
            self.linear.append(linear)

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def set_mask(self, mask):
        self.attention.set_mask(mask)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
        assert len(Q.shape) == len(K.shape) == len(V.shape), f"invalid dimensions, got Q:{Q.shape}, K: {K.shape}, V:{V.shape}"

        heads = [self.attention(layer[0](Q), layer[1](K), layer[2](V), self.d_k_value) for layer in self.linear]
        concat_heads = torch.cat(heads, dim = -1)
        return torch.matmul(concat_heads, self.W_O)

In [4]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, in_features: int, out_features: int, intermediate_features = None):
        super(PositionWiseFeedForward, self).__init__()
        if intermediate_features is None:
            self.intermediate_features = in_features*4

        self.layers = nn.Sequential(
            nn.Linear(in_features, self.intermediate_features),
            nn.ReLU(),
            nn.Linear(self.intermediate_features, out_features)
        )

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def forward(self, X):
        return self.layers(X)

In [5]:
class EncoderStack(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int):
        super(EncoderStack, self).__init__()

        self.multi_head_attention = MultiHeadAttention(d_model, d_k, d_v, h)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.position_wise_feed_forward = PositionWiseFeedForward(d_v, d_v)
        self.layer_norm2 = nn.LayerNorm(d_model)

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def forward(self, X):
        sublayer_1_output = self.multi_head_attention(X, X, X)
        sublayer_1_normalised = self.layer_norm1(X + sublayer_1_output)
        sublayer_2_output = self.position_wise_feed_forward(sublayer_1_normalised)
        output = self.layer_norm2(sublayer_1_normalised + sublayer_2_output)
        return output

In [6]:
class DecoderStack(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int, mask = None):
        super(DecoderStack, self).__init__()

        self.multi_head_attention1 = MultiHeadAttention(d_model, d_k, d_v, h, mask)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.multi_head_attention2 = MultiHeadAttention(d_model, d_k, d_v, h)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.position_wise_feed_forward = PositionWiseFeedForward(d_v, d_v)
        self.layer_norm3 = nn.LayerNorm(d_model)

    def set_mask(self, mask):
        self.multi_head_attention1.set_mask(mask)

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def forward(self, X, Q, K):

        V = self.multi_head_attention1(X, X, X)
        V_norm = self.layer_norm1(X + V)
        sublayer_2_output = self.multi_head_attention2(Q, K, V_norm)
        sublayer_2_normalised = self.layer_norm2(V_norm + sublayer_2_output)
        sublayer_3_output = self.position_wise_feed_forward(sublayer_2_normalised)
        output = self.layer_norm3(sublayer_2_normalised + sublayer_3_output)

        return output

In [7]:
class PositionalEncoding(nn.Module):
    """Positional encoding."""
    def __init__(self, num_hiddens, max_len=1000, dropout = 0):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Create a long enough `P`
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)
    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

In [8]:
class AttentionIsAllYouNeed(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int, number_of_encoder_stacks: int, number_of_decoder_stacks: int):
        super(AttentionIsAllYouNeed, self).__init__()

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.embedding = None
        self.positional_encoding = None

        encoder_list = nn.ModuleList()

        for _ in range(number_of_encoder_stacks):
            encoder_list.append(EncoderStack(d_model, d_k, d_v, h))

        self.encoder = nn.Sequential(*encoder_list)

        self.decoder = nn.ModuleList()

        for _ in range(number_of_decoder_stacks):
            self.decoder.append(DecoderStack(d_model, d_k, d_v, h))

        self.final_layer = nn.Linear(d_v, d_v)

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def init_embedding(self, pretrained_embeddings, freeze_embeddings = False, sparse = False):
        self.embedding = nn.Embedding.from_pretrained(pretrained_embeddings, freeze = freeze_embeddings, sparse=sparse)

    def set_mask(self, mask):
        for layer in self.decoder:
            layer.set_mask(mask)


    def forward(self, X, Y):
        X_encoded = self.encoder(X)
        for stack in self.decoder:
            Y = stack(Y, X_encoded, X_encoded)

        Y_hat = self.final_layer(Y)

        return Y_hat

# Data Loading

In [15]:
glove_vectors = GloVe(name='6B', dim=100)

start_token = "</s>"
start_index = 0

glove_vocab = vocab(glove_vectors.stoi)
glove_vocab.insert_token(start_token, start_index)
glove_vocab.set_default_index(start_index)
pretrained_embeddings = glove_vectors.vectors
pretrained_embeddings = torch.cat((torch.zeros(1,pretrained_embeddings.shape[1]),pretrained_embeddings))

In [14]:
pretrained_embeddings.shape

torch.Size([400001, 100])

In [25]:
glove_vocab["november"]

487

# Testing

In [30]:
encoding_dim = 100
embedding = nn.Embedding.from_pretrained(pretrained_embeddings, freeze = True, sparse=True)
embedding.eval()
pos_encoding = PositionalEncoding(encoding_dim)
pos_encoding.eval()

X = embedding(torch.tensor([[7, 10, 72, 1, 8], [9, 20, 50, 60, 10]]))
X_pos = pos_encoding(X)
display(X_pos)
display(X)

tensor([[[ 8.5703e-02,  7.7799e-01,  1.6569e-01,  1.1337e+00,  3.8239e-01,
           1.3540e+00,  1.2870e-02,  1.2246e+00, -4.3817e-01,  1.5016e+00,
          -3.5874e-01,  6.5017e-01,  5.5156e-02,  1.6965e+00, -1.7958e-01,
           1.0679e+00,  3.9101e-01,  1.1604e+00, -2.6635e-01,  7.8862e-01,
           5.3698e-01,  1.4938e+00,  9.3660e-01,  1.6690e+00,  2.1793e-01,
           5.3358e-01,  2.2383e-01,  6.3796e-01, -1.7656e-01,  1.1748e+00,
          -2.0367e-01,  1.1393e+00,  1.9832e-02,  8.9587e-01, -2.0244e-01,
           1.5500e+00, -1.5460e-01,  1.9865e+00, -2.6863e-01,  7.0910e-01,
          -3.2866e-01,  6.5812e-01, -1.6943e-01,  5.7999e-01, -4.6727e-02,
           8.3673e-01,  7.0824e-01,  2.5089e-01, -9.1559e-02,  3.8220e-02,
          -1.9747e-01,  1.1028e+00,  5.5221e-01,  2.3816e+00, -6.5636e-01,
          -2.2502e+00, -3.1556e-01, -2.0550e-01,  1.7709e+00,  1.4026e+00,
          -7.9827e-01,  2.1597e+00, -3.3042e-01,  1.3138e+00,  7.7386e-01,
           1.2260e+00,  5

tensor([[[ 8.5703e-02, -2.2201e-01,  1.6569e-01,  1.3373e-01,  3.8239e-01,
           3.5401e-01,  1.2870e-02,  2.2461e-01, -4.3817e-01,  5.0164e-01,
          -3.5874e-01, -3.4983e-01,  5.5156e-02,  6.9648e-01, -1.7958e-01,
           6.7926e-02,  3.9101e-01,  1.6039e-01, -2.6635e-01, -2.1138e-01,
           5.3698e-01,  4.9379e-01,  9.3660e-01,  6.6902e-01,  2.1793e-01,
          -4.6642e-01,  2.2383e-01, -3.6204e-01, -1.7656e-01,  1.7480e-01,
          -2.0367e-01,  1.3931e-01,  1.9832e-02, -1.0413e-01, -2.0244e-01,
           5.5003e-01, -1.5460e-01,  9.8655e-01, -2.6863e-01, -2.9090e-01,
          -3.2866e-01, -3.4188e-01, -1.6943e-01, -4.2001e-01, -4.6727e-02,
          -1.6327e-01,  7.0824e-01, -7.4911e-01, -9.1559e-02, -9.6178e-01,
          -1.9747e-01,  1.0282e-01,  5.5221e-01,  1.3816e+00, -6.5636e-01,
          -3.2502e+00, -3.1556e-01, -1.2055e+00,  1.7709e+00,  4.0260e-01,
          -7.9827e-01,  1.1597e+00, -3.3042e-01,  3.1382e-01,  7.7386e-01,
           2.2595e-01,  5

In [32]:
pos_encoding.P.shape

torch.Size([1, 1000, 100])

In [14]:
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
display(P.shape)
display(P)

torch.Size([1, 60, 32])

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  5.3317e-01,  ...,  1.0000e+00,
           1.7783e-04,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  9.0213e-01,  ...,  1.0000e+00,
           3.5566e-04,  1.0000e+00],
         ...,
         [ 4.3616e-01,  8.9987e-01,  5.9521e-01,  ...,  9.9984e-01,
           1.0136e-02,  9.9995e-01],
         [ 9.9287e-01,  1.1918e-01,  9.3199e-01,  ...,  9.9983e-01,
           1.0314e-02,  9.9995e-01],
         [ 6.3674e-01, -7.7108e-01,  9.8174e-01,  ...,  9.9983e-01,
           1.0492e-02,  9.9994e-01]]])

In [8]:
multi_head_attention = MultiHeadAttention(512, 512, 512, 8)
multi_head_attention.linear

ModuleList(
  (0-7): 8 x ModuleList(
    (0-2): 3 x Linear(in_features=512, out_features=512, bias=True)
  )
)

In [9]:
batch_size = 2    # Example batch size
seq_length = 5    # Length of the sequence (number of tokens)
d_k = 64         # Dimension of the keys (and queries)
model_dim = 64  # Dimension of the model
h = 8

# Generate random tensors for Q, K, V
Q = torch.randn(batch_size, seq_length, d_k)  # Queries
K = torch.randn(batch_size, seq_length, d_k)  # Keys
V = torch.randn(batch_size, seq_length, model_dim)  # Values

# Print the shapes for confirmation
print("Shape of Q:", Q.shape)  # Expected: (batch_size, seq_length, d_k)
print("Shape of K:", K.shape)  # Expected: (batch_size, seq_length, d_k)
print("Shape of V:", V.shape)  # Expected: (batch_size, seq_length, model_dim)

multi_head_attention = MultiHeadAttention(model_dim, d_k, model_dim, h)
output = multi_head_attention(Q, K, V)
print(f'Shape of output:{output.shape}')
print(output)

Shape of Q: torch.Size([2, 5, 64])
Shape of K: torch.Size([2, 5, 64])
Shape of V: torch.Size([2, 5, 64])
Shape of output:torch.Size([2, 5, 64])
tensor([[[        nan,  4.0096e+37,  6.2351e+37, -4.2478e+36, -2.0465e+37,
          -4.1833e+36,         nan,         nan, -7.8783e+36,  1.0117e+37,
          -9.4423e+35,  2.0174e+36, -5.8751e+36, -3.1546e+36, -3.0629e+36,
          -1.6364e+37, -1.8854e+36, -5.5364e+35,  3.2821e+36, -3.7999e+36,
          -8.9872e+35, -2.5219e+36,  3.8568e+36, -5.8303e+36,  1.0008e+37,
           4.1014e+35, -5.0816e+36, -1.4643e+35, -4.0584e+37, -1.0357e+38,
                  nan, -3.9588e+36,  1.1894e+36,  9.2627e+36,         nan,
           3.2849e+36, -2.2424e+36,  6.2773e+36,  6.5724e+36, -2.8173e+36,
          -1.7648e+37, -4.5027e+37,  9.5633e+35, -3.9691e+37,  2.6361e+36,
           6.8147e+36,  2.2129e+37,  4.7346e+36,  3.1635e+36,  1.1845e+36,
          -9.1324e+36,         nan,  3.2725e+37, -2.1265e+37, -2.5488e+37,
                  nan, -1.6300e

In [10]:
batch_size = 2    # Example batch size
seq_length = 5    # Length of the sequence (number of tokens)
d_k = 64         # Dimension of the keys (and queries)
model_dim = d_k  # Dimension of the model
h = 8

# Generate random tensors for Q, K, V
X = torch.randn(batch_size, seq_length, d_k)  # Queries
encoder_stack = EncoderStack(model_dim, d_k, model_dim, h)
encoder_output = encoder_stack(X)
print(encoder_output.shape)
print(encoder_output)

decoder_stack = DecoderStack(model_dim, d_k, model_dim, h, None)
decoder_output = decoder_stack(X, encoder_output, encoder_output)
print(decoder_output.shape)
print(decoder_output)

torch.Size([2, 5, 64])
tensor([[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, 

In [11]:
batch_size = 2    # Example batch size
seq_length = 5    # Length of the sequence (number of tokens)
d_k = 64         # Dimension of the keys (and queries)
model_dim = d_k  # Dimension of the model
h = 8

# Generate random tensors for Q, K, V
X = torch.randn(batch_size, seq_length, d_k)
Y = torch.randn(batch_size, seq_length, d_k)

attention_is_all_you_need = AttentionIsAllYouNeed(model_dim, d_k, model_dim, h, 1, 1)
output = attention_is_all_you_need(X, Y)
print(output.shape)
print(output)

torch.Size([2, 5, 64])
tensor([[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, 