## Simple self-attention mechanism
The goal of self-attention is to compute context vectors: enriched embeddings that combine information from all other elements of the input sequence. In other words,
for every token we create a better representation by letting it interact with, and weigh the importance of, every other position in the same sequence.

In [1]:
# First, we compute the intermediate values, called attention scores
import torch
torch.manual_seed(123)
inputs = torch.tensor(
 [[0.43, 0.15, 0.89], # Your (x^1)
 [0.55, 0.87, 0.66], # journey (x^2)
 [0.57, 0.85, 0.64], # starts (x^3)
 [0.22, 0.58, 0.33], # with (x^4)
 [0.77, 0.25, 0.10], # one (x^5)
 [0.05, 0.80, 0.55]] # step (x^6)
)
query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [2]:
# Then we normalize the attention scores so that the resulting weights sum to 1
attn_weights_2_tmp = attn_scores_2/attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())


Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


In [3]:
# Softmax is the more advisable normalization: it handles extreme values better and offers more favorable gradient properties during training
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)
attn_weights_2_naive = softmax_naive(attn_scores_2)
print(attn_weights_2_naive)

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


The softmax function ensures that the attention weights are always positive. This makes the output interpretable as probabilities or relative importance,
where higher weights indicate greater importance.
Note that this naive softmax implementation (softmax_naive) may encounter
numerical instability problems, such as overflow and underflow, when dealing with
large or small input values. Therefore, in practice, it’s advisable to use the PyTorch
implementation of softmax:

In [4]:
#PyTorch implementation of softmax
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
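The instability of the naive softmax mentioned above can be reproduced with large scores. The sketch below is only an illustration: `softmax_stable` and `large_scores` are hypothetical names, and subtracting the maximum before exponentiating is one common stabilization trick, not necessarily what PyTorch does internally.

```python
import torch

def softmax_naive(x):
    # Direct definition: exp(x) overflows to inf for large x in float32
    return torch.exp(x) / torch.exp(x).sum(dim=0)

def softmax_stable(x):
    # Shifting by the maximum leaves the result unchanged mathematically,
    # but keeps every exponent <= 0 so exp() cannot overflow
    shifted = x - x.max()
    exp_x = torch.exp(shifted)
    return exp_x / exp_x.sum(dim=0)

large_scores = torch.tensor([1000.0, 1001.0, 1002.0])  # illustrative values

naive_out = softmax_naive(large_scores)    # inf / inf produces nan entries
stable_out = softmax_stable(large_scores)  # finite weights that sum to 1

print("naive: ", naive_out)
print("stable:", stable_out)
print("torch: ", torch.softmax(large_scores, dim=0))
```

The stable variant and `torch.softmax` should agree, while the naive version returns `nan` for these inputs.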


In [5]:
#context vector z(2) is the weighted sum of all input vectors, obtained by multiplying each input vector by its corresponding attention weight
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i
print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])
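The weighted sum in the loop above can also be written as a single vector-matrix product. The following sketch repeats the inputs so that it runs on its own; `context_loop` and `context_matmul` are names introduced here for illustration.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your
     [0.55, 0.87, 0.66],  # journey
     [0.57, 0.85, 0.64],  # starts
     [0.22, 0.58, 0.33],  # with
     [0.77, 0.25, 0.10],  # one
     [0.05, 0.80, 0.55]]  # step
)
query = inputs[1]

# Attention weights of the second token with respect to all tokens
attn_weights_2 = torch.softmax(inputs @ query, dim=0)

# Explicit weighted sum, one input vector at a time
context_loop = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_loop += attn_weights_2[i] * x_i

# Same weighted sum as one (6,) @ (6, 3) product
context_matmul = attn_weights_2 @ inputs

print(torch.allclose(context_loop, context_matmul))
```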


In [6]:
#Computing attention scores for all inputs
attn_scores = torch.empty(6,6)
for i, x_i in enumerate(inputs):
    for j,x_j in enumerate(inputs): 
        attn_scores[i,j] = torch.dot(x_i, x_j)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


When computing the preceding attention score tensor, we used for loops in
Python. However, for loops are generally slow, and we can achieve the same results
using matrix multiplication:

In [7]:
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


The dim parameter in functions like torch.softmax
specifies the dimension of the input tensor along which the function will be computed. By setting dim=-1, we are instructing the softmax function to apply the normalization along the last dimension of the attn_scores tensor. If attn_scores is a
two-dimensional tensor (for example, with a shape of [rows, columns]), it will normalize across the columns so that the values in each row (summing over the column
dimension) sum up to 1.

In [8]:
#Computing attention weights 
attn_weights = torch.softmax(attn_scores, dim = -1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [9]:
print("All row sums:", attn_weights.sum(dim=-1))

All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
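To make the role of dim concrete, the small sketch below normalizes the same tensor along both dimensions. The tensor `scores` is an arbitrary example, not data from this notebook.

```python
import torch

scores = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 1.0, 1.0]])

# dim=-1: normalize across the columns, so every row sums to 1
row_normalized = torch.softmax(scores, dim=-1)

# dim=0: normalize across the rows, so every column sums to 1
col_normalized = torch.softmax(scores, dim=0)

print(row_normalized.sum(dim=-1))
print(col_normalized.sum(dim=0))
```

For attention weights we want the first variant, because each row holds the weights of one query over all keys.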


## Implementing self-attention with trainable weights
Compared with the simplified self-attention mechanism above, the most notable difference is the introduction of weight matrices that are
updated during model training. These trainable weight matrices are crucial so that
the model (specifically, the attention module inside the model) can learn to produce
“good” context vectors.

We will implement the self-attention mechanism step by step by introducing the
three trainable weight matrices Wq, Wk, and Wv. These three matrices are used to
project the embedded input tokens, x(i), into query, key, and value vectors, respectively.

In [10]:
# We start here by computing only one context vector, z(2), for illustration purposes.
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

torch.manual_seed(123)

W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)


We set requires_grad=False to reduce clutter in the outputs, but if we were to use
the weight matrices for model training, we would set requires_grad=True to update
these matrices during model training. 

Next, we compute the query, key, and value vectors for x(2):

In [11]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value


Even though our temporary goal is only to compute the one context vector, z(2), we still
require the key and value vectors for all input elements as they are involved in computing the attention weights with respect to the query q(2).
We can obtain all keys and values via matrix multiplication:

In [12]:
keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


The second step is to compute the attention scores

In [13]:
keys_2 = keys[1]  # key vector of the second input token
attn_scores_22 = query_2.dot(keys_2)
print(attn_scores_22)

tensor(1.8524)


In [14]:
#Generalizing to compute all attention scores
attn_scores_2 = query_2 @ keys.T 
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


We compute the attention weights by scaling the attention scores and
then applying the softmax function. The difference from before is that we now scale the attention scores by dividing
them by the square root of the embedding dimension of the keys (taking the square
root is mathematically the same as exponentiating by 0.5):

In [15]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2/d_k ** 0.5, dim = -1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
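A commonly cited motivation for the scaling is that dot products tend to grow with the embedding dimension, which pushes softmax toward a nearly one-hot output with very small gradients. The sketch below illustrates the effect with random vectors. The dimension d_k = 64 and the random data are assumptions chosen for illustration only.

```python
import torch

torch.manual_seed(0)
d_k = 64                       # illustrative key dimension, larger than in this notebook
query = torch.randn(d_k)
keys = torch.randn(6, d_k)

scores = keys @ query          # raw dot products, spread grows with d_k

unscaled = torch.softmax(scores, dim=-1)
scaled = torch.softmax(scores / d_k ** 0.5, dim=-1)

# Dividing by sqrt(d_k) flattens the distribution, so the largest weight shrinks
print("max weight, unscaled:", unscaled.max().item())
print("max weight, scaled:  ", scaled.max().item())
```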


Similar to when we computed the context vector as a weighted sum over the input vectors, we now compute the context vector as a weighted sum over the
value vectors. Here, the attention weights serve as a weighting factor that weighs the respective importance of each value vector. Also as before, we can use matrix multiplication to obtain the output in one step:

In [16]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])


## Implementing a compact self-attention Python class

At this point, we have gone through a lot of steps to compute the self-attention outputs. We did so mainly for illustration purposes so we could go through one step at a
time. In practice, with the LLM implementation in the next chapter in mind, it is
helpful to organize this code into a Python class, as shown below

What is nn.Module?

- nn.Module is the base class for all neural network components in PyTorch.
- Every layer (nn.Linear, nn.Conv2d, nn.LSTM, etc.) is a subclass of nn.Module.
- Whole models (like GPT, ResNet, etc.) are also subclasses of nn.Module.

In [17]:
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
        
    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores/keys.shape[-1] ** 0.5, dim = -1)
        
        context_vec = attn_weights @ values
        return context_vec

In [18]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


We can improve the SelfAttention_v1 implementation further by utilizing
PyTorch’s nn.Linear layers, which effectively perform matrix multiplication when
the bias units are disabled. Additionally, a significant advantage of using nn.Linear
instead of manually implementing nn.Parameter(torch.rand(...)) is that nn.Linear
has an optimized weight initialization scheme, contributing to more stable and
effective model training.
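The claim that a bias-free nn.Linear is just a matrix multiplication can be checked directly. In this sketch the names `linear`, `x`, and `manual` are introduced for illustration. Note that nn.Linear stores its weight with shape (d_out, d_in), so the equivalent product uses the transpose.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_out = 3, 2
linear = nn.Linear(d_in, d_out, bias=False)
x = torch.rand(6, d_in)

# The weight is stored as (d_out, d_in), i.e. transposed relative to
# the (d_in, d_out) matrices used in SelfAttention_v1
print(linear.weight.shape)

# Calling the layer is the same as multiplying by the transposed weight
manual = x @ linear.weight.T
print(torch.allclose(linear(x), manual))
```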

In [19]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)
        
    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores/keys.shape[-1] ** 0.5, dim = -1)
        context_vec = attn_weights @ values
        return context_vec

Note that SelfAttention_v1 and SelfAttention_v2 give different outputs because
they use different initial weights for the weight matrices since nn.Linear uses a more
sophisticated weight initialization scheme.

In [20]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


## Exercise 3.1
Note that nn.Linear in SelfAttention_v2 uses a different weight initialization
scheme than the nn.Parameter(torch.rand(d_in, d_out)) used in SelfAttention_v1,
which causes the two mechanisms to produce different results. To check that both
implementations, SelfAttention_v1 and SelfAttention_v2, are otherwise similar, we can transfer the weight matrices from a SelfAttention_v2 object to a SelfAttention_v1 object, such that both objects then produce the same results.
Your task is to correctly assign the weights from an instance of SelfAttention_v2
to an instance of SelfAttention_v1. To do this, you need to understand the relationship between the weights in both versions. (Hint: nn.Linear stores the weight
matrix in a transposed form.) After the assignment, you should observe that both
instances produce the same outputs.


In [21]:
#Collecting weight instances from SelfAttention_v2
sa_v1 = SelfAttention_v1(d_in , d_out)
sa_v2 = SelfAttention_v2(d_in, d_out)

with torch.no_grad():  # we don’t need gradients for this
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

out_v2 = sa_v2(inputs)
out_v1 = sa_v1(inputs)

print(torch.allclose(out_v1, out_v2))

True


In [22]:
print(out_v1, "\n \n", out_v2)

tensor([[0.1839, 0.0178],
        [0.1815, 0.0205],
        [0.1818, 0.0202],
        [0.1826, 0.0191],
        [0.1875, 0.0144],
        [0.1799, 0.0218]], grad_fn=<MmBackward0>) 
 
 tensor([[0.1839, 0.0178],
        [0.1815, 0.0205],
        [0.1818, 0.0202],
        [0.1826, 0.0191],
        [0.1875, 0.0144],
        [0.1799, 0.0218]], grad_fn=<MmBackward0>)


In [23]:
# Alternative method: sa_v2.W_query, sa_v2.W_key, and sa_v2.W_value are whole nn.Linear objects; their .weight attribute holds the weight matrix, which we transpose and wrap in a new nn.Parameter
sa_v1 = SelfAttention_v1(d_in , d_out)
sa_v2 = SelfAttention_v2(d_in, d_out)
sa_v1.W_query = torch.nn.Parameter(sa_v2.W_query.weight.T)
sa_v1.W_key   = torch.nn.Parameter(sa_v2.W_key.weight.T)
sa_v1.W_value = torch.nn.Parameter(sa_v2.W_value.weight.T)

out_v2 = sa_v2(inputs)
out_v1 = sa_v1(inputs)

print(out_v1, "\n \n", out_v2)

tensor([[-0.2061,  0.3977],
        [-0.2032,  0.3992],
        [-0.2032,  0.3991],
        [-0.2039,  0.3977],
        [-0.2044,  0.3967],
        [-0.2035,  0.3986]], grad_fn=<MmBackward0>) 
 
 tensor([[-0.2061,  0.3977],
        [-0.2032,  0.3992],
        [-0.2032,  0.3991],
        [-0.2039,  0.3977],
        [-0.2044,  0.3967],
        [-0.2035,  0.3986]], grad_fn=<MmBackward0>)
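One difference between the two approaches is worth checking: copy_ writes the values into separate memory, whereas wrapping weight.T in a new nn.Parameter appears to keep sharing the underlying storage with the nn.Linear weight. The sketch below probes this with a standalone layer. The names `shared` and `copied` are hypothetical, and the behavior shown reflects my understanding of nn.Parameter rather than something stated in the original notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(3, 2, bias=False)

# Approach from the alternative method: wrap the transposed view directly
shared = nn.Parameter(linear.weight.T)

# Approach from the first solution: copy the values into a separate tensor
copied = nn.Parameter(torch.empty(3, 2))
with torch.no_grad():
    copied.copy_(linear.weight.T)

# Compare the memory addresses of the underlying data
print(shared.data_ptr() == linear.weight.data_ptr())
print(copied.data_ptr() == linear.weight.data_ptr())

# Change the original weight in place and see which tensor follows
with torch.no_grad():
    linear.weight.add_(1.0)
print(torch.equal(shared, linear.weight.T))
print(torch.equal(copied, linear.weight.T))
```

If the two modules are meant to be trained independently afterward, the copy_ variant is the safer choice.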


##  Hiding future words with causal attention
Causal attention, also known as masked attention, is a specialized form of self-attention. When computing attention scores for a given token, it restricts the model to the previous and current inputs in the sequence.
This is in contrast to the standard self-attention mechanism, which allows access to the entire input sequence at once.

To achieve this in GPT-like LLMs, for each token processed, we mask out
the future tokens, which come after the current token in the input text. We mask out the attention weights above the diagonal, and we normalize the nonmasked attention weights such that the attention weights sum to 1 in
each row.

Our next step is to implement the causal attention mask in code. To obtain the masked attention weights,
let’s work with the attention scores and weights from the previous section.

In [24]:
# we compute the attention weights using the softmax function as we have done previously:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores/keys.shape[-1] ** 0.5, dim = -1)
print(attn_weights)

tensor([[0.1574, 0.1694, 0.1692, 0.1687, 0.1649, 0.1703],
        [0.1588, 0.1672, 0.1667, 0.1720, 0.1577, 0.1775],
        [0.1591, 0.1672, 0.1666, 0.1719, 0.1579, 0.1773],
        [0.1626, 0.1667, 0.1664, 0.1700, 0.1609, 0.1734],
        [0.1651, 0.1664, 0.1662, 0.1686, 0.1630, 0.1706],
        [0.1602, 0.1671, 0.1666, 0.1712, 0.1591, 0.1758]],
       grad_fn=<SoftmaxBackward0>)


We can implement the second step using PyTorch’s tril function to create a mask
where the values above the diagonal are zero:

In [25]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


Now, we can multiply this mask with the attention weights to zero-out the values above
the diagonal:

In [26]:
masked_simple = attn_weights*mask_simple
print(masked_simple)

tensor([[0.1574, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1588, 0.1672, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1591, 0.1672, 0.1666, 0.0000, 0.0000, 0.0000],
        [0.1626, 0.1667, 0.1664, 0.1700, 0.0000, 0.0000],
        [0.1651, 0.1664, 0.1662, 0.1686, 0.1630, 0.0000],
        [0.1602, 0.1671, 0.1666, 0.1712, 0.1591, 0.1758]],
       grad_fn=<MulBackward0>)


The third step is to renormalize the attention weights to sum up to 1 again in each
row. We can achieve this by dividing each element in each row by the sum in each row:

In [27]:
rows_sum = masked_simple.sum(dim = -1, keepdim=True)
masked_simple_norm = masked_simple/rows_sum
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4871, 0.5129, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3227, 0.3392, 0.3381, 0.0000, 0.0000, 0.0000],
        [0.2442, 0.2505, 0.2499, 0.2554, 0.0000, 0.0000],
        [0.1991, 0.2006, 0.2004, 0.2033, 0.1966, 0.0000],
        [0.1602, 0.1671, 0.1666, 0.1712, 0.1591, 0.1758]],
       grad_fn=<DivBackward0>)


While we could wrap up our implementation of causal attention at this point, we can
still improve it. Let’s take a mathematical property of the softmax function and implement
the computation of the masked attention weights more efficiently in fewer steps.

The softmax function converts its inputs into a probability distribution. When negative
infinity values (-∞) are present in a row, the softmax function treats them as zero
probability. (Mathematically, this is because e^(-∞) approaches 0.)
We can implement this more efficient masking “trick” by creating a mask with 1s
above the diagonal and then replacing these 1s with negative infinity (-inf) values.
This way, we avoid the renormalization step in which we divide by rows_sum:

In [28]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[-0.0555,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.1005, -0.0276,    -inf,    -inf,    -inf,    -inf],
        [-0.0984, -0.0277, -0.0325,    -inf,    -inf,    -inf],
        [-0.0603, -0.0247, -0.0278,  0.0029,    -inf,    -inf],
        [-0.0328, -0.0220, -0.0238, -0.0033, -0.0507,    -inf],
        [-0.0840, -0.0240, -0.0281,  0.0099, -0.0935,  0.0473]],
       grad_fn=<MaskedFillBackward0>)


In [29]:
# Applying Softmax finally
attn_weights = torch.softmax(masked/keys.shape[-1] ** 0.5, dim = -1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4871, 0.5129, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3227, 0.3392, 0.3381, 0.0000, 0.0000, 0.0000],
        [0.2442, 0.2505, 0.2499, 0.2554, 0.0000, 0.0000],
        [0.1991, 0.2006, 0.2004, 0.2033, 0.1966, 0.0000],
        [0.1602, 0.1671, 0.1666, 0.1712, 0.1591, 0.1758]],
       grad_fn=<SoftmaxBackward0>)
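That the two masking routes give the same weights can be verified on arbitrary scores. The sketch below uses random scores rather than the notebook's values, and `variant_1` and `variant_2` are names introduced here.

```python
import torch

torch.manual_seed(0)
n = 4
scores = torch.randn(n, n)  # stand-in attention scores

# Variant 1: softmax first, zero out future positions, renormalize each row
weights = torch.softmax(scores, dim=-1)
lower = torch.tril(torch.ones(n, n))
masked_weights = weights * lower
variant_1 = masked_weights / masked_weights.sum(dim=-1, keepdim=True)

# Variant 2: set future scores to -inf, then apply softmax once
future = torch.triu(torch.ones(n, n), diagonal=1).bool()
variant_2 = torch.softmax(scores.masked_fill(future, -torch.inf), dim=-1)

print(torch.allclose(variant_1, variant_2))
```

Both variants agree because the exponentials of the masked entries drop out of the softmax denominator, which is exactly what the renormalization achieves in the first variant.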


## Masking additional attention weights with dropout
Dropout is a technique in which randomly selected hidden layer units are ignored during training. This method helps prevent overfitting by ensuring that a model does not become overly reliant on any specific
set of hidden layer units. It’s important to emphasize that dropout is only used
during training and is disabled afterward.
In the transformer architecture, including models like GPT, dropout in the attention
mechanism is typically applied at two specific times: after calculating the attention
weights or after applying the attention weights to the value vectors. Here we will
apply the dropout mask after computing the attention weights, because it’s the more common variant in practice.

In [30]:
#Example implementation
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6,6)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])
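The surviving entries above are 2 rather than 1 because nn.Dropout rescales the kept values by 1/(1 - p) during training, which keeps the expected sum unchanged. The sketch below also shows that the layer becomes an identity in evaluation mode. The variable names are illustrative.

```python
import torch

torch.manual_seed(0)
p = 0.5
dropout = torch.nn.Dropout(p)
example = torch.ones(6, 6)

# Training mode: each entry is either dropped (0) or scaled to 1 / (1 - p) = 2
dropout.train()
train_out = dropout(example)
print(train_out.unique())

# Evaluation mode: dropout is disabled and the input passes through unchanged
dropout.eval()
eval_out = dropout(example)
print(torch.equal(eval_out, example))
```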


In [31]:
#Applying to our attention weights
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6454, 0.6784, 0.6762, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5009, 0.4998, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4013, 0.0000, 0.4066, 0.0000, 0.0000],
        [0.0000, 0.3342, 0.3333, 0.3424, 0.3182, 0.0000]],
       grad_fn=<MulBackward0>)


Before we implement a compact causal attention class, let’s ensure that the code can handle batches consisting of
more than one input so that the CausalAttention class supports the batch outputs
produced by the data loader we implemented in chapter 2.
For simplicity, to simulate such batch inputs, we duplicate the input text example:

In [32]:
#.stack combines a sequence of tensors by adding a new dimension
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [33]:
# Implementing the causal self-attention class
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # 1s above the diagonal mark the future positions to hide
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # Transpose the last two dimensions and keep the batch dimension first
        attn_scores = queries @ keys.transpose(1, 2)
        # Trim the mask to the actual sequence length and hide future tokens
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        # Apply dropout to the attention weights (a no-op when dropout=0.0 or in eval mode)
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec
        

We can use the CausalAttention class as follows, similar to SelfAttention previously:

In [34]:
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)


context_vecs.shape: torch.Size([2, 6, 2])


In [35]:
print(context_vecs)

tensor([[[-0.4821,  0.4336],
         [-0.5368,  0.5483],
         [-0.5545,  0.5886],
         [-0.4937,  0.5311],
         [-0.4589,  0.5169],
         [-0.4479,  0.4971]],

        [[-0.4821,  0.4336],
         [-0.5368,  0.5483],
         [-0.5545,  0.5886],
         [-0.4937,  0.5311],
         [-0.4589,  0.5169],
         [-0.4479,  0.4971]]], grad_fn=<UnsafeViewBackward0>)
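The mask in CausalAttention is stored with register_buffer rather than as a plain attribute or a parameter. The sketch below shows what that implies, using a hypothetical minimal module named `MaskHolder` that exists only for this illustration: the buffer is not trainable, but it is saved in the state dict and follows the module when .to() is called.

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, context_length):
        super().__init__()
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

holder = MaskHolder(4)

# A buffer is not a trainable parameter ...
print(list(holder.parameters()))

# ... but it is part of the module state that gets saved and loaded
print("mask" in holder.state_dict())

# It also follows the module through .to(); a dtype change is used here
# because it can be run without a GPU
holder = holder.to(torch.float64)
print(holder.mask.dtype)
```

The same mechanism moves the mask to the GPU together with the model, which avoids device mismatch errors in the forward pass.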


## A wrapper class to implement multi-head attention

In [36]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # nn.ModuleList registers each head as a submodule so that its parameters are tracked and updated
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)]
        )

    def forward(self, x):
        # Run every head on x and concatenate the results along the embedding dimension
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [37]:
#Testing our Multi Head Attention Wrapper
torch.manual_seed(123)
context_length = batch.shape[1]
d_in , d_out = 3 , 2
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


## Exercise 3.2: Returning two-dimensional embedding vectors
Change the input arguments for the MultiHeadAttentionWrapper(..., num_heads=2) call such that the output context vectors are two-dimensional instead of
four-dimensional while keeping the setting num_heads=2. Hint: You don’t have to
modify the class implementation; you just have to change one of the other input
arguments.

In [38]:
d_in, d_out = 3, 1  # each of the 2 heads now contributes 1 output dimension
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[0.0189, 0.2729],
         [0.2181, 0.3037],
         [0.2804, 0.3125],
         [0.2830, 0.2793],
         [0.2476, 0.2541],
         [0.2748, 0.2513]],

        [[0.0189, 0.2729],
         [0.2181, 0.3037],
         [0.2804, 0.3125],
         [0.2830, 0.2793],
         [0.2476, 0.2541],
         [0.2748, 0.2513]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


##  Implementing multi-head attention with weight splits

The following MultiHeadAttention class integrates the multi-head functionality within a single class.
It splits the input into multiple heads by reshaping the projected query, key, and value tensors and then combines the results from these heads after computing attention.

In [59]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Each head works on d_out / num_heads dimensions
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine the head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # Each projection has shape (b, num_tokens, d_out)
        keys = self.W_key(x)
        values = self.W_value(x)
        queries = self.W_query(x)

        # Split d_out into heads: (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Move the head dimension forward: (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        queries = queries.transpose(1, 2)

        # Dot products per head: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)
        # Trim the mask to the sequence length and hide future tokens
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Back to (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)
        # Merge the heads again: (b, num_tokens, d_out)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec

In [60]:
# Demonstration
a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],
                    [0.8993, 0.0390, 0.9268, 0.7388],
                    [0.7179, 0.7058, 0.9156, 0.4340]],
                   
                    [[0.0772, 0.3565, 0.1479, 0.5331],
                    [0.4066, 0.2318, 0.4545, 0.9737],
                    [0.4606, 0.5159, 0.4220, 0.5786]]]])

print(a.shape)

torch.Size([1, 2, 3, 4])


In [61]:
#perform a batched matrix multiplication between the tensor itself and a view of the tensor where we transposed the last two dimensions, num_tokens and head_dim:
print(a @ a.transpose(2, 3))

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])
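The view and transpose calls in the forward method can be checked on a small tensor. In the sketch below, the sizes and the tensor `x` are arbitrary choices for illustration. It shows that each head receives a contiguous slice of the d_out dimension and that the merge step restores the original layout.

```python
import torch

b, num_tokens, num_heads, head_dim = 1, 3, 2, 2
d_out = num_heads * head_dim

# Stand-in for a projected tensor of shape (b, num_tokens, d_out)
x = torch.arange(b * num_tokens * d_out, dtype=torch.float32).view(b, num_tokens, d_out)

# Split the last dimension into heads and move the head dimension forward
heads = x.view(b, num_tokens, num_heads, head_dim).transpose(1, 2)
print(heads.shape)  # (b, num_heads, num_tokens, head_dim)

# Head 0 sees the first head_dim columns, head 1 the remaining ones
print(torch.equal(heads[:, 0], x[..., :head_dim]))
print(torch.equal(heads[:, 1], x[..., head_dim:]))

# Undo the split: transpose back, make the memory contiguous, flatten the heads
merged = heads.transpose(1, 2).contiguous().view(b, num_tokens, d_out)
print(torch.equal(merged, x))
```

The contiguous() call is needed because transpose only changes strides, and view requires a compatible memory layout.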


The MultiHeadAttention class can be used similarly to the SelfAttention and
CausalAttention classes we implemented earlier:

In [62]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


In [None]:
#Initialising a class with similar attention heads and input/output dimensions as smallest GPT-2 Model