[![Open notebook in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/afondiel/computer-science-notebook/blob/master/core/ai-ml/deep-learning-notes/neural-nets/notebook/attention_is_all_you_need.ipynb)

In [None]:
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt

<img width="600" height="300" src="https://cdn-uploads.huggingface.co/production/uploads/6438a9027de34e8ea7e4b257/4MMtJDefZBU8dpmHana6B.png">


Source: [On Coding Your First Attention (Hugging Face blog)](https://huggingface.co/blog/Jaward/coding-your-first-attention)

In [None]:
def attention(Q, K, V, mask=None, dropout=None):
    """Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V.

    Args:
        Q: query tensor; its last dimension is d_k.
        K: key tensor; must share Q's last dimension (d_k).
        V: value tensor; its second-to-last dimension must match K's.
        mask: optional tensor broadcastable to the score matrix
            ``Q @ K^T`` (shape ``(..., len_q, len_k)``). Positions where
            ``mask == 0`` get a score of -1e9, i.e. ~zero attention weight.
            ``None`` (default) means no masking.
        dropout: optional callable (e.g. an ``nn.Dropout`` module) applied to
            the attention weights before they are multiplied with V.

    Returns:
        Tensor of shape ``(..., len_q, d_v)``: attention-weighted sum of V.
    """
    # `mask` used to be read from a notebook-level global of the same name,
    # which raised NameError on a fresh kernel; it is now an explicit argument.
    dk = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(dk)
    if mask is not None:
        # Large negative (not -inf) so a fully masked row still gives a finite softmax.
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, V)


In [None]:
# Smoke-test attention(Q, K, V) on small random tensors.

import torch.nn.functional as F

# Fix the RNG so the tensors printed below are the same on every run.
torch.manual_seed(0)

# sample input tensors
Q = torch.randn(2, 3, 4)
K = torch.randn(2, 3, 4)
V = torch.randn(2, 3, 4)
# No masking in this test. Kept at notebook level in case attention() reads
# `mask` as a global instead of taking it as a parameter.
mask = None

# call the attention function
output = attention(Q, K, V)
# (2, 3, 3) attention weights @ (2, 3, 4) values -> (2, 3, 4)
assert output.shape == (2, 3, 4), output.shape

# print the inputs and their shapes (the output is printed in the next cell)
print(f"Q: {Q}, \nK: {K}, \nV: {V}")
print(f"Q-shape: {Q.shape}, \nK-shape: {K.shape}, \nV-shape: {V.shape}")
print(f"Q-shape: {Q.shape}, \nK-shape: {K.shape}, \nV-shape: {V.shape}")


Q: tensor([[[ 0.3248, -0.2816,  0.5649, -0.5883],
         [ 0.1014, -0.2892, -1.1494, -0.0343],
         [ 1.1611, -1.1778, -0.4965, -2.5952]],

        [[-0.8530, -0.4816,  0.1375,  1.1548],
         [-1.0562, -0.5972, -2.0863, -0.6770],
         [ 0.1306, -0.0780, -0.2148, -1.5332]]]), 
K: tensor([[[-0.6672,  0.3029, -1.2540, -0.2968],
         [-0.1970, -1.3694,  0.1549,  0.1588],
         [ 0.2104,  0.0468,  0.0048, -0.3037]],

        [[-0.0449, -0.5207, -0.1919, -1.0864],
         [ 0.0602, -0.0561,  0.7803, -0.5877],
         [ 0.1192,  1.6522, -0.0620,  1.8709]]]), 
V: tensor([[[-1.4642,  1.5672, -0.4723,  1.6154],
         [-0.7321,  1.4285, -0.7772, -2.8687],
         [ 0.2703, -0.6095, -0.6333, -2.0936]],

        [[ 0.0725, -0.3944,  0.8649,  0.9405],
         [-0.2680, -0.1354,  0.0665,  0.0850],
         [-0.6947,  0.6490,  0.2722,  1.5652]]])
Q-shape: torch.Size([2, 3, 4]), 
K-shape: torch.Size([2, 3, 4]), 
V-shape: torch.Size([2, 3, 4])


In [None]:
# Display the attention result next to its shape.
print("output: {}, output-shape: {}".format(output, output.shape))

output: tensor([[[-0.5134,  0.6832, -0.6544, -1.5743],
         [-0.8294,  0.9842, -0.5961, -0.5416],
         [-0.5480,  0.6991, -0.6429, -1.3978]],

        [[-0.4515,  0.2713,  0.3368,  1.1063],
         [-0.0725, -0.2338,  0.6568,  0.8553],
         [-0.0918, -0.2438,  0.5469,  0.6699]]]), output-shape: torch.Size([2, 3, 4])


<img width="600" height="800" src="https://cdn-uploads.huggingface.co/production/uploads/6438a9027de34e8ea7e4b257/c-RzcFcoyRVFqYCSxgvsS.png">




In [None]:


# Scaled Dot-Product Attention // Self-Attention
class SingleHeadAttention(nn.Module):
    """One attention head: learned Q/K/V projections + scaled dot-product attention.

    Args:
        in_dim: size of the last dimension of the Q, K and V inputs.
        attn_dim: size of each projected query/key/value vector (d_k); this is
            also the size of the output's last dimension.
    """

    def __init__(self, in_dim, attn_dim):
        super().__init__()
        self.Q_linear = nn.Linear(in_dim, attn_dim)
        self.K_linear = nn.Linear(in_dim, attn_dim)
        self.V_linear = nn.Linear(in_dim, attn_dim)

    def forward(self, Q, K, V, mask=None, dropout=None):
        """Project Q, K, V to ``attn_dim`` and attend.

        Args:
            Q, K, V: input tensors whose last dimension is ``in_dim``.
            mask: optional tensor broadcastable to the ``(..., len_q, len_k)``
                score matrix; positions where ``mask == 0`` are blocked.
            dropout: optional callable (e.g. ``nn.Dropout``) applied to the
                attention weights.

        Returns:
            Tensor whose last dimension is ``attn_dim``.
        """
        Q_proj = self.Q_linear(Q)
        K_proj = self.K_linear(K)
        V_proj = self.V_linear(V)

        # Scale by the *projected* key width (attn_dim). The previous code used
        # Q.size(-1), i.e. the raw input width (in_dim), which over-dampens
        # the scores whenever in_dim != attn_dim.
        dk = K_proj.size(-1)
        scores = torch.matmul(Q_proj, K_proj.transpose(-2, -1)) / math.sqrt(dk)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        p_attn = F.softmax(scores, dim=-1)

        if dropout is not None:
            p_attn = dropout(p_attn)

        return torch.matmul(p_attn, V_proj)

In [None]:
# Test for SingleHeadAttention
single_head_attn = SingleHeadAttention(in_dim=512, attn_dim=64)
Q_test = torch.randn(2, 10, 512)
K_test = torch.randn(2, 10, 512)
V_test = torch.randn(2, 10, 512)
output_single = single_head_attn(Q_test, K_test, V_test)
print("Single Head Attention Output Shape:", output_single.shape)  # Expected: [2, 10, 64]

Single Head Attention Output Shape: torch.Size([2, 10, 64])


<img width="600" height="800" src="https://cdn-uploads.huggingface.co/production/uploads/6438a9027de34e8ea7e4b257/QfUeWOfSU1J64OR7yIYUn.png">

In [None]:
# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Run ``num_heads`` independent SingleHeadAttention heads and mix their outputs.

    Each head maps the inputs to ``attn_dim`` features. The head outputs are
    concatenated along the last dimension (``num_heads * attn_dim`` features),
    and a final linear layer projects them back to ``in_dim``.

    Args:
        in_dim: size of the last dimension of the Q, K, V inputs and of the output.
        attn_dim: per-head projection size.
        num_heads: number of heads; must be at least 1.

    Raises:
        ValueError: if ``num_heads < 1``.
    """

    def __init__(self, in_dim, attn_dim, num_heads):
        super().__init__()
        if num_heads < 1:
            # Without this, forward() fails later with an opaque torch.cat
            # "non-empty list" error, far from the real cause.
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")
        self.heads = nn.ModuleList(
            [SingleHeadAttention(in_dim, attn_dim) for _ in range(num_heads)]
        )
        self.linear = nn.Linear(num_heads * attn_dim, in_dim)

    def forward(self, Q, K, V, mask=None, dropout=None):
        """Apply every head to (Q, K, V), concatenate, and project to ``in_dim``.

        ``mask`` and ``dropout`` are forwarded unchanged to each head.
        """
        head_outputs = [head(Q, K, V, mask, dropout) for head in self.heads]
        return self.linear(torch.cat(head_outputs, dim=-1))

In [None]:
# Test for MultiHeadAttention (reuses Q_test/K_test/V_test from the single-head test cell).
multi_head_attn = MultiHeadAttention(in_dim=512, attn_dim=64, num_heads=8)
output_multi = multi_head_attn(Q_test, K_test, V_test)
print("Multi-Head Attention Output Shape:", output_multi.shape)  # Expected: [2, 10, 512]
# 8 heads x 64 features are concatenated (512) and projected back to in_dim=512.
assert output_multi.shape == (2, 10, 512), output_multi.shape


## References

- [Let's build GPT: from scratch, in code, spelled out. - 37:46](https://www.youtube.com/watch?v=kCc8FmEb1nY)

HuggingFace:

- [On Coding Your First Attention - HF](https://huggingface.co/blog/Jaward/coding-your-first-attention)
- [Attention Is All You Need (paper overview)](https://huggingface.co/papers/1706.03762)

Google:
- [Attention Mechanism](https://www.youtube.com/watch?v=fjJOgb-E41w&list=PLIivdWyY5sqIlLF9JHbyiqzZbib9pFt4x&index=3)
- [Attention Mechanism: Overview](https://www.youtube.com/watch?v=8PmOaVYVeKY&list=PLBgogxgQVM9s0i9oloJwjIG-zj6Svsm20&index=2)

Transformer Notes:
- [Transformer notes (Google Doc)](https://docs.google.com/document/d/19zFJ4qWq7u3x5sKCd3ej0v9z9Oozo08g_rRPuqTnYls/edit)
- [Transformer notes (GitHub)](https://github.com/afondiel/computer-science-notes/blob/master/ai/nlp-notes/models/transformers-notes.md)