# Naive Self Attention
A simple self-attention mechanism without trainable weights.

In [1]:
import torch

In [2]:
# Toy embeddings: one 3-dimensional row per token of the example sentence.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1: Your
    [0.55, 0.87, 0.66],  # x^2: journey
    [0.57, 0.85, 0.64],  # x^3: starts
    [0.22, 0.58, 0.33],  # x^4: with
    [0.77, 0.25, 0.10],  # x^5: one
    [0.05, 0.80, 0.55],  # x^6: step
])

## Single Input Self Attention

In [3]:
# Self attention for a single query token: the second one, i.e. "journey"
query = inputs[1]

# Step 1: attention scores = dot product of the query with every token embedding
attn_scores = torch.stack([torch.dot(query, token_emb) for token_emb in inputs])

# Step 2: attention weights
# First attempt: divide each score by the sum of all scores so the weights add up to 1
attn_wts_tmp = torch.div(attn_scores, attn_scores.sum(dim=0))

def softmax(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Hand-written softmax along one dimension.

    Computes ``exp(x) / sum(exp(x))`` along ``dim``. The maximum along ``dim``
    is subtracted before exponentiating; this leaves the mathematical result
    unchanged but keeps ``exp`` from overflowing to ``inf`` (and the result
    from becoming ``nan``) when the scores are large.

    Args:
        x: Tensor of scores.
        dim: Dimension to normalize over. Defaults to 0, which matches the
            original single-argument behavior.

    Returns:
        Tensor with the same shape as ``x`` whose entries along ``dim`` are
        non-negative and sum to 1.
    """
    # An empty tensor has no maximum to subtract; return an (empty) float result.
    if x.numel() == 0:
        return torch.exp(x)

    # Shift so the largest entry along `dim` is 0 -> every exponent is <= 0.
    shifted = x - torch.amax(x, dim=dim, keepdim=True)
    exp_x = torch.exp(shifted)  # computed once and reused below
    return exp_x / exp_x.sum(dim=dim, keepdim=True)

# Normalize the scores twice: with the hand-written softmax and with PyTorch's built-in one
attn_wts_naive = softmax(attn_scores)
attn_wts = torch.softmax(attn_scores, dim=0)

# Step 3: context vector = sum of all token embeddings, each scaled by its attention weight
context_vec = torch.zeros_like(query)
for weight, token_emb in zip(attn_wts, inputs):
    context_vec += weight * token_emb

print("Original Representation: ", query)
print("Self Attention Representation: ", context_vec)


Original Representation:  tensor([0.5500, 0.8700, 0.6600])
Self Attention Representation:  tensor([0.4419, 0.6515, 0.5683])


## All Inputs Self Attention

In [4]:
# Calculating self attention for all inputs
#
# Every step is computed twice: a "naive" version with explicit Python loops
# and a vectorized version using matrix operations, so the two can be compared.

# Token count is read from the data instead of being hard-coded,
# so this cell keeps working if `inputs` gets more or fewer rows.
num_tokens = inputs.shape[0]

# Step 1: Calculate attention scores (dot product of every token pair)
attn_scores_naive = torch.empty((num_tokens, num_tokens), dtype=inputs.dtype)
for i, token_emb_i in enumerate(inputs):
    for j, token_emb_j in enumerate(inputs):
        attn_scores_naive[i, j] = torch.dot(token_emb_i, token_emb_j)

attn_scores = inputs @ inputs.T

# Step 2: Calculate attention weights (softmax over each row of scores)
attn_wts_naive = torch.empty((num_tokens, num_tokens), dtype=inputs.dtype)
for token_index, token_attn_score_vec in enumerate(attn_scores_naive):
    attn_wts_naive[token_index] = softmax(token_attn_score_vec)

# dim=1 normalizes across each row, matching the row-by-row loop above
attn_wts = torch.softmax(attn_scores, dim=1)

# Step 3: Calculate context vector (weighted sum of all token embeddings per token)
context_vecs_naive = torch.zeros_like(inputs)
for i in range(num_tokens):
    for j in range(num_tokens):
        context_vecs_naive[i] += attn_wts_naive[i, j] * inputs[j]

context_vecs = attn_wts @ inputs

In [5]:
# Exact comparison: torch.equal is True only if every element matches exactly.
# Loop-based and vectorized float computations can round differently, so False
# here does not by itself mean the two implementations disagree.
print(torch.equal(attn_scores_naive, attn_scores))
print(torch.equal(attn_wts_naive, attn_wts))
print(torch.equal(context_vecs_naive, context_vecs))

# Tolerance-based comparison is the meaningful check for floating-point results.
# These assertions print nothing on success and fail loudly on a real mismatch.
assert torch.allclose(attn_scores_naive, attn_scores), "attention scores differ beyond tolerance"
assert torch.allclose(attn_wts_naive, attn_wts), "attention weights differ beyond tolerance"
assert torch.allclose(context_vecs_naive, context_vecs), "context vectors differ beyond tolerance"

False
False
False


In [6]:
# Display the loop-based scores, to compare by eye with the vectorized `attn_scores` in the next cell.
attn_scores_naive

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [7]:
# Display the vectorized scores (`inputs @ inputs.T`), to compare by eye with `attn_scores_naive`.
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [8]:
# Element-wise exact comparison: shows which entries of the two score matrices differ
attn_scores_naive == attn_scores

tensor([[False, False,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True, False, False,  True],
        [ True,  True,  True, False,  True,  True],
        [ True,  True,  True,  True,  True,  True]])

In [9]:
# Look at one entry reported as unequal (row 0, column 0) and check that both tensors share a dtype
print(attn_scores_naive[0, 0], attn_scores[0, 0])
print(attn_scores_naive.dtype, attn_scores.dtype)

tensor(0.9995) tensor(0.9995)
torch.float32 torch.float32
