In [17]:
from importlib.metadata import version
import torch
import torch.nn.functional as F

# Report the installed torch distribution version and whether CUDA is usable.
print("torch version:", version("torch"))
print(torch.cuda.is_available())

torch version: 2.7.1
True


In [2]:
a = torch.tensor([3.0, 4.0, 5.0, 6.0])

# Dot product of a with itself, i.e. the sum of squared entries
# (9 + 16 + 25 + 36). Note this is the *squared* L2 norm, not the norm itself.
norm = a @ a

print(f"norm: {norm}")

norm: 86.0


In [5]:
# 3 x 4 integer matrix holding 0..11 in row-major order:
#   [[0, 1,  2,  3],
#    [4, 5,  6,  7],
#    [8, 9, 10, 11]]
a = torch.arange(12).reshape(3, 4)

print(a.shape)

torch.Size([3, 4])


In [8]:
# Entry (i, j) of the result is the dot product of row i and row j of a.
a @ a.T

tensor([[ 14,  38,  62],
        [ 38, 126, 214],
        [ 62, 214, 366]])

In [13]:
# Dot product of the last row with itself; it equals the bottom-right entry
# of a @ a.T.
a[2] @ a[2]

tensor(366)

In [27]:
# Token matrix of shape (l, d), built directly as float32.
x = torch.tensor(
    [[0, 1, 2, 3],
     [4, 5, 6, 7],
     [8, 9, 10, 11]],
    dtype=torch.float32,
)

# Pairwise dot products between the rows: (l, d) @ (d, l) -> (l, l).
w = torch.mm(x, x.T)
print(w.shape)

# Softmax over the last dimension turns every row into weights; still (l, l).
w = torch.softmax(w, dim=-1)
print(w.shape)

# Each output row is a weighted sum of the rows of x: (l, l) @ (l, d) -> (l, d).
x_prime = torch.mm(w, x)
print(x_prime.shape)

torch.Size([3, 3])
torch.Size([3, 3])
torch.Size([3, 4])


In [25]:
# Check the element type of x after the .float() conversion.
print(x.dtype)

torch.float32


In [20]:
# Cast x to 32-bit floating point.
x = x.to(torch.float32)

In [21]:
# Show the dtype of x as the cell result.
x.dtype

torch.float32

In [23]:
# Inspect the softmax weights. The dot products of these integer-valued rows
# are large, so each row of w puts almost all of its weight on one column.
w

tensor([[1.4252e-21, 3.7751e-11, 1.0000e+00],
        [0.0000e+00, 6.0546e-39, 1.0000e+00],
        [0.0000e+00, 0.0000e+00, 1.0000e+00]])

In [28]:
import torch

# One 3-dimensional embedding per token of "Your journey starts with one step".
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1: Your
    [0.55, 0.87, 0.66],  # x^2: journey
    [0.57, 0.85, 0.64],  # x^3: starts
    [0.22, 0.58, 0.33],  # x^4: with
    [0.77, 0.25, 0.10],  # x^5: one
    [0.05, 0.80, 0.55],  # x^6: step
])

print(inputs.size())

torch.Size([6, 3])


In [33]:
# Pairwise dot products between all token embeddings -> (l, l).
w = torch.matmul(inputs, inputs.T)
print(w.shape)
print(w)

# Row-wise softmax turns the scores of each token into weights -> (l, l).
s = torch.softmax(w, dim=-1)
print(s.shape)

# Every output row is a weighted sum of the input rows -> (l, d).
x_prime = torch.matmul(s, inputs)
print(x_prime.shape)

print(f"x_prime:{x_prime}")

torch.Size([6, 6])
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
torch.Size([6, 6])
torch.Size([6, 3])
x_prime:tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


In [43]:
# The embedding of the second token acts as the query.
query = inputs[1]
print(f"query: {query}")

# torch.empty allocates the buffer without initialising it, so the values
# printed next are arbitrary; every slot is overwritten by the loop below.
attn_scores_2 = torch.empty(inputs.shape[0])
print(f"attn_scores_2: {attn_scores_2}")

# Score of token `pos` = dot product of its embedding with the query
# (both are 1-D vectors, so no transpose is involved).
for pos in range(len(inputs)):
    attn_scores_2[pos] = torch.dot(inputs[pos], query)

print(attn_scores_2)

query: tensor([0.5500, 0.8700, 0.6600])
attn_scores_2: tensor([3.3427e-40, 0.0000e+00, 1.3329e-08, 5.2439e-08, 8.9683e-44, 0.0000e+00])
tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [44]:
# Plain normalisation: divide every score by the total so the results sum to 1.
torch.div(attn_scores_2, attn_scores_2.sum())

tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])

In [46]:
def softmax_naive(x, dim=0):
    """Compute the softmax of ``x`` along ``dim`` from basic tensor operations.

    The maximum along ``dim`` is subtracted before exponentiating. This leaves
    the mathematical result unchanged but keeps ``exp`` from overflowing to
    ``inf`` (and the result from becoming NaN) when the scores are large.

    Args:
        x: Tensor of scores.
        dim: Dimension to normalise over. Defaults to 0, matching the original
            one-argument behaviour.

    Returns:
        Tensor with the same shape as ``x`` whose entries along ``dim`` are
        non-negative and sum to 1.
    """
    if x.numel() == 0:
        # Nothing to normalise; amax cannot reduce over an empty dimension.
        return torch.exp(x)
    shifted = x - torch.amax(x, dim=dim, keepdim=True)
    # Exponentiate once and reuse the result for numerator and denominator.
    exp_shifted = torch.exp(shifted)
    return exp_shifted / exp_shifted.sum(dim=dim, keepdim=True)

# Convert the raw scores for the second token into weights with the
# hand-written softmax.
attn_weights_2_naive = softmax_naive(attn_scores_2)

print("Attention weights:", attn_weights_2_naive)
print("Sum:", torch.sum(attn_weights_2_naive))

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


In [45]:
# Normalise the scores with PyTorch's built-in softmax over dimension 0.
attn_weights_2 = attn_scores_2.softmax(dim=0)

print("Attention weights:", attn_weights_2)
print("Sum:", torch.sum(attn_weights_2))

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


In [40]:
# Softmax denominator. attn_scores_2 is 1-D, so reducing over dim 0 and
# reducing over all elements give the same scalar.
print(attn_scores_2.exp().sum(dim=0))
print(attn_scores_2.exp().sum())

tensor(18.7453)
tensor(18.7453)


In [47]:
# The second token is the query whose context vector is being built.
query = inputs[1]

# Accumulate the attention-weighted sum of all token embeddings.
context_vec_2 = torch.zeros(query.shape)
for weight, token in zip(attn_weights_2, inputs):
    context_vec_2 += weight * token

print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


In [48]:
# Print the full embedding matrix for reference (one row per token).
print(inputs)

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])


In [49]:
# Score matrix with one row and one column per token. The size is taken from
# `inputs` rather than hard-coded, so the cell keeps working if the number of
# tokens changes. torch.empty leaves the buffer uninitialised; every entry is
# written by the nested loop below.
attn_scores = torch.empty(inputs.shape[0], inputs.shape[0])

# attn_scores[i, j] is the dot product between token i and token j.
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)

print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [50]:
# All pairwise dot products computed in a single matrix multiplication.
attn_scores = torch.matmul(inputs, inputs.T)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [51]:
# Softmax over the last dimension: each row of scores becomes a row of weights.
attn_weights = attn_scores.softmax(dim=-1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [52]:
# Sanity check with the rounded values of the second row of attn_weights,
# copied from the printed output: they add up to 1.
sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])

1.0

In [56]:
# Sum along the last dimension (the one softmax normalised): one total per row.
attn_weights.sum(dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In [61]:
# Row i of the result is the attention-weighted sum of all token embeddings
# for token i.
all_context_vecs = torch.matmul(attn_weights, inputs)
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


In [62]:
# Display the context vectors computed earlier with F.softmax, for comparison.
x_prime 

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

In [63]:
# Compare the two results element by element. torch.isclose allows for
# floating-point rounding, whereas an exact == would report False for values
# that differ only in the last bit.
torch.isclose(all_context_vecs, x_prime)

tensor([[True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True]])