In [67]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

# Prefer the Apple-Silicon GPU (MPS) when present. Both branches yield a torch.device,
# so `device` always has the same type (never a bare string).
# NOTE(review): no tensor in the cells shown here is moved to `device`; everything runs on CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps


In [68]:
batch_size = 1
seq_len = 6              # number of tokens in the sequence (max_seq_len)
# `vocab_size` is kept because the next cell reads it, but here it holds the
# sequence length, not the size of a vocabulary.
vocab_size = seq_len
embedding_size = 8       # d_model

In [69]:
# Fixed seed so the random tensors drawn below (and the printed outputs) are reproducible.
torch.manual_seed(132)
# Random stand-in for token embeddings + positional encodings:
# shape (batch_size, vocab_size, embedding_size) = (1, 6, 8).
positional_embed = torch.randn((batch_size, vocab_size, embedding_size))


PART 3: Self Attention & Multi-Head Attention

- Input embeddings capture the meaning of each token; positional encodings capture where that token sits in the sequence.

- Self-attention captures how the tokens relate to one another. How? Each token is turned into three vectors:

    - Queries Q: “What am I looking for?”
    - Keys K: “What do others offer?”
    - Values V: “What do I take from them?”

- Each of Q, K and V has its own weight matrix ($W_Q$, $W_K$, $W_V$). The input embeddings (with the positional encoding added) are multiplied by that matrix, e.g. $Q = X W_Q$.

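- To compute attention:
    - calculate $QK^\top$ -> transposing K makes the product take the dot product of every query with every key, so each token is scored against every token in the sequence (a seq_len × seq_len matrix)
    - divide by $\sqrt{d_k}$ (the query/key dimension, not $d_{model}$) -> this keeps the scores from growing with the vector size; overly large scores saturate the softmax and lead to unstable gradients
    - take the softmax along each row -> turns each token's scores into attention weights that sum to 1
    - multiply the weights by V -> each output row is a weighted average of the value vectors
    - in one line: $\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$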

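- Multi-head attention runs several self-attention heads in parallel. Each head has its own $W_Q$, $W_K$, $W_V$ of width $d_k = d_{model} / \text{heads}$. The head outputs are concatenated and usually passed through a final linear projection.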

In [70]:
vocab_size = 8  # NOTE(review): not read by any later cell shown here; overwrites the earlier vocab_size (the sequence length, 6)
d_model = 8     # embedding_size
heads = 2       # number of attention heads; change for multi-head attention

# Each head works in a d_model / heads sized subspace. Require an exact split so
# integer division cannot silently drop dimensions.
assert d_model % heads == 0, "d_model must be divisible by heads"
dk = d_model // heads  # per-head query/key width
dv = d_model // heads  # per-head value width

In [71]:
# Random stand-ins for the learned projections W_Q, W_K, W_V.
# They are drawn in the order wq, wk, wv so the seeded values match the outputs below.
# Each is (d_model, width) = (8, 4).
wq, wk, wv = (torch.randn(d_model, width) for width in (dk, dk, dv))

wq.shape


torch.Size([8, 4])

In [72]:
# Queries: every token's embedding projected by W_Q -> (1, 6, dk) = (1, 6, 4).
q = torch.matmul(positional_embed, wq)
q

tensor([[[ 0.7588, -0.5084, -3.7370,  2.5728],
         [ 2.2291, -1.9921, -0.0241,  0.0073],
         [ 2.6539,  4.3607, -3.9289, -1.5583],
         [-1.2607,  1.8096,  4.1857,  3.5735],
         [-0.2976,  0.2599, -0.8966,  3.9972],
         [ 2.5838,  0.4920, -1.7484,  5.4579]]])

In [73]:
# Keys: every token's embedding projected by W_K -> (1, 6, dk) = (1, 6, 4).
k = torch.matmul(positional_embed, wk)
k


tensor([[[-4.2813, -1.4716, -0.8473,  1.1307],
         [ 1.3028, -3.2291,  3.8855, -1.2693],
         [-5.6884,  1.1877,  0.5104,  1.7375],
         [ 0.3116, -3.2665,  1.4287,  1.8076],
         [ 0.0575, -1.8384,  2.4462,  2.0588],
         [-4.2674, -2.8822,  0.9956,  2.5251]]])

In [74]:
# Values: every token's embedding projected by W_V -> (1, 6, dv) = (1, 6, 4).
v = torch.matmul(positional_embed, wv)
v

tensor([[[ 2.7858,  0.2538,  3.8358, -0.5133],
         [ 4.3383,  0.0158,  0.9234,  4.0296],
         [ 2.5303,  0.5589,  1.3418, -0.1750],
         [ 4.1207, -2.6188,  2.8017, -0.8283],
         [-0.6290,  1.1233,  3.1431, -2.9573],
         [ 4.9291, -0.6216,  4.2550,  0.6335]]])

In [75]:
v.shape  # (batch, seq_len, dv)

torch.Size([1, 6, 4])

In [76]:
# Raw attention scores: entry [i, j] is the dot product of query i with key j.
# (Same as torch.matmul(q, k.transpose(-2, -1)).)
qkT = q @ k.transpose(-2, -1)
qkT.size()  # [1, 6, 4] @ [1, 4, 6] -> [1, 6, 6]


torch.Size([1, 6, 6])

In [77]:
# Scale by sqrt(dk), the query/key width (not d_model).
score = qkT.div(np.sqrt(dk))
score.shape  # [1, 6, 6]

torch.Size([1, 6, 6])

In [78]:
# Normalize over the last (key) axis so that every query's row of weights sums to 1.
softmax = score.softmax(dim=-1)
softmax.shape

torch.Size([1, 6, 6])

In [79]:
# Attention weights: row i holds how strongly query token i attends to each key token.
softmax

tensor([[[5.9723e-01, 5.1160e-05, 3.0764e-02, 1.8298e-01, 2.3849e-02,
          1.6512e-01],
         [2.5808e-04, 7.0204e-01, 3.7513e-06, 2.5148e-01, 4.5167e-02,
          1.0501e-03],
         [2.9461e-01, 6.2736e-03, 6.4987e-01, 1.7608e-02, 3.1512e-02,
          1.2517e-04],
         [5.6472e-04, 9.3755e-04, 7.7077e-01, 2.4161e-03, 1.3603e-01,
          8.9283e-02],
         [8.7890e-02, 3.0153e-05, 2.8013e-01, 4.9009e-02, 6.4149e-02,
          5.1879e-01],
         [2.5369e-03, 5.1094e-05, 1.2667e-03, 5.3412e-01, 4.4564e-01,
          1.6383e-02]]])

In [80]:
# NOTE: despite its name, `weights` is the attention OUTPUT. Each row is a
# softmax-weighted combination of the value vectors; the weights themselves live in `softmax`.
weights = torch.matmul(softmax, v)
weights.shape  # (batch, seq_len, dv)

torch.Size([1, 6, 4])

In [81]:
# Attention output: one dv-sized row per token (softmax weights @ values).
weights

tensor([[[ 3.2947, -0.3863,  3.6224, -0.4292],
         [ 4.0594, -0.5974,  1.5002,  2.4876],
         [ 2.5456,  0.4273,  2.1568, -0.3474],
         [ 2.3204,  0.5219,  1.8515, -0.4792],
         [ 3.6725, -0.1999,  3.2594,  0.0043],
         [ 2.0119, -0.9070,  2.9783, -1.7512]]])

In [82]:
# Sanity check: softmax over the key axis must make every row sum to 1.
row_sums = softmax.sum(dim=-1)  # sums along the last (key) dimension
print("Row sums:", row_sums)
# Turn the visual check into a real one, so Restart & Run All fails loudly if it breaks.
assert torch.allclose(row_sums, torch.ones_like(row_sums)), "attention rows must sum to 1"

Row sums: tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]])


In [83]:
# Reference single-head self-attention without debug output.
# NOTE: the next cell defines a class with the same name, which replaces this one once it runs.

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input into queries, keys and values with three linear layers
    and returns ``softmax(q @ k^T / sqrt(dk)) @ v``.

    Args:
        d_model: size of the last dimension of the input.
        dk: width of the query/key projections.
        dv: width of the value projection (and therefore of the output).
    """

    def __init__(self, d_model, dk, dv):
        super().__init__()
        self.wq = nn.Linear(d_model, dk)
        self.wk = nn.Linear(d_model, dk)
        self.wv = nn.Linear(d_model, dv)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: tensor whose last dimension is ``d_model``, e.g. (batch, seq_len, d_model).
            mask: optional tensor broadcastable to (..., seq_len, seq_len). True/nonzero
                marks query/key pairs that may attend to each other. None (the default)
                lets every token attend to every token.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last dimension ``dv``.
        """
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)
        return self.scaled_dot_product_attention(q, k, v, mask)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """Compute ``softmax(q @ k^T / sqrt(dk)) @ v``.

        q, k: (..., seq_len, dk); v: (..., seq_len, dv).
        mask: optional, True/nonzero = may attend. A query row whose keys are all
        masked out produces NaN, because the softmax is taken over all -inf.
        """
        dk = q.size(-1)  # query/key width (not d_model)
        # Entry [..., i, j] is the dot product of query i with key j.
        scores = torch.matmul(q, k.transpose(-2, -1))
        # Dividing by sqrt(dk) keeps the score magnitude from growing with dk, which
        # would push softmax into saturation. A Python float avoids building a tensor.
        scores = scores / dk ** 0.5
        if mask is not None:
            scores = scores.masked_fill(mask.logical_not(), float("-inf"))
        weights = F.softmax(scores, dim=-1)  # each row sums to 1 over the keys
        return torch.matmul(weights, v)      # weighted sum of the value vectors


In [84]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention that logs tensor shapes.

    Computes ``softmax(q @ k^T / sqrt(dk)) @ v``. While ``verbose`` is True (the
    default), it prints the shape of every intermediate tensor so the data flow
    can be followed step by step.

    Args:
        d_model: size of the last dimension of the input.
        dk: width of the query/key projections.
        dv: width of the value projection (and therefore of the output).
        verbose: print intermediate shapes on every forward pass (default True).
    """

    def __init__(self, d_model, dk, dv, verbose=True):
        super().__init__()
        self.verbose = verbose
        self.wq = nn.Linear(d_model, dk)
        self.wk = nn.Linear(d_model, dk)
        self.wv = nn.Linear(d_model, dv)

    def _log(self, message):
        """Print ``message`` only when verbose shape logging is enabled."""
        if self.verbose:
            print(message)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: tensor whose last dimension is ``d_model``, e.g. (batch, seq_len, d_model).
            mask: optional tensor broadcastable to (..., seq_len, seq_len). True/nonzero
                marks query/key pairs that may attend to each other. None (the default)
                lets every token attend to every token.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last dimension ``dv``.
        """
        self._log(f"Input x shape: {x.shape}")

        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)

        self._log(f"Query (q) shape: {q.shape}")
        self._log(f"Key (k) shape: {k.shape}")
        self._log(f"Value (v) shape: {v.shape}")

        return self.scaled_dot_product_attention(q, k, v, mask)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """Compute ``softmax(q @ k^T / sqrt(dk)) @ v``.

        q, k: [batch_size, seq_len, dk]
        v: [batch_size, seq_len, dv]
        mask: optional, True/nonzero = may attend. A query row whose keys are all
        masked out produces NaN, because the softmax is taken over all -inf.
        """
        # Entry [..., i, j] is the dot product of query i with key j.
        scores = torch.matmul(q, k.transpose(-2, -1))
        self._log(f"Scores shape (q @ k.T): {scores.shape}")

        dk = q.size(-1)  # query/key width (not d_model)
        # Keep score magnitude independent of dk; a Python float avoids building a tensor.
        scores = scores / dk ** 0.5
        self._log(f"Scaled scores shape: {scores.shape}")

        if mask is not None:
            scores = scores.masked_fill(mask.logical_not(), float("-inf"))

        weights = F.softmax(scores, dim=-1)  # each row sums to 1 over the keys
        self._log(f"Softmax weights shape: {weights.shape}")

        output = torch.matmul(weights, v)  # weighted sum of the value vectors
        self._log(f"Output shape (weights @ v): {output.shape}")

        return output




In [85]:
# A single head may use any projection width: 6 here, independent of the
# per-head dk = dv = 4 computed for the multi-head setup above.
model = SelfAttention(d_model=8, dk=6, dv=6)
attention = model(positional_embed)  # (batch, seq_len, dv)

print("Output:", attention)
print("\nOutput shape:", attention.shape)


Input x shape: torch.Size([1, 6, 8])
Query (q) shape: torch.Size([1, 6, 6])
Key (k) shape: torch.Size([1, 6, 6])
Value (v) shape: torch.Size([1, 6, 6])
Scores shape (q @ k.T): torch.Size([1, 6, 6])
Scaled scores shape: torch.Size([1, 6, 6])
Softmax weights shape: torch.Size([1, 6, 6])
Output shape (weights @ v): torch.Size([1, 6, 6])
Output: tensor([[[ 0.1561, -0.2973,  0.4194, -0.5169, -0.2889, -0.0047],
         [ 0.2791, -0.4766,  0.5885, -0.7013, -0.4492,  0.0928],
         [ 0.1080, -0.2189,  0.3061, -0.4617, -0.2391, -0.0028],
         [ 0.1573, -0.2738,  0.4379, -0.4548, -0.2977, -0.0793],
         [ 0.1876, -0.2317,  0.3680, -0.5485, -0.3321,  0.0698],
         [ 0.1673, -0.2950,  0.4363, -0.5147, -0.3023, -0.0130]]],
       grad_fn=<UnsafeViewBackward0>)

Output shape: torch.Size([1, 6, 6])
