## Steps for a simple self-attention mechanism
1. Get input vectors for each token
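2. Compute the dot product between your specific token (the query) and every input token; these dot products are the attention scores
3. Normalize the attention scores into attention weights using softmax (or a different normalization technique)
4. Get the context vector of your specific token by multiplying each input vector by its attention weight (scalar multiplication) and summing the results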

In [2]:
# Import libraries
import torch

In [3]:
# Input vectors for "Your journey starts with one step"

inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your (x^1)
    [0.55, 0.87, 0.66], # journey (x^2)
    [0.57, 0.85, 0.64], # starts (x^3)
    [0.22, 0.58, 0.33], # with (x^4)
    [0.77, 0.25, 0.10], # one (x^5)
    [0.05, 0.80, 0.55]] # step (x^6)
)

In [4]:
inputs.shape

torch.Size([6, 3])

In [5]:
print(inputs.shape[0])

6


In [6]:
print(torch.empty(inputs.shape[0]))

tensor([1.3910e+10, 4.2117e-41, 1.3910e+10, 4.2117e-41, 1.3910e+10, 4.2117e-41])


In [7]:
# Sample dot product of "Your" and "Journey"
# .43 * .55 + .15 * .87 + .89 * .66 = 0.2365 + 0.1305 + 0.5874 = 0.9544

dot_product = torch.dot(inputs[0], inputs[1])
print(f'Dot Product of "Your" and "Journey": {dot_product.item()}')

Dot Product of "Your" and "Journey": 0.9544000625610352


In [8]:
# Generating attention score for inputs[1] which is "journey"
# To generate an attention score for a given query vector, we compute the dot product

query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [9]:
query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    print(i, x_i)
    attn_scores_2[i] = torch.dot(x_i, query)
    print(attn_scores_2[i])
print(attn_scores_2)

0 tensor([0.4300, 0.1500, 0.8900])
tensor(0.9544)
1 tensor([0.5500, 0.8700, 0.6600])
tensor(1.4950)
2 tensor([0.5700, 0.8500, 0.6400])
tensor(1.4754)
3 tensor([0.2200, 0.5800, 0.3300])
tensor(0.8434)
4 tensor([0.7700, 0.2500, 0.1000])
tensor(0.7070)
5 tensor([0.0500, 0.8000, 0.5500])
tensor(1.0865)
tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [10]:
# Normalize weights for "journey"
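# This type of normalization divides each score by the sum of all scores; we call it "Linear Normalization" here (note: this is not min-max scaling, which would compute (x - min) / (max - min))
# The attention weights of all tokens sum to 1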
# The attention weight vector has the same length as the number of tokens
# Each number in the vector corresponds to the attention weight for that token, so the first number is the attention weight for "Your", the second for "journey", and so on
# Linear Normalization is not preferred because it does not handle negative scores well and can lead to skewed distributions

attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights using Linear Normalization:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

Attention weights using Linear Normalization: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


In [11]:
attn_scores_2

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [12]:
# Softmax Normalization for "journey"
# Softmax is used to convert raw attention scores into probabilities that sum to 1
# This is better than Linear Normalization because:
#  it emphasizes the highest scores more strongly
#  it ensures all weights are positive
#  it provides a smoother gradient

# Softmax function
def softmax_naive(x):
    """Plain softmax along dim 0: exp(x) divided by the sum of exp(x).

    Deliberately left without the usual max-subtraction trick, so large
    scores can overflow; it exists only to contrast with torch.softmax.
    """
    exponentials = torch.exp(x)
    normalizer = exponentials.sum(dim=0)
    return exponentials / normalizer

attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights using naive softmax:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())

Attention weights using naive softmax: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


In [13]:
# Using PyTorch's built-in softmax function which is optimized and numerically stable
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights from PyTorch's softmax:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

Attention weights from PyTorch's softmax: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


In [14]:
# Generating context vector for "journey"
# The context vector is a weighted sum of the input vectors, where the weights are the attention weights
# Context Vector = Σ_i (attention_weight_i * input_vector_i) which is a scalar-vector multiplication followed by a summation
# scalar multiplication = attention_weight_i * input_vector_i
# summation = sum(scalar multiplication)

query = inputs[1]

context_vec_2 = torch.zeros(query.shape)

for i,x_i in enumerate(inputs):
    print(context_vec_2)

    print("\n\n\n")
    print(f"{attn_weights_2[i]}   {x_i}")
    context_vec_2 += attn_weights_2[i]*x_i
print("Context vector for 'journey':", context_vec_2)

tensor([0., 0., 0.])




0.13854756951332092   tensor([0.4300, 0.1500, 0.8900])
tensor([0.0596, 0.0208, 0.1233])




0.2378913015127182   tensor([0.5500, 0.8700, 0.6600])
tensor([0.1904, 0.2277, 0.2803])




0.23327402770519257   tensor([0.5700, 0.8500, 0.6400])
tensor([0.3234, 0.4260, 0.4296])




0.12399158626794815   tensor([0.2200, 0.5800, 0.3300])
tensor([0.3507, 0.4979, 0.4705])




0.10818186402320862   tensor([0.7700, 0.2500, 0.1000])
tensor([0.4340, 0.5250, 0.4813])




0.15811361372470856   tensor([0.0500, 0.8000, 0.5500])
Context vector for 'journey': tensor([0.4419, 0.6515, 0.5683])


### Now that we've gone through the steps for the token "journey", this is how you compute the whole thing at once

In [15]:
# Calculate attention scores for all queries

# Size the score matrix from the data instead of hard-coding 6, so the cell
# keeps working if the number of input tokens changes.
num_tokens = inputs.shape[0]
attn_scores = torch.empty(num_tokens, num_tokens)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        # Entry (i, j) is the dot product of token i (query) with token j.
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [16]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [17]:
inputs.T

tensor([[0.4300, 0.5500, 0.5700, 0.2200, 0.7700, 0.0500],
        [0.1500, 0.8700, 0.8500, 0.5800, 0.2500, 0.8000],
        [0.8900, 0.6600, 0.6400, 0.3300, 0.1000, 0.5500]])

In [18]:
# Alternatively, we can compute all attention scores at once using matrix multiplication

attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [19]:
attn_scores.shape

torch.Size([6, 6])

In [20]:
# Get attention weights for all queries using softmax, all rows should sum up to 1
# dim=-1 indicates that softmax is applied across the last dimension (i.e., across each row)

attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

print("All row sums:", attn_weights.sum(dim=-1))

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


In [21]:
# Generating context vectors for all queries at once

all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


In [22]:
andy = torch.matmul(attn_weights, inputs)
andy

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

## Now we move on to making the self-attention mechanism trainable

In [23]:
# Define input and output dimensions, again we're using inputs[1] which is "journey" as our example

# x_2 is the input vector for "journey"
x_2 = inputs[1]

# The dimension of "journey" which is 3
d_in = inputs.shape[1]

# Why do we set the dimension to 2 instead of 3?
d_out = 2

In [24]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [44]:
# Initialize random weight matrices[3,2] for Query, Key, and Value
# requires_grad=False to indicate these are fixed for this example, if we were training a model, we would set requires_grad=True
# Since we set the seed manually, we will get the same random numbers every time for our weights

torch.manual_seed(123)

# Why are weights randomized?
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

# @ symbol denotes matrix multiplication in PyTorch
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

query_2

tensor([0.4306, 1.4551])

In [45]:
W_query.shape

torch.Size([3, 2])

In [48]:
query_2

tensor([0.4306, 1.4551])

In [46]:
W_key.shape

torch.Size([3, 2])

In [49]:
key_2

tensor([0.4433, 1.1419])

In [47]:
W_value.shape

torch.Size([3, 2])

In [51]:
value_2.shape

torch.Size([2])

In [52]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [53]:
# Generate keys and values for all inputs at once

keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]])

In [55]:
# Get attention score for query_2 against all keys

keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor(1.8524)
tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [56]:
# Next, we have to normalize these attention scores to get attention weights
# We will use softmax normalization

# d_k is the dimension of the key vectors
d_k = keys.shape[1]

print(d_k)

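# Scale the scores by sqrt(d_k) before softmax (scaled dot-product attention) so the softmax does not become overly peaked as d_k grows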
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print("The attention weights are:",attn_weights_2)

attn_weights_2.sum()

2
The attention weights are: tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


tensor(1.)

In [57]:
# We now compute the context vector for query_2 using the attention weights and the value vectors

context_vec_2 = attn_weights_2 @ values
print("The context vector is:",context_vec_2)

The context vector is: tensor([0.3061, 0.8210])


## Implementing a trainable self-attention but with every token at once

In [30]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    Args:
        d_in: size of each input token embedding.
        d_out: size of the query/key/value projections, and therefore of
            each returned context vector.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Created in query, key, value order; for a fixed seed this order
        # decides which random values each matrix receives.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute one context vector per token in ``x``.

        Args:
            x: tensor whose last two dimensions are (num_tokens, d_in).

        Returns:
            Tensor whose last two dimensions are (num_tokens, d_out).
        """
        # Project the *argument* with this module's *own* parameters. The
        # previous version read notebook-level globals (inputs, W_key, ...),
        # which ignored x and left the module's parameters unused.
        keys = x @ self.W_key
        values = x @ self.W_value
        queries = x @ self.W_query

        # transpose(-2, -1) equals .T for 2-D input and also handles a leading batch dim.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Divide by sqrt(d_k) before softmax (scaled dot-product attention).
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values

        return context_vec
    
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in=3, d_out=2)
output_v1 = sa_v1(inputs)

# We can double check if our self attention works by comparing output_v1[1] with context_vec_2 "journey"
print("Output from SelfAttention_v1:", output_v1)

Output from SelfAttention_v1: tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]])


In [58]:
context_vec_2

tensor([0.3061, 0.8210])

In [60]:
import torch.nn as nn


# Improved Self-Attention with Linear Layers
# It's better to use nn.Linear layers for the weight matrices because they handle biases and initialization more effectively
class SelfAttention_v2(nn.Module):
    """Scaled dot-product self-attention whose projections are nn.Linear layers.

    Args:
        d_in: size of each input token embedding.
        d_out: size of the query/key/value projections.
        qkv_bias: whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()

        def make_projection():
            return nn.Linear(d_in, d_out, bias=qkv_bias)

        # Built in query, key, value order so seeded initialization is stable.
        self.W_query = make_projection()
        self.W_key = make_projection()
        self.W_value = make_projection()

    def forward(self, x):
        """Return the context vectors for a 2-D input of shape (num_tokens, d_in)."""
        k = self.W_key(x)
        v = self.W_value(x)
        q = self.W_query(x)

        scale = k.shape[1]**0.5
        raw_scores = q @ k.T
        weights = torch.softmax(raw_scores / scale, dim=-1)
        return weights @ v
    
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in=3, d_out=2, qkv_bias=False)
output_v2 = sa_v2(inputs)
print("Output from SelfAttention_v2:", output_v2)

Output from SelfAttention_v2: tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


## Causal attention

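### Also known as masked attention, it "masks" the future tokens so that each token can only attend to itself and the tokens before it. This is needed for next-token prediction, where the model must not see tokens that come later in the sequence.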

In [61]:
# We are going to reuse the query and key weight from SelfAttention_v2 for convenience

# queries and keys are the inputs projected through the query and key weight layers (they are not weight matrices themselves)
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[1]**0.5, dim=-1)
print("Attention weights from reused weights:\n", attn_weights)

Attention weights from reused weights:
 tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [62]:
# We will now make a "mask" using `tril` to prevent attention to future tokens

# One row/column per token; torch.tril keeps the ones on and below the main
# diagonal and zeroes everything above it.
context_length = attn_weights.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)
# The zeros are the entries *above* the diagonal, not the diagonal itself.
print("The zeros above the main diagonal mark future tokens that should not be attended to.")

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])
We have a diagonal of zeros which are future tokens that should not be attended to.


In [63]:
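# I don't get this part
# Explanation: multiplying element-wise by the lower-triangular mask zeroes out the attention weights
# above the diagonal, i.e. the weights each token would otherwise place on tokens that come after it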

masked_simple = attn_weights * mask_simple
print("Masked attention weights:\n", masked_simple)

Masked attention weights:
 tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


In [66]:
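# Why is keepdim=True here? What does it do?
# Answer: keepdim=True keeps the summed dimension with size 1, so row_sums has shape [6, 1] instead of [6]
# and broadcasts correctly when each row of masked_simple is divided by its own sum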
# dim=-1 means that the summation is done across the last dimension (i.e., across each row)

row_sums = masked_simple.sum(dim=-1, keepdim=True)
print(row_sums)
masked_simple_norm = masked_simple / row_sums
print("Masked and normalized attention weights:\n", masked_simple_norm)
print("Each row sums to one:", masked_simple_norm.sum(dim=-1))

tensor([[0.1921],
        [0.3700],
        [0.5357],
        [0.6775],
        [0.8415],
        [1.0000]], grad_fn=<SumBackward1>)
Masked and normalized attention weights:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)
Each row sums to one: tensor([1., 1., 1., 1., 1., 1.], grad_fn=<SumBackward1>)


In [36]:
# Softmax maps -inf scores to a weight of exactly zero (exp(-inf) = 0), so we fill the future positions with -inf before applying softmax

mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print("Masked attention scores with -inf:\n", masked)

Masked attention scores with -inf:
 tensor([[0.3111,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1655, 0.2602,   -inf,   -inf,   -inf,   -inf],
        [0.1667, 0.2602, 0.2577,   -inf,   -inf,   -inf],
        [0.0510, 0.1080, 0.1064, 0.0643,   -inf,   -inf],
        [0.1415, 0.1875, 0.1863, 0.0987, 0.1121,   -inf],
        [0.0476, 0.1192, 0.1171, 0.0731, 0.0477, 0.0966]],
       grad_fn=<MaskedFillBackward0>)


In [37]:
# Apply softmax to get masked attention weights

attn_weights = torch.softmax(masked / keys.shape[1]**0.5, dim=-1)
print("Masked attention weights after softmax:\n", attn_weights)

Masked attention weights after softmax:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)


## Dropout
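### is a technique where randomly selected hidden layer units are ignored during training (these units are "dropped out")
We use this technique to help prevent overfitting by ensuring that the model does not become reliant on any specific set of hidden units.
Dropout is ONLY used during training and is disabled afterwards.
The values that are kept are scaled up by 1 / (1 - dropout_rate), which is why the surviving entries become 2 in the 50% example below.
In the attention mechanism, dropout is typically applied at one of two points: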
1. After calculating attention weights
2. After applying the attention weights to the value vectors

In [38]:
# We will use a high dropout rate of 50% (which means masking half of the attention weights) to demonstrate the effect

# Dropout example
torch.manual_seed(123)
dropout_rate = 0.5
dropout = torch.nn.Dropout(dropout_rate)
example = torch.ones(6, 6)
dropped = dropout(example)
print("Example after applying dropout:\n", dropped)

Example after applying dropout:
 tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


In [39]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6380, 0.6816, 0.6804, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5090, 0.5085, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4120, 0.0000, 0.3869, 0.0000, 0.0000],
        [0.0000, 0.3418, 0.3413, 0.3308, 0.3249, 0.0000]],
       grad_fn=<MulBackward0>)


## Implementing a compact causal attention to use for the multi-head attention

In [40]:
batch = torch.stack((inputs, inputs), dim=0)

# Batch shape is [2,6,3] because two input texts with six tokens each where each token is a three-dimensional embedding vector
print("Batch shape:", batch.shape)

Batch shape: torch.Size([2, 6, 3])


In [41]:
class CausalAttention(nn.Module):
    """Single-head causal (masked) self-attention over a batch of sequences.

    Args:
        d_in: size of each input token embedding.
        d_out: size of the query/key/value projections.
        context_length: side length of the stored causal mask; inputs are
            masked with its top-left num_tokens x num_tokens corner.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions.
        # Registered as a buffer so it follows the module across devices.
        future_positions = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        self.register_buffer('mask', future_positions)

    def forward(self, x):
        """Return context vectors of shape (batch, num_tokens, d_out).

        Args:
            x: 3-D tensor of shape (batch, num_tokens, d_in).
        """
        _, num_tokens, _ = x.shape
        k = self.W_key(x)
        q = self.W_query(x)
        v = self.W_value(x)

        # Swap the token and feature dims of the keys, leaving the batch dim in place.
        scores = q @ k.transpose(1, 2)

        # Crop the stored mask to the actual sequence length and hide future tokens.
        causal_mask = self.mask[:num_tokens, :num_tokens].bool()
        scores = scores.masked_fill(causal_mask, -torch.inf)

        scale = k.shape[-1]**0.5
        weights = torch.softmax(scores / scale, dim=-1)
        weights = self.dropout(weights)
        return weights @ v

In [42]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 6, 2])


## Multi-head attention
### We will be dividing the attention mechanism into multiple "heads" which each operate independently.

In [None]:
class MultiHeadAttentionWrapper(nn.Module):
    """Runs several independent CausalAttention heads and concatenates their outputs.

    Each head gets the same constructor arguments, so the output's last
    dimension is num_heads * d_out.
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList()
        for _ in range(num_heads):
            head = CausalAttention(
                d_in, d_out, context_length, dropout, qkv_bias
            )
            self.heads.append(head)

    def forward(self, x):
        """Apply every head to ``x`` and join the results along the last dimension."""
        per_head_outputs = [head(x) for head in self.heads]
        return torch.cat(per_head_outputs, dim=-1)

In [None]:
torch.manual_seed(123)
context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2

mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head attention computed with one set of fused projections.

    The d_out-wide projections are split into num_heads chunks of width
    head_dim = d_out // num_heads, attended over per head, merged back
    together and passed through a final linear layer.

    Args:
        d_in: size of each input token embedding.
        d_out: total output width across all heads; must be divisible by num_heads.
        context_length: side length of the stored causal mask.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections carry a bias term.
    """

    def __init__(self, d_in, d_out,
                 context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions.
        upper_triangle = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        self.register_buffer("mask", upper_triangle)

    def _split_heads(self, projected, b, num_tokens):
        """Reshape (b, num_tokens, d_out) into (b, num_heads, num_tokens, head_dim)."""
        per_head = projected.view(b, num_tokens, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Return a tensor of shape (batch, num_tokens, d_out).

        Args:
            x: 3-D tensor of shape (batch, num_tokens, d_in).
        """
        b, num_tokens, _ = x.shape

        k = self._split_heads(self.W_key(x), b, num_tokens)
        q = self._split_heads(self.W_query(x), b, num_tokens)
        v = self._split_heads(self.W_value(x), b, num_tokens)

        # Per-head scores: (b, num_heads, num_tokens, num_tokens).
        scores = q @ k.transpose(2, 3)

        # Crop the mask to the sequence length; it broadcasts over batch and heads.
        causal_mask = self.mask[:num_tokens, :num_tokens].bool()
        scores = scores.masked_fill(causal_mask, -torch.inf)

        scale = k.shape[-1]**0.5
        weights = torch.softmax(scores / scale, dim=-1)
        weights = self.dropout(weights)

        # Back to (b, num_tokens, num_heads, head_dim), then merge the heads.
        merged = (weights @ v).transpose(1, 2)
        merged = merged.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(merged)