# **Example 6.2.3 (Multihead attention implementation)**

This tutorial is based on the following [link](https://medium.com/@fareedkhandev/create-gpt-from-scratch-using-python-part-1-bd89ccf6206a); the detailed video by Andrej Karpathy is here: [video link](https://youtu.be/kCc8FmEb1nY)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Getting the libraries

In [None]:
# import libraries

import torch
import torch.nn as nn
from torch.nn import functional as F




For ease of understanding, we use random data as the input.

Note that here:
* _B_ is the batch size
* _T_ is the sequence length (the number of tokens, or time steps)
* _C_ is the number of channels (the embedding dimension of each token)
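
As a minimal sketch of this notation, the snippet below builds a random input and indexes into it. The sizes match the ones used later in this notebook; the seed is an arbitrary choice for this illustration.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for this illustration
B, T, C = 4, 8, 32    # batch size, sequence length (time steps), channels
x = torch.randn(B, T, C)

print(x.shape)        # the whole batch: (B, T, C)
print(x[0].shape)     # one sequence in the batch: (T, C)
print(x[0, 0].shape)  # one token vector: (C,)
```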

### Building a self attention head from scratch

Notes:
- Attention is a **communication mechanism**. It can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.

- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to **positionally encode** tokens.

- Each example across the batch dimension is processed completely independently; examples never "talk" to each other.


Here it is assumed that information from future nodes does not flow to the past: a token may attend only to itself and to earlier tokens. This is enforced with the lower-triangular matrix `tril`. Some applications, however, require communication across all nodes. Sentiment classification is one example: you need the context of the entire sentence, so every token must be able to communicate with every other token.
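
The difference between masked and unmasked attention can be sketched on a small random score matrix. This is an illustration only: `scores` is a hypothetical stand-in for the raw query-key affinities, and `T = 4` is chosen just to keep the printout small.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)  # arbitrary seed for this sketch
T = 4
scores = torch.randn(T, T)  # stand-in for raw query-key affinities

# Causal (masked) attention: position t may only attend to positions <= t.
tril = torch.tril(torch.ones(T, T))
causal = F.softmax(scores.masked_fill(tril == 0, float('-inf')), dim=-1)

# Unmasked attention: every position attends to every position.
full = F.softmax(scores, dim=-1)

print(causal)  # entries above the diagonal are exactly zero
print(full)    # every entry is positive
```

In both cases each row sums to 1; the mask only decides which positions are allowed to receive weight.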

- In an "encoder" attention block just remove the masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.

- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)

In the following code block:
```python
k = key(x)
q = query(x)
v = value(x)
```
all of q, k, and v are computed from x.


**Cross attention**: q comes from x, but k and v may come from an external block. For example, in a transformer decoder block the query comes from x, while k and v come from the encoder block or from adjacent nodes. We are basically reading information from the side. Cross-attention is used in scenarios where we want to pull in information from an external source.
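
A minimal sketch of a single cross-attention head is shown below. It assumes a hypothetical tensor `enc_out` standing in for the output of an encoder; the sizes and names (`T_dec`, `T_enc`, `enc_out`) are made up for this illustration and do not appear elsewhere in the notebook.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)  # arbitrary seed for this sketch
B, T_dec, T_enc, C, head_size = 2, 5, 7, 32, 16
x = torch.randn(B, T_dec, C)        # decoder-side input: source of the queries
enc_out = torch.randn(B, T_enc, C)  # hypothetical encoder output: source of keys and values

query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

q = query(x)        # (B, T_dec, head_size)
k = key(enc_out)    # (B, T_enc, head_size)
v = value(enc_out)  # (B, T_enc, head_size)

wei = q @ k.transpose(-2, -1) * head_size**-0.5  # (B, T_dec, T_enc)
wei = F.softmax(wei, dim=-1)  # no causal mask: all encoder positions are visible
out = wei @ v                 # (B, T_dec, head_size)
print(wei.shape, out.shape)
```

Note that the attention matrix is no longer square: each of the `T_dec` queries distributes its weight over the `T_enc` external positions.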

- "Scaled" attention additionally multiplies `wei` by 1/sqrt(head_size). With this scaling, when the inputs Q and K have unit variance, `wei` has unit variance too, so softmax stays diffuse and does not saturate too much.
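
The effect of the scaling factor can be checked numerically. The sketch below uses random unit-variance queries and keys; the sizes are larger than in the rest of the notebook only so that the variance estimates are stable, and the `logits` vector is a made-up example.

```python
import torch

torch.manual_seed(0)  # arbitrary seed for this sketch
B, T, head_size = 32, 64, 16
k = torch.randn(B, T, head_size)  # unit-variance keys
q = torch.randn(B, T, head_size)  # unit-variance queries

unscaled = q @ k.transpose(-2, -1)
scaled = unscaled * head_size**-0.5

print(unscaled.var().item())  # roughly head_size
print(scaled.var().item())    # roughly 1

# Large-magnitude logits push softmax toward a one-hot vector.
logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
diffuse = torch.softmax(logits, dim=-1)
sharp = torch.softmax(logits * 8, dim=-1)
print(diffuse)  # weight is spread over all entries
print(sharp)    # most weight collapses onto the largest entry
```

Without the scaling, the variance of the scores grows with `head_size`, and softmax increasingly behaves like picking a single token.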

In [None]:
# version 4: self-attention with value!
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)

# all the queries will multiply with all the keys
wei = q @ k.transpose(-2,-1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
v = value(x)
out = wei @ v
#out = wei @ x
print(wei[0])

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)


In [None]:
print(v.size())
print(wei.size())

torch.Size([4, 8, 16])
torch.Size([4, 8, 8])


In [None]:
print(tril)

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])




Now let's look at the above code. Instead of simply averaging information from the previous tokens and the current token, we use a triangular matrix to control the flow of information.

This lower-triangular matrix masks out the weights that we create: positions where `tril` is 0 are filled with negative infinity, and softmax then normalizes each row. If `wei` started as all zeros (the commented-out line `wei = torch.zeros((T,T))`), masking followed by softmax would give uniform weights over the visible tokens, that is, a plain average.

We don't want the weights to be uniform. We are looking for interesting relationships that can be expressed in the _wei_ matrix. For example, a vowel character might be more interested in consonant characters from the past. This information needs to flow from the past to the present in a data-dependent manner.
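
For comparison, here is the uniform case worked out in code: when the scores carry no data-dependent information (all zeros), masking plus softmax reduces to a plain average over the current and past tokens. `T = 4` is used only to keep the printout short.

```python
import torch
from torch.nn import functional as F

T = 4
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros(T, T)                          # no data-dependent affinities
wei = wei.masked_fill(tril == 0, float('-inf'))  # hide future positions
wei = F.softmax(wei, dim=-1)
print(wei)
# tensor([[1.0000, 0.0000, 0.0000, 0.0000],
#         [0.5000, 0.5000, 0.0000, 0.0000],
#         [0.3333, 0.3333, 0.3333, 0.0000],
#         [0.2500, 0.2500, 0.2500, 0.2500]])
```

Replacing the zeros with query-key dot products is what turns this fixed average into a learned, data-dependent one.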



### Query and Key

Every token at each position emits two vectors: Query (Q) and Key (K).
- Query (Q) : what am I looking for?
- Key (K) : What do I contain?

We obtain the affinity between two tokens by taking the dot product of one token's Q with the other token's K. These dot products form _wei_, which tells us how strongly a specific token relates to every other token in the sequence. For example, if a query and a key are well aligned, their dot product is large, so the corresponding token is given more importance. Let's implement this.

Let's understand the importance of query and key using an example:

Take the last row of _wei[0]_, whose last element is **0.2391**. This row belongs to the 8th token, which communicates through a **query** that says, for example:

**_I am a vowel (it knows what content it has) at the 8th position (it knows its position too) and I need some information related to consonants till 4th position._**

All the nodes then emit **keys**. One channel of a key might say "I am a consonant and I occur at or before the 4th position"; such a key has a large value in that channel. The dot product of Q and K then gives a high affinity, like the value highlighted in yellow in the above figure. Because the weight is high, a lot of information from that token is aggregated into this position, which therefore learns a lot about it.

This is what _wei_ represents. Through this matrix the context of the sentence is learned: it says how much information should be aggregated from each token in the past.
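
The alignment idea can be sketched with hand-written vectors. Everything here is hypothetical: the 2-dimensional keys and the query are chosen by hand so that the arithmetic is easy to follow, whereas in the model they are produced by learned linear layers.

```python
import torch
from torch.nn import functional as F

# Three hypothetical past tokens, each emitting a 2-dimensional key.
keys = torch.tensor([[1.0, 0.0],   # token 0: weakly expresses feature 0
                     [0.0, 1.0],   # token 1: expresses feature 1 only
                     [3.0, 0.0]])  # token 2: strongly expresses feature 0

# The current token's query "looks for" feature 0.
query = torch.tensor([2.0, 0.0])

scores = keys @ query                # dot products: 2, 0 and 6
weights = F.softmax(scores, dim=-1)  # token 2 receives most of the weight
print(scores)
print(weights)
```

The key that is best aligned with the query gets the largest score, so after softmax most of the aggregated information comes from that token.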

## Self attention head
Now let's create a self-attention head class out of the above code:

In [None]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout) # regularization: randomly zeroes some attention weights during training

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,head_size)
        q = self.query(x) # (B,T,head_size)
        # compute attention scores ("affinities"), scaled by 1/sqrt(head_size)
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,head_size)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

In the above class we have implemented a single scaled dot-product self-attention head. Note that `Head` reads `n_embd`, `block_size`, and `dropout` as globals, so the hyperparameters cell further below must be run before a `Head` is instantiated. Now let's see how we can add a multihead attention block.

**Multihead attention**: It is just applying multiple self-attentions in parallel and concatenating the results.

This is similar to a group convolution: instead of doing one big convolution, we split it into n smaller convolutions and add or concatenate their results. With $n_{embd} = 32$ and 4 heads, each head performs 8-dimensional self-attention. Earlier we had 1 communication channel; now we have 4 communication channels in parallel. Each channel is 8-dimensional, so concatenating the 4 outputs gives back 32 channels.
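
The arithmetic above can be sketched as follows. The per-head outputs are random stand-ins rather than real attention outputs; only the shapes matter here.

```python
import torch

n_embd, n_head = 32, 4
head_size = n_embd // n_head  # 8 dimensions per head
B, T = 4, 8

# Random stand-ins for the per-head outputs, each of shape (B, T, head_size).
head_outputs = [torch.randn(B, T, head_size) for _ in range(n_head)]

out = torch.cat(head_outputs, dim=-1)  # concatenate along the channel dimension
print(head_size)  # 8
print(out.shape)  # (B, T, n_head * head_size) = (4, 8, 32)
```

Concatenating along the last dimension is what restores the original channel count, so the multihead block can be dropped in wherever a single 32-dimensional head would fit.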

So keeping this in mind we can implement the multihead attention class like below.

In [None]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return out

Here we use multihead attention inside a `BigramLanguageModel`. A bigram model is a basic model that predicts the probability of each token based on the previous token; a bigram is a pair of consecutive tokens (characters in this notebook, since the tokenization is character-level). The class for the Bigram Language model is implemented below.

Note that we add a token embedding layer and a positional embedding layer, since together they tell the model both what each token is and where it sits in the sequence. The multihead attention layer is used here with 4 heads.
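
How the two embeddings combine can be sketched in isolation. The value `vocab_size = 65` is an arbitrary placeholder for this illustration (the real value depends on the text file), and the token ids are random.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed for this sketch
vocab_size, block_size, n_embd = 65, 8, 32  # vocab_size is a placeholder value
B, T = 4, 8

token_embedding_table = nn.Embedding(vocab_size, n_embd)
position_embedding_table = nn.Embedding(block_size, n_embd)

idx = torch.randint(vocab_size, (B, T))              # (B, T) random token ids
token_emb = token_embedding_table(idx)               # (B, T, n_embd): what each token is
pos_emb = position_embedding_table(torch.arange(T))  # (T, n_embd): where each token sits
x = token_emb + pos_emb                              # broadcast over the batch -> (B, T, n_embd)
print(x.shape)
```

The same `(T, n_embd)` positional table is added to every sequence in the batch, which is why broadcasting is enough and no explicit loop is needed.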

In [None]:
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token id is mapped to an n_embd-dimensional embedding vector
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = MultiHeadAttention(4,n_embd//4) # i.e., 4 heads of 8-dimensional self-attention
        self.lm_head = nn.Linear(n_embd,vocab_size)

    def forward(self, idx, targets=None):
        B,T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        token_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = token_emb + pos_emb # both position identities and token identities #(B,T,C)
        x = self.sa_head(x)    # apply multi-head self-attention, (B,T,C)
        x2 = x                 # keep the attention output so it can be returned for inspection
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)


        return logits, loss, x2

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss, x = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


We are using a "poem.txt" file to perform this experiment. Feel free to replace it with other text files and explore how the results change.

We set the batch size to 32, the maximum context length for prediction (`block_size`) to 8, and the embedding dimension (`n_embd`) to 32. The model is trained for 50500 iterations (`max_iters`), and the loss is evaluated every 10000 iterations (`eval_interval`).

* ### Read the input file and perform Tokenization
Once we read the input text file, the unique characters are identified to help with tokenization. A simple character-level tokenization maps each unique character to an integer; the number of unique characters is the vocabulary size.

* ### Train-validation splitting
The tokenized data is then split into train and validation sets (the first 90% for training, the rest for validation).

* ### Breaking data into chunks/ blocks:
We use the _get_batch(split)_ function to generate small batches of inputs x and targets y using the batch size and block size.

* ### Estimate loss
This function is used to estimate the loss on the train and validation sets.

* ### Model training
Next we create the model and initialize the _AdamW_ optimizer. We use the optimizer to train the model and the estimate loss function to evaluate the loss on the train and val sets.

* ### Evaluation using model
Once the model is trained, it is used to generate 500 new tokens starting from a single-token context, and the result is decoded back into text.
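
The first three steps can be tried on a toy string before running them on the real file. This is a sketch under stated assumptions: the sentence below is a made-up stand-in for the poem, and `batch_size = 4` is chosen only to keep the example small.

```python
import torch

torch.manual_seed(0)  # arbitrary seed for this sketch
text = "the quick brown fox jumps over the lazy dog"  # toy stand-in for the poem file

# Character-level tokenization: one integer id per unique character.
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda ids: ''.join(itos[i] for i in ids)

# Train/validation split: first 90% for training, the rest for validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

# One batch: y is x shifted one position to the left (the next-token targets).
block_size, batch_size = 8, 4
ix = torch.randint(len(train_data) - block_size, (batch_size,))
x = torch.stack([train_data[i:i + block_size] for i in ix])
y = torch.stack([train_data[i + 1:i + block_size + 1] for i in ix])

print(len(chars))                   # vocabulary size of the toy text
print(decode(encode("lazy dog")))   # round trip returns the original string
print(x.shape, y.shape)             # both (batch_size, block_size)
```

Each row of `y` holds, at position t, the character that follows position t of the matching row in `x`, which is exactly what the cross-entropy loss compares against.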


In [None]:
# hyperparameters
batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?
max_iters = 50500
eval_interval = 10000
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 32
n_head = 4
#n_layer = 6
dropout = 0.2

# ------------
torch.manual_seed(1337)

with open('/content/drive/MyDrive/DL_Book_Notebooks/Chapter 6: Attention Networks and transformers/Data/Rime_ancient_poem2.txt', 'r', encoding = 'utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss(): #It computes the loss function as the CE between the actual and the predicted token
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss, x = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

model = BigramLanguageModel()
m = model.to(device)

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss, x = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)


print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))

In [None]:
xb.size()