In [2]:
import torch 
# Report the installed PyTorch build.
print(f'PyTorch version: {torch.__version__}')

# Ask for CUDA availability once and reuse the answer for both report lines.
cuda_available = torch.cuda.is_available()
print(f'CUDA available: {cuda_available}')
cuda_version = torch.version.cuda if cuda_available else "Not available"
print(f'CUDA version: {cuda_version}')


PyTorch version: 2.5.1
CUDA available: True
CUDA version: 11.8


In [3]:
from importlib.metadata import version 
# Version string recorded in the installed package metadata for "torch".
torch_package_version = version("torch")
print(f'torch version: {torch_package_version}')


torch version: 2.5.1


In [4]:
import torch

# Toy embeddings: one 3-dimensional row per token of the example sentence.
token_embeddings = [
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
]
inputs = torch.tensor(token_embeddings)

print(inputs.shape)

torch.Size([6, 3])


In [7]:
# Step 1: unnormalized attention scores of every token against the query

query = inputs[1]  # the 2nd input token acts as the query

num_tokens = inputs.shape[0]
attn_scores_2 = torch.empty(num_tokens)
for idx in range(num_tokens):
    # score = dot product between token `idx` and the query token
    attn_scores_2[idx] = torch.dot(inputs[idx], query)

print(attn_scores_2)



tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [8]:
# Normalize the attention scores by dividing by their total

score_total = attn_scores_2.sum()
attn_weights_2_tmp = attn_scores_2 / score_total
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())


Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


In [9]:
# Calculate attention weights using softmax 

def softmax_naive(x, dim=0):
    """Textbook softmax: exponentiate, then divide by the sum along ``dim``.

    No maximum is subtracted before exponentiating, so this is a teaching
    version only; large inputs can overflow in ``torch.exp``.

    Args:
        x: Tensor of scores.
        dim: Axis along which the result sums to 1. The default (0) matches
            the original 1D use; pass ``dim=1`` (or ``-1``) to normalize each
            row of a 2D tensor.

    Returns:
        Tensor with the same shape as ``x``.
    """
    exp_x = torch.exp(x)  # exponentiate once and reuse for numerator and denominator
    # keepdim=True keeps the reduced axis so the division broadcasts for any `dim`.
    return exp_x / exp_x.sum(dim=dim, keepdim=True)

attn_weights_2_naive = softmax_naive(attn_scores_2)

# Print the weights and confirm that they add up to one.
for label, value in (
    ("Attention weights:", attn_weights_2_naive),
    ("Sum:", attn_weights_2_naive.sum()),
):
    print(label, value)




Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


In [11]:
# In practice, use PyTorch's built-in softmax

attn_weights_2 = attn_scores_2.softmax(dim=0)

weights_total = attn_weights_2.sum()
print("Attention weights:", attn_weights_2)
print("Sum:", weights_total)



Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


In [14]:
# Context vector for the 2nd token: attention-weighted sum of all input rows

query = inputs[1]  # the 2nd input token is the query

context_vec_2 = torch.zeros(query.shape)

for weight, token_embedding in zip(attn_weights_2, inputs):
    # accumulate each token's embedding scaled by its attention weight
    context_vec_2 += weight * token_embedding

print("Context vector:", context_vec_2)







Context vector: tensor([0.4419, 0.6515, 0.5683])


In [15]:
# Attention scores for every (query, key) pair of tokens

num_tokens = inputs.shape[0]
attn_scores = torch.zeros(num_tokens, num_tokens)

for row in range(num_tokens):
    for col in range(num_tokens):
        # dot(x_row, x_col) == dot(x_col, x_row), so the matrix is symmetric
        attn_scores[row, col] = torch.dot(inputs[row], inputs[col])
print(attn_scores)





tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [17]:
# Same pairwise scores as the nested loops, computed as a single matrix product.
# Use the embedding tensor `inputs`; `input` is Python's built-in function and
# has no `@` or `.T`, so the original line fails on a fresh kernel.
attn_scores = inputs @ inputs.T
print(attn_scores)

print(f'tensor shape = {attn_scores.shape}')

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
tensor shape = torch.Size([6, 6])


In [21]:
# Normalize along the last axis so that every row of weights sums to 1
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

# Quick verification that each row sums to 1
for row_idx, row_weights in enumerate(attn_weights):
    print(f'row{row_idx} = {sum(row_weights)}')


tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
row0 = 0.9999999403953552
row1 = 1.0
row2 = 1.0
row3 = 1.0
row4 = 1.0000001192092896
row5 = 1.0


In [32]:
# Compute the context vectors for all tokens with one matrix product
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

print(all_context_vecs[1,:])
# Check against the previously computed second context vector
print(context_vec_2)

# The row comes from a matmul and context_vec_2 from a Python accumulation loop,
# so their floats may differ in the last bits; compare with a tolerance rather
# than bit-for-bit.
assert torch.allclose(all_context_vecs[1,:], context_vec_2), "Tensor are not equal"


tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
tensor([0.4419, 0.6515, 0.5683])
tensor([0.4419, 0.6515, 0.5683])


# Implement Self-Attention with trainable weights

In [36]:
x_2 = inputs[1]  # second input element
print(x_2.shape)

d_in = inputs.size(1)  # input embedding size (number of columns of `inputs`)
print(d_in)

d_out = 2  # output embedding size, chosen by hand




torch.Size([3])
3


In [39]:
# Goal: compute a context vector using query/key/value projections

torch.manual_seed(123)

# torch.nn.Parameter is the tensor type used to hold model parameters.
# requires_grad=False keeps these matrices fixed (frozen) for this walkthrough.
# The three matrices are drawn in the order query, key, value.
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

# Query, key and value vectors for the second token only
query_2, key_2, value_2 = x_2 @ W_query, x_2 @ W_key, x_2 @ W_value

print(f'query_2 = {query_2}')
print(f'key_2 = {key_2}')

# Keys and values for every token
keys, values = inputs @ W_key, inputs @ W_value

print(f'keys = {keys}')
print(f'values = {values}')





query_2 = tensor([0.4306, 1.4551])
key_2 = tensor([0.4433, 1.1419])
keys = tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]])
values = tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]])


In [49]:
# Unnormalized attention scores for query 2

# score of query 2 against its own key
attn_scores_22 = torch.dot(query_2, key_2)
print(attn_scores_22)

# scores of query 2 against every key
attn_scores_2 = query_2 @ keys.T
print(f'attn_scores_2 = {attn_scores_2}')

# scale by sqrt(d_k) before the softmax; a 1D tensor has only dim 0
d_k = keys.shape[1]
scaled_scores_2 = attn_scores_2 / d_k**0.5
attn_weights_2 = torch.softmax(scaled_scores_2, dim=0)
print(f'attn_weights_2 = {attn_weights_2}')

context_vec_2 = attn_weights_2 @ values
print(f'context_vec_2 = {context_vec_2}')


tensor(1.8524)
attn_scores_2 = tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
attn_weights_2 = tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
context_vec_2 = tensor([0.3061, 0.8210])


In [51]:
# Implementing a compact SelfAttention Class 
import torch.nn as nn 

class SelfAttention_v1(nn.Module):
    """Scaled dot-product self-attention with raw ``nn.Parameter`` weight matrices.

    Args:
        d_in: Number of columns of the input ``x``.
        d_out: Number of columns of each projection (and of the output).
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        weight_shape = (d_in, d_out)
        # Drawn in the order query, key, value from the uniform [0, 1) sampler.
        self.W_query = nn.Parameter(torch.rand(*weight_shape))
        self.W_key = nn.Parameter(torch.rand(*weight_shape))
        self.W_value = nn.Parameter(torch.rand(*weight_shape))

    def forward(self, x):
        """Return one context vector per row of the 2D input ``x``."""
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value
        scale = queries.shape[1] ** 0.5
        scores = (queries @ keys.T) / scale
        weights = torch.softmax(scores, dim=-1)
        return weights @ values

torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in = 3, d_out = 2)
# Call the module itself rather than .forward(): nn.Module.__call__ dispatches to
# forward() and also runs any registered hooks.
print(sa_v1(inputs))




tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [57]:
class SelfAttention_v2(nn.Module):
    """Scaled dot-product self-attention built from ``nn.Linear`` projections.

    Args:
        d_in: Input feature size of each projection layer.
        d_out: Output feature size of each projection layer.
        qkv_bias: Whether the query/key/value layers carry a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias = False):
        super().__init__()
        # Layers are created in the order query, key, value.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Return one context vector per row of the 2D input ``x``."""
        queries, keys, values = self.W_query(x), self.W_key(x), self.W_value(x)
        scale = keys.shape[-1] ** 0.5
        scaled_scores = (queries @ keys.T) / scale
        weights = torch.softmax(scaled_scores, dim=-1)
        return weights @ values
    

torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
# Call the module directly instead of sa_v2.forward(inputs): sa_v2(...) goes
# through nn.Module.__call__, which by default invokes forward().
sa_v2_context_vecs = sa_v2(inputs)
print(sa_v2_context_vecs)


tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


# Hiding Future words with Causal Attention 

In [62]:
# Project with sa_v2's own layers and keep the keys, so the softmax scaling below
# uses the dimension of *these* keys rather than the unrelated global `keys`
# tensor left over from the earlier trainable-weights cell.
queries_v2 = sa_v2.W_query(inputs)
keys_v2 = sa_v2.W_key(inputs)
attn_scores = queries_v2 @ keys_v2.T

context_length = attn_scores.shape[0]
# Ones strictly above the diagonal mark the "future" positions to hide.
mask = torch.triu(torch.ones(context_length, context_length), diagonal = 1)
print(mask)

# -inf scores become exactly 0 after the softmax.
attn_scores_causal = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(attn_scores_causal)

attn_weights_causal = torch.softmax(attn_scores_causal / (keys_v2.shape[-1]**0.5), dim = -1)
print(attn_weights_causal)





tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])
tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [66]:
# Apply dropout
torch.manual_seed(123)

dropout = torch.nn.Dropout(p=0.5)
example = torch.ones((6, 6))
# With p = 0.5 the entries that survive are scaled up to 2 (see the printed output).
dropped_example = dropout(example)
print(dropped_example)


dropped_attn_weights = dropout(attn_weights_causal)
print(dropped_attn_weights)









tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])
tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.0000, 0.4638, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.3968, 0.3775, 0.3941, 0.0000],
        [0.3869, 0.3327, 0.0000, 0.0000, 0.3331, 0.3058]],
       grad_fn=<MulBackward0>)


In [67]:
# Stack two copies of `inputs` along a new leading (batch) axis.
batch = torch.stack([inputs, inputs])
print(batch.shape) # 2 inputs with 6 tokens each, and each token has embedding dimension 3

torch.Size([2, 6, 3])


In [81]:
# Review masked_fill_() method to understand the code


"""  
Basic Syntax and Function
tensor.masked_fill_(mask, value)
This method:

Modifies tensor in-place (that's what the underscore _ indicates)
Replaces elements with value wherever the corresponding position in mask is True
Leaves elements unchanged wherever mask is False
"""

class CausalAttention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention for batched input.

    Args:
        d_in: Input feature size of the query/key/value projections.
        d_out: Output feature size of the projections (and of the context vectors).
        context_length: Side length of the precomputed causal mask, i.e. the
            longest sequence ``forward`` accepts.
        dropout: Probability given to ``nn.Dropout``, applied to the attention weights.
        qkv_bias: Whether the projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout = 0.5, qkv_bias = False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # register_buffer() is for tensors that need to be saved with the model but
        # are not trainable parameters. The causal mask is constant, so it is built
        # once here instead of on every forward pass, and becomes available as
        # self.mask (ones strictly above the diagonal mark future positions).
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal = 1))

    def forward(self, x):
        """Compute causal context vectors.

        Args:
            x: 3D tensor unpacked as (batch, num_tokens, features).

        Returns:
            Tensor of shape (batch, num_tokens, d_out).

        Raises:
            ValueError: If ``num_tokens`` exceeds the mask size fixed at construction.
        """
        _, num_tokens, _ = x.shape
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            # Without this check the sliced mask would be smaller than the score
            # matrix and masked_fill_ would fail with an opaque broadcasting error.
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )

        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Swap the token and feature axes of keys so each batch item yields a
        # (num_tokens, num_tokens) score matrix.
        attn_scores = queries @ keys.transpose(1, 2)
        # self.mask is (context_length, context_length); slice it down to the
        # current sequence length before masking the future positions in place.
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(123)

# The mask only needs to cover the number of tokens per sequence in `batch`.
context_length = batch.size(1)
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)
context_vec = ca(batch)
print(context_vec)





"""

Original keys tensor:

Shape: [batch_size, sequence_length, d_k]
Example with batch_size=2, sequence_length=3, d_k=4:

keys = [
    # Batch item 1
    [
        [k11, k12, k13, k14],  # Key vector for token 1
        [k21, k22, k23, k24],  # Key vector for token 2
        [k31, k32, k33, k34]   # Key vector for token 3
    ],
    
    # Batch item 2
    [
        [m11, m12, m13, m14],  # Key vector for token 1
        [m21, m22, m23, m24],  # Key vector for token 2
        [m31, m32, m33, m34]   # Key vector for token 3
    ]
]


After keys.transpose(1,2):

Shape: [batch_size, d_k, sequence_length]
Transformed tensor:

transposed_keys = [
    # Batch item 1
    [
        [k11, k21, k31],  # First dimension of all key vectors
        [k12, k22, k32],  # Second dimension of all key vectors
        [k13, k23, k33],  # Third dimension of all key vectors
        [k14, k24, k34]   # Fourth dimension of all key vectors
    ],
    
    # Batch item 2
    [
        [m11, m21, m31],  # First dimension of all key vectors
        [m12, m22, m32],  # Second dimension of all key vectors
        [m13, m23, m33],  # Third dimension of all key vectors
        [m14, m24, m34]   # Fourth dimension of all key vectors
    ]
]

"""


tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)


'\n\nOriginal keys tensor:\n\nShape: [batch_size, sequence_length, d_k]\nExample with batch_size=2, sequence_length=3, d_k=4:\n\nkeys = [\n    # Batch item 1\n    [\n        [k11, k12, k13, k14],  # Key vector for token 1\n        [k21, k22, k23, k24],  # Key vector for token 2\n        [k31, k32, k33, k34]   # Key vector for token 3\n    ],\n    \n    # Batch item 2\n    [\n        [m11, m12, m13, m14],  # Key vector for token 1\n        [m21, m22, m23, m24],  # Key vector for token 2\n        [m31, m32, m33, m34]   # Key vector for token 3\n    ]\n]\n\n\nAfter keys.transpose(1,2):\n\nShape: [batch_size, d_k, sequence_length]\nTransformed tensor:\n\ntransposed_keys = [\n    # Batch item 1\n    [\n        [k11, k21, k31],  # First dimension of all key vectors\n        [k12, k22, k32],  # Second dimension of all key vectors\n        [k13, k23, k33],  # Third dimension of all key vectors\n        [k14, k24, k34]   # Fourth dimension of all key vectors\n    ],\n    \n    # Batch item 2\n 

Tensor Concatenation Along Different Dimensions Visually Explained

In [None]:

""" 
Tensor Concatenation Along Different Dimensions Visually Explained

Let's visualize concatenation with a simple example. Assume:

2 attention heads
Batch size = 2
Sequence length = 3
Input dimension (d_in) = 4
Output dimension per head (d_out) = 2


Input tensor: [2, 3, 4] (batch_size, sequence_length, d_in)

[
  # Batch item 1
  [
    [i111, i112, i113, i114],  # Token 1 embedding
    [i121, i122, i123, i124],  # Token 2 embedding
    [i131, i132, i133, i134]   # Token 3 embedding
  ],
  
  # Batch item 2
  [
    [i211, i212, i213, i214],  # Token 1 embedding
    [i221, i222, i223, i224],  # Token 2 embedding
    [i231, i232, i233, i234]   # Token 3 embedding
  ]
]


Each head outputs: [2, 3, 2] (batch_size, sequence_length, d_out)
# Output from Head 1
[
  # Batch item 1
  [
    [h1_111, h1_112],  # Token 1 output
    [h1_121, h1_122],  # Token 2 output
    [h1_131, h1_132]   # Token 3 output
  ],
  
  # Batch item 2
  [
    [h1_211, h1_212],  # Token 1 output
    [h1_221, h1_222],  # Token 2 output
    [h1_231, h1_232]   # Token 3 output
  ]
]

# Output from Head 2
[
  # Batch item 1
  [
    [h2_111, h2_112],  # Token 1 output
    [h2_121, h2_122],  # Token 2 output
    [h2_131, h2_132]   # Token 3 output
  ],
  
  # Batch item 2
  [
    [h2_211, h2_212],  # Token 1 output
    [h2_221, h2_222],  # Token 2 output
    [h2_231, h2_232]   # Token 3 output
  ]
]


Concatenation along different dimensions:
1. Along dimension -1 (or 2) - Feature dimension (what's used in code):
Output shape: [2, 3, 4] (batch_size, sequence_length, d_out*num_heads)

[
  # Batch item 1
  [
    [h1_111, h1_112, h2_111, h2_112],  # Token 1 features from both heads
    [h1_121, h1_122, h2_121, h2_122],  # Token 2 features from both heads
    [h1_131, h1_132, h2_131, h2_132]   # Token 3 features from both heads
  ],
  
  # Batch item 2
  [
    [h1_211, h1_212, h2_211, h2_212],  # Token 1 features from both heads
    [h1_221, h1_222, h2_221, h2_222],  # Token 2 features from both heads
    [h1_231, h1_232, h2_231, h2_232]   # Token 3 features from both heads
  ]
]

2. Along dimension 0 - Batch dimension:
Output shape: [4, 3, 2] (batch_size*num_heads, sequence_length, d_out)

[
  # Batch 1, Head 1
  [
    [h1_111, h1_112],
    [h1_121, h1_122],
    [h1_131, h1_132]
  ],
  
  # Batch 2, Head 1
  [
    [h1_211, h1_212],
    [h1_221, h1_222],
    [h1_231, h1_232]
  ],
  
  # Batch 1, Head 2
  [
    [h2_111, h2_112],
    [h2_121, h2_122],
    [h2_131, h2_132]
  ],
  
  # Batch 2, Head 2
  [
    [h2_211, h2_212],
    [h2_221, h2_222],
    [h2_231, h2_232]
  ]
]

3. Along dimension 1 - Sequence dimension:
Output shape: [2, 6, 2] (batch_size, sequence_length*num_heads, d_out)

[
  # Batch item 1
  [
    [h1_111, h1_112],  # Token 1, Head 1
    [h1_121, h1_122],  # Token 2, Head 1
    [h1_131, h1_132],  # Token 3, Head 1
    [h2_111, h2_112],  # Token 1, Head 2
    [h2_121, h2_122],  # Token 2, Head 2
    [h2_131, h2_132]   # Token 3, Head 2
  ],
  
  # Batch item 2
  [
    [h1_211, h1_212],  # Token 1, Head 1
    [h1_221, h1_222],  # Token 2, Head 1
    [h1_231, h1_232],  # Token 3, Head 1
    [h2_211, h2_212],  # Token 1, Head 2
    [h2_221, h2_222],  # Token 2, Head 2
    [h2_231, h2_232]   # Token 3, Head 2
  ]
]


维度2（特征维度）拼接的优势
保持数据结构的语义完整性

批次大小（batch_size）保持不变：每个样本的独立性得到保持
序列长度（sequence_length）保持不变：每个token的位置信息得到保持
只是增加了每个token的特征维度：将不同注意力头的特征信息组合在一起
直观的特征表示

对于每个token，我们得到的是所有注意力头的特征的组合
相当于每个token都获得了来自不同注意力头的"视角"或"表示"

Preserves Data Structural Semantics

Batch size remains unchanged: maintains sample independence
Sequence length remains unchanged: preserves token position information
Only increases feature dimension: combines information from different attention heads
Intuitive Feature Representation

For each token, we get a combination of features from all attention heads
Each token receives "perspectives" or "representations" from different attention heads
"""

In [None]:
class MultiHeadAttentionWrapper(nn.Module):
    """Runs ``num_heads`` independent CausalAttention heads and joins their outputs.

    The constructor arguments other than ``num_heads`` are forwarded unchanged
    to every CausalAttention head.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(CausalAttention(d_in, d_out, context_length, dropout, qkv_bias))
        self.heads = nn.ModuleList(heads)

    def forward(self, x):
        """Apply every head to ``x`` and concatenate the results on the last axis."""
        per_head_outputs = [head(x) for head in self.heads]
        return torch.cat(per_head_outputs, dim=-1)


torch.manual_seed(123)

d_in = 3
d_out = 2
context_length = batch.size(1)  # number of tokens per sequence
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, dropout=0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

In [84]:
class MultiHeadAttentionWrapper(nn.Module):
    """Stack of independent CausalAttention heads, concatenated feature-wise.

    NOTE(review): a class with this same name is also defined in the preceding
    cell; whichever cell runs last is the definition in effect.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        self.heads = nn.ModuleList()
        for _ in range(num_heads):
            # every head receives the same constructor arguments
            self.heads.append(
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            )

    def forward(self, x):
        """Concatenate the output of each head along the last dimension."""
        head_outputs = tuple(head(x) for head in self.heads)
        return torch.cat(head_outputs, dim=-1)

torch.manual_seed(123)

d_in, d_out = 3, 2
context_length = batch.size(1)  # number of tokens
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, dropout=0.00, num_heads=2
)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)



tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


In [None]:
""" 
MultiHeadAttentionWrapper
└── self.heads (nn.ModuleList)
    ├── self.heads[0] (CausalAttention)
    │   ├── self.W_query (nn.Linear)
    │   │   ├── weight (nn.Parameter)
    │   │   └── bias (nn.Parameter) [if qkv_bias=True]
    │   ├── self.W_key (nn.Linear)
    │   │   ├── weight (nn.Parameter)
    │   │   └── bias (nn.Parameter) [if qkv_bias=True]
    │   ├── self.W_value (nn.Linear)
    │   │   ├── weight (nn.Parameter)
    │   │   └── bias (nn.Parameter) [if qkv_bias=True]
    │   ├── self.dropout (nn.Dropout)
    │   └── self.mask (registered buffer)
    ├── self.heads[1] (CausalAttention)
    │   [Same structure as heads[0]]
    └── [Additional heads...]
"""


In [86]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with one shared projection per Q/K/V.

    Queries, keys and values are projected to ``d_out`` features, split into
    ``num_heads`` heads of ``head_dim = d_out // num_heads`` features, attended
    per head under a causal mask, merged back to ``d_out`` features and passed
    through a final linear output projection.

    Args:
        d_in: Input feature size of the query/key/value projections.
        d_out: Total output feature size across all heads.
        context_length: Side length of the precomputed causal mask, i.e. the
            longest sequence ``forward`` accepts.
        dropout: Probability given to ``nn.Dropout``, applied to the attention weights.
        num_heads: Number of attention heads; must divide ``d_out``.
        qkv_bias: Whether the query/key/value layers carry a bias term.

    Raises:
        AssertionError: If ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        self.d_out, self.num_heads = d_out, num_heads
        self.head_dim = self.d_out // self.num_heads

        self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Constant causal mask (ones strictly above the diagonal), built once and
        # stored as a non-trainable buffer.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal = 1)
            )

    def forward(self, x):
        """Compute causal multi-head context vectors.

        Args:
            x: 3D tensor unpacked as (batch, num_tokens, features).

        Returns:
            Tensor of shape (batch, num_tokens, d_out).

        Raises:
            ValueError: If ``num_tokens`` exceeds the mask size fixed at construction.
        """
        b, num_tokens, _ = x.shape
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            # Without this check the sliced mask would be smaller than the score
            # matrix and masked_fill_ would fail with an opaque broadcasting error.
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )

        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # The projections have shape (batch, num_tokens, d_out); split the last
        # axis into heads: (batch, num_tokens, num_heads, head_dim).
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # Move the head axis forward: (batch, num_heads, num_tokens, head_dim),
        # so the matmuls below operate per head.
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)

        # Slice the mask to the current sequence length and convert it to boolean
        # so it can be used with masked_fill_().
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        # context_vec is (batch, num_heads, num_tokens, head_dim); swap back to
        # (batch, num_tokens, num_heads, head_dim) so the last two axes can be merged.
        context_vec = context_vec.transpose(1, 2)
        # Merge the heads into (batch, num_tokens, d_out). .contiguous() is needed
        # because .view() requires contiguous memory and transpose() does not provide it.
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec

torch.manual_seed(123)
batch_size, num_tokens, d_in = tuple(batch.shape)
d_out = 2

# context_length is carried over from the previous cell.
mha = MultiHeadAttention(
    d_in, d_out, context_length, dropout=0.0, num_heads=2, qkv_bias=False
)
context_vecs = mha(batch)
print(context_vecs)

print("context_vecs.shape:", context_vecs.shape)






""" 
Without Output Projection:
Head 1 output: [0.5, 0.8]
Head 2 output: [0.3, 0.2]

Final output after concatenation: [0.5, 0.8, 0.3, 0.2]



With Output Projection:
Head 1 output: [0.5, 0.8]
Head 2 output: [0.3, 0.2]

Concatenated: [0.5, 0.8, 0.3, 0.2]

Weights matrix example:
W_out = [[0.1, 0.7, 0.2, 0.0],
         [0.5, 0.1, 0.4, 0.3],
         [0.0, 0.3, 0.9, 0.2],
         [0.2, 0.5, 0.1, 0.8]]

Final output after projection:
[0.67, 0.51, 0.55, 0.69]

The gist is that output projection gives more flexible representation and the final representation can distribute information optimally.


The final output is calculated by multiplying the concatenated vector by the weight matrix:
Output[0] - First output dimension:
(0.5 × 0.1) + (0.8 × 0.7) + (0.3 × 0.2) + (0.2 × 0.0) = 0.05 + 0.56 + 0.06 + 0 = 0.67
Output[1] - Second output dimension:
(0.5 × 0.5) + (0.8 × 0.1) + (0.3 × 0.4) + (0.2 × 0.3) = 0.25 + 0.08 + 0.12 + 0.06 = 0.51
Output[2] - Third output dimension:
(0.5 × 0.0) + (0.8 × 0.3) + (0.3 × 0.9) + (0.2 × 0.2) = 0 + 0.24 + 0.27 + 0.04 = 0.55
Output[3] - Fourth output dimension:
(0.5 × 0.2) + (0.8 × 0.5) + (0.3 × 0.1) + (0.2 × 0.8) = 0.1 + 0.4 + 0.03 + 0.16 = 0.69
So the correct final output is [0.67, 0.51, 0.55, 0.69].
What "Blending Information Across Heads" Means
Looking at the calculations above, we can see this blending happening:
Example: Output[2] = 0.55

Receives 0% contribution from first dimension of Head 1 (weight = 0.0)
Receives 44% contribution from second dimension of Head 1 (0.8 × 0.3 = 0.24)
Receives 49% contribution from first dimension of Head 2 (0.3 × 0.9 = 0.27)
Receives 7% contribution from second dimension of Head 2 (0.2 × 0.2 = 0.04)

This shows how the third dimension of the output (0.55) is primarily being influenced by information from Head 2's first dimension and Head 1's second dimension. The model is learning to extract and combine specific features from different attention heads.

"""


tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


' \nWithout Output Projection:\nHead 1 output: [0.5, 0.8]\nHead 2 output: [0.3, 0.2]\n\nFinal output after concatenation: [0.5, 0.8, 0.3, 0.2]\n\n\n\nWith Output Projection:\nHead 1 output: [0.5, 0.8]\nHead 2 output: [0.3, 0.2]\n\nConcatenated: [0.5, 0.8, 0.3, 0.2]\n\nWeights matrix example:\nW_out = [[0.1, 0.7, 0.2, 0.0],\n         [0.5, 0.1, 0.4, 0.3],\n         [0.0, 0.3, 0.9, 0.2],\n         [0.2, 0.5, 0.1, 0.8]]\n\nFinal output after projection:\n[0.34, 0.57, 0.54, 0.69]\n\nThe gist is that output projection gives more flexible representation and the final representation can distribute information optimally.\n\n\nThe final output is calculated by multiplying the concatenated vector by the weight matrix:\nOutput[0] - First output dimension:\n(0.5 × 0.1) + (0.8 × 0.7) + (0.3 × 0.2) + (0.2 × 0.0) = 0.05 + 0.56 + 0.06 + 0 = 0.67\nOutput[1] - Second output dimension:\n(0.5 × 0.5) + (0.8 × 0.1) + (0.3 × 0.4) + (0.2 × 0.3) = 0.25 + 0.08 + 0.12 + 0.06 = 0.51\nOutput[2] - Third output dime