<a href="https://colab.research.google.com/github/DavidCastroPena/Vaswani2017/blob/main/replicatingAttentionIsAllWhatYouNeed.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:


# When a transformer analyzes a sentence such as "the cat sat", each word (token) first gets a query vector,
# e.g. q_sat for "sat". Each query, i.e. the token currently being analyzed, is compared via a dot product with all the keys.
# Every word has a key: k_the, k_cat, and k_sat. Geometrically, the more closely two vectors point in the same direction,
# the larger their dot product (for fixed lengths), and a large dot product signals that the two tokens are related.

"""
q_sat · k_the = 0.1

q_sat · k_cat = 2.1

q_sat · k_sat = 1.5

"""

# "sat" is more related to "cat" than to "the".

#Recall that the keys and queries are linear transformations of the token embeddings. Intuitively, the query is a question
#asked by each token that aims to uncover the role of that word in a given text. The word "question" is shorthand for
#something very precise: the query defines a direction in vector space along which relevance is measured.



#Lesson 1 goal: implement and understand:

#Attention(Q,K,V) = softmax(QK^T / sqrt(d_k)) V

#Key concepts:

#PyTorch tensors

#Matrix multiplication

#Softmax

#Masking

#Shape reasoning


#Setup
import torch
import torch.nn.functional as F

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)


device: cpu
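
To make the "the cat sat" scores from the first cell concrete, here is a minimal sketch that pushes them through a softmax, the step that turns raw dot products into attention weights. It assumes the three scores belong to k_the, k_cat, and k_sat in sentence order; the two-dimensional vectors at the end are made up purely to illustrate the geometric point about direction.

```python
import torch
import torch.nn.functional as F

# Example scores from the first cell: q_sat against k_the, k_cat, and (assumed) k_sat.
keys = ["the", "cat", "sat"]
scores = torch.tensor([0.1, 2.1, 1.5])

# Softmax turns raw dot products into weights that are positive and sum to 1.
weights = F.softmax(scores, dim=-1)

for key, score, weight in zip(keys, scores.tolist(), weights.tolist()):
    print(f"q_sat . k_{key} = {score:.1f} -> weight {weight:.3f}")

# Direction matters: an aligned key gives a larger dot product than an orthogonal one.
# These vectors are hypothetical, chosen only for illustration.
q = torch.tensor([1.0, 0.0])
k_aligned = torch.tensor([2.0, 0.0])
k_orthogonal = torch.tensor([0.0, 2.0])
print("aligned:", torch.dot(q, k_aligned).item())
print("orthogonal:", torch.dot(q, k_orthogonal).item())
```

The weights come out to roughly 0.080, 0.594, and 0.326, so "sat" attends most strongly to "cat", matching the intuition in the comments.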


In [2]:

# Q, K, V, and the mask are tensors. A tensor is a multi-dimensional array of numbers, and in transformers all the meaning
# arises from learned tensor transformations and interactions.

# Embeddings are vectors, but instead of processing those vectors sequentially, one by one, tensors let you
# represent and compute over a whole sentence (or many sentences) at once, which is far more efficient.


# A vector is a tensor (a 1-D tensor)

# A matrix is a tensor (a 2-D tensor)

# So the real question is:

# Why do we need higher-dimensional tensors instead of just one vector at a time? Imagine you process one word at a time,
# using vectors only. Sentence: "The cat sat"

# You would have to do this:

# Take embedding of "The" → vector

# Compare it to "cat" → vector

# Compare it to "sat" → vector

# Repeat for "cat"

# Repeat for "sat"

# That's:

# nested loops

# sequential computation

# very slow

# hard to parallelize

# messy gradients

# This step-by-step, sequential style is roughly how RNNs process a sentence.

# What tensors give you: structure

# Tensors let you represent many things at once.

# Instead of:

# "One word → one vector"

# You represent:

# All words, all positions, all heads, all batches, at the same time


def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (batch, heads, q_len, d_k)
    K: (batch, heads, k_len, d_k)
    V: (batch, heads, k_len, d_v)
    mask: (batch, heads, q_len, k_len) with 1 for allowed, 0 for blocked (optional)

    returns:
      out: (batch, heads, q_len, d_v)
      attn: (batch, heads, q_len, k_len)
    """
    d_k = Q.size(-1)

    # (batch, heads, q_len, k_len)
    scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)

    if mask is not None:
        # Set blocked positions to a large negative value so softmax ~ 0 there.
        scores = scores.masked_fill(mask == 0, -1e9)

    attn = F.softmax(scores, dim=-1)
    out = attn @ V
    return out, attn
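

The nested-loop picture described in the comments of this cell can be checked directly against the single tensor operation. The sketch below is only an illustration: the embeddings `X` and the projection matrices `W_q` and `W_k` are random stand-ins (in a real model they would be learned), and the sizes are arbitrary. It builds queries and keys as linear transformations of the embeddings, then computes every query-key dot product twice, once with loops and once with one matrix multiplication.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: 3 tokens ("the", "cat", "sat"), embedding size 4, d_k = 3.
seq_len, d_model, d_k = 3, 4, 3
X = torch.randn(seq_len, d_model)   # stand-in embeddings, random for illustration
W_q = torch.randn(d_model, d_k)     # stand-in for a learned query projection
W_k = torch.randn(d_model, d_k)     # stand-in for a learned key projection

Q = X @ W_q   # (seq_len, d_k): one query per token
K = X @ W_k   # (seq_len, d_k): one key per token

# One vector at a time: nested loops over every (query, key) pair.
scores_loop = torch.empty(seq_len, seq_len)
for i in range(seq_len):
    for j in range(seq_len):
        scores_loop[i, j] = torch.dot(Q[i], K[j])

# All pairs at once: a single matrix multiplication.
scores_matmul = Q @ K.T

print("same result:", torch.allclose(scores_loop, scores_matmul, atol=1e-5))
```

Both routes produce the same (seq_len x seq_len) score matrix. The matrix form is the one that extends naturally to the extra batch and heads dimensions used in `scaled_dot_product_attention`.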


In [3]:
# Let's create a small example to see attention at work

batch, heads, seq_len, d_k, d_v = 1, 1, 4, 3, 2

Q = torch.randn(batch, heads, seq_len, d_k, device=device)
K = torch.randn(batch, heads, seq_len, d_k, device=device)
V = torch.randn(batch, heads, seq_len, d_v, device=device)

out, attn = scaled_dot_product_attention(Q, K, V)

print("Q shape:", Q.shape)
print("attn shape:", attn.shape)
print("out shape:", out.shape)

print("\nAttention matrix (q_len x k_len):\n", attn[0,0].detach().cpu())
print("\nRow sums (should be ~1):\n", attn[0,0].sum(dim=-1).detach().cpu())


Q shape: torch.Size([1, 1, 4, 3])
attn shape: torch.Size([1, 1, 4, 4])
out shape: torch.Size([1, 1, 4, 2])

Attention matrix (q_len x k_len):
 tensor([[0.3921, 0.2020, 0.0964, 0.3094],
        [0.2307, 0.2815, 0.3649, 0.1228],
        [0.4548, 0.1400, 0.0701, 0.3350],
        [0.1593, 0.2621, 0.4593, 0.1193]])

Row sums (should be ~1):
 tensor([1.0000, 1.0000, 1.0000, 1.0000])
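
The example above never passes a `mask`, so here is a minimal sketch of the masking path. It assumes a causal (lower-triangular) mask, which is one common choice and not something the notebook prescribes, and it uses a compact copy of the attention function named `attention` so that the block runs on its own.

```python
import torch
import torch.nn.functional as F

def attention(Q, K, V, mask=None):
    # Compact copy of scaled_dot_product_attention from cell In [2].
    scores = Q @ K.transpose(-2, -1) / (Q.size(-1) ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    attn = F.softmax(scores, dim=-1)
    return attn @ V, attn

torch.manual_seed(0)
batch, heads, seq_len, d_k, d_v = 1, 1, 4, 3, 2
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_v)

# Lower-triangular mask: 1 = allowed, 0 = blocked.
# Position i may only attend to positions 0..i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

out, attn = attention(Q, K, V, mask=causal_mask)

print("mask:\n", causal_mask[0, 0])
print("\nmasked attention (upper triangle should be 0):\n", attn[0, 0])
print("\nrow sums (should be ~1):\n", attn[0, 0].sum(dim=-1))
```

Blocked positions receive a score of -1e9 before the softmax, so their weights become zero and each row renormalizes over the allowed positions only. The first row can see only itself, so its single allowed weight is 1.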
