In [1]:
from importlib.metadata import version

print("torch version:", version("torch"))

torch version: 2.0.1


In [2]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
    [0.55, 0.87, 0.66], # journey  (x^2)
    [0.57, 0.85, 0.64], # starts   (x^3)
    [0.22, 0.58, 0.33], # with     (x^4)
    [0.77, 0.25, 0.10], # one      (x^5)
    [0.05, 0.80, 0.55]] # step     (x^6)
)

In [3]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [13]:
query = inputs[1]  # journey (x^2)


attention_scores = torch.empty(inputs.shape[0])  # (num_tokens,) = (6,)
for i, x in enumerate(inputs):
    attention_scores[i] = torch.dot(query, x)  # x^2 . x^(i)
    
print("Attention scores:", attention_scores)

Attention scores: tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [12]:
query

tensor([0.5500, 0.8700, 0.6600])

In [11]:
for i, score in enumerate(inputs):
    print(i, score)

0 tensor([0.4300, 0.1500, 0.8900])
1 tensor([0.5500, 0.8700, 0.6600])
2 tensor([0.5700, 0.8500, 0.6400])
3 tensor([0.2200, 0.5800, 0.3300])
4 tensor([0.7700, 0.2500, 0.1000])
5 tensor([0.0500, 0.8000, 0.5500])


In [8]:
for i, x in enumerate(inputs):
    print(f"Attention score for x^{i+1}: {attention_scores[i]:.4f}")  # i+1 because i starts at 0

Attention score for x^1: 0.9544
Attention score for x^2: 1.4950
Attention score for x^3: 1.4754
Attention score for x^4: 0.8434
Attention score for x^5: 0.7070
Attention score for x^6: 1.0865


In [14]:
for idx, element in enumerate(inputs[0]):
    print(f"Element {idx}: {element}")

Element 0: 0.4300000071525574
Element 1: 0.15000000596046448
Element 2: 0.8899999856948853


In [15]:
res = 0.

for idx, element in enumerate(inputs[0]):
    res += inputs[0][idx] * query[idx] # x^1 * x^2[idx]

print("Result:", res)
print("Dot product:", torch.dot(inputs[0], query))  # x^1 . x^2

Result: tensor(0.9544)
Dot product: tensor(0.9544)


In [None]:
attention_weights_2 = attention_scores / attention_scores.sum()

print("Attention weights:", attention_weights_2)
attention_weights_2.sum()  # should be 1.0

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])


tensor(1.0000)

In [17]:
def softmax(x):
    exp_x = torch.exp(x)
    return exp_x / exp_x.sum()

attention_weights_2 = softmax(attention_scores)
print("Attention weights (softmax):", attention_weights_2)
attention_weights_2.sum()  # should be 1.0

Attention weights (softmax): tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


tensor(1.)

In [19]:
attention_weights_2 = torch.softmax(attention_scores, dim=0)


print("Attention weights:", attention_weights_2)
print("Sum of attention weights:", attention_weights_2.sum())  # should be 1.0

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum of attention weights: tensor(1.)


In [20]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [25]:
query.shape

torch.Size([3])

In [27]:
query = inputs[1]  # journey (x^2)

context_vecs_2 = torch.zeros_like(query)  # (embedding_dim,) = (3,)
for i, x_vec in enumerate(inputs):
    context_vecs_2 += attention_weights_2[i] * x_vec  # a^(i) * x^(i)

print("Context vector:", context_vecs_2)
print("Context vector shape:", context_vecs_2.shape)  # (embedding_dim,)

Context vector: tensor([0.4419, 0.6515, 0.5683])
Context vector shape: torch.Size([3])


In [28]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [29]:
attention_scores_all = torch.empty(6, 6)  # (num_tokens, num_tokens) = (6, 6)

for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attention_scores_all[i, j] = torch.dot(x_i, x_j)  # x^(i) . x^(j)
        
print("Attention scores (all):", attention_scores_all)
attention_scores_all.shape
    

Attention scores (all): tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


torch.Size([6, 6])

In [30]:
# Can be achieved with matrix multiplication as well
aattention_scores_all_matmul = inputs @ inputs.T
aattention_scores_all_matmul

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [31]:
attention_scores_all = torch.matmul(inputs, inputs.T)  # (num_tokens, num_tokens) = (6, 6)
attention_scores_all

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [32]:
# Normalize each row with softmax
attention_weights_all = torch.softmax(attention_scores_all, dim=1)  # (6, 6)
print("Attention weights (all):", attention_weights_all)
print("Shape", attention_weights_all.shape)

Attention weights (all): tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
Shape torch.Size([6, 6])


In [34]:
attention_weights_all.sum(dim=-1)  # each row should sum to 1.0

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In [35]:
attention_weights_all[-1].sum()

tensor(1.)

In [37]:
# Compute all context vectors
context_vectors_all = torch.zeros_like(inputs)  # (num_tokens, embedding_dim) = (6, 3)
for i in range(inputs.shape[0]):  # for each token
    for j in range(inputs.shape[0]):  # for each token
        context_vectors_all[i] += attention_weights_all[i, j] * inputs[j]
        
print(context_vectors_all)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


In [38]:
# In a much simpler way using matrix multiplication
context_vectors_all_matmul = attention_weights_all @ inputs
context_vectors_all_matmul

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

In [41]:
# Sanity check for the 2nd context vector
context_vecs_2

tensor([0.4419, 0.6515, 0.5683])

In [43]:
inputs.shape

torch.Size([6, 3])

In [42]:
# Query vector
# Key vector
# Value vector

x_2 = inputs[1]  # journey (x^2)
input_dimension = inputs.shape[1]  # embedding_dim = 3
output_dimension = 2 

In [48]:
x_2, x_2.shape

(tensor([0.5500, 0.8700, 0.6600]), torch.Size([3]))

In [44]:
torch.manual_seed(123)  # for reproducibility

W_query = torch.nn.Parameter(torch.rand(input_dimension, output_dimension), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(input_dimension, output_dimension), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(input_dimension, output_dimension), requires_grad=False)

In [46]:
W_query.shape, W_key.shape, W_value.shape

(torch.Size([3, 2]), torch.Size([3, 2]), torch.Size([3, 2]))

In [49]:
# Now compute the query, key, and value vectors for x^2
query_2 = x_2 @ W_query  # (output_dimension,) = (2,)
key_2 = x_2 @ W_key      # (output_dimension,) = (2,)
value_2 = x_2 @ W_value  # (output_dimension,) = (2,)

In [51]:
query_2, key_2, value_2

(tensor([0.4306, 1.4551]), tensor([0.4433, 1.1419]), tensor([0.3951, 1.0037]))

In [52]:
query_2.shape

torch.Size([2])

In [53]:
inputs.shape

torch.Size([6, 3])

In [54]:
# We can obtain all keys and values via matrix multiplication
keys = inputs @ W_key      # (num_tokens, output_dimension) = (6, 2)
values = inputs @ W_value  # (num_tokens, output_dimension) = (6, 2)

keys, values

(tensor([[0.3669, 0.7646],
         [0.4433, 1.1419],
         [0.4361, 1.1156],
         [0.2408, 0.6706],
         [0.1827, 0.3292],
         [0.3275, 0.9642]]),
 tensor([[0.1855, 0.8812],
         [0.3951, 1.0037],
         [0.3879, 0.9831],
         [0.2393, 0.5493],
         [0.1492, 0.3346],
         [0.3221, 0.7863]]))

In [55]:
keys.shape, values.shape

(torch.Size([6, 2]), torch.Size([6, 2]))

In [56]:
# compute the attention score w22

attention_score_22 = torch.dot(query_2, keys[1])  # w_22 = q_2 . k_2
print("Attention score w_22:", attention_score_22)

Attention score w_22: tensor(1.8524)


In [57]:
# Generalizing to compute all

attention_scores_22 = torch.matmul(keys, query_2)  # (num_tokens,) = (6,)
attention_scores_22

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

In [60]:
query_2.shape[0], keys.shape[-1]

(2, 2)

In [61]:
# compute the attention weights

d_k = query_2.shape[0]  # output_dimension
attention_weights_22 = torch.softmax(attention_scores_22 / d_k**0.5, dim=-1)  # (num_tokens,) = (6,)
print("Attention weights:", attention_weights_22)

Attention weights: tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


In [62]:
attention_weights_22.shape

torch.Size([6])

In [63]:
# compute the context vector
context_vecs_2 = attention_weights_2 @ values
print(context_vecs_2)

tensor([0.3069, 0.8188])


In [65]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    
    def __init__(self, input_dimension, output_dimension):
        super().__init__()
        self.output_dimension = output_dimension
        
        self.weight_query = nn.Parameter(torch.rand(input_dimension, output_dimension))
        self.weight_key   = nn.Parameter(torch.rand(input_dimension, output_dimension))
        self.weight_value = nn.Parameter(torch.rand(input_dimension, output_dimension))
        
    
    def forward(self, x: torch.Tensor):
        keys = x @ self.weight_key
        queries = x @ self.weight_query
        values = x @ self.weight_value
        
        attention_scores = queries @ keys.T
        attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1)
        context_vectors = attention_weights @ values
        return context_vectors
    

torch.manual_seed(123)
sa_v1 = SelfAttention_v1(input_dimension, output_dimension)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [66]:
# More efficient implementation with nn.Linear
import torch.nn as nn

class SelfAttention_v2(nn.Module):
    def __init__(self, input_dimension, output_dimension, qkv_bias=False):
        super().__init__()
        self.output_dimension = output_dimension
        
        self.weight_query = nn.Linear(input_dimension, output_dimension, bias=qkv_bias)
        self.weight_key   = nn.Linear(input_dimension, output_dimension, bias=qkv_bias)
        self.weight_value = nn.Linear(input_dimension, output_dimension, bias=qkv_bias)
        
    def forward(self, x: torch.Tensor):
        queries = self.weight_query(x)
        keys = self.weight_key(x)
        values = self.weight_value(x)
        
        attention_scores = queries @ keys.T
        attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1)
        context_vector = attention_weights @ values
        return context_vector
    

torch.manual_seed(789)
sa_v2 = SelfAttention_v2(input_dimension, output_dimension)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


In [67]:
sa_v2.named_parameters()

<generator object Module.named_parameters at 0x14bd1ac40>

In [68]:
queries = sa_v2.weight_query(inputs)
keys = sa_v2.weight_key(inputs)
# values = sa_v2.weight_value(inputs)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [69]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [70]:
masked_simple = attn_weights*mask_simple
print(masked_simple)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


In [71]:
row_sums = masked_simple.sum(dim=-1, keepdims=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)
