**Causal Attention**: each token attends only to itself and earlier tokens; a mask blocks information from future tokens.

**Self Attention**: each token attends to every token in the sequence, past and future, and computes scores (attention weights) that decide which tokens matter most for understanding the context.

**MultiHead Attention**: multiple attention mechanisms (heads) run at the same time.

**Shortcomings of RNNs**
- vanishing gradient problem
- an RNN decoder can't directly access earlier hidden states from the encoder during the decoding phase.
- short-term memory: RNNs don't work well for long sequences.

* RNNs have to compress the entire encoded input into a fixed-size hidden state before passing it to the decoder (*Information Bottleneck*)

To address this, the attention mechanism was introduced

- *Bahdanau attention* mechanism -> additive Attention

    ↳ Bahdanau attention lets the decoder focus on different parts of the input at each time step.

    ↳ It assigns a weight (derived from an attention score) to each encoder state, which decides how important each input token is for the current output token.

- *Luong attention* mechanism -> multiplicative attention
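
The two scoring styles can be contrasted in a few lines. This is only a rough sketch, not the full models from the original papers: the names `W_s`, `W_h`, `v`, `W` and all sizes are illustrative assumptions, the weights are random instead of trained, and the multiplicative score shown is the bilinear ("general") form.

```python
import torch

torch.manual_seed(0)
hidden = 4                                # hypothetical hidden size
enc_states = torch.randn(5, hidden)       # 5 encoder hidden states h_i
dec_state = torch.randn(hidden)           # current decoder state s

# additive (Bahdanau-style) score: v^T tanh(W_s s + W_h h_i)
W_s = torch.randn(hidden, hidden)
W_h = torch.randn(hidden, hidden)
v = torch.randn(hidden)
additive_scores = torch.tanh(dec_state @ W_s + enc_states @ W_h) @ v

# multiplicative (Luong-style) score: s^T W h_i
W = torch.randn(hidden, hidden)
multiplicative_scores = (dec_state @ W) @ enc_states.T

# in both cases: softmax over encoder positions, then a weighted sum of encoder states
additive_weights = torch.softmax(additive_scores, dim=0)
multiplicative_weights = torch.softmax(multiplicative_scores, dim=0)
context_additive = additive_weights @ enc_states
context_multiplicative = multiplicative_weights @ enc_states

print(additive_weights)
print(multiplicative_weights)
```

Either way the decoder ends up with one weight per encoder state; only the function that produces the raw score differs.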

**Self Attention**
- "self" in self-attention means the mechanism assesses and learns the relationships and dependencies between various parts of the input itself

- the goal of self-attention is to compute a *context vector (z)* for each element (x) in the input sequence

e.g. z(2), the context vector for x(2), is an embedding that contains information about x(2) and all other x(i)'s.

#### **SIMPLIFIED SELF ATTENTION**

In [7]:
import tiktoken
tokenizer = tiktoken.get_encoding('gpt2')

text = 'Your journey starts with one step'
ids = tokenizer.encode(text)
print(ids)
print(tokenizer.n_vocab)

[7120, 7002, 4940, 351, 530, 2239]
50257


In [19]:
output_dim = 3 #d

import torch
import torch.nn as nn
torch.manual_seed(12)
embedding_layer = nn.Embedding(tokenizer.n_vocab, output_dim)
embedding_layer(torch.tensor(ids))

tensor([[ 0.4521,  1.0357,  0.1780],
        [ 1.4295,  0.9807, -0.9021],
        [ 1.5082,  0.8670, -0.1339],
        [ 0.2479,  2.3997,  0.7642],
        [-0.6501,  0.3397,  0.8724],
        [-0.0332, -0.2958,  0.1109]], grad_fn=<EmbeddingBackward0>)

In [20]:
import torch
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], #Your
     [0.55, 0.87, 0.66], #journey
     [0.57, 0.85, 0.64], #starts
     [0.22, 0.58, 0.33], #with
     [0.77, 0.25, 0.10], #one
     [0.05, 0.80, 0.55]] #step
)

query = inputs[1]   # second token as query
attn_scores_2 = torch.empty(inputs.shape[0])  #initialize

# taking dot product of query with each token
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)


tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [22]:
# manual check of attn_scores_2[0]: (0.43*0.55) + (0.15*0.87) + (0.89*0.66) = 0.9544

The **dot product** is used as a measure of similarity because it quantifies how closely two vectors are aligned.

Here, the magnitude of each attention score (the dot product of the query with an input element) indicates how strongly the query (`inputs[1]`) **"attends to"** that element.

The higher the `attn_score`, the higher the similarity between the two input elements.
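
A tiny illustration of that claim, using made-up 2-D vectors (not the embeddings from this notebook). It also shows that the dot product grows with vector length, not only with alignment; cosine similarity is included purely for comparison.

```python
import torch
import torch.nn.functional as F

a = torch.tensor([1.0, 0.0])
same_dir = torch.tensor([2.0, 0.0])      # same direction as a, but twice as long
orthogonal = torch.tensor([0.0, 3.0])    # at 90 degrees to a
opposite = torch.tensor([-1.0, 0.0])     # pointing the other way

dot_same = torch.dot(a, same_dir)        # aligned -> positive
dot_orth = torch.dot(a, orthogonal)      # unrelated -> zero
dot_opp = torch.dot(a, opposite)         # opposed -> negative

# cosine similarity divides out the lengths, the raw dot product does not
cos_same = F.cosine_similarity(a, same_dir, dim=0)

print(dot_same, dot_orth, dot_opp, cos_same)
```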

In [26]:
# normalised attn scores (also called attention weights)
attn_weights_2_tmp= attn_scores_2 / attn_scores_2.sum()

print(f'Attention Weights:\n{attn_weights_2_tmp}')
print(f'Sum: {attn_weights_2_tmp.sum():.4f}')

Attention Weights:
tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: 1.0000


In [27]:
# using softmax for normalisation
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)

print(f'Attention Weights:\n{attn_weights_2}')
print(f'Sum: {attn_weights_2.sum()}')

Attention Weights:
tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: 1.0
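
Why softmax instead of dividing by the sum? A short sketch: `softmax_naive` is a hypothetical helper written only for this comparison, and the `mixed` and `large` tensors are made-up values. Sum-normalisation can produce negative "weights" when a score is negative, and a hand-written softmax overflows for large scores, while `torch.softmax` stays positive and finite.

```python
import torch

def softmax_naive(x):
    # exponentiate every score, then divide by the total
    return torch.exp(x) / torch.exp(x).sum(dim=0)

scores = torch.tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
naive = softmax_naive(scores)
builtin = torch.softmax(scores, dim=0)       # same values as the naive version here

# a negative score breaks plain sum-normalisation
mixed = torch.tensor([2.0, -1.0, 0.5])
sum_norm = mixed / mixed.sum()               # contains a negative entry
soft = torch.softmax(mixed, dim=0)           # all entries positive

# very large scores overflow exp() in the naive version
large = torch.tensor([1000.0, 1001.0])
naive_large = softmax_naive(large)           # inf / inf -> nan
stable_large = torch.softmax(large, dim=0)   # finite

print(sum_norm, soft)
print(naive_large, stable_large)
```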


**Context Vector**

The context vector (here `z(2)`) is a weighted sum of the embedded input tokens: multiply each token by its attention weight, then sum the resulting vectors.

In [28]:
# calculating the context vector z(2)
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)

for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i]*x_i

print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


manual calculations

In [29]:
# print(0.1385*0.43, 0.1385*0.15, 0.1385*0.89)
# print(0.2379*0.55, 0.2379*0.87, 0.2379*0.66)
# print(0.2333*0.57, 0.2333*0.85, 0.2333*0.64)
# print(0.1240*0.22, 0.1240*0.58, 0.1240*0.33)
# print(0.1082*0.77, 0.1082*0.25, 0.1082*0.10)
# print(0.1581*0.05, 0.1581*0.80, 0.1581*0.55)

In [30]:
# print(0.059555000000000004+0.13084500000000002+0.132981+0.02728+0.083314+0.007905)
# print(0.020775000000000002+0.206973+0.198305+0.07192+0.02705+0.12648)
# print(0.12326500000000001+0.15701400000000001+0.149312+0.040920000000000005+0.010820000000000001+0.086955)

**Computing all context vectors**

`attention_scores` ->
`attention_weights` -> 
`context_vectors`

In [31]:
# compute attention scores for all input elements

attn_scores = torch.empty(6,6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i,j] = torch.dot(x_i, x_j)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [32]:
# attention scores using matmul (or @), since Python for loops are slow

# attn_scores = torch.matmul(inputs, inputs.T)
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [33]:
# attn weights
attn_weights = torch.softmax(attn_scores, dim=-1)  # dim=-1 normalises along the last dimension, so each row sums to 1
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [34]:
# context vectors
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
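
The three steps (scores -> weights -> context vectors) can be wrapped into one small function. `simple_self_attention` is a hypothetical helper name used only for this sketch; it has no trainable weights.

```python
import torch

def simple_self_attention(x):
    # x: (num_tokens, dim)
    scores = x @ x.T                           # pairwise dot products
    weights = torch.softmax(scores, dim=-1)    # each row sums to 1
    return weights @ x, weights                # context vectors, attention weights

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your
     [0.55, 0.87, 0.66],  # journey
     [0.57, 0.85, 0.64],  # starts
     [0.22, 0.58, 0.33],  # with
     [0.77, 0.25, 0.10],  # one
     [0.05, 0.80, 0.55]]  # step
)

all_context_vecs, attn_weights = simple_self_attention(inputs)
print(all_context_vecs[1])        # should match the context_vec_2 computed earlier
print(attn_weights.sum(dim=-1))   # every row sums to 1
```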


#### **IMPLEMENTING SELF-ATTENTION WITH TRAINABLE WEIGHTS**

We introduce three trainable weight matrices: Wq, Wk, Wv.
These matrices project the embedded input tokens x(i) into `query`, `key` and `value` vectors.

Wq, Wk, Wv -> these matrices are updated during training

In [35]:
inputs, inputs.shape

(tensor([[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]),
 torch.Size([6, 3]))

Usually the input and output embedding dimensions are equal (`d_in` = `d_out`); the example below uses different sizes (`d_in` = 3, `d_out` = 2).

In [36]:
# computing only one context vector (z²)
x_2 = inputs[1]
d_in = inputs.shape[1]  # input embedding dim
d_out = 2 #output embedding dim

In [37]:
# initialize three weight matrices(query, key and value)
torch.manual_seed(123)
W_q = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False)
W_k = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False)
W_v = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False)  # for actual training, use requires_grad=True on all three

In [38]:
# query, key and value vectors for x(2) (each has d_out = 2 elements)
query_2 = x_2 @ W_q
key_2 = x_2 @ W_k
value_2 = x_2 @ W_v

print(f'Query vector corresponding to input token x(2):\n{query_2}')

Query vector corresponding to input token x(2):
tensor([0.4306, 1.4551])


In [41]:
# to calculate the context vector (z2), we need to compute all key and value vectors
# all keys, values 

keys = inputs @ W_k
values = inputs @ W_v
print(f'Keys shape: {keys.shape}')
print(f'Values shape: {values.shape}')

Keys shape: torch.Size([6, 2])
Values shape: torch.Size([6, 2])


In [44]:
# keys, values, query_2

In [45]:
# attention score = dot product of a query with a key
# attention score of x(2) with itself
keys_2 = keys[1]

attn_score_22 = query_2.dot(keys_2)
attn_score_22

tensor(1.8524)

In [46]:
# all attention scores corresponding to x(2)
attn_scores_2 = query_2 @ keys.T
attn_scores_2

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

In [47]:
# attention weights = softmax of the attention scores scaled by sqrt(d_k)
import math

d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2/ math.sqrt(d_k), dim=-1)
attn_weights_2

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
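
Why divide by `sqrt(d_k)`? A rough experiment with random vectors (the sample count and the three `d_k` values are arbitrary assumptions): for unit-variance components, the spread of a dot product grows roughly like `sqrt(d_k)`, and dividing by `sqrt(d_k)` brings it back to about 1, which keeps softmax from saturating as the dimension grows.

```python
import torch

torch.manual_seed(0)
num_samples = 10_000
results = {}

for d_k in (2, 64, 512):
    q = torch.randn(num_samples, d_k)      # random "queries"
    k = torch.randn(num_samples, d_k)      # random "keys"
    dots = (q * k).sum(dim=-1)             # one dot product per row
    raw_std = dots.std().item()
    scaled_std = (dots / d_k**0.5).std().item()
    results[d_k] = (raw_std, scaled_std)
    print(f'd_k={d_k:4d}  raw std={raw_std:7.3f}  scaled std={scaled_std:.3f}')
```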

In [48]:
# context vector -> weighted sum over the value vectors
context_vec_2 = attn_weights_2 @ values
context_vec_2  # single context vector for token 2 in the input sequences

tensor([0.3061, 0.8210])

#### Implementing a compact self-attention Python class

In [49]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))  #trainable weights matrices
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self,x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim = -1)

        context_vec = attn_weights @ values
        return context_vec

In [50]:
# example
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in=3, d_out=2)

context_vector = sa_v1(inputs)
context_vector

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

Why `nn.Linear`?
- internally initializes the weights and (optional) bias
- automatically registers its weights as trainable parameters
- applies a linear transformation to the incoming data
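
A quick check of these points on a standalone layer. The random input `x` and the sizes are assumptions made for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
d_in, d_out = 3, 2
layer = nn.Linear(d_in, d_out, bias=False)
x = torch.rand(6, d_in)

weight_shape = tuple(layer.weight.shape)     # stored as (d_out, d_in)
manual = x @ layer.weight.T                  # what the layer computes without bias
same = torch.allclose(layer(x), manual)

param_names = [name for name, _ in layer.named_parameters()]
trainable = layer.weight.requires_grad       # registered as a trainable parameter

print(weight_shape, same, param_names, trainable)
```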

In [51]:
# a better way: use nn.Linear

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias = False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias= qkv_bias)  # stores W_query as (d_out, d_in)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias= qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)   # internally does x @ W_key.T
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        # print(f'Attention Weights: {attn_weights}')
        context_vec = attn_weights @ values
        return context_vec

In [52]:
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in, d_out)

context_vector = sa_v2(inputs)
context_vector

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)

#### **CAUSAL ATTENTION** (or masked attention)
**standard self-attention**: has access to the entire input sequence at once.

**causal attention**: each token sees only previous and current inputs in the sequence; future tokens are masked out.

One way to implement the attention mask
1. normalise the attention scores using softmax to get attention weights
2. mask with 0's above the diagonal to obtain masked attention weights
3. renormalise each row so the masked attention weights sum to 1

In [53]:
# to apply a causal mask -> mask elements above the diagonal with 0's.

# 1.
# compute attention weights
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T

attn_weights = torch.softmax(attn_scores / math.sqrt(keys.shape[-1]), dim = -1)
attn_weights

tensor([[0.1717, 0.1762, 0.1761, 0.1555, 0.1627, 0.1579],
        [0.1636, 0.1749, 0.1746, 0.1612, 0.1605, 0.1652],
        [0.1637, 0.1749, 0.1746, 0.1611, 0.1606, 0.1651],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.1632, 0.1674],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.1639],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)

In [54]:
# 2. set values above the diagonal to zero
context_length = attn_scores.shape[0]

# torch.tril -> lower triangular matrix
mask_simple = torch.tril(torch.ones(context_length, context_length))
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [55]:
# multiply this mask with attn_weights 
masked_simple = attn_weights* mask_simple
masked_simple

tensor([[0.1717, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1749, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1637, 0.1749, 0.1746, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.0000, 0.0000],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<MulBackward0>)

In [57]:
# 3. renormalise these attn_weights
row_sum = masked_simple.sum(dim=-1, keepdim= True)
# row_sum
masked_simple_norm = masked_simple / row_sum
masked_simple_norm

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<DivBackward0>)

A more efficient way to do the same is to mask the attention scores with -inf and then apply softmax: exp(-inf) = 0, so the masked positions receive zero weight and the remaining weights in each row already sum to 1.

*masking-trick*
1. mask with 1's above the diagonal
2. replace 1's with -inf
3. apply softmax

In [58]:
# the same result using the masking 'trick'
# torch.triu -> upper triangular matrix
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)  # diagonal=1 starts one above the main diagonal, so the main diagonal stays unmasked
mask

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [59]:
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[0.3111,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1655, 0.2602,   -inf,   -inf,   -inf,   -inf],
        [0.1667, 0.2602, 0.2577,   -inf,   -inf,   -inf],
        [0.0510, 0.1080, 0.1064, 0.0643,   -inf,   -inf],
        [0.1415, 0.1875, 0.1863, 0.0987, 0.1121,   -inf],
        [0.0476, 0.1192, 0.1171, 0.0731, 0.0477, 0.0966]],
       grad_fn=<MaskedFillBackward0>)

In [60]:
# apply softmax
attn_weights = torch.softmax(masked / math.sqrt(keys.shape[-1]), dim=1)
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)
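
The two masking routes give identical weights. A self-contained check on random scores (the 4x4 size and the seed are arbitrary assumptions):

```python
import torch

torch.manual_seed(0)
T = 4
scores = torch.randn(T, T)

# route 1: softmax -> zero out the future -> renormalise each row
weights = torch.softmax(scores, dim=-1)
masked = weights * torch.tril(torch.ones(T, T))
route_1 = masked / masked.sum(dim=-1, keepdim=True)

# route 2: fill the future with -inf -> softmax once
future = torch.triu(torch.ones(T, T), diagonal=1).bool()
route_2 = torch.softmax(scores.masked_fill(future, float('-inf')), dim=-1)

print(torch.allclose(route_1, route_2))
print(route_2)
```

Renormalising divides out the probability mass that went to future tokens, so route 1 ends up with exactly the softmax over the visible tokens, which is what route 2 computes directly.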

**Masking additional attention weights with dropout**

These attn_weights can be used to calculate the context vectors,

but to reduce overfitting we randomly select some attn_weights and drop them out (set them to zero).

*Dropout*? -> only used during training; it is disabled at inference.
Here we apply dropout after computing the attn weights.

In [62]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(p=0.5) #dropout rate 50%

#example
example = torch.ones(6,6)
dropout(example)

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])

why 2?

To keep the expected value of the weights unchanged, the remaining elements are scaled up by a factor of `1/(1 - p) = 1/0.5 = 2`
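
A small check of both claims: rescaling keeps the average close to 1 in training mode, and dropout does nothing in eval mode. The 1000x1000 tensor of ones is just a large made-up example.

```python
import torch

torch.manual_seed(123)
dropout = torch.nn.Dropout(p=0.5)
x = torch.ones(1000, 1000)

dropout.train()                                # training mode: dropout active
train_out = dropout(x)
unique_vals = train_out.unique().tolist()      # entries are either dropped or doubled
train_mean = train_out.mean().item()           # stays close to 1

dropout.eval()                                 # eval mode: dropout is a no-op
eval_out = dropout(x)

print(unique_vals, round(train_mean, 3), torch.equal(eval_out, x))
```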

In [63]:
# apply dropout to attn_weights
torch.manual_seed(123)
dropout(attn_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0335, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6804, 0.0000, 0.0000, 0.0000],
        [0.4889, 0.5090, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3988, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3418, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)

##### Causal Attention Class

In [64]:
# batch
batch = torch.stack((inputs, inputs), dim=0)
batch, batch.shape

(tensor([[[0.4300, 0.1500, 0.8900],
          [0.5500, 0.8700, 0.6600],
          [0.5700, 0.8500, 0.6400],
          [0.2200, 0.5800, 0.3300],
          [0.7700, 0.2500, 0.1000],
          [0.0500, 0.8000, 0.5500]],
 
         [[0.4300, 0.1500, 0.8900],
          [0.5500, 0.8700, 0.6600],
          [0.5700, 0.8500, 0.6400],
          [0.2200, 0.5800, 0.3300],
          [0.7700, 0.2500, 0.1000],
          [0.0500, 0.8000, 0.5500]]]),
 torch.Size([2, 6, 3]))

In [65]:
class CasualAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias = False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias= qkv_bias)  # stores W_query as (d_out, d_in)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias= qkv_bias)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
            )
        

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1,2) # transpose the last two dims, keeping the batch dim as it is
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
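        attn_weights = torch.softmax(attn_scores/ math.sqrt(keys.shape[-1]), dim=-1)
        attn_weights = self.dropout(attn_weights)  # randomly drop attention weights (active only in training mode)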

        context_vec = attn_weights @ values
        return context_vec

why `register_buffer('buffer_name', tensor)`
- part of the model but not trainable.
- moves with the model on `.to(device)`.
- persists with the model (saved and loaded via `model.state_dict()`)
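
These properties can be checked on a toy module. `WithMask` is a hypothetical class written only for this sketch.

```python
import torch
import torch.nn as nn

class WithMask(nn.Module):  # hypothetical toy module
    def __init__(self, context_length):
        super().__init__()
        self.proj = nn.Linear(2, 2)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

m = WithMask(context_length=3)

param_names = [name for name, _ in m.named_parameters()]   # trainable tensors only
buffer_names = [name for name, _ in m.named_buffers()]     # the mask lives here
state_keys = list(m.state_dict().keys())                   # parameters and buffers

print(param_names)
print(buffer_names)
print(state_keys)
print(m.mask.requires_grad)
```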

In [66]:
torch.manual_seed(123)
context_length = batch.shape[1]

ca = CasualAttention(d_in, d_out, context_length, dropout=0)
context_vecs = ca(batch)
print(f'Context Vector: {context_vecs}')
print(f'Context Vector Shape: {context_vecs.shape}')

Context Vector: tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)
Context Vector Shape: torch.Size([2, 6, 2])


#### **Multi Head Attention**

Multi-head attention runs the attention mechanism several times in parallel, each time with different learned projections.

In [67]:
# running multiple (here two) causal attention modules side by side and concatenating their outputs

class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()

        self.heads = nn.ModuleList([CasualAttention
                                    (d_in, d_out, context_length, dropout, qkv_bias) 
                                    for _ in range(num_heads)]
                                    )

    def forward(self, x):
        context_vec = torch.cat([head(x) for head in self.heads], dim = -1)  # concat the two context vector matrices
        return context_vec

In [68]:
torch.manual_seed(123)
context_length = batch.shape[1]  # num of tokens
d_in, d_out = 3,2 

mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)

context_vecs = mha(batch)
print(f'Context Vector:\n{context_vecs}')
print(f'Context Vector Shape:\n{context_vecs.shape}')

Context Vector:
tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
Context Vector Shape:
torch.Size([2, 6, 4])


**Implementing MHA more efficiently with weight splits**

Here, rather than stacking multiple causal attention blocks, we compute one large query, key and value projection, split each into multiple heads, and combine the heads after computing attention.
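
A sketch of what the split does to shapes, using a dummy tensor in place of the projected keys, queries or values (all sizes here are made up). Each head simply receives its own contiguous slice of the `d_out` columns.

```python
import torch

b, num_tokens, num_heads, head_dim = 2, 6, 2, 4
d_out = num_heads * head_dim

# stand-in for a projected tensor of shape (b, num_tokens, d_out)
x = torch.arange(b * num_tokens * d_out, dtype=torch.float32).view(b, num_tokens, d_out)

# (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
heads = x.view(b, num_tokens, num_heads, head_dim).transpose(1, 2)

# head h holds columns [h*head_dim, (h+1)*head_dim) of the original tensor
head0_matches = torch.equal(heads[:, 0], x[..., :head_dim])
head1_matches = torch.equal(heads[:, 1], x[..., head_dim:])

# undoing the split recovers the original layout
merged = heads.transpose(1, 2).contiguous().view(b, num_tokens, d_out)

print(heads.shape, head0_matches, head1_matches, torch.equal(merged, x))
```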

In [74]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()

        assert d_out % num_heads == 0, 'd_out must be divisible by num_heads'

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.out_proj = nn.Linear(d_out, d_out)  # layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )


    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # .view() -> to reshape tensors
        # split d_out into num_heads, head_dim
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # transpose from shape (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)

        keys = keys.transpose(1,2)
        queries = queries.transpose(1,2)
        values = values.transpose(1,2)

        # attn scores -> dot product of queries and keys for each head
        attn_scores = queries @ keys.transpose(2,3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / math.sqrt(keys.shape[-1]),
                                     dim = -1)

        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1,2)

        # transposing makes tensor non-contiguous
        # therefore before flattening into shape (b, num_tokens, self.d_out) make them contiguous
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)  # self.d_out = self.num_heads * self.head_dim

        context_vec = self.out_proj(context_vec)
        return context_vec
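
Why `.contiguous()` before `.view()`? A minimal demonstration on a small made-up tensor: after a transpose the memory layout no longer matches the shape, so `.view()` refuses to merge the last two dimensions until the data is copied into contiguous memory. `.reshape()` is shown as an alternative that copies when it has to.

```python
import torch

t = torch.arange(24).view(2, 3, 4)
tt = t.transpose(1, 2)              # shape (2, 4, 3), same memory, different strides

is_contig = tt.is_contiguous()      # False after the transpose

try:
    tt.view(2, 12)                  # cannot merge the last two dims without a copy
    view_failed = False
except RuntimeError:
    view_failed = True

flat = tt.contiguous().view(2, 12)  # copy into contiguous memory, then view
flat_reshape = tt.reshape(2, 12)    # reshape copies automatically when needed

print(is_contig, view_failed, torch.equal(flat, flat_reshape))
```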


In [77]:
# print(batch)
print(batch.shape)

torch.Size([2, 6, 3])


In [78]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2

mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)
print(f'Context Vector:\n{context_vecs}')
print(f'Context Vector Shape:\n{context_vecs.shape}')

Context Vector:
tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
Context Vector Shape:
torch.Size([2, 6, 2])


**For the smallest GPT-2 model**
- `embedding_size` = 768
- `attention_heads` = 12
- and `d_in` = `d_out`
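
With these sizes, the per-head dimension and the size of one attention block follow directly. This sketch assumes the same layer layout as the `MultiHeadAttention` class in this notebook (three bias-free projections plus an output projection with bias); it is not a statement about the exact layout of the released GPT-2 weights.

```python
import torch.nn as nn

emb_dim, num_heads = 768, 12
head_dim = emb_dim // num_heads     # 768 / 12 = 64 dimensions per head

# query, key and value projections (no bias) plus the output projection (with bias)
layers = [nn.Linear(emb_dim, emb_dim, bias=False) for _ in range(3)]
layers.append(nn.Linear(emb_dim, emb_dim))

num_params = sum(p.numel() for layer in layers for p in layer.parameters())

print(head_dim)      # 64
print(num_params)    # 4 * 768 * 768 + 768 = 2,360,064
```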

In [72]:
# mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
# mask.bool()[:3, :3]

In [71]:
# a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],   
#             [0.8993, 0.0390, 0.9268, 0.7388],
#             [0.7179, 0.7058, 0.9156, 0.4340]],

#            [[0.0772, 0.3565, 0.1479, 0.5331],
#             [0.4066, 0.2318, 0.4545, 0.9737],
#             [0.4606, 0.5159, 0.4220, 0.5786]]]])
# a.transpose(2,3)
# a @ a.transpose(2,3)