## attention is all you need

<p align="center">
    <img src="./images/transformer.png" alt="transformer architecture" width="400"/>
</p>


In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

In [2]:
class InputEmbedding(nn.Module):
    """Learned token-embedding lookup, scaled by sqrt(d_model).

    Args:
        d_model: width of each embedding vector (called d_model in the
            "attention is all you need" paper).
        vocab_size: number of distinct token ids; one embedding row is learnt per id.
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        # nn.Embedding is a trainable lookup table: row i is the vector for token id i,
        # so the table needs vocab_size rows of d_model entries each.
        self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, indices):
        """Return the embeddings for `indices`, multiplied by sqrt(d_model)."""
        looked_up = self.emb(indices)
        scale = np.sqrt(self.d_model)
        return looked_up * scale


# Embed five arbitrary token ids held in a 1-D tensor (no batch dimension).
indices = torch.tensor([1, 123, 678, 21, 90])
ie = InputEmbedding(d_model=512, vocab_size=10_000)
out = ie(indices)
# One 512-wide vector comes back per token id.
print(indices.shape, out.shape, sep="\n")

torch.Size([5])
torch.Size([5, 512])


In [3]:
# Dropout with p=0.3: in training mode each element is zeroed independently with
# probability 0.3 (so ~3 of 10 on average) and the survivors are scaled by 1/(1-p).
# Used for regularisation during training.
t = torch.rand(10)
d = nn.Dropout(p=0.3)
d(t)

tensor([0.0000, 0.0000, 1.1098, 0.0000, 0.5565, 0.1319, 1.1614, 0.8198, 0.9523,
        0.0000])

In [4]:
# scratch check: the float literal 1e4 evaluates to 10000.0
1e4

10000.0

In [15]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to embeddings, then apply dropout.

    The encoding table is built once at construction time and reused for both
    training and inference; it is never trained.

    Args:
        d_model: embedding dimension. The encoding has the same width so it can be
            summed with the embeddings.
        seq_len: maximum sequence length you expect to see; encodings are
            precomputed up to that length.
        dropout: dropout probability applied to the summed result.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        self.set_positional_encoding()

    def set_positional_encoding(self):
        """Build the (1, seq_len, d_model) table from "attention is all you need".

        The encodings could be learned instead, but the authors found no difference
        and chose the fixed form as it may generalise to sequence lengths longer than
        those seen during training.
        """
        # pos from the paper, as a column so it broadcasts against the frequencies
        position = torch.arange(0, self.seq_len, 1, dtype=torch.float).unsqueeze(1)
        # 2i from the paper: the even dimension indices; sin fills the even columns
        # and cos fills the odd columns using the same frequencies
        i_2 = torch.arange(0, self.d_model, 2)
        # numerically stable way of computing (10_000)^(2i/d_model) from the paper
        denominator = torch.exp((i_2 / self.d_model) * np.log(1e4))
        angles = position / denominator

        pe = torch.zeros(self.seq_len, self.d_model)
        pe[:, 0::2] = torch.sin(angles)
        # for odd d_model there is one fewer odd column than even columns, so trim
        pe[:, 1::2] = torch.cos(angles)[:, : self.d_model // 2]

        # A buffer follows the module across .to(device) / .half() without being a
        # trainable parameter. persistent=False keeps it out of state_dict, so
        # checkpoints saved before this table was a buffer still load unchanged.
        # The leading unsqueeze adds the batch dimension: (1, seq_len, d_model).
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return dropout(x + positional encoding).

        Args:
            x: tensor of shape (batch, seq_len, d_model); its sequence length is
                expected to be at most the `seq_len` given at construction.

        Returns:
            A new tensor with the same shape as `x`; `x` itself is not modified.
        """
        # only the rows matching x's sequence length are added; the addition is
        # out of place so the caller's tensor is left untouched
        encoding = self.pe[:, : x.shape[1], :].to(x.dtype)
        return self.dropout(x + encoding)


# Five token ids with a leading batch dimension: shape (1, 5)
indices = torch.tensor([[1, 123, 678, 21, 90]])
ie = InputEmbedding(d_model=512, vocab_size=10_000)
max_seq_len = 1000
pe = PositionalEncoding(d_model=512, seq_len=max_seq_len, dropout=0.3)

# embed the ids, then add positional information; the shape is unchanged by the latter
out = ie(indices)
print(indices.shape)
print(out.shape)
out_pe = pe(out)
print(out_pe.shape)

torch.Size([1, 5])
torch.Size([1, 5, 512])
torch.Size([1, 5, 512])


In [25]:
# In batchnorm, we calculate the mean and std dev across the batch dimension, i.e. statistics
# across the batch, one per feature. After normalisation, we do an affine transform
# (m, c from y = m*x + c), and m, c are learnable. This is so that the output doesn't necessarily
# stick to zero mean and unit variance (which would limit model expressivity).
# But batchnorm only gets a representative mean/var if the batch is large enough. Layernorm
# is batch-size independent: the mean/var is calculated across all features, for each input example
# independently, and applied to that input example alone. It's used in transformers, RNNs, etc.
# where batch sizes may vary. See https://arxiv.org/abs/1607.06450
# Note this has nothing to do with neural network layers, just the way it's applied to the data
# within the network.
class LayerNorm(nn.Module):
    """Normalise the last dimension of the input, then apply a learnable scale and shift.

    Args:
        eps: small constant added to the standard deviation to avoid division by zero.

    Attributes:
        m: learnable scalar scale, initialised to 1.
        c: learnable scalar shift, initialised to 0, so a freshly constructed layer
            outputs zero-mean, (approximately) unit-std features.
    """

    def __init__(self, eps=1e-7) -> None:
        super().__init__()
        self.eps = eps
        self.m = nn.Parameter(torch.tensor([1.0]))
        # the shift starts at zero: starting at one would offset every activation
        # by +1 before any training has happened
        self.c = nn.Parameter(torch.tensor([0.0]))

    def forward(self, x):
        """Return m * (x - mean) / (std + eps) + c, with stats taken over the last dim."""
        # stats along the feature dimension; keepdim=True keeps a size-1 axis so they
        # broadcast against x without a separate unsqueeze
        mu = torch.mean(x, dim=-1, keepdim=True)
        # torch.std's default (Bessel-corrected) estimator is used here
        std = torch.std(x, dim=-1, keepdim=True)
        # the learnable params m and c allow a nonzero mean and non-unit variance
        return self.m * ((x - mu) / (std + self.eps)) + self.c


# normalising over the last (feature) dimension leaves the shape unchanged
ln = LayerNorm()
t = torch.rand(32, 100, 3)
ln(t).size()

torch.Size([32, 100, 3])

In [33]:
class FFN(nn.Module):
    """Feed-forward block: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: size of the input and output feature dimension.
        n_hidden: size of the hidden layer between the two linear maps.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, n_hidden, dropout: float) -> None:
        super().__init__()
        # widen to n_hidden, apply the non-linearity, then project back to d_model
        widen = nn.Linear(d_model, n_hidden)
        narrow = nn.Linear(n_hidden, d_model)
        self.net = nn.Sequential(widen, nn.ReLU(), nn.Dropout(dropout), narrow)

    def forward(self, x):
        """Run `x` through the stack; only the last dimension is transformed."""
        out = self.net(x)
        return out


# input is (batch, n, dim)
# n can be any number of examples: the linear layers act on the last (feature)
# dimension only, each neuron learning a function of those features for the output.
t = torch.rand(32, 100, 512)
ffn = FFN(d_model=512, n_hidden=2048, dropout=0.3)
ffn(t).size()

torch.Size([32, 100, 512])

In [68]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The embedding dimension is split into `n_heads` equal chunks of size
    `d_k = d_model // n_heads`, with each chunk (possibly) attending differently.

    Args:
        d_model: embedding dimension; must be divisible by `n_heads`.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention probabilities.

    After each forward pass, `self.similarity_prob` holds the attention
    probabilities of shape (batch, n_heads, seq_q, seq_k) for visualisation.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.Wk = nn.Linear(d_model, d_model)
        self.Wq = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

    @staticmethod
    def attention(k, q, v, mask, dropout: nn.Dropout):
        """Compute softmax(q @ k^T / sqrt(d_k)) @ v.

        Args:
            k: keys, shape (..., seq_k, d_k).
            q: queries, shape (..., seq_q, d_k).
            v: values, shape (..., seq_k, d_k).
            mask: None, or a tensor broadcastable to (..., seq_q, seq_k); positions
                where it equals 0 receive no attention.
            dropout: None, or a dropout module applied to the attention probabilities.

        Returns:
            (attention, probabilities): the weighted sum of values, shape
            (..., seq_q, d_k), and the probabilities, shape (..., seq_q, seq_k).
        """
        d_k = q.shape[-1]
        # similarity between what each position needs (q) and what every position
        # offers (k), scaled by the root of the per-head dimension.
        # Row i holds query i's scores against all keys: (..., seq_q, seq_k)
        similarities = (q @ k.transpose(-1, -2)) / (d_k ** 0.5)
        # `is not None` rather than truthiness: bool() of a multi-element tensor raises
        if mask is not None:
            # -inf before the softmax becomes exactly zero probability after it
            similarities = similarities.masked_fill(mask == 0, float("-inf"))
        # normalise over the keys so each query's weights sum to one
        similarities_probability = F.softmax(similarities, dim=-1)
        if dropout is not None:
            similarities_probability = dropout(similarities_probability)
        # gather the values as a weighted sum (the weights are learnt)
        attention = similarities_probability @ v
        # similarities_probability is returned for visualisation
        return attention, similarities_probability

    def forward(self, k, q, v, mask):
        """Project k, q, v, attend per head, and merge the heads back together.

        For the encoder k, q and v are all the same input, but they are separate
        arguments because in the decoder some may come from the encoder and some
        from the decoder.

        Args:
            k: (batch, seq_k, d_model) input to the key projection.
            q: (batch, seq_q, d_model) input to the query projection.
            v: (batch, seq_k, d_model) input to the value projection.
            mask: None, or a mask whose zeros mark positions to hide. A 3-D mask
                (batch, seq_q, seq_k) is shared across heads; otherwise it must
                broadcast to (batch, n_heads, seq_q, seq_k).

        Returns:
            Tensor of shape (batch, seq_q, d_model).
        """
        # https://youtu.be/XfpMkf4rD6E?list=LL&t=1472
        k = self.Wk(k)  # what info do i have that is useful?
        q = self.Wq(q)  # what info do i need?
        v = self.Wv(v)  # what info do i publicly reveal?

        # split along the embedding dimension into heads; we want
        # (batch, head, seq, d_k) so that each head sees the whole sequence
        k = self._split_heads(k)
        q = self._split_heads(q)
        v = self._split_heads(v)

        # a (batch, seq_q, seq_k) mask needs a head axis to line up with the scores
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)

        # attention has shape (batch, heads, seq_q, d_k)
        attention, self.similarity_prob = MultiHeadAttention.attention(
            k, q, v, mask, self.dropout
        )
        # view requires contiguous memory, which transpose disturbs, so make the
        # tensor contiguous before merging the heads back into d_model
        attention = (
            attention.transpose(1, 2).contiguous().view(q.shape[0], -1, self.d_model)
        )
        return self.Wo(attention)

    def _split_heads(self, x):
        """Reshape (batch, seq, d_model) into (batch, n_heads, seq, d_k)."""
        return x.view(x.shape[0], x.shape[1], self.n_heads, self.d_k).transpose(1, 2)


t = torch.rand(32, 100, 512)
# all-ones mask of shape (batch, seq, seq): no position is hidden
mask = torch.ones(t.shape[0], t.shape[1], t.shape[1])
mha = MultiHeadAttention(d_model=512, n_heads=4, dropout=0.3)
# self attention: k, q and v all come from the same input
mha(k=t, q=t, v=t, mask=mask).size()

torch.Size([32, 100, 512])

In [46]:
# nn.Linear yields its parameters as weight then bias; this shows the length of the last one
l = nn.Linear(in_features=512, out_features=512)
len(tuple(l.parameters())[-1])

512

In [55]:
# shape check for per-head attention on a (batch, heads, seq, d_k) tensor
t = torch.rand(32, 4, 100, 128)

# softmax runs over the last axis so each row of weights sums to one;
# dim=1 would normalise across the heads instead
attn_weights = F.softmax(t @ t.transpose(-1, -2), dim=-1)
(attn_weights @ torch.rand(32, 4, 100, 128)).shape

torch.Size([32, 4, 100, 128])

In [70]:
class NormAndAdd(nn.Module):
    """Residual connection around a sub-layer that is fed layer-normalised input.

    Args:
        dropout: dropout probability applied to the sub-layer's output before the
            residual addition.
    """

    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNorm()

    def forward(self, x, prev_layer):
        """Return dropout(prev_layer(norm(x))) + x.

        `prev_layer` is any callable taking the normalised input. This differs from
        the add & norm of the original paper: the INPUT of the sub-layer is
        normalised rather than its output, one of the few improvements to the
        original architecture, which helps against exploding gradients and gives
        better training dynamics.
        """
        normalised = self.norm(x)
        sublayer_out = prev_layer(normalised)
        return self.dropout(sublayer_out) + x

In [73]:
class EncoderLayer(nn.Module):
    """One encoder block: self attention, then a feed-forward network, each inside
    its own residual connection.

    Args:
        mha: the attention module, called with the same tensor as k, q and v.
        ffn: the feed-forward module.
        dropout: dropout probability for both residual connections.
    """

    def __init__(self, mha: MultiHeadAttention, ffn: FFN, dropout: float) -> None:
        super().__init__()
        self.mha = mha
        self.ffn = ffn
        # two residual connections, as each has its own layernorm and layernorm parameters
        self.rc1 = NormAndAdd(dropout)
        self.rc2 = NormAndAdd(dropout)

    def forward(self, x, src_mask):
        """Run attention and then the feed-forward net over `x`, passing `src_mask`
        to the attention module."""

        def self_attention(normed):
            # k, q and v are all the same (normalised) input
            return self.mha(normed, normed, normed, src_mask)

        attended = self.rc1(x, self_attention)
        return self.rc2(attended, self.ffn)


t = torch.rand(32, 100, 512)
# all-ones mask of shape (batch, seq, seq): no position is hidden
mask = torch.ones(32, 100, 100)
mha = MultiHeadAttention(d_model=512, n_heads=4, dropout=0.3)
ffn = FFN(d_model=512, n_hidden=2048, dropout=0.3)
ecl = EncoderLayer(mha=mha, ffn=ffn, dropout=0.3)

# an encoder layer maps (batch, seq, d_model) to the same shape
ecl(t, mask).size()

torch.Size([32, 100, 512])

In [77]:
class Encoder(nn.Module):
    """A stack of encoder layers followed by one final layer normalisation.

    Args:
        layers: the encoder layers, applied in order; each is called as
            `layer(x, src_mask)`.
    """

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNorm()

    def forward(self, x, src_mask):
        """Pass `x` through every layer in turn, then normalise the result once."""
        for layer in self.layers:
            x = layer(x, src_mask)
        # There is a single shared norm, so it is applied once to the output of the
        # whole stack. Applying it inside the loop would re-normalise the residual
        # stream after every layer with the same parameters.
        return self.norm(x)


t = torch.rand(32, 100, 512)
# all-ones mask of shape (batch, seq, seq): no position is hidden
mask = torch.ones(32, 100, 100)


# ten encoder layers, each built with its own attention and feed-forward modules
encoder = Encoder(
    nn.ModuleList(
        EncoderLayer(
            MultiHeadAttention(d_model=512, n_heads=4, dropout=0.3),
            FFN(d_model=512, n_hidden=2048, dropout=0.3),
            dropout=0.3,
        )
        for _ in range(10)
    )
)

encoder(t, mask).size()

torch.Size([32, 100, 512])