## 1.1 Brief Introduction

* We will discuss three attention variants: self-attention, causal attention and multi-head attention.
* These variants build on each other. The goal is to arrive at a compact and efficient implementation of multi-head attention that we can plug into the LLM architecture.

## 1.2 Self-attention with trainable weights

* In self-attention, our goal is to compute a context vector for each element in the input sequence.
* Let's use the following input sequence, with each token represented by a 3-dimensional embedding vector.


In [1]:
import torch
inputs = torch.tensor(
    [[0.34, 0.55, 0.66],  # Attention (x1)
     [0.99, 0.87, 0.56],  # is (x2)
     [0.67, 0.12, 0.65],  # all (x3)
     [0.99, 0.89, 0.53],  # you (x4)
     [0.77, 0.67, 0.77]]  # need (x5)
)

* Next, we initialize the query, key and value weight matrices, which project each input embedding into its query, key and value vectors.


In [2]:
torch.manual_seed(123)
# Each matrix has shape (d_in, d_out) = (3, 5): it maps 3-dim embeddings to 5 dims
w_query = torch.rand(3,5)
w_key = torch.rand(3,5)
w_value = torch.rand(3,5)

In [3]:
## Projecting the inputs into the query, key and value matrices
query = torch.matmul(inputs, w_query)
key = torch.matmul(inputs, w_key)
value = torch.matmul(inputs, w_value)
print(f"Query:{query}")
print(f"Key:{key}")
print(f"Value:{value}")

Query:tensor([[0.7853, 0.7042, 0.1919, 0.4651, 0.6335],
        [1.2236, 1.0150, 0.3807, 0.9519, 0.8824],
        [0.5073, 0.8091, 0.2301, 0.6112, 0.3424],
        [1.2314, 0.9971, 0.3804, 0.9497, 0.8875],
        [1.0513, 1.0183, 0.3207, 0.8049, 0.7873]])
Key:tensor([[1.0406, 0.5497, 1.1973, 1.2639, 0.4013],
        [1.5090, 0.7888, 1.8532, 1.7368, 0.9145],
        [0.7995, 0.3310, 1.1880, 0.9585, 0.5008],
        [1.5063, 0.7924, 1.8402, 1.7294, 0.9189],
        [1.3882, 0.7024, 1.7238, 1.6450, 0.7225]])
Value:tensor([[0.8222, 0.4466, 0.7402, 0.9640, 0.9085],
        [1.5228, 0.9283, 1.2754, 1.3727, 1.6282],
        [0.9374, 0.6549, 0.8192, 1.1236, 0.9768],
        [1.5193, 0.9229, 1.2689, 1.3464, 1.6250],
        [1.3175, 0.7990, 1.1416, 1.3923, 1.4168]])
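* Each row of `query` is simply the corresponding input embedding multiplied by `w_query`, so projecting the whole matrix at once is the same as projecting each token on its own. A minimal check, reusing the toy inputs and seed from above (the names `x_2` and `q_2` are illustrative):

```python
import torch

torch.manual_seed(123)
inputs = torch.tensor(
    [[0.34, 0.55, 0.66],  # Attention (x1)
     [0.99, 0.87, 0.56],  # is (x2)
     [0.67, 0.12, 0.65],  # all (x3)
     [0.99, 0.89, 0.53],  # you (x4)
     [0.77, 0.67, 0.77]]  # need (x5)
)
w_query = torch.rand(3, 5)  # (d_in=3, d_out=5)

query = inputs @ w_query  # all tokens at once: (5, 3) @ (3, 5) -> (5, 5)
x_2 = inputs[1]           # embedding of "is", shape (3,)
q_2 = x_2 @ w_query       # query vector of "is" alone, shape (5,)

print(query.shape)                    # torch.Size([5, 5])
print(torch.allclose(query[1], q_2))  # True
```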


* Next, we compute the attention scores: entry (i, j) is the dot product of query i with key j, so the full score matrix is $QK^T$.

In [4]:
attn_scores = torch.matmul(query,key.T)
print(f"Attention Scores:{attn_scores}")

Attention Scores:tensor([[2.2762, 3.4834, 1.8520, 3.4806, 3.1385],
        [3.8442, 5.8128, 3.1209, 5.8050, 5.2712],
        [2.1581, 3.2049, 1.7041, 3.2004, 2.9221],
        [3.8416, 5.8109, 3.1214, 5.8031, 5.2692],
        [3.3711, 5.1021, 2.7245, 5.0962, 4.6205]])
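* To make the "every query against every key" reading concrete, the sketch below compares the vectorized `query @ key.T` with explicit loops over dot products. It uses random stand-in matrices rather than the tensors above, so only the equality matters, not the values:

```python
import torch

torch.manual_seed(0)
query = torch.rand(5, 5)  # stand-in query matrix (5 tokens, d_out=5)
key = torch.rand(5, 5)    # stand-in key matrix

# Vectorized: all query/key pairs in one matrix product.
attn_scores = query @ key.T

# Explicit loops: score[i, j] is the dot product of query i and key j.
loop_scores = torch.empty(5, 5)
for i in range(query.shape[0]):
    for j in range(key.shape[0]):
        loop_scores[i, j] = torch.dot(query[i], key[j])

print(torch.allclose(attn_scores, loop_scores))  # True
```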


* Next, we compute the attention weights by scaling the scores and normalizing each row with softmax:

$\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$

* Here $Q$ is the query matrix, $K$ is the key matrix, and $d_k$ is the dimension of the key vectors (the output dimension `d_out`, which is 5 here).

In [5]:
d_k = key.shape[-1]
attn_weights = torch.softmax(attn_scores/d_k**0.5,dim=-1)
print(f"Attention Weights:{attn_weights}")

Attention Weights:tensor([[0.1486, 0.2551, 0.1230, 0.2547, 0.2186],
        [0.1186, 0.2860, 0.0858, 0.2850, 0.2245],
        [0.1559, 0.2490, 0.1273, 0.2485, 0.2194],
        [0.1186, 0.2860, 0.0859, 0.2850, 0.2245],
        [0.1277, 0.2770, 0.0957, 0.2763, 0.2233]])
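* The division by $\sqrt{d_k}$ keeps the scores from growing with the key dimension: if query and key entries are independent with zero mean and unit variance, their dot product has variance $d_k$, which would push softmax toward a nearly one-hot output with very small gradients. A quick empirical check of that variance argument (the dimension 64 and the sample count are arbitrary choices):

```python
import torch

torch.manual_seed(123)
d_k = 64
n_samples = 10_000

# Random query/key vectors with zero-mean, unit-variance entries.
q = torch.randn(n_samples, d_k)
k = torch.randn(n_samples, d_k)

dots = (q * k).sum(dim=-1)  # one dot product per sample
scaled = dots / d_k**0.5    # the scaling used in softmax(QK^T / sqrt(d_k))

print(dots.var().item())    # close to d_k = 64
print(scaled.var().item())  # close to 1
```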


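* To get the context vectors, we multiply the attention-weight matrix by the value matrix. Each context vector is a weighted sum of all value vectors, using that token's row of attention weights.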

In [6]:
context_vector = torch.matmul(attn_weights,value)
print(f"Context Vector:{context_vector}")

Context Vector:tensor([[1.3009, 0.7934, 1.1089, 1.2789, 1.3941],
        [1.3424, 0.8171, 1.1409, 1.2998, 1.4386],
        [1.2932, 0.7887, 1.1030, 1.2750, 1.3859],
        [1.3424, 0.8171, 1.1409, 1.2998, 1.4385],
        [1.3305, 0.8103, 1.1317, 1.2938, 1.4259]])
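* The weighted-sum view can be checked directly: building one context vector by hand from the value rows gives the same result as the matrix product. A sketch with random stand-in weights and values (the token index `i = 1` is an arbitrary choice):

```python
import torch

torch.manual_seed(0)
attn_weights = torch.softmax(torch.rand(5, 5), dim=-1)  # stand-in: each row sums to 1
value = torch.rand(5, 5)                                # stand-in value matrix

context_vector = attn_weights @ value

# Context vector for token i, built by hand as a weighted sum of value rows.
i = 1
manual = torch.zeros(value.shape[1])
for j in range(value.shape[0]):
    manual += attn_weights[i, j] * value[j]

print(torch.allclose(context_vector[i], manual))  # True
```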


In [8]:
## Implementing a compact self-attention class
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        # nn.Parameter registers the matrices as trainable weights of the module
        self.w_query = nn.Parameter(torch.rand(d_in, d_out))
        self.w_key = nn.Parameter(torch.rand(d_in, d_out))
        self.w_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        # Project the inputs: (num_tokens, d_in) -> (num_tokens, d_out)
        keys = x @ self.w_key
        queries = x @ self.w_query
        values = x @ self.w_value
        # Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [10]:
##use case for the class
d_in = inputs.shape[-1] # or 3
d_out = 5 # As used in previous manual calculations
sa_v1 = SelfAttention_v1(d_in,d_out)
context_vector = sa_v1(inputs)
print(f"Context Vector:{context_vector}")

Context Vector:tensor([[1.1028, 1.0398, 0.9905, 1.1706, 1.2588],
        [1.1476, 1.0834, 1.0098, 1.2004, 1.2930],
        [1.0974, 1.0344, 0.9883, 1.1669, 1.2544],
        [1.1480, 1.0837, 1.0099, 1.2006, 1.2933],
        [1.1326, 1.0687, 1.0034, 1.1903, 1.2813]], grad_fn=<MmBackward0>)
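* Wrapping the matrices in `nn.Parameter` is what makes them learnable: they appear in `named_parameters()` and receive gradients. A self-contained sketch, with a random stand-in input and a dummy `.sum()` loss used purely for illustration:

```python
import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.w_query = nn.Parameter(torch.rand(d_in, d_out))
        self.w_key = nn.Parameter(torch.rand(d_in, d_out))
        self.w_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.w_key
        queries = x @ self.w_query
        values = x @ self.w_value
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values

torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in=3, d_out=5)
x = torch.rand(5, 3)  # stand-in for the 5-token input sequence

# The three weight matrices show up as named parameters.
for name, p in sa_v1.named_parameters():
    print(name, tuple(p.shape))  # w_query (3, 5), then w_key and w_value, same shape

# A dummy scalar loss backpropagates into all three matrices.
loss = sa_v1(x).sum()
loss.backward()
print(all(p.grad is not None for p in sa_v1.parameters()))  # True
```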


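* We can improve the `SelfAttention_v1` implementation by using PyTorch's `nn.Linear` layers, which perform a plain matrix multiplication when the bias units are disabled. `nn.Linear` also initializes its weights with a more carefully chosen scheme than `torch.rand`, which typically makes training more stable.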
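* A minimal check that an `nn.Linear` layer with `bias=False` is just a matrix multiplication. Note that it stores its weight transposed relative to the `(d_in, d_out)` matrices used above; the input here is a random stand-in:

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
d_in, d_out = 3, 5
layer = nn.Linear(d_in, d_out, bias=False)
x = torch.rand(5, d_in)  # stand-in for the 5-token input sequence

# nn.Linear stores its weight as (d_out, d_in) and computes x @ weight.T.
print(layer.weight.shape)                            # torch.Size([5, 3])
print(torch.allclose(layer(x), x @ layer.weight.T))  # True
```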

In [13]:
## Self-attention class using PyTorch's nn.Linear layers
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # With bias=False, each nn.Linear performs a plain matrix multiplication
        self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.w_key(x)
        queries = self.w_query(x)
        values = self.w_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [14]:
##use case
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in,d_out)
context_vector = sa_v2(inputs)
print(f"Context Vector:{context_vector}")

Context Vector:tensor([[ 0.6918,  0.4175, -0.7206, -0.2338,  0.2625],
        [ 0.6856,  0.4112, -0.7157, -0.2336,  0.2580],
        [ 0.6935,  0.4199, -0.7217, -0.2341,  0.2648],
        [ 0.6856,  0.4111, -0.7156, -0.2336,  0.2579],
        [ 0.6878,  0.4134, -0.7174, -0.2337,  0.2595]], grad_fn=<MmBackward0>)
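* The two classes produce different context vectors only because they start from different random weights (`torch.rand` versus `nn.Linear`'s own initialization). Copying the `nn.Linear` weights, transposed, into a `SelfAttention_v1` instance should make the outputs match. A self-contained sketch with a random stand-in input:

```python
import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.w_query = nn.Parameter(torch.rand(d_in, d_out))
        self.w_key = nn.Parameter(torch.rand(d_in, d_out))
        self.w_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries, keys, values = x @ self.w_query, x @ self.w_key, x @ self.w_value
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        queries, keys, values = self.w_query(x), self.w_key(x), self.w_value(x)
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values

torch.manual_seed(123)
x = torch.rand(5, 3)  # stand-in for the 5-token input sequence
sa_v1 = SelfAttention_v1(3, 5)
sa_v2 = SelfAttention_v2(3, 5)

# Copy v2's weights into v1; nn.Linear stores (d_out, d_in), so transpose.
with torch.no_grad():
    sa_v1.w_query.copy_(sa_v2.w_query.weight.T)
    sa_v1.w_key.copy_(sa_v2.w_key.weight.T)
    sa_v1.w_value.copy_(sa_v2.w_value.weight.T)

print(torch.allclose(sa_v1(x), sa_v2(x)))  # True
```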


## 1.3 Causal Attention

* For many LLM tasks, you want the self-attention mechanism to consider only the tokens that appear before the current position when predicting the next token in a sequence.
* Causal attention (also called masked attention) enforces this: when computing the attention scores for a given token, the model only considers the current and previous inputs in the sequence.

In [15]:
## Applying causal attention
# Step 1: compute the unmasked attention weights with softmax, as before
queries = sa_v2.w_query(inputs)
keys = sa_v2.w_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5,dim=-1)
print(attn_weights)

tensor([[0.2090, 0.1883, 0.2181, 0.1878, 0.1968],
        [0.2184, 0.1807, 0.2265, 0.1803, 0.1941],
        [0.2103, 0.1907, 0.2120, 0.1908, 0.1962],
        [0.2184, 0.1806, 0.2266, 0.1802, 0.1941],
        [0.2151, 0.1834, 0.2238, 0.1830, 0.1948]], grad_fn=<SoftmaxBackward0>)


In [16]:
# Step 2: build a lower-triangular mask (1s on and below the diagonal, 0s above it)
context_length = attn_scores.shape[0]
mask = torch.tril(torch.ones(context_length,context_length))
print(mask)

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])


In [18]:
# Step 3: multiply by the mask to zero out the weights above the diagonal
masked = attn_weights*mask
print(masked)

tensor([[0.2090, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2184, 0.1807, 0.0000, 0.0000, 0.0000],
        [0.2103, 0.1907, 0.2120, 0.0000, 0.0000],
        [0.2184, 0.1806, 0.2266, 0.1802, 0.0000],
        [0.2151, 0.1834, 0.2238, 0.1830, 0.1948]], grad_fn=<MulBackward0>)


In [19]:
## Step 4: renormalize the attention weights so that each row sums to 1
row_sums = masked.sum(dim=1,keepdim=True)
masked_norm = masked /row_sums
print(masked_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5472, 0.4528, 0.0000, 0.0000, 0.0000],
        [0.3431, 0.3111, 0.3458, 0.0000, 0.0000],
        [0.2710, 0.2242, 0.2812, 0.2236, 0.0000],
        [0.2151, 0.1834, 0.2238, 0.1830, 0.1948]], grad_fn=<DivBackward0>)


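* The softmax function converts its inputs into a probability distribution.
* Because $e^{-\infty} = 0$, any $-\infty$ entries in a row receive zero probability. We can therefore mask the attention scores above the diagonal with $-\infty$ *before* applying softmax, which yields normalized causal weights in a single step without the separate renormalization.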
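* A tiny illustration of that property on a single row of made-up scores: the `-inf` entry gets exactly zero weight, and the rest equals a softmax over the unmasked scores alone.

```python
import torch

scores = torch.tensor([0.5, 1.5, -torch.inf])
weights = torch.softmax(scores, dim=-1)

# exp(-inf) = 0, so the masked position gets zero weight and the
# remaining entries match a softmax over the unmasked scores alone.
print(weights[2].item())                                               # 0.0
print(torch.allclose(weights[:2], torch.softmax(scores[:2], dim=-1)))  # True
```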

In [20]:
mask = torch.triu(torch.ones(context_length,context_length),diagonal=1)
masked = attn_scores.masked_fill(mask.bool(),-torch.inf)
print(masked)

tensor([[-0.3166,    -inf,    -inf,    -inf,    -inf],
        [-0.5420, -0.9652,    -inf,    -inf,    -inf],
        [-0.3316, -0.5503, -0.3142,    -inf,    -inf],
        [-0.5402, -0.9646, -0.4576, -0.9698,    -inf],
        [-0.4812, -0.8374, -0.3923, -0.8418, -0.7029]],
       grad_fn=<MaskedFillBackward0>)


In [21]:
## Applying softmax, scaled by sqrt(d_k) as before;
## the result matches the renormalized weights computed above
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5472, 0.4528, 0.0000, 0.0000, 0.0000],
        [0.3431, 0.3111, 0.3458, 0.0000, 0.0000],
        [0.2710, 0.2242, 0.2812, 0.2236, 0.0000],
        [0.2151, 0.1834, 0.2238, 0.1830, 0.1948]], grad_fn=<SoftmaxBackward0>)
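* The two masking routes (softmax, zero out, renormalize versus fill with `-inf`, then softmax) should give the same causal weights, since excluding the masked scores from the softmax and renormalizing the surviving weights are mathematically identical. A self-contained check on random stand-in scores (the seed and the 5x5 size are arbitrary):

```python
import torch

torch.manual_seed(123)
n, d_k = 5, 5
attn_scores = torch.randn(n, n)  # stand-in for queries @ keys.T

# Route 1: softmax, zero out future positions, renormalize each row.
causal = torch.tril(torch.ones(n, n))
weights = torch.softmax(attn_scores / d_k**0.5, dim=-1) * causal
weights_renorm = weights / weights.sum(dim=-1, keepdim=True)

# Route 2: set future scores to -inf, then apply a single softmax.
future = torch.triu(torch.ones(n, n), diagonal=1).bool()
masked_scores = attn_scores.masked_fill(future, -torch.inf)
weights_inf = torch.softmax(masked_scores / d_k**0.5, dim=-1)

print(torch.allclose(weights_renorm, weights_inf))  # True
```

The `-inf` route needs one softmax and no extra division, which is why it is the usual choice in practice.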
