In this section, we implement a simplified self-attention mechanism to
compute these weights and the resulting context vector one step at a time.

Consider the following input sentence, which has already been embedded
into 3-dimensional vectors as discussed in chapter 2. We choose a small
embedding dimension for illustration purposes to ensure it fits on the page
without line breaks

In [1]:
import torch 
inputs = torch.tensor([
    [0.43, 0.15, 0.89], # Your     (X^1)
    [0.55, 0.87, 0.66], # journey  (X^2)
    [0.57, 0.85, 0.64], # starts   (X^3)
    [0.22, 0.58, 0.33], # with     (X^4)
    [0.77, 0.25, 0.10], # one      (X^5)
    [0.05, 0.80, 0.55]  # step     (X^6)
])


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.0 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "c:\Users\USER\Documents\M Martin\code LLMs From Scratch\.torch_2_0_1\Lib\site-packages\ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "c:\Users\USER\Documents\M Martin\code LLMs From Scratch\.torch_2_0_1\Lib\site-packages\traitlets\config\application.py", line 1075, in launch_instance
    app.start()
  File "c:\Users\USER\Documents\M Martin\code LLMs From Scratch\.torch_2_0_1\L

The first step of implementing self-attention is to compute the intermediate
values ω, referred to as attention scores, as illustrated in Figure 3.8 (P72)

Figure 3.8 The overall goal of this section is to illustrate the computation of the context vector `z(2)`
using the second input sequence, `x(2)` as a query. This figure shows the first intermediate step,
computing the attention scores `ω` between the query `x(2)` and all other input elements as a dot
product. (Note that the numbers in the figure are truncated to one digit after the decimal point to
reduce visual clutter.)

### Understanding dot products

In [2]:
# We determine these scores by computing the dot product of the query, x(2), with every other input token:
query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


A dot product is essentially just a concise way of multiplying two vectors
element-wise and then summing the products, which we can demonstrate as
follows:

In [3]:
res = 0

for idx, element in enumerate(inputs[0]):
    res += inputs[0][idx] * query[idx]

print(res)
print(torch.dot(inputs[0],query))

tensor(0.9544)
tensor(0.9544)


A dot product is essentially just a concise way of multiplying two vectors
element-wise and then summing the products, which we can demonstrate as
follows:

In the next step, as shown in Figure 3.9, we normalize each of the attention
scores that we computed previously.

"""After computing the attention scores ω21 to ω2T with respect to the input query x(2),
the next step is to obtain the attention weights α21 to α2T by normalizing the attention scores"""



The main goal behind the normalization shown in Figure 3.9 is to obtain
attention weights that sum up to 1. This normalization is a convention that is
useful for interpretation and for maintaining training stability in an LLM.
Here's a straightforward method for achieving this normalization step

In [4]:
attn_scores_2_tmp = attn_scores_2/attn_scores_2.sum()
print("Attention weights: ", attn_scores_2_tmp)
print("Sum: ", attn_scores_2_tmp.sum())

Attention weights:  tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum:  tensor(1.0000)


As the output shows, the attention weights now sum to 1

In practice, it's more common and advisable to use the softmax function for
normalization. This approach is better at managing extreme values and offers
more favorable gradient properties during training. Below is a basic
implementation of the softmax function for normalizing the attention scores:

In [5]:
def softmax_naive(x):
    return torch.exp(x)/torch.exp(x).sum()

attn_scores_2_naive = softmax_naive(attn_scores_2)
print("Attention weights: ", attn_scores_2_naive)
print("Sum: ", attn_scores_2_naive.sum())

Attention weights:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum:  tensor(1.)


In addition, the softmax function ensures that the attention weights are
always positive. This makes the output interpretable as probabilities or
relative importance, where higher weights indicate greater importance.

Note that this naive softmax implementation (softmax_naive) may encounter
numerical instability problems, such as overflow and underflow, when
dealing with large or small input values. Therefore, in practice, it's advisable
to use the PyTorch implementation of softmax, which has been extensively
optimized for performance:

In [6]:
attn_weights_2 = torch.softmax(attn_scores_2,dim=0)
print("Attention weights: ", attn_scores_2_naive)
print("Sum: ", attn_weights_2.sum())

Attention weights:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum:  tensor(1.)


Now that we computed the normalized attention weights, we are ready for the
final step illustrated in Figure 3.10: calculating the context vector z(2) by
multiplying the embedded input tokens, x(i), with the corresponding attention
weights and then summing the resulting vectors

The context vector z(2) depicted in Figure 3.10 is calculated as a weighted
sum of all input vectors. This involves multiplying each input vector by its
corresponding attention weight:

In [7]:
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    # print(x_i)
    # print(attn_weights_2[i])
    context_vec_2 += attn_weights_2[i]*x_i 
print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


In the next section, we will generalize this procedure for computing context
vectors to calculate all context vectors simultaneously.

### Computing attention weights for all input tokens

We follow the same three steps as before, as summarized in Figure 3.12,
except that we make a few modifications in the code to compute all context
vectors instead of only the second context vector, z(2).

First, in step 1 as illustrated in Figure 3.12 P(77), we add an additional for-loop to
compute the dot products for all pairs of inputs.

In [8]:
attn_scores = torch.empty(6,6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)


tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


When computing the preceding attention score tensor, we used for-loops in
Python. However, for-loops are generally slow, and we can achieve the same
results using matrix multiplication:

In [9]:
attn_scores = inputs@inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In step 2, as illustrated in Figure 3.12, we now normalize each row so that the
values in each row sum to 1

In [10]:
attn_weights = torch.softmax(attn_scores, dim=1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


Before we move on to step 3, the final step shown in Figure 3.12, let's briefly
verify that the rows indeed all sum to 1:

In [11]:
row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
print("Row 2 sum: ", row_2_sum)
print("All row sums : ", attn_weights.sum(dim=1))

Row 2 sum:  1.0
All row sums :  tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


In the third and last step, we now use these attention weights to compute all
context vectors via matrix multiplication

In [12]:
all_context_vecs = attn_weights@inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


We can double-check that the code is correct by comparing the 2nd row with
the context vector z(2) that we computed previously

In [13]:
print("previous 2nd context vector: ", context_vec_2)

previous 2nd context vector:  tensor([0.4419, 0.6515, 0.5683])


### Implementing self-attention with trainable weights

##### Computing the attention weights step by step

We will implement the self-attention mechanism step by step by introducing
the three trainable weight matrices Wq, Wk, and Wv. These three matrices are
used to project the embedded input tokens, x(i), into query, key, and value
vectors as illustrated in Figure 3.14.

Earlier in section 3.3.1, we defined the second input element x(2) as the query
when we computed the simplified attention weights to compute the context
vector z(2). Later, in section 3.3.2, we generalized this to compute all context
vectors z(1) ... z(T) for the six-word input sentence "Your journey starts with
one step."


Similarly, we will start by computing only one context vector, z(2), for
illustration purposes. In the next section, we will modify this code to
calculate all context vectors.


Let's begin by defining a few variables:

In [14]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

Next, we initialize the three weight matrices Wq, Wk, and Wv that are shown
in Figure 3.14

In [15]:
torch.manual_seed(123)
w_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
w_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
w_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

Note that we are setting requires_grad=False to reduce clutter in the
outputs for illustration purposes, but if we were to use the weight matrices for
model training, we would set requires_grad=True to update these matrices
during model training.

Next, we compute the query, key, and value vectors as shown earlier in
Figure 3.14:

In [16]:
query_2 = x_2@w_query
key_2 = x_2@w_key
value_2 = x_2@w_value
print(query_2)

tensor([0.4306, 1.4551])


#### Weight parameters vs attention weights

In [17]:
keys = inputs @ w_key
values = inputs @ w_value
print("keys.shape: ", keys.shape)
print("values.shape: ", values.shape)

keys.shape:  torch.Size([6, 2])
values.shape:  torch.Size([6, 2])


The attention score computation is a dot-product computation similar to what we
have used in the simplified self-attention mechanism in section 3.3. The new aspect here is that we
are not directly computing the dot-product between the input elements but using the query and
key obtained by transforming the inputs via the respective weight matrices

In [18]:
# First, let's compute the attention score ω22:
keys_2 = keys[1]
attn_scores_22 = query_2.dot(keys_2)
print(attn_scores_22)

tensor(1.8524)


In [19]:
# Again, we can generalize this computation to all attention scores via matrix
# multiplication
attn_scores_2 = query_2@keys.T # All attention scores for given
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


Next, as illustrated in Figure 3.16, we compute the attention weights by
scaling the attention scores and using the softmax function we used earlier..
The difference to earlier is that we now scale the attention scores by dividing
them by the square root of the embedding dimension of the keys, (note that
taking the square root is mathematically the same as exponentiating by 0.5)

In [20]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2/d_k**0.5, dim=-1)

In [21]:
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


Now, the final step is to compute the context vectors, as illustrated in Figure
3.17

Similar to section 3.3, where we computed the context vector as a weighted
sum over the input vectors, we now compute the context vector as a weighted
sum over the value vectors. Here, the attention weights serve as a weighting
factor that weighs the respective importance of each value vector. Similar to
section 3.3, we can use matrix multiplication to obtain the output in one step

In [22]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])


So far, we only computed a single context vector, z(2). In the next section, we
will generalize the code to compute all context vectors in the input sequence,
z(1) to z(T).

Why query, key, and value?

The terms "key," "query," and "value" in the context of attention mechanisms
are borrowed from the domain of information retrieval and databases, where
similar concepts are used to store, search, and retrieve information.

A "query" is analogous to a search query in a database. It represents the
current item (e.g., a word or token in a sentence) the model focuses on or
tries to understand. The query is used to probe the other parts of the input
sequence to determine how much attention to pay to them.


The "key" is like a database key used for indexing and searching. In the
attention mechanism, each item in the input sequence (e.g., each word in a
sentence) has an associated key. These keys are used to match with the query.
The "value" in this context is similar to the value in a key-value pair in a
database. It represents the actual content or representation of the input items.

Once the model determines which keys (and thus which parts of the input)
are most relevant to the query (the current focus item), it retrieves the
corresponding values.

### Implementing a compact self-attention Python class

In [23]:
# A compact self-attention class
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.w_query = nn.Parameter(torch.rand(d_in, d_out))
        self.w_key = nn.Parameter(torch.rand(d_in, d_out))
        self.w_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x@self.w_key
        queries = x @ self.w_query
        values = x @ self.w_value

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
                attn_scores/keys.shape[-1]**0.5, dim=-1
            ) 
        context_vec = attn_weights @ values
        return context_vec

In this PyTorch code, SelfAttention_v1 is a class derived from nn.Module,
which is a fundamental building block of PyTorch models, which provides
necessary functionalities for model layer creation and management.

The __init__ method initializes trainable weight matrices (W_query, W_key,
and W_value) for queries, keys, and values, each transforming the input
dimension d_in to an output dimension d_out.

During the forward pass, using the forward method, we compute the attention
scores (attn_scores) by multiplying queries and keys, normalizing these
scores using softmax. Finally, we create a context vector by weighting the
values with these normalized attention scores

We can use this class as follows

In [24]:
d_in = inputs.shape[1]
d_out = 2
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [25]:
inputs.shape

torch.Size([6, 3])

We can improve the SelfAttention_v1 implementation further by utilizing
PyTorch's nn.Linear layers, which effectively perform matrix multiplication
when the bias units are disabled. Additionally, a significant advantage of
using nn.Linear instead of manually implementing
nn.Parameter(torch.rand(...)) is that nn.Linear has an optimized weight
initialization scheme, contributing to more stable and effective model
training.

In [26]:
# Listing 3.2 A self-attention class using PyTorch's Linear layers
import torch.nn as nn
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.w_query = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.w_key = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out,bias=qkv_bias)

    def forward(self, x):
        keys = self.w_key(x)
        queries = self.w_query(x)
        values = self.w_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
                attn_scores/keys.shape[-1]**0.5, dim=-1
            ) 
        context_vec = attn_weights @ values
        return context_vec

In [27]:
d_in = inputs.shape[1]
d_out = 2
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


Exercice 3.1

Note that nn.Linear in SelfAttention_v2 uses a different weight
initialization scheme as nn.Parameter(torch.rand(d_in, d_out)) used in
SelfAttention_v1, which causes both mechanisms to produce different
results. To check that both implementations, SelfAttention_v1 and
SelfAttention_v2, are otherwise similar, we can transfer the weightmatrices from a SelfAttention_v2 object to a SelfAttention_v1, such that
both objects then produce the same results.

Your task is to correctly assign the weights from an instance of
SelfAttention_v2 to an instance of SelfAttention_v1. To do this, you need
to understand the relationship between the weights in both versions. (Hint:
nn.Linear stores the weight matrix in a transposed form.) After the
assignment, you should observe that both instances produce the same outputs

In [28]:
# import torch.nn as nn
# class SelfAttention_v2(nn.Module):
#     def __init__(self, d_in, d_out, qkv_bias=False):
#         super().__init__()
#         self.d_out = d_out
#         self.w_query = nn.Linear(d_in, d_out,bias=qkv_bias)
#         self.w_key = nn.Linear(d_in, d_out,bias=qkv_bias)
#         self.w_value = nn.Linear(d_in, d_out,bias=qkv_bias)

#     def forward(self, x):
#         keys = self.w_key@x
#         queries = self.w_query@x
#         values = self.w_value@x

#         attn_scores = queries @ keys.T
#         attn_weights = torch.softmax(
#                 attn_scores/keys.shape[-1]**0.5, dim=-1
#             ) 
#         context_vec = attn_weights @ values
#         return context_vec

In [29]:
# d_in = inputs.shape[1]
# d_out = 2
# torch.manual_seed(123)
# sa_v1 = SelfAttention_v2(d_in, d_out)
# print(sa_v1(inputs))

### Hiding future words with causal attention

In this section, we modify the standard self-attention mechanism to create a
causal attention mechanism, which is essential for developing an LLM in the
subsequent chapters.

Causal attention, also known as masked attention, is a specialized form of
self-attention. It restricts a model to only consider previous and current inputs
in a sequence when processing any given token. This is in contrast to the
standard self-attention mechanism, which allows access to the entire input
sequence at once.

##### Applying a causal attention mask

One way to obtain the masked attention weight matrix in causal attention is to apply
the softmax function to the attention scores, zeroing out the elements above the diagonal and
normalizing the resulting matrix

In [30]:
# In the first step illustrated in Figure 3.20, we compute the attention weights
# using the softmax function as we have done in previous sections:

queries = sa_v2.w_query(inputs)
keys = sa_v2.w_key(inputs)

attn_scores = queries @ keys.T

attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


We can implement step 2 in Figure 3.20 using PyTorch's tril function to
create a mask where the values above the diagonal are zero:

In [31]:
# The resulting mask is as follows:
mask_simle = torch.tensor([[1., 0., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0.],
[1., 1., 1., 1., 0., 0.],
[1., 1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1., 1.]])

In [32]:
masked_simple = attn_weights*mask_simle

As we can see, the elements above the diagonal are successfully zeroed out:

In [33]:
print(masked_simple)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


The third step in Figure 3.20 is to renormalize the attention weights to sum up
to 1 again in each row. We can achieve this by dividing each element in each
row by the sum in each row

In [34]:
row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple/row_sums

The result is an attention weight matrix where the attention weights above the
diagonal are zeroed out and where the rows sum to 1

In [35]:
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


#### leakage

When we apply a mask and then renormalize the attention weights, it might
initially appear that information from future tokens (which we intend to
mask) could still influence the current token because their values are part of
the softmax calculation. However, the key insight is that when we
renormalize the attention weights after masking, what we're essentially doing
is recalculating the softmax over a smaller subset (since masked positions
don't contribute to the softmax value).

In simpler terms, after masking and renormalization, the distribution of
attention weights is as if it was calculated only among the unmasked
positions to begin with. This ensures there's no information leakage from
future (or otherwise masked) tokens as we intended

While we could be technically done with implementing causal attention at
this point, we can take advantage of a mathematical property of the softmax
function and implement the computation of the masked attention weights
more efficiently in fewer steps, as shown in Figure 3.21.

The softmax function converts its inputs into a probability distribution. When
negative infinity values (-∞) are present in a row, the softmax function treats
them as zero probability. (Mathematically, this is because e-∞ approaches 0.)
We can implement this more efficient masking "trick" by creating a mask
with 1's above the diagonal and then replacing these 1's with negative infinity
(-inf) values:

In [36]:
context_length = 6
mask = torch.triu(torch.ones(context_length, context_length),diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)

This results in the following mask:

In [37]:
print(masked)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


Now, all we need to do is apply the softmax function to these masked results,
and we are done:

In [38]:
attn_weights = torch.softmax(masked/keys.shape[-1]**0.5, dim=1)

In [39]:
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


We could now use the modified attention weights to compute the context
vectors 

In [40]:
context_vec = attn_weights @ values

##### Masking additional attention weights with dropout

Dropout in deep learning is a technique where randomly selected hidden
layer units are ignored during training, effectively "dropping" them out. This
method helps prevent overfitting by ensuring that a model does not become
overly reliant on any specific set of hidden layer units. It's important to
emphasize that dropout is only used during training and is disabled afterward.


In the transformer architecture, including models like GPT, dropout in the
attention mechanism is typically applied in two specific areas: after
calculating the attention scores or after applying the attention weights to thevalue vectors

Here, we will apply the dropout mask after computing the attention weights,
as illustrated in Figure 3.22, because it's the more common variant in
practice

In the following code, we apply PyTorch's dropout implementation first to a
6×6 tensor consisting of ones for illustration purposes:

In [41]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6,6)
print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


When applying dropout to an attention weight matrix with a rate of 50%, half
of the elements in the matrix are randomly set to zero. To compensate for the
reduction in active elements, the values of the remaining elements in the
matrix are scaled up by a factor of 1/0.5 =2. This scaling is crucial to
maintain the overall balance of the attention weights, ensuring that the
average influence of the attention mechanism remains consistent during both
the training and inference phases.

In [42]:
# Now, let's apply dropout to the attention weight matrix itself:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.8966, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4921, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4350, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


#### Now, let's apply dropout to the attention weight matrix itself:

In [43]:
# For simplicity, to simulate such batch inputs, we duplicate the input text example:
batch = torch.stack((inputs, inputs))

In [44]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape 

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec

We can use the CausalAttention class as follows, similar to SelfAttention
previously

In [45]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0,False)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 6, 2])


### Extending single-head attention to multi-head attention

Stacking multiple single-head attention layers

In practical terms, implementing multi-head attention involves creating
multiple instances of the self-attention mechanism (depicted earlier in Figure
3.18 in section 3.4.1), each with its own weights, and then combining their
outputs. Using multiple instances of the self-attention mechanism can be
computationally intensive, but it's crucial for the kind of complex pattern
recognition that models like transformer-based LLMs are known for.

Figure 3.24 illustrates the structure of a multi-head attention module, which
consists of multiple single-head attention modules, as previously depicted in
Figure 3.18, stacked on top of each other

The multi-head attention module in this figure depicts two single-head attention
modules stacked on top of each other. So, instead of using a single matrix Wv for computing the
value matrices, in a multi-head attention module with two heads, we now have two value weight
matrices: W
v1 and Wv2. The same applies to the other weight matrices, Wq and Wk. We obtain
two sets of context vectors Z1 and Z2 that we can combine into a single context vector matrix Z

As mentioned before, the main idea behind multi-head attention is to run the
attention mechanism multiple times (in parallel) with different, learned linear
projections -- the results of multiplying the input data (like the query, key,
and value vectors in attention mechanisms) by a weight matrix.

In [46]:
# In code, we can achieve this by implementing a simple MultiHeadAttentionWrapper class that stacks multiple instances of ourpreviously implemented CausalAttention module

# A wrapper class to implement multi-head attention
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_lenght, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in=d_in, d_out=d_out, context_length=context_lenght, dropout=dropout) for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [47]:
torch.manual_seed(123)
context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, 2)
context_vecs = mha(batch)

In [48]:
print(context_vecs)
print("context_vec.shape : ", context_vecs.shape)

tensor([[[-0.0844,  0.0414,  0.0766,  0.0171],
         [-0.2264, -0.0039,  0.2143,  0.1185],
         [-0.4163, -0.0564,  0.3878,  0.2453],
         [-0.5014, -0.1011,  0.4992,  0.3401],
         [-0.7754, -0.1867,  0.7387,  0.4868],
         [-1.1632, -0.3303,  1.1224,  0.8460]],

        [[-0.0844,  0.0414,  0.0766,  0.0171],
         [-0.2264, -0.0039,  0.2143,  0.1185],
         [-0.4163, -0.0564,  0.3878,  0.2453],
         [-0.5014, -0.1011,  0.4992,  0.3401],
         [-0.7754, -0.1867,  0.7387,  0.4868],
         [-1.1632, -0.3303,  1.1224,  0.8460]]], grad_fn=<CatBackward0>)
context_vec.shape :  torch.Size([2, 6, 4])


The first dimension of the resulting context_vecs tensor is 2 since we have
two input texts (the input texts are duplicated, which is why the context
vectors are exactly the same for those). The second dimension refers to the 6
tokens in each input. The third dimension refers to the 4-dimensional
embedding of each token.

Exercice 3.2 : Returning 2-dimensional embedding vectors.

Change the input arguments for the MultiHeadAttentionWrapper(...,
num_heads=2) call such that the output context vectors are 2-dimensional
instead of 4-dimensional while keeping the setting num_heads=2. Hint: You
don't have to modify the class implementation; you just have to change one of
the other input arguments.

In [49]:
torch.manual_seed(123)
context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 1
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, 2)
context_vecs = mha(batch)

In [50]:
print(context_vecs)
print("context_vec.shape : ", context_vecs.shape)

tensor([[[-9.1476e-02,  3.4164e-02],
         [-2.6796e-01, -1.3427e-03],
         [-4.8421e-01, -4.8909e-02],
         [-6.4808e-01, -1.0625e-01],
         [-8.8380e-01, -1.7140e-01],
         [-1.4744e+00, -3.4327e-01]],

        [[-9.1476e-02,  3.4164e-02],
         [-2.6796e-01, -1.3427e-03],
         [-4.8421e-01, -4.8909e-02],
         [-6.4808e-01, -1.0625e-01],
         [-8.8380e-01, -1.7140e-01],
         [-1.4744e+00, -3.4327e-01]]], grad_fn=<CatBackward0>)
context_vec.shape :  torch.Size([2, 6, 2])


### Implementing multi-head attention with weight splits

In the previous section, we created a MultiHeadAttentionWrapper to
implement multi-head attention by stacking multiple single-head attention
modules. This was done by instantiating and combining several
CausalAttention objects

In [None]:
# An efficient multi-head attention class


In [58]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out,
    context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # assert d_out % num_heads == 0 #d_out must be divisible by num_headself.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads #A
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out) #B
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
        'mask',
        torch.triu(torch.ones(context_length, context_length), diagonal=1)
    )
        
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x) 
        queries = self.W_query(x)
        values = self.W_value(x)

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2) 
        values = values.transpose(1, 2) 

        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec

In [None]:
torch.manual_seed(123)
context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 1
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, 2)
context_vecs = mha(batch)