In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torchtext.vocab import GloVe, vocab

# Model

In [68]:
class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.mask = None

    def set_mask(self, mask, value = float("-inf")):
        self.mask = mask
        self.value = value

    def forward(self, Q, K, V, d_k):
        QK_T = torch.matmul(Q, torch.transpose(K, -1, -2))
        attention_scores = torch.div(QK_T, torch.sqrt(d_k))

        if self.mask is not None:
            attention_scores.masked_fill_(self.mask, self.value)


        softmax = F.softmax(attention_scores, dim = -1)
        return torch.matmul(softmax, V)

In [69]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int):
        super(MultiHeadAttention, self).__init__()

        self.h = h
        self.d_k_value = torch.Tensor([d_k])
        self.linear = nn.ModuleList()
        self.W_O = nn.Parameter(torch.Tensor(h*d_v, d_model))
        self.attention = Attention()

        for _ in range(self.h):
            linear = nn.ModuleList([nn.Linear(d_k, d_model), nn.Linear(d_k, d_model), nn.Linear(d_v, d_model)])
            self.linear.append(linear)

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def set_mask(self, mask, value = float("-inf")):
        self.attention.set_mask(mask, value)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
        assert len(Q.shape) == len(K.shape) == len(V.shape), f"invalid dimensions, got Q:{Q.shape}, K: {K.shape}, V:{V.shape}"

        heads = [self.attention(layer[0](Q), layer[1](K), layer[2](V), self.d_k_value) for layer in self.linear]
        concat_heads = torch.cat(heads, dim = -1)
        return torch.matmul(concat_heads, self.W_O)

In [70]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, in_features: int, out_features: int, intermediate_features = None):
        super(PositionWiseFeedForward, self).__init__()
        if intermediate_features is None:
            self.intermediate_features = in_features*4

        self.layers = nn.Sequential(
            nn.Linear(in_features, self.intermediate_features),
            nn.ReLU(),
            nn.Linear(self.intermediate_features, out_features)
        )

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def forward(self, X):
        return self.layers(X)

In [71]:
class EncoderStack(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int):
        super(EncoderStack, self).__init__()

        self.multi_head_attention = MultiHeadAttention(d_model, d_k, d_v, h)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.position_wise_feed_forward = PositionWiseFeedForward(d_v, d_v)
        self.layer_norm2 = nn.LayerNorm(d_model)

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def set_mask(self, mask, value = float("-inf")):
        self.multi_head_attention.set_mask(mask, value)

    def forward(self, X):
        sublayer_1_output = self.multi_head_attention(X, X, X)
        sublayer_1_normalised = self.layer_norm1(X + sublayer_1_output)
        sublayer_2_output = self.position_wise_feed_forward(sublayer_1_normalised)
        output = self.layer_norm2(sublayer_1_normalised + sublayer_2_output)
        return output

In [72]:
class DecoderStack(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int):
        super(DecoderStack, self).__init__()

        self.multi_head_attention1 = MultiHeadAttention(d_model, d_k, d_v, h)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.multi_head_attention2 = MultiHeadAttention(d_model, d_k, d_v, h)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.position_wise_feed_forward = PositionWiseFeedForward(d_v, d_v)
        self.layer_norm3 = nn.LayerNorm(d_model)

    def set_mask(self, autoregressive_padding_mask, padding_mask, value = float("-inf")):
        self.multi_head_attention1.set_mask(autoregressive_padding_mask, value = value)
        self.multi_head_attention2.set_mask(padding_mask, value = value)

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def forward(self, X, Q, K):

        V = self.multi_head_attention1(X, X, X)
        V_norm = self.layer_norm1(X + V)
        sublayer_2_output = self.multi_head_attention2(Q, K, V_norm)
        sublayer_2_normalised = self.layer_norm2(V_norm + sublayer_2_output)
        sublayer_3_output = self.position_wise_feed_forward(sublayer_2_normalised)
        output = self.layer_norm3(sublayer_2_normalised + sublayer_3_output)

        return output

In [73]:
class PositionalEncoding(nn.Module):
    """Positional encoding."""
    def __init__(self, num_hiddens, max_len=1000, dropout = 0):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Create a long enough `P`
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)


    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

In [74]:
class AttentionIsAllYouNeed(nn.Module):
    def __init__(self, d_model: int,
                 d_k: int,
                 d_v: int,
                 h: int,
                 output: int,
                 number_of_encoder_stacks: int,
                 number_of_decoder_stacks: int,
                 padding_token = None,
                 mask_value = float("-inf")):
        super(AttentionIsAllYouNeed, self).__init__()

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h
        self.padding_token = padding_token
        self.mask_value = mask_value

        self.cached_query = None
        self.cached_key = None

        self.autoregressive_padding_mask = None
        self.encoder_padding_mask = None
        self.decoder_padding_mask = None

        self.embedding = None
        self.positional_encoding = PositionalEncoding(d_model)

        encoder_list = nn.ModuleList()

        for _ in range(number_of_encoder_stacks):
            encoder_list.append(EncoderStack(d_model, d_k, d_v, h))

        self.encoder = nn.Sequential(*encoder_list)

        self.decoder = nn.ModuleList()

        for _ in range(number_of_decoder_stacks):
            self.decoder.append(DecoderStack(d_model, d_k, d_v, h))

        self.final_layer = nn.Linear(d_v, output)

    def init_weights(self, init_fn):
        self.apply(init_fn)

    def init_embedding(self, pretrained_embeddings, freeze_embeddings = False, sparse = False):
        self.embedding = nn.Embedding.from_pretrained(pretrained_embeddings, freeze = freeze_embeddings, sparse=sparse)

    def init_state(self, Query, Key):
       self.cached_query = Query
       self.cached_key = Key

    def clear_cache(self):
        self.cached_query = None
        self.cached_key = None

        self.autoregressive_padding_mask = None
        self.encoder_padding_mask = None
        self.decoder_padding_mask = None

    def create_encoder_mask(self, X):
        token_len = X.size(1)

        if self.padding_token != None:
            padded_tokens = (X == self.padding_token)
            padding_mask = padded_tokens.unsqueeze(1).expand(-1, token_len, -1)
            self.encoder_padding_mask = padding_mask

    def create_decoder_mask(self, X):
        token_len = X.size(1)

        self.autoregressive_padding_mask = torch.triu(torch.ones(token_len, token_len), diagonal=1).bool()

        if self.padding_token != None:
            padded_tokens = (X == self.padding_token)
            padding_mask = padded_tokens.unsqueeze(1).expand(-1, token_len, -1)
            self.autoregressive_padding_mask = self.autoregressive_padding_mask | padding_mask
            self.decoder_padding_mask = padding_mask


    def set_mask(self):
        for stack in self.encoder:
            stack.set_mask(self.encoder_padding_mask, self.mask_value)

        if self.autoregressive_padding_mask is not None: # Indicates create_decoder_mask hasn't been called yet
            for stack in self.decoder:
                stack.set_mask(self.autoregressive_padding_mask, self.decoder_padding_mask, self.mask_value)

    def encode(self, X, cache = True):
        self.create_encoder_mask(X)
        self.set_mask()

        X = self.embedding(X)
        X = self.positional_encoding(X)
        state = self.encoder(X)

        if cache:
            self.cached_query = state
            self.cached_key = state
            return None
        else:
            return (state, state)

    def decode(self, Y, Q = None, K = None):
        self.create_decoder_mask(Y)
        self.set_mask()

        if not (Q and K):
            assert self.cached_query is not None, "No cached state to use"
            assert self.cached_key is not None, "No cached state to use"

            Q = self.cached_query
            K = self.cached_key



        Y = self.embedding(Y)
        Y = self.positional_encoding(Y)

        for stack in self.decoder:
            Y = stack(Y, Q, K)

        Y_hat = self.final_layer(Y)

        return Y_hat

    def forward(self, X, Y):
        self.encode(X, cache = True)
        Y_hat = self.decode(Y)
        return Y_hat

In [75]:
def get_glove_pre_trained_with_special_tokens(dim = 100):
    glove_vectors = GloVe(name='6B', dim=dim)

    pad_token = "<pad>"
    pad_index = 0

    eos_token = "<eos>"
    eos_index = 1

    bos_token = "<bos>"
    bos_index = 2

    glove_vocab = vocab(glove_vectors.stoi)

    glove_vocab.insert_token(pad_token, pad_index)
    glove_vocab.insert_token(eos_token, eos_index)
    glove_vocab.insert_token(bos_token, bos_index)

    glove_vocab.set_default_index(bos_index)

    pretrained_embeddings = glove_vectors.vectors
    pretrained_embeddings = torch.cat((torch.randn(3,pretrained_embeddings.shape[1]),pretrained_embeddings))
    pretrained_embeddings[0] = torch.zeros(pretrained_embeddings.shape[1]) # Setting padding token embedding as 0s

    return pretrained_embeddings, glove_vocab

# Running the model

In [76]:
def kaiming_custom_init(m):
    if isinstance(m, nn.Linear):
        init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)
    elif isinstance(m, nn.Parameter):
        if m.requires_grad:
            init.kaiming_normal_(m, mode='fan_out', nonlinearity='relu')

In [81]:
pretrained_embeddings, glove_vocab = get_glove_pre_trained_with_special_tokens()
attention_is_all_you_need = AttentionIsAllYouNeed(d_model = 100,
                                                  d_k = 100,
                                                  d_v = 100,
                                                  h = 6,
                                                  output = pretrained_embeddings.shape[0],
                                                  number_of_encoder_stacks = 5,
                                                  number_of_decoder_stacks = 5,
                                                  padding_token=0)

attention_is_all_you_need.init_weights(kaiming_custom_init)
attention_is_all_you_need.init_embedding(pretrained_embeddings)

In [82]:
X = torch.tensor([[6, 90, 13, 100, 1, 0], [2, 7, 8, 60, 1, 0]])
Y = torch.tensor([[51, 42, 967, 1, 0, 0], [2, 90, 56, 80, 1, 0]])

x = torch.tensor([[6, 90, 13, 100, 1], [2, 7, 8, 60, 1]])
y = torch.tensor([[51, 42, 967, 1, 0], [2, 90, 56, 80, 1]])

output_1 = attention_is_all_you_need(X, Y)
output_2 = attention_is_all_you_need(x, y)

display(output_1.shape)
display(output_2.shape)

torch.Size([2, 6, 400003])

torch.Size([2, 5, 400003])

In [84]:
output_1

tensor([[[-0.0014, -0.0017,  0.0279,  ...,  0.0017, -0.0066,  0.0010],
         [-0.0188,  0.0033,  0.0096,  ...,  0.0038, -0.0041, -0.0079],
         [-0.0269, -0.0134, -0.0054,  ..., -0.0234, -0.0263, -0.0420],
         [-0.0215,  0.0205, -0.0092,  ..., -0.0294, -0.0041, -0.0373],
         [-0.0120,  0.0107, -0.0043,  ..., -0.0438, -0.0115, -0.0328],
         [-0.0140,  0.0113, -0.0050,  ..., -0.0384, -0.0102, -0.0196]],

        [[-0.0115, -0.0128, -0.0102,  ..., -0.0325, -0.0280, -0.0368],
         [-0.0019, -0.0068,  0.0288,  ..., -0.0098,  0.0134, -0.0035],
         [-0.0374,  0.0159,  0.0278,  ...,  0.0041, -0.0041,  0.0005],
         [-0.0317,  0.0010,  0.0109,  ..., -0.0022, -0.0206, -0.0242],
         [-0.0182,  0.0171, -0.0084,  ..., -0.0301, -0.0033, -0.0379],
         [-0.0140,  0.0113, -0.0050,  ..., -0.0384, -0.0102, -0.0196]]],
       grad_fn=<ViewBackward0>)

# Testing

In [35]:
def _sequence_mask(X, valid_len, value=0):
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X

def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis.

    Defined in :numref:`sec_attention-scoring-functions`"""
    # X: 3D tensor, valid_lens: 1D or 2D tensor

    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

In [36]:
valid_lens = torch.tensor([3, 2])
value = -1.e16

X = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 0, 0, 0]])
maxlen = X.size(1)
mask = torch.arange((maxlen))[None, :] < valid_lens[:, None]
X[~mask] = value

display(X.shape)
display(torch.arange((maxlen))[None, :])
display(torch.arange((maxlen))[None, :] < valid_lens[:, None])
display(X)

torch.Size([2, 5])

tensor([[0, 1, 2, 3, 4]])

tensor([[ True,  True,  True, False, False],
        [ True,  True, False, False, False]])

tensor([[                 1,                  2,                  3,
         -10000000000000000, -10000000000000000],
        [                 4,                  5, -10000000000000000,
         -10000000000000000, -10000000000000000]])

In [37]:
valid_lens = torch.tensor([[3, 2, 1], [3, 3, 3]])
value = -1.e16

X = torch.ones(2, 3, 5)

shape = X.shape
# display(valid_lens.dim())
# display(torch.repeat_interleave(valid_lens, shape[1]))
display(X.reshape(-1, shape[-1]).shape)

if valid_lens.dim() == 1:
    valid_lens = torch.repeat_interleave(valid_lens, shape[1])
else:
    valid_lens = valid_lens.reshape(-1)
# On the last axis, replace masked elements with a very large negative
# value, whose exponentiation outputs 0
display(valid_lens)
X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
display(X)
display(nn.functional.softmax(X.reshape(shape), dim=-1))

torch.Size([6, 5])

tensor([3, 2, 1, 3, 3, 3])

tensor([[ 1.0000e+00,  1.0000e+00,  1.0000e+00, -1.0000e+06, -1.0000e+06],
        [ 1.0000e+00,  1.0000e+00, -1.0000e+06, -1.0000e+06, -1.0000e+06],
        [ 1.0000e+00, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
        [ 1.0000e+00,  1.0000e+00,  1.0000e+00, -1.0000e+06, -1.0000e+06],
        [ 1.0000e+00,  1.0000e+00,  1.0000e+00, -1.0000e+06, -1.0000e+06],
        [ 1.0000e+00,  1.0000e+00,  1.0000e+00, -1.0000e+06, -1.0000e+06]])

tensor([[[0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
         [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000]]])

In [38]:
batch_size = 2
seq_len = 6
d_k = 64

# Simulate query (Q) and key (K) matrices
Q = torch.rand(batch_size, seq_len, d_k)
K = torch.rand(batch_size, seq_len, d_k)

# Calculate attention scores
attention_scores = torch.bmm(Q, K.transpose(1, 2)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
display(attention_scores.shape)

# Create a mask for padding (1 for valid positions and 0 for padded positions)
# For example, assume the first sequence has 3 valid tokens and the second has 4
mask_v = torch.tensor([[True, True, True, False, False, False], [True, True, True, True, False, False]])
mask = mask_v.unsqueeze(1).expand(-1, seq_len, -1)
display(mask.shape)
display(mask)

# Apply the mask to the attention scores
masked_attention_scores = attention_scores.masked_fill_(~mask, float('-inf'))

display(F.softmax(masked_attention_scores, dim = -1))

torch.Size([2, 6, 6])

torch.Size([2, 6, 6])

tensor([[[ True,  True,  True, False, False, False],
         [ True,  True,  True, False, False, False],
         [ True,  True,  True, False, False, False],
         [ True,  True,  True, False, False, False],
         [ True,  True,  True, False, False, False],
         [ True,  True,  True, False, False, False]],

        [[ True,  True,  True,  True, False, False],
         [ True,  True,  True,  True, False, False],
         [ True,  True,  True,  True, False, False],
         [ True,  True,  True,  True, False, False],
         [ True,  True,  True,  True, False, False],
         [ True,  True,  True,  True, False, False]]])

tensor([[[0.3250, 0.2849, 0.3902, 0.0000, 0.0000, 0.0000],
         [0.3572, 0.3377, 0.3051, 0.0000, 0.0000, 0.0000],
         [0.3075, 0.3108, 0.3817, 0.0000, 0.0000, 0.0000],
         [0.3352, 0.3525, 0.3123, 0.0000, 0.0000, 0.0000],
         [0.3633, 0.3016, 0.3351, 0.0000, 0.0000, 0.0000],
         [0.3345, 0.3327, 0.3327, 0.0000, 0.0000, 0.0000]],

        [[0.2011, 0.2337, 0.2528, 0.3124, 0.0000, 0.0000],
         [0.2570, 0.2572, 0.2337, 0.2521, 0.0000, 0.0000],
         [0.2304, 0.2498, 0.2618, 0.2580, 0.0000, 0.0000],
         [0.2510, 0.2636, 0.2366, 0.2488, 0.0000, 0.0000],
         [0.2893, 0.2415, 0.2201, 0.2491, 0.0000, 0.0000],
         [0.2442, 0.2662, 0.2391, 0.2505, 0.0000, 0.0000]]])

In [57]:
glove_vectors = GloVe(name='6B', dim=100)

pad_token = "<pad>"
pad_index = 0

eos_token = "<eos>"
eos_index = 1

bos_token = "<bos>"
bos_index = 2

glove_vocab = vocab(glove_vectors.stoi)

glove_vocab.insert_token(pad_token, pad_index)
# glove_vocab.set_default_index(pad_index)

glove_vocab.insert_token(eos_token, eos_index)
# glove_vocab.set_default_index(eos_index)

glove_vocab.insert_token(bos_token, bos_index)
# glove_vocab.set_default_index(bos_index)

glove_vocab.set_default_index(bos_index)

display(glove_vectors.vectors.shape)

pretrained_embeddings = glove_vectors.vectors
pretrained_embeddings = torch.cat((torch.randn(3,pretrained_embeddings.shape[1]),pretrained_embeddings))
display(pretrained_embeddings.shape)
pretrained_embeddings[0] = torch.zeros(pretrained_embeddings.shape[1]) # Setting padding token embedding as 0s
display(pretrained_embeddings.shape)

torch.Size([400000, 100])

torch.Size([400003, 100])

torch.Size([400003, 100])

In [59]:
len(glove_vocab.get_itos())

400002

In [63]:
encoding_dim = 100
embedding = nn.Embedding.from_pretrained(pretrained_embeddings, freeze = True, sparse=True)
embedding.eval()
pos_encoding = PositionalEncoding(encoding_dim)
pos_encoding.eval()
Linear = nn.Linear(encoding_dim, encoding_dim)


X = embedding(torch.tensor([[0, 1, 2, 0, 0, 0, 0], [0, 1, 2, 0, 0, 0, 0]]))
X = pos_encoding(X)
X = Linear(X)
display(X.shape)
attention_scores = torch.bmm(X, X.transpose(1, 2)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
display(attention_scores.shape)

torch.Size([2, 7, 100])

torch.Size([2, 7, 7])

In [71]:
pos_encoding.P.shape

torch.Size([1, 1000, 100])

In [14]:
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
display(P.shape)
display(P)

torch.Size([1, 60, 32])

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  5.3317e-01,  ...,  1.0000e+00,
           1.7783e-04,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  9.0213e-01,  ...,  1.0000e+00,
           3.5566e-04,  1.0000e+00],
         ...,
         [ 4.3616e-01,  8.9987e-01,  5.9521e-01,  ...,  9.9984e-01,
           1.0136e-02,  9.9995e-01],
         [ 9.9287e-01,  1.1918e-01,  9.3199e-01,  ...,  9.9983e-01,
           1.0314e-02,  9.9995e-01],
         [ 6.3674e-01, -7.7108e-01,  9.8174e-01,  ...,  9.9983e-01,
           1.0492e-02,  9.9994e-01]]])

In [8]:
multi_head_attention = MultiHeadAttention(512, 512, 512, 8)
multi_head_attention.linear

ModuleList(
  (0-7): 8 x ModuleList(
    (0-2): 3 x Linear(in_features=512, out_features=512, bias=True)
  )
)

In [9]:
batch_size = 2    # Example batch size
seq_length = 5    # Length of the sequence (number of tokens)
d_k = 64         # Dimension of the keys (and queries)
model_dim = 64  # Dimension of the model
h = 8

# Generate random tensors for Q, K, V
Q = torch.randn(batch_size, seq_length, d_k)  # Queries
K = torch.randn(batch_size, seq_length, d_k)  # Keys
V = torch.randn(batch_size, seq_length, model_dim)  # Values

# Print the shapes for confirmation
print("Shape of Q:", Q.shape)  # Expected: (batch_size, seq_length, d_k)
print("Shape of K:", K.shape)  # Expected: (batch_size, seq_length, d_k)
print("Shape of V:", V.shape)  # Expected: (batch_size, seq_length, model_dim)

multi_head_attention = MultiHeadAttention(model_dim, d_k, model_dim, h)
output = multi_head_attention(Q, K, V)
print(f'Shape of output:{output.shape}')
print(output)

Shape of Q: torch.Size([2, 5, 64])
Shape of K: torch.Size([2, 5, 64])
Shape of V: torch.Size([2, 5, 64])
Shape of output:torch.Size([2, 5, 64])
tensor([[[        nan,  4.0096e+37,  6.2351e+37, -4.2478e+36, -2.0465e+37,
          -4.1833e+36,         nan,         nan, -7.8783e+36,  1.0117e+37,
          -9.4423e+35,  2.0174e+36, -5.8751e+36, -3.1546e+36, -3.0629e+36,
          -1.6364e+37, -1.8854e+36, -5.5364e+35,  3.2821e+36, -3.7999e+36,
          -8.9872e+35, -2.5219e+36,  3.8568e+36, -5.8303e+36,  1.0008e+37,
           4.1014e+35, -5.0816e+36, -1.4643e+35, -4.0584e+37, -1.0357e+38,
                  nan, -3.9588e+36,  1.1894e+36,  9.2627e+36,         nan,
           3.2849e+36, -2.2424e+36,  6.2773e+36,  6.5724e+36, -2.8173e+36,
          -1.7648e+37, -4.5027e+37,  9.5633e+35, -3.9691e+37,  2.6361e+36,
           6.8147e+36,  2.2129e+37,  4.7346e+36,  3.1635e+36,  1.1845e+36,
          -9.1324e+36,         nan,  3.2725e+37, -2.1265e+37, -2.5488e+37,
                  nan, -1.6300e

In [10]:
batch_size = 2    # Example batch size
seq_length = 5    # Length of the sequence (number of tokens)
d_k = 64         # Dimension of the keys (and queries)
model_dim = d_k  # Dimension of the model
h = 8

# Generate random tensors for Q, K, V
X = torch.randn(batch_size, seq_length, d_k)  # Queries
encoder_stack = EncoderStack(model_dim, d_k, model_dim, h)
encoder_output = encoder_stack(X)
print(encoder_output.shape)
print(encoder_output)

decoder_stack = DecoderStack(model_dim, d_k, model_dim, h)
decoder_output = decoder_stack(X, encoder_output, encoder_output)
print(decoder_output.shape)
print(decoder_output)

torch.Size([2, 5, 64])
tensor([[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, 

In [11]:
batch_size = 2    # Example batch size
seq_length = 5    # Length of the sequence (number of tokens)
d_k = 64         # Dimension of the keys (and queries)
model_dim = d_k  # Dimension of the model
h = 8

# Generate random tensors for Q, K, V
X = torch.randn(batch_size, seq_length, d_k)
Y = torch.randn(batch_size, seq_length, d_k)

attention_is_all_you_need = AttentionIsAllYouNeed(model_dim, d_k, model_dim, h, 1, 1)
output = attention_is_all_you_need(X, Y)
print(output.shape)
print(output)

torch.Size([2, 5, 64])
tensor([[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, 

In [68]:
X.shape

3

False

In [63]:
X = torch.tensor([[2, 3, 4, 1, 0, 0, 0], [6, 8, 24, 7, 1, 0, 0]])

encoding_dim = 100
embedding = nn.Embedding.from_pretrained(pretrained_embeddings, freeze = True, sparse=True)
embedding.eval()
pos_encoding = PositionalEncoding(encoding_dim)
pos_encoding.eval()
Linear = nn.Linear(encoding_dim, encoding_dim)
attention = Attention()


attention_is_all_you_need.create_mask(X)
autoregressive_padding_mask = attention_is_all_you_need.autoregressive_padding_mask
padding_mask = attention_is_all_you_need.padding_mask

X_ = embedding(X)
X_ = pos_encoding(X_)
X_ = Linear(X_)

attention.set_mask(autoregressive_padding_mask)
auto_X = attention(X_, X_, X_, torch.tensor(encoding_dim))
attention.set_mask(padding_mask)
pad_X = attention(X_, X_, X_, torch.tensor(encoding_dim))


display(auto_X)
display(pad_X)

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3862, 0.6138, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2037, 0.3103, 0.4860, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0186, 0.0134, 0.0176, 0.9504, 0.0000, 0.0000, 0.0000],
         [0.2284, 0.1735, 0.2044, 0.3937, 0.0000, 0.0000, 0.0000],
         [0.2611, 0.1619, 0.1860, 0.3910, 0.0000, 0.0000, 0.0000],
         [0.2990, 0.1558, 0.1674, 0.3778, 0.0000, 0.0000, 0.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4494, 0.5506, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2418, 0.2702, 0.4879, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1864, 0.2136, 0.2419, 0.3582, 0.0000, 0.0000, 0.0000],
         [0.0125, 0.0194, 0.0199, 0.0142, 0.9341, 0.0000, 0.0000],
         [0.1262, 0.1511, 0.1852, 0.1579, 0.3795, 0.0000, 0.0000],
         [0.1348, 0.1535, 0.1783, 0.1525, 0.3809, 0.0000, 0.0000]]],
       grad_fn=<SoftmaxBackward0>)

tensor([[[0.9736, 0.0072, 0.0056, 0.0136, 0.0000, 0.0000, 0.0000],
         [0.1937, 0.3078, 0.2318, 0.2668, 0.0000, 0.0000, 0.0000],
         [0.1389, 0.2116, 0.3315, 0.3180, 0.0000, 0.0000, 0.0000],
         [0.0186, 0.0134, 0.0176, 0.9504, 0.0000, 0.0000, 0.0000],
         [0.2284, 0.1735, 0.2044, 0.3937, 0.0000, 0.0000, 0.0000],
         [0.2611, 0.1619, 0.1860, 0.3910, 0.0000, 0.0000, 0.0000],
         [0.2990, 0.1558, 0.1674, 0.3778, 0.0000, 0.0000, 0.0000]],

        [[0.3581, 0.2049, 0.1657, 0.1165, 0.1548, 0.0000, 0.0000],
         [0.2021, 0.2476, 0.1825, 0.1317, 0.2361, 0.0000, 0.0000],
         [0.1531, 0.1710, 0.3088, 0.1397, 0.2274, 0.0000, 0.0000],
         [0.1455, 0.1668, 0.1888, 0.2797, 0.2192, 0.0000, 0.0000],
         [0.0125, 0.0194, 0.0199, 0.0142, 0.9341, 0.0000, 0.0000],
         [0.1262, 0.1511, 0.1852, 0.1579, 0.3795, 0.0000, 0.0000],
         [0.1348, 0.1535, 0.1783, 0.1525, 0.3809, 0.0000, 0.0000]]],
       grad_fn=<SoftmaxBackward0>)