In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F

In [None]:
# Scaled dot-product attention on a toy sentence:
#   Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
# Q, K and V are drawn at random here, so the cell illustrates the shapes and
# the order of operations rather than any learned behaviour.
tokens = ["The", " ", "cat", " ", "sat", " ", "on", " ", "the", " ", "mat", "."]
n_tokens = len(tokens)
d_k = 6  # size of each query / key / value vector

# Seed the global RNG so that re-running this cell (or Restart & Run All)
# always produces the same Q, K, V and therefore the same output.
SEED = 42
torch.manual_seed(SEED)

# randomly initialize Q, K, V with Standard Normal distribution (mean=0, std=1)
Q = torch.randn(n_tokens, d_k)  # n_tokens x d_k
K = torch.randn(n_tokens, d_k)  # n_tokens x d_k
V = torch.randn(n_tokens, d_k)  # n_tokens x d_k

# Raw similarity between every query (rows) and every key (columns)
# (n_tokens x d_k) @ (d_k x n_tokens) = (n_tokens x n_tokens)
scores = Q @ K.T

# Values can become large, so we scale them down by the square root of d_k
# to prevent softmax from saturating;
# scaling keeps the variance of the dot product more consistent
# (n_tokens x n_tokens) / sqrt(d_k) = (n_tokens x n_tokens)
scaled_score = scores / (d_k ** 0.5)

# Softmax over the last dimension to get the attention weights.
# The last dimension corresponds to the keys, so for each query the softmax
# is applied across all keys, converting each row into a probability
# distribution.
attn_weights = F.softmax(scaled_score, dim=-1)

# Sanity check: each query's weights over the keys must sum to 1
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(n_tokens))

# (n_tokens x n_tokens) @ (n_tokens x d_k) = (n_tokens x d_k)
# the attention weights are used to weight the values;
# each output row is a weighted sum of the value vectors
output_original = attn_weights @ V

output_original

tensor([[-0.9442,  0.0367,  0.3134, -0.7202, -0.3174,  0.3125],
        [ 0.0459,  0.2607,  0.1560, -0.6120,  1.2428,  0.6578],
        [-0.8398, -0.1090,  0.1608, -0.4982, -0.4739,  0.0708],
        [-0.3803, -0.1351,  0.0249, -0.4601,  0.1185,  0.0723],
        [-0.4724,  0.0727,  0.0468, -0.5159,  0.2836,  0.1467],
        [-0.7485, -0.2667,  0.3595, -0.4836,  0.0442,  0.1493],
        [-0.3828, -0.1605,  0.1255, -0.2986,  0.1661,  0.0273],
        [-0.4709, -0.2345,  0.1679, -0.1312,  0.1520,  0.0548],
        [-0.5072, -0.2233,  0.1079, -0.4745,  0.1901, -0.1194],
        [-0.9217, -0.1705,  0.3464, -0.5381, -0.2762,  0.1383],
        [-1.0513, -0.1603,  0.4750, -0.6926, -0.2177,  0.2249],
        [-1.0403, -0.1923,  0.2474, -0.8217, -0.5872,  0.1576]])