## 1.1 Brief Introduction

* We will discuss three attention variants here: self-attention, causal attention and multi-head attention.
* These variants build on each other; the goal is to arrive at a compact and efficient implementation of multi-head attention that we can plug into the LLM architecture.

## 1.2 Self-attention with trainable weights

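* In self-attention, our goal is to compute a context vector for each element in the input sequence: a new representation of that element that combines information from all elements of the sequence.
* Let's use the following input sequence, where each token is represented by a three-dimensional embedding vector.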


In [1]:
import torch
inputs = torch.tensor(
   [ [0.34,0.55,0.66],#Attention(x1)
    [0.99,0.87,0.56],#is(x2)
    [0.67,0.12,0.65],#all(x3)
    [0.99,0.89,0.53],#you(x4)
    [0.77,0.67,0.77],#need(x5)
   ]
)

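* Next we initialize the query, key and value weight matrices, which project each input embedding into its query, key and value vectors. In this manual walk-through they are plain random tensors; in a trainable module they are registered as `nn.Parameter`s so they can be updated during training, as in the classes below.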


In [2]:
torch.manual_seed(123)
w_query = torch.rand(3,5)
w_key = torch.rand(3,5)
w_value = torch.rand(3,5)

In [3]:
## Projecting the inputs into query, key and value matrices
query = torch.matmul(inputs, w_query)
key = torch.matmul(inputs, w_key)
value = torch.matmul(inputs, w_value)
print(f"Query:{query}")
print(f"Key:{key}")
print(f"Value:{value}")

Query:tensor([[0.7853, 0.7042, 0.1919, 0.4651, 0.6335],
        [1.2236, 1.0150, 0.3807, 0.9519, 0.8824],
        [0.5073, 0.8091, 0.2301, 0.6112, 0.3424],
        [1.2314, 0.9971, 0.3804, 0.9497, 0.8875],
        [1.0513, 1.0183, 0.3207, 0.8049, 0.7873]])
Key:tensor([[1.0406, 0.5497, 1.1973, 1.2639, 0.4013],
        [1.5090, 0.7888, 1.8532, 1.7368, 0.9145],
        [0.7995, 0.3310, 1.1880, 0.9585, 0.5008],
        [1.5063, 0.7924, 1.8402, 1.7294, 0.9189],
        [1.3882, 0.7024, 1.7238, 1.6450, 0.7225]])
Value:tensor([[0.8222, 0.4466, 0.7402, 0.9640, 0.9085],
        [1.5228, 0.9283, 1.2754, 1.3727, 1.6282],
        [0.9374, 0.6549, 0.8192, 1.1236, 0.9768],
        [1.5193, 0.9229, 1.2689, 1.3464, 1.6250],
        [1.3175, 0.7990, 1.1416, 1.3923, 1.4168]])


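* Next we calculate the attention scores: the dot product of every query with every key, i.e. $QK^T$. Entry $(i, j)$ measures how strongly token $i$ attends to token $j$.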

In [4]:
attn_scores = torch.matmul(query,key.T)
print(f"Attention Scores:{attn_scores}")

Attention Scores:tensor([[2.2762, 3.4834, 1.8520, 3.4806, 3.1385],
        [3.8442, 5.8128, 3.1209, 5.8050, 5.2712],
        [2.1581, 3.2049, 1.7041, 3.2004, 2.9221],
        [3.8416, 5.8109, 3.1214, 5.8031, 5.2692],
        [3.3711, 5.1021, 2.7245, 5.0962, 4.6205]])
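As a quick sanity check, each entry of `attn_scores` is the dot product of one query row with one key row. This minimal sketch recomputes entry (0, 0) in plain Python from the rounded values printed above, so it only agrees with the tensor to about three decimal places.

```python
# Rounded rows copied from the printed `query` and `key` tensors above
query_0 = [0.7853, 0.7042, 0.1919, 0.4651, 0.6335]  # query for "Attention" (x1)
key_0 = [1.0406, 0.5497, 1.1973, 1.2639, 0.4013]    # key for "Attention" (x1)

def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

score_00 = dot(query_0, key_0)
print(round(score_00, 4))  # close to attn_scores[0, 0] printed above
```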


* Next we compute the attention weights with scaled dot-product attention, where the attention weights are given by:


$\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$.

* Here $Q$ is the query matrix, $K$ is the key matrix, and $d_k$ is the dimension of the key vectors (`d_out`, which is 5 here). The softmax is applied to each row, so every row of attention weights sums to 1.
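Why divide by $\sqrt{d_k}$? If query and key entries are roughly independent with unit variance, their dot product has variance about $d_k$, so larger key dimensions push the softmax toward a near one-hot distribution with very small gradients. The plain-Python sketch below estimates this empirically under that simplifying assumption (standard-normal entries, which learned projections are not), so treat it as intuition rather than proof.

```python
import random
import statistics

def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def score_variance(d_k, n_samples=2000, scale=False, seed=0):
    """Empirical variance of q . k for random standard-normal vectors of length d_k."""
    rng = random.Random(seed)
    scores = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        s = dot(q, k)
        if scale:
            s /= d_k ** 0.5  # the 1/sqrt(d_k) factor from the formula
        scores.append(s)
    return statistics.variance(scores)

for d_k in (4, 64):
    raw = score_variance(d_k)
    scaled = score_variance(d_k, scale=True)
    print(f"d_k={d_k:3d}  var(raw)={raw:6.2f}  var(scaled)={scaled:5.2f}")
```

The unscaled variance grows roughly in proportion to `d_k`, while the scaled variance stays near 1 regardless of the key dimension.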

In [5]:
d_k = key.shape[-1]
attn_weights = torch.softmax(attn_scores/d_k**0.5,dim=-1)
print(f"Attention Weights:{attn_weights}")

Attention Weights:tensor([[0.1486, 0.2551, 0.1230, 0.2547, 0.2186],
        [0.1186, 0.2860, 0.0858, 0.2850, 0.2245],
        [0.1559, 0.2490, 0.1273, 0.2485, 0.2194],
        [0.1186, 0.2860, 0.0859, 0.2850, 0.2245],
        [0.1277, 0.2770, 0.0957, 0.2763, 0.2233]])
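To make the formula concrete, this plain-Python sketch applies the scaled softmax to the first row of the printed `attn_scores` (rounded to 4 decimals, so the result matches the first row of `attn_weights` only to about three decimal places). The `softmax` helper is written here for illustration and subtracts the row maximum for numerical stability.

```python
import math

def softmax(xs):
    """Numerically stable softmax of a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# First row of the printed `attn_scores` (rounded to 4 decimals)
scores_row0 = [2.2762, 3.4834, 1.8520, 3.4806, 3.1385]
d_k = 5  # key dimension (d_out) used above

weights_row0 = softmax([s / math.sqrt(d_k) for s in scores_row0])
print([round(w, 4) for w in weights_row0])
print(sum(weights_row0))  # the weights of a row sum to 1 (up to rounding)
```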


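* To get the context vectors we multiply the attention weight matrix by the value matrix: each context vector is a weighted sum of all value vectors, using the corresponding row of attention weights as the weights.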

In [6]:
context_vector = torch.matmul(attn_weights,value)
print(f"Context Vector:{context_vector}")

Context Vector:tensor([[1.3009, 0.7934, 1.1089, 1.2789, 1.3941],
        [1.3424, 0.8171, 1.1409, 1.2998, 1.4386],
        [1.2932, 0.7887, 1.1030, 1.2750, 1.3859],
        [1.3424, 0.8171, 1.1409, 1.2998, 1.4385],
        [1.3305, 0.8103, 1.1317, 1.2938, 1.4259]])
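The weighted-sum view can be checked by hand. This plain-Python sketch combines the first row of the printed `attn_weights` with the rows of the printed `value` matrix; because both inputs are rounded, it reproduces the first context vector only to about three decimal places.

```python
# First row of the printed `attn_weights` and the rows of the printed `value` matrix
weights_row0 = [0.1486, 0.2551, 0.1230, 0.2547, 0.2186]
value_rows = [
    [0.8222, 0.4466, 0.7402, 0.9640, 0.9085],
    [1.5228, 0.9283, 1.2754, 1.3727, 1.6282],
    [0.9374, 0.6549, 0.8192, 1.1236, 0.9768],
    [1.5193, 0.9229, 1.2689, 1.3464, 1.6250],
    [1.3175, 0.7990, 1.1416, 1.3923, 1.4168],
]

def weighted_sum(weights, vectors):
    """Return sum_j weights[j] * vectors[j], i.e. one row of attn_weights @ values."""
    dim = len(vectors[0])
    return [sum(w * v[i] for w, v in zip(weights, vectors)) for i in range(dim)]

context_0 = weighted_sum(weights_row0, value_rows)
print([round(c, 4) for c in context_0])  # compare with the first row of context_vector
```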


In [7]:
##implementing a compact self-attention class
import torch.nn as nn
class SelfAttention_v1(nn.Module):
  def __init__(self,d_in,d_out):
    super().__init__()
    self.w_query = nn.Parameter(torch.rand(d_in,d_out))
    self.w_key = nn.Parameter(torch.rand(d_in,d_out))
    self.w_value = nn.Parameter(torch.rand(d_in,d_out))

  def forward(self,x):
    keys = x @ self.w_key
    queries = x @ self.w_query
    values = x @ self.w_value
    attn_scores = queries @ keys.T
    attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)
    context_vec = attn_weights @ values
    return context_vec

In [8]:
##use case for the class
d_in = inputs.shape[-1] # or 3
d_out = 5 # As used in previous manual calculations
sa_v1 = SelfAttention_v1(d_in,d_out)
context_vector = sa_v1(inputs)
print(f"Context Vector:{context_vector}")

Context Vector:tensor([[1.1028, 1.0398, 0.9905, 1.1706, 1.2588],
        [1.1476, 1.0834, 1.0098, 1.2004, 1.2930],
        [1.0974, 1.0344, 0.9883, 1.1669, 1.2544],
        [1.1480, 1.0837, 1.0099, 1.2006, 1.2933],
        [1.1326, 1.0687, 1.0034, 1.1903, 1.2813]], grad_fn=<MmBackward0>)


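* We can improve the `SelfAttention_v1` implementation further by using PyTorch's `nn.Linear` layers, which perform plain matrix multiplications when the bias units are disabled. `nn.Linear` also uses a more sophisticated default weight initialization than `torch.rand`, which tends to make training more stable; this is also why `SelfAttention_v2` produces different outputs than `SelfAttention_v1`.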
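`nn.Linear(d_in, d_out, bias=False)` stores its weight with shape `(d_out, d_in)` and computes `x @ weight.T`, so it is the same operation as `x @ W` with `W = weight.T`. The plain-Python sketch below illustrates that equivalence; the weight values are hypothetical and are not what PyTorch would initialize.

```python
def matmul(a, b):
    """(n x m) @ (m x p) for matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    """Swap rows and columns of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*m)]

def linear_no_bias(x, weight):
    """Bias-free linear layer: weight has shape (d_out, d_in) and the result is x @ weight.T."""
    return [[sum(w * v for w, v in zip(w_row, x_row)) for w_row in weight] for x_row in x]

x = [[0.34, 0.55, 0.66],   # "Attention" (x1)
     [0.99, 0.87, 0.56]]   # "is" (x2)
# Hypothetical weight, stored the way nn.Linear stores it: (d_out=2, d_in=3)
weight = [[0.1, -0.2, 0.3],
          [0.5, 0.0, -0.4]]

print(linear_no_bias(x, weight))
print(matmul(x, transpose(weight)))  # same values: x @ W with W = weight.T
```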

In [9]:
## Self-attention class using PyTorch's nn.Linear layers
class SelfAttention_v2(nn.Module):
  def __init__(self,d_in,d_out,qkv_bias=False):
    super().__init__()
    self.w_key = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.w_query = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.w_value = nn.Linear(d_in,d_out,bias=qkv_bias)

  def forward(self,x):
    keys = self.w_key(x)
    queries = self.w_query(x)
    values = self.w_value(x)
    attn_scores = queries @ keys.T  # .T assumes a 2D (num_tokens, d_out) input
    attn_weights = torch.softmax(
        attn_scores/keys.shape[-1]**0.5,dim=-1
    )
    context_vec = attn_weights @ values
    return context_vec

In [10]:
##use case
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in,d_out)
context_vector = sa_v2(inputs)
print(f"Context Vector:{context_vector}")

Context Vector:tensor([[ 0.6918,  0.4175, -0.7206, -0.2338,  0.2625],
        [ 0.6856,  0.4112, -0.7157, -0.2336,  0.2580],
        [ 0.6935,  0.4199, -0.7217, -0.2341,  0.2648],
        [ 0.6856,  0.4111, -0.7156, -0.2336,  0.2579],
        [ 0.6878,  0.4134, -0.7174, -0.2337,  0.2595]], grad_fn=<MmBackward0>)


## 1.3 Causal Attention

* For many LLM tasks, you will want the self-attention mechanism to consider only tokens that appear prior to the current position when predicting the next token in a sequence.
* Causal attention restricts the model so that, when computing the attention scores for a given token, it only considers the current and previous tokens in the sequence.

In [11]:
##applying a causal attention
#1 computing the attention weights using softmax
queries = sa_v2.w_query(inputs)
keys = sa_v2.w_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5,dim=-1)
print(attn_weights)

tensor([[0.2090, 0.1883, 0.2181, 0.1878, 0.1968],
        [0.2184, 0.1807, 0.2265, 0.1803, 0.1941],
        [0.2103, 0.1907, 0.2120, 0.1908, 0.1962],
        [0.2184, 0.1806, 0.2266, 0.1802, 0.1941],
        [0.2151, 0.1834, 0.2238, 0.1830, 0.1948]], grad_fn=<SoftmaxBackward0>)


In [12]:
# Create a lower-triangular mask: 1s on and below the diagonal, 0s above it
context_length = attn_scores.shape[0]
mask = torch.tril(torch.ones(context_length,context_length))
print(mask)

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])


In [13]:
#multiplying mask with attention weights
masked = attn_weights*mask
print(masked)

tensor([[0.2090, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2184, 0.1807, 0.0000, 0.0000, 0.0000],
        [0.2103, 0.1907, 0.2120, 0.0000, 0.0000],
        [0.2184, 0.1806, 0.2266, 0.1802, 0.0000],
        [0.2151, 0.1834, 0.2238, 0.1830, 0.1948]], grad_fn=<MulBackward0>)


In [14]:
## Renormalize the attention weights so that each row sums to 1
row_sums = masked.sum(dim=1,keepdim=True)
masked_norm = masked /row_sums
print(masked_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5472, 0.4528, 0.0000, 0.0000, 0.0000],
        [0.3431, 0.3111, 0.3458, 0.0000, 0.0000],
        [0.2710, 0.2242, 0.2812, 0.2236, 0.0000],
        [0.2151, 0.1834, 0.2238, 0.1830, 0.1948]], grad_fn=<DivBackward0>)


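* The softmax function converts its inputs into a probability distribution.
* A more efficient way to apply the causal mask is to replace the scores above the diagonal with negative infinity *before* applying softmax. Since $e^{-\infty} = 0$, softmax assigns those positions zero probability, and the remaining weights in each row already sum to 1, so no separate renormalization step is needed.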
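Masking before or after the softmax gives the same weights. The plain-Python sketch below checks this for the second token's row: the first two scores come from the printed `attn_scores` (with d_k = 5), while the three "future" scores are hypothetical placeholders; the result does not depend on them.

```python
import math

def softmax(xs):
    """Numerically stable softmax; -inf entries get probability 0."""
    m = max(xs)  # finite as long as at least one entry is unmasked
    exps = [math.exp(x - m) for x in xs]  # math.exp(-math.inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

scale = 1 / math.sqrt(5)  # 1/sqrt(d_k) with d_k = 5, as in the cells above
# Row for the second token (position 1); the last three scores are hypothetical "future" values
scores = [-0.5420, -0.9652, 0.31, -0.12, 0.77]
position = 1  # this token may attend to positions 0 and 1 only

# Approach 1: softmax over all scores, zero out future positions, then renormalize
full = softmax([s * scale for s in scores])
kept = [w if j <= position else 0.0 for j, w in enumerate(full)]
renormalized = [w / sum(kept) for w in kept]

# Approach 2: set future scores to -inf, then apply softmax once
masked = [s * scale if j <= position else -math.inf for j, s in enumerate(scores)]
direct = softmax(masked)

print([round(w, 4) for w in renormalized])
print([round(w, 4) for w in direct])
```

Both lists match the second row of `masked_norm` above (about `0.5472, 0.4528`), which is why the masked-softmax approach can replace the renormalization step.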

In [15]:
mask = torch.triu(torch.ones(context_length,context_length),diagonal=1)
masked = attn_scores.masked_fill(mask.bool(),-torch.inf)
print(masked)

tensor([[-0.3166,    -inf,    -inf,    -inf,    -inf],
        [-0.5420, -0.9652,    -inf,    -inf,    -inf],
        [-0.3316, -0.5503, -0.3142,    -inf,    -inf],
        [-0.5402, -0.9646, -0.4576, -0.9698,    -inf],
        [-0.4812, -0.8374, -0.3923, -0.8418, -0.7029]],
       grad_fn=<MaskedFillBackward0>)


In [16]:
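## Applying softmax. Note: unlike the earlier cells, `masked` is not divided by
## sqrt(d_k) here, so these weights differ from `masked_norm`; using
## torch.softmax(masked / keys.shape[-1]**0.5, dim=-1) would reproduce `masked_norm`.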
attn_weights = torch.softmax(masked,dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6042, 0.3958, 0.0000, 0.0000, 0.0000],
        [0.3545, 0.2848, 0.3607, 0.0000, 0.0000],
        [0.2949, 0.1929, 0.3203, 0.1919, 0.0000],
        [0.2330, 0.1632, 0.2547, 0.1625, 0.1867]], grad_fn=<SoftmaxBackward0>)


## 1.3.1 Masking additional attention weights with dropout

* `Dropout` in deep learning is a technique where randomly selected hidden layer units are ignored during training, effectively `dropping` them out. It is only active during training and is disabled at inference time.
* This method helps prevent overfitting by ensuring that a model does not become overly reliant on any specific set of hidden layer units.
* In the transformer architecture, dropout in the attention mechanism is typically applied at one of two points: after computing the attention weights, or after applying the attention weights to the value vectors. Here we apply it to the attention weights.

In [17]:
##example
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)#dropout of 50%
example = torch.ones(6,6)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


* When dropout with a rate of 50% is applied to an attention weight matrix, each element is set to zero independently with probability 0.5, so roughly (not exactly) half of the elements are zeroed.
* To compensate for the reduction in active elements, the values of the remaining elements are scaled up by a factor of 1/(1 - 0.5) = 2.
* This scaling keeps the expected value of each weight unchanged, so the average influence of the attention mechanism is consistent between training and inference, where dropout is switched off (for example after calling `model.eval()`) and no scaling is applied.
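This zero-and-rescale rule is usually called inverted dropout. The plain-Python sketch below is a simplified stand-in for `nn.Dropout` in training mode (it uses Python's `random` module, so it will not reproduce PyTorch's masks): each element is zeroed with probability `p`, survivors are multiplied by `1/(1-p)`, and over many elements the mean stays close to the original value.

```python
import random

def inverted_dropout(values, p, rng):
    """Zero each element with probability p and scale survivors by 1/(1-p) (training mode)."""
    keep_scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * keep_scale for v in values]

rng = random.Random(123)  # plain-Python RNG, not torch's generator
ones = [1.0] * 100_000
dropped = inverted_dropout(ones, p=0.5, rng=rng)

frac_zero = dropped.count(0.0) / len(dropped)
mean = sum(dropped) / len(dropped)
print(f"fraction zeroed: {frac_zero:.3f}")    # close to 0.5, but not exactly
print(f"mean after dropout: {mean:.3f}")       # close to 1.0, the mean without dropout
```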

In [18]:
##applying dropout to the attention weight matrix
torch.manual_seed(123)
print(dropout(attn_weights))


tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5697, 0.7214, 0.0000, 0.0000],
        [0.5898, 0.0000, 0.6406, 0.0000, 0.0000],
        [0.4660, 0.0000, 0.0000, 0.3249, 0.0000]], grad_fn=<MulBackward0>)


## 1.3.2 Implementing a compact causal attention class

In [19]:
# Stack two copies of the inputs so the causal attention class can be tested on
# batches with more than one sequence, like those produced by a data loader
batch = torch.stack((inputs,inputs),dim=0)
print(batch.shape)

torch.Size([2, 5, 3])


* The result is a three-dimensional tensor consisting of two input texts with 5 tokens each, where each token is a three-dimensional embedding vector.

In [27]:
class CausalAttention(nn.Module):
  def __init__(self,d_in,d_out,context_length,dropout_rate,qkv_bias=False):
    super().__init__()
    self.w_query = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.w_key = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.w_value = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.dropout = nn.Dropout(dropout_rate)
    self.register_buffer(
        "mask",
        torch.triu(torch.ones(context_length,context_length),diagonal=1)
    )

  def forward(self,x):
    b,num_tokens,d_in = x.shape
    keys = self.w_key(x)
    queries = self.w_query(x)
    values = self.w_value(x)
    # Transpose dimensions 1 and 2 of keys, keeping the batch dimension first
    attn_scores = queries @ keys.transpose(1,2)
    # Slice the mask to the current sequence length; it broadcasts over the batch dimension
    attn_scores = attn_scores.masked_fill(self.mask.bool()[:num_tokens,:num_tokens],-torch.inf)
    attn_weights = torch.softmax(
        attn_scores /keys.shape[-1]**0.5, dim=-1
    )
    attn_weights = self.dropout(attn_weights)
    context_vec = attn_weights @ values
    return context_vec

In [24]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)  # dropout_rate=0.0 disables dropout
context_vecs = ca(batch)
print(f"Context vectors shape:{context_vecs.shape}")

Context vectors shape:torch.Size([2, 5, 5])


* The resulting context vector tensor is three-dimensional (2 texts × 5 tokens), where each token is now represented by a five-dimensional embedding (`d_out = 5`).

## 1.4 Multi-Head attention

* Our final step is to extend the previously implemented causal attention class to multiple heads.
* This is called `multi-head attention`.
* The term `multi-head` refers to dividing the attention mechanism into multiple "heads", each with its own query, key and value weights and each operating independently.

## 1.4.1 Stacking multiple single-head attention layers


* Implementing multi-head attention involves creating multiple instances of the self-attention mechanism, each with its own weights, and then combining their outputs.

In [30]:
## wrapper class for multi-head attention
class MultiHeadAttentionWrapper(nn.Module):
  def __init__(self,d_in,d_out,context_length,dropout,num_heads,qkv_bias=False):
    super().__init__()
    # One independent CausalAttention module (with its own weights) per head
    self.attn_heads = nn.ModuleList(
        [CausalAttention(d_in,d_out,context_length,dropout,qkv_bias) for _ in range(num_heads)]
    )

  def forward(self,x):
    # Concatenate the per-head context vectors along the last (feature) dimension
    return torch.cat([head(x) for head in self.attn_heads],dim=-1)

In [33]:
torch.manual_seed(123)
context_length = batch.shape[1]
d_in,d_out = 3,2
mha = MultiHeadAttentionWrapper(
    d_in,d_out,context_length,0.0,num_heads=2
)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4992, -0.0313,  0.4853,  0.2945],
         [-0.7145, -0.1641,  0.6931,  0.4656],
         [-0.6476, -0.0726,  0.6398,  0.3718],
         [-0.7172, -0.1349,  0.7042,  0.4430],
         [-0.7294, -0.1260,  0.7166,  0.4458]],

        [[-0.4992, -0.0313,  0.4853,  0.2945],
         [-0.7145, -0.1641,  0.6931,  0.4656],
         [-0.6476, -0.0726,  0.6398,  0.3718],
         [-0.7172, -0.1349,  0.7042,  0.4430],
         [-0.7294, -0.1260,  0.7166,  0.4458]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 5, 4])


* The first dimension of the resulting `context_vecs` tensor is 2 since we have two input texts.
* The second dimension refers to the 5 tokens in each input.
* The third dimension is the four-dimensional embedding of each token: each of the two heads outputs `d_out = 2` dimensions, and concatenating them gives 2 × 2 = 4.
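The concatenation step can be sketched in plain Python. The values below are the first two tokens of the first text in the printed output, split into the two-column slice each head contributed (by construction of `torch.cat`, head 1 fills the first two columns and head 2 the last two); `concat_heads` is a hypothetical helper, not part of the notebook.

```python
def concat_heads(head_outputs):
    """Concatenate per-head context vectors token by token along the feature dimension.

    head_outputs[h][t] is the context vector that head h produced for token t.
    """
    num_tokens = len(head_outputs[0])
    return [[x for head in head_outputs for x in head[t]] for t in range(num_tokens)]

# First two tokens of the first text, split into the slices each head contributed
head_1 = [[-0.4992, -0.0313], [-0.7145, -0.1641]]
head_2 = [[0.4853, 0.2945], [0.6931, 0.4656]]

combined = concat_heads([head_1, head_2])
print(combined)  # each token vector now has 2 heads x 2 dims = 4 entries
```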

## 1.4.2 Implementing multi-head attention with weight splits

* So far we've implemented a `MultiHeadAttentionWrapper` that performs multi-head attention by stacking multiple single-head attention modules.
* In the `MultiHeadAttentionWrapper` class, multiple heads are implemented by creating a list of `CausalAttention` objects, each representing a separate attention head that is processed one after another in `forward`.
* Now let's implement multi-head attention in a single class: instead of separate modules, we project the inputs once with `d_out`-sized query, key and value weights and then split the results into heads.
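Why can one big projection replace several per-head projections? Slicing the output columns of a single `d_in × d_out` projection gives exactly the result of projecting with the matching column slice of the weight matrix. The plain-Python sketch below checks this with a hypothetical weight matrix, using the `x @ W` convention of `SelfAttention_v1` (with `nn.Linear`, the same split applies to the rows of its `weight`).

```python
def matmul(a, b):
    """(n x m) @ (m x p) for matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def columns(m, start, stop):
    """Columns start..stop-1 of a matrix stored as a list of rows."""
    return [row[start:stop] for row in m]

x = [[0.34, 0.55, 0.66],   # "Attention" (x1)
     [0.99, 0.87, 0.56]]   # "is" (x2)
num_heads, head_dim = 2, 2
# Hypothetical combined query weights of shape (d_in=3, d_out=num_heads*head_dim=4)
w_query = [[0.1, 0.2, 0.3, 0.4],
           [0.5, 0.6, 0.7, 0.8],
           [0.9, 1.0, 1.1, 1.2]]

combined = matmul(x, w_query)  # one projection for all heads: shape (2, 4)
for h in range(num_heads):
    lo, hi = h * head_dim, (h + 1) * head_dim
    per_head = matmul(x, columns(w_query, lo, hi))  # separate (3 x 2) projection for head h
    print(f"head {h}: same result? {per_head == columns(combined, lo, hi)}")
```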

In [36]:
class MultiHeadAttention(nn.Module):
  def __init__(self,d_in,d_out,context_length,dropout,num_heads,qkv_bias=False):
    super().__init__()
    assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads  # each head works on a d_out/num_heads slice
    self.w_query = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.w_key = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.w_value = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.out_proj = nn.Linear(d_out,d_out)  # linear layer that combines the head outputs
    self.dropout = nn.Dropout(dropout)
    self.register_buffer(
        "mask",
        torch.triu(torch.ones(context_length,context_length),diagonal=1)
    )

  def forward(self,x):
    b,num_tokens,d_in = x.shape
    # Single projections shared by all heads: (b, num_tokens, d_out)
    keys = self.w_key(x)
    values = self.w_value(x)
    queries = self.w_query(x)
    # Split d_out into num_heads x head_dim: (b, num_tokens, num_heads, head_dim)
    keys = keys.view(b,num_tokens,self.num_heads,self.head_dim)
    values = values.view(b,num_tokens,self.num_heads,self.head_dim)
    queries = queries.view(b,num_tokens,self.num_heads,self.head_dim)

    # Transpose from (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
    keys = keys.transpose(1,2)
    values = values.transpose(1,2)
    queries = queries.transpose(1,2)

    # Per-head attention scores: (b, num_heads, num_tokens, num_tokens)
    attn_scores = queries @ keys.transpose(2,3)
    # Causal mask truncated to the current length; unsqueeze adds batch and head dims for broadcasting
    mask_bool = self.mask.bool()[:num_tokens, :num_tokens].unsqueeze(0).unsqueeze(0)

    attn_scores.masked_fill_(mask_bool,-torch.inf)

    # keys.shape[-1] is now head_dim, so the scores are scaled by sqrt(head_dim)
    attn_weights = torch.softmax(
        attn_scores /keys.shape[-1]**0.5,dim=-1
    )
    attn_weights = self.dropout(attn_weights)

    # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
    context_vec_per_head = (attn_weights @ values).transpose(1,2)

    # Merge the heads back into d_out: (b, num_tokens, d_out)
    context_vec = context_vec_per_head.contiguous().view(
        b,num_tokens,self.d_out
    )
    context_vec = self.out_proj(context_vec)
    return context_vec

* The query, key and value tensors are split into heads through tensor reshaping and transposing operations, using PyTorch's `.view` and `.transpose` methods.
* The key operation is to split the `d_out` dimension into `num_heads` and `head_dim`, where `head_dim = d_out // num_heads`.
* This split is done with the `.view` method: a tensor of shape `(b, num_tokens, d_out)` is reshaped to `(b, num_tokens, num_heads, head_dim)`.
* The tensors are then transposed to bring the `num_heads` dimension before the `num_tokens` dimension, resulting in a shape of `(b, num_heads, num_tokens, head_dim)`.
* This transposition is crucial for correctly aligning the queries, keys and values across the different heads and for performing batched matrix multiplications efficiently.
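The split can be mimicked on nested Python lists. In this sketch (the helper `split_heads` is hypothetical), head `h` receives features `h*head_dim` to `(h+1)*head_dim - 1` of every token, which is the layout `.view(b, num_tokens, num_heads, head_dim).transpose(1, 2)` produces for a contiguous tensor. Each entry encodes `10 * token + feature` so the result is easy to read.

```python
def split_heads(x, num_heads):
    """(b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim), mirroring
    .view(b, num_tokens, num_heads, head_dim).transpose(1, 2) on nested lists."""
    b, num_tokens, d_out = len(x), len(x[0]), len(x[0][0])
    assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
    head_dim = d_out // num_heads
    return [
        [
            [x[i][t][h * head_dim:(h + 1) * head_dim] for t in range(num_tokens)]
            for h in range(num_heads)
        ]
        for i in range(b)
    ]

# One sequence (b=1) of 3 tokens with d_out=4; entry value = 10 * token + feature
x = [[[10 * t + f for f in range(4)] for t in range(3)]]
heads = split_heads(x, num_heads=2)
print(heads[0][0])  # head 0: [[0, 1], [10, 11], [20, 21]]
print(heads[0][1])  # head 1: [[2, 3], [12, 13], [22, 23]]
```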


In [39]:
## Example tensor of shape (b=1, num_heads=2, num_tokens=3, head_dim=4)
a = torch.randn(1,2,3,4)
print(a)

tensor([[[[-2.2150, -1.3193, -2.0915,  0.9629],
          [-0.0319, -0.4790,  0.7668,  0.0275],
          [-0.5872,  1.1952, -1.2096, -0.5560]],

         [[-2.7202,  0.5421, -1.1541,  0.7763],
          [-0.7067, -0.9222,  3.8954, -0.6027],
          [-0.0480,  0.5349,  1.1031,  1.3334]]]])


In [41]:
## Batched matrix multiplication between the tensor and a view of it with the
## last two dimensions (num_tokens and head_dim) transposed
print(a @ a.transpose(2,3))

tensor([[[[11.9482, -0.8749,  1.7182],
          [-0.8749,  0.8192, -1.4965],
          [ 1.7182, -1.4965,  3.5454]],

         [[ 9.6278, -3.5411,  0.1825],
          [-3.5411, 16.8871,  3.0340],
          [ 0.1825,  3.0340,  3.2833]]]])


* In this case, PyTorch's matrix multiplication handles the four-dimensional input tensor by multiplying over the last two dimensions (`num_tokens`, `head_dim`) and repeating the operation for each head.

In [44]:
first_head = a[0,0,:,:]
first_res = first_head @ first_head.T
print("First head:\n",first_res)

second_head = a[0,1,:,:]
second_res =second_head @ second_head.T
print("\nSecond head:\n",second_res)

First head:
 tensor([[11.9482, -0.8749,  1.7182],
        [-0.8749,  0.8192, -1.4965],
        [ 1.7182, -1.4965,  3.5454]])

Second head:
 tensor([[ 9.6278, -3.5411,  0.1825],
        [-3.5411, 16.8871,  3.0340],
        [ 0.1825,  3.0340,  3.2833]])


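* Continuing with `MultiHeadAttention`, after computing the attention weights and context vectors, the context vectors from all heads are transposed back to the shape `(b, num_tokens, num_heads, head_dim)`. These vectors are then reshaped (flattened) into the shape `(b, num_tokens, d_out)`, effectively combining the outputs from all heads, and finally passed through the output projection `out_proj`. The `.contiguous()` call is needed because `.transpose` returns a non-contiguous view, and `.view` cannot merge the `num_heads` and `head_dim` dimensions of such a tensor without a copy.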
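Merging is the inverse of the split. The plain-Python sketch below (the helper `merge_heads` is hypothetical) mirrors `.transpose(1, 2).contiguous().view(b, num_tokens, d_out)`: for every token it concatenates the head slices in head order, restoring the original feature layout. The per-head values are hypothetical and follow the `10 * token + feature` pattern.

```python
def merge_heads(heads):
    """(b, num_heads, num_tokens, head_dim) -> (b, num_tokens, d_out), mirroring
    .transpose(1, 2).contiguous().view(b, num_tokens, d_out) on nested lists."""
    b, num_heads, num_tokens = len(heads), len(heads[0]), len(heads[0][0])
    return [
        [[v for h in range(num_heads) for v in heads[i][h][t]] for t in range(num_tokens)]
        for i in range(b)
    ]

# Hypothetical per-head context vectors: b=1, num_heads=2, num_tokens=3, head_dim=2
heads = [[
    [[0, 1], [10, 11], [20, 21]],   # head 0
    [[2, 3], [12, 13], [22, 23]],   # head 1
]]
merged = merge_heads(heads)
print(merged)  # [[[0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23]]]
```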

In [45]:
torch.manual_seed(123)
batch_size,context_length,d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in,d_out,context_length,0.0,num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:",context_vecs.shape)

tensor([[[0.2695, 0.4288],
         [0.2774, 0.3024],
         [0.2874, 0.3482],
         [0.2852, 0.3052],
         [0.2888, 0.3003]],

        [[0.2695, 0.4288],
         [0.2774, 0.3024],
         [0.2874, 0.3482],
         [0.2852, 0.3052],
         [0.2888, 0.3003]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 5, 2])
