In [1]:
import torch
import torch.nn as nn

#### **Simplified Self-Attention**


In [2]:
# Toy embeddings for the sentence "Your journey starts with one step":
# one row per token, three features per token.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1: "Your"
    [0.55, 0.87, 0.66],  # x^2: "journey"
    [0.57, 0.85, 0.64],  # x^3: "starts"
    [0.22, 0.58, 0.33],  # x^4: "with"
    [0.77, 0.25, 0.10],  # x^5: "one"
    [0.05, 0.80, 0.55],  # x^6: "step"
])

**Single Input Attention Calculation**


In [8]:
# Use the second row (the token "journey") as the query vector.
query = inputs[1]

# One unnormalized attention score per token in the sequence.
attn_scores_2 = torch.empty(inputs.size(0))

# Score each token against the query with a dot product, i.e. the sum of
# element-wise products, for example:
#   [1, 2, 3, 4, 5] . [2, 1, 2, 1, 2]
#   = (1*2) + (2*1) + (3*2) + (4*1) + (5*2) = 24
for i in range(inputs.size(0)):
    x_i = inputs[i]
    attn_scores_2[i] = x_i.dot(query)

attn_scores_2


tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [19]:
# Simplest normalization: divide every score by the total so that the
# resulting weights sum to 1.
# NOTE: the printed labels are misspelled; they are kept verbatim so the
# recorded cell output stays in sync with the code.
attn_scores_norm = attn_scores_2 / attn_scores_2.sum()
print(f"Unnormailzed: {attn_scores_2}")
print(f"Normailzed: {attn_scores_norm}")

Unnormailzed: tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
Normailzed: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])


In [30]:
# Softmax is the more desirable way to normalize the scores.
# Hand-rolled version: exponentiate, then divide by the total. Note that this
# naive form is prone to overflow and underflow issues.
exp_scores = attn_scores_2.exp()
print(exp_scores / exp_scores.sum())

# The built-in softmax is the preferred implementation.
print(torch.softmax(attn_scores_2, dim=0))

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


In [46]:
# Updated embedding for the query token: the input embeddings combined
# using the softmax-normalized attention scores as weights.
torch.matmul(inputs.T, torch.softmax(attn_scores_2, dim=0))

tensor([0.4419, 0.6515, 0.5683])

**Full Attention Calculation**


In [100]:
# Attention scores for every token at once: entry [i, j] is the dot product
# of token i with token j.
attn_scores = torch.matmul(inputs, inputs.T)

# Softmax over the last dimension normalizes each row of scores.
attn_weights = torch.softmax(attn_scores, dim=-1)

In [101]:
# Updated embeddings for all tokens: each output row is the input embeddings
# weighted by the corresponding row of attention weights.
torch.matmul(attn_weights, inputs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

#### **Self-Attention**


#### **Causal Attention**


#### **Multi-head Attention**
