# CODING GPT FROM SCRATCH

This notebook follows [Code GPT From Scratch](https://www.youtube.com/watch?v=kCc8FmEb1nY), by Andrej Karpathy. It is my own implementation, with my own explanations, written to better understand the way transformers work.

The goal of this notebook is to implement a ***G***enerative ***P***retrained ***T***ransformer ***(GPT)*** that will generate Shakespeare-like text.

## 1. Preparing the Dataset
In this project, we are going to use the [Tiny Shakespeare Dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt).

To prepare the dataset for processing, we will:

1. Download the Shakespeare data.
2. Build the tokenizer: which distinct characters do we have?
3. Create the encoder (character to int) and the decoder (int back to character).
4. Show the first 200 tokens in the dataset.
5. Define the block size and the batch size.
6. Set the random seed to 1337 with `torch.manual_seed(1337)`.
7. Split the data 90%-10% into train/val.
8. Randomly sample batches from the data.

As a sanity check, the initial loss should be that of a uniform distribution over the 65 characters (tokens): `loss_i = -ln(1/65) ≈ 4.17`.
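
The expected initial loss can be checked by hand. The snippet below is a minimal sketch that assumes a vocabulary of 65 characters, as stated above, and uses only the standard library: a model that gives every token the same probability has a cross-entropy of `-ln(1/65)`, whatever the target is.

```python
import math

vocab_size = 65  # number of distinct characters, as stated above

# A uniform model assigns probability 1/vocab_size to every token, so the
# cross-entropy (negative log-likelihood) of any target is -ln(1/vocab_size).
uniform_prob = 1.0 / vocab_size
expected_initial_loss = -math.log(uniform_prob)

print(f"expected initial loss: {expected_initial_loss:.4f}")
```

An untrained model whose loss is far above this value is probably badly initialized.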

### 1.1 Download the data

In [None]:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

### 1.2 Build the Tokenizer, Encoder and Decoder
Our goal is to be able to convert a string of text into a sequence of integers, and back.
To do this, we first need to create a tokenizer that maps each character to an integer. This is done by going over the entire dataset and building the vocabulary: the set of all unique characters in the dataset.

After that, we will create the encoder, which maps a string of text to a sequence of numbers, and the decoder, which maps a sequence of numbers back to a string of text.

In [None]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print('Vocabulary size (num of unique chars):', vocab_size)

# create a mapping from character to index and vice versa
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

# create the encoder (ch -> idx) and decoder (idx -> ch) functions
encode = lambda s: [stoi[c] for c in s] #encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) #decoder: take a list of integers, output a string

print(encode("hi there"))
print (decode(encode("hi there")))

### 1.3 Train and Test Split

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:

data = torch.tensor(encode(text), dtype=torch.long)
n = int(len(data)*0.9)
train_data, val_data = data[:n], data[n:]

# Let's print the first 200 elements of data. This is what our data looks like.
print(data[:200])

## 2. Feeding Data Into The Network

We will never feed the entire dataset into the network. Instead we feed it with a sequence, up to a fixed block size. 

### 2.1. The Idea of Input `Context`, and `Block_Size` 
We are going to use a `bigram model`: a very simple model that predicts the next character (token) from the one before it. The input sequence is called the `context`, and the maximum length of the `context` is called `block_size`. We want the model to predict the next token given a `context` of any length in `[1..block_size]`. Later on, this will be useful for the transformer network, which has to predict from sequences of any length up to `block_size`.

In [None]:
block_size = 8
print(train_data[:block_size+1])

In [None]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
print(f"The input tensor is: {x}")
print(f"The target tensor is: {y}")
print("So that...")
for t in range(block_size):
    context = x[:t+1]
    target =  y[t]
    print(f"For input {context}, the output is: {target}")

### 2.2. Processing Input Batch

With `block_size=8`, a single sequence gives us `8` different input contexts (of lengths `1,...,block_size`) with `8` corresponding output targets; in other words, `8` training examples. To improve training speed, we add a `batch` dimension: we randomly sample `batch_size` sequences from the training data and process them together.

In [None]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8

def get_batch(split):
    data = train_data if split=='train' else val_data
    # Sample batch_size random samples from data
    ix = torch.randint(len(data) - block_size, (batch_size, ))
    x = torch.stack([data[i:i+block_size] for i in ix]).to(device)
    y = torch.stack([data[i+1: i+ 1 + block_size] for i in ix]).to(device)
    
    return x, y

In [None]:
xb, yb = get_batch(split="train")

print("input batch:")
print(f"shape: {xb.shape}")
print(xb)

print("target batch:")
print(f"shape: {yb.shape}")
print(yb)

The batch contains 4 sequences.
Each sequence provides 8 different contexts (of lengths `1,..,block_size`).
So there are 8*4=32 different input-target pairs in every batch.

In [None]:
for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"For input tensor: {context}, the target is: {target}")
    print ("-----")

## 3. Building the Bigram Language Model

As described before, this is a very simple language model. The goal is to predict the next token from the current one.

In [None]:
class BigramLanguageModel(nn.Module):
    """ Bigram Language Model.
        Args:
            vocab_size (int): Number of tokens (size of vocabulary)
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx) # B,T,C
        # F.cross_entropy expects logits of shape (N, C), so (B, T, C) is flattened to (B*T, C) below
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B*T, C),  targets.view(B*T))

        return logits ,loss

    def generate(self, idx, max_new_tokens):
        """ Generate new tokens, given current context.
            In a simple bigram model, we just need to look up the logits for the next token.
            In more complex models, we are also interested in the history of the tokens. In other
            words, a larger context.

            Args:
                idx (Tensor): Initial context
                max_new_tokens (int): Maximum number of new tokens to generate.
            Returns:
                Tensor: Generated tokens

        """

        # idx is a tensor of shape (B, T)
        idx = idx.to(device)  # .to() is not in-place, so the result must be reassigned
        for i in range(max_new_tokens):
            # get the prediction for the current context
            logits, loss = self.forward(idx) # logits shape B, T, C
            # In the bigram model, focus only on the last time step.
            logits = logits[:, -1, :]
            # convert logits to probabilities
            probs = F.softmax(logits, dim=-1)
            # sample the next token from the distribution
            next_token = torch.multinomial(probs, num_samples=1).to(device)
            idx = torch.cat([idx, next_token], dim=1) # shape B, T+1

        return idx
            
            


model = BigramLanguageModel(vocab_size)
model.to(device)
out, loss = model(xb, yb)
print(out.shape)
print(loss)

# We expect the loss to be around -log(1/vocab_size) = -log(1/65) = 4.17
print(f"Expected loss for uniform distribution: {-torch.log(torch.tensor(1/vocab_size))}")

# Example - generate new tokens
print("\n")
print("Example - generate new tokens")
print("(Model is not trained yet, so it will generate random tokens)")
idx = torch.zeros((1,1), dtype=torch.long).to(device)
new_tokens = model.generate(idx, 100)[0].tolist()

# decode the generated tokens
gen_txt = decode(new_tokens)
print(gen_txt)


### 3.1 Training the B-Gram Model
Without training the model, the results are random. The loss is roughly `-ln(1/vocab_size)`. To improve the model, we will train it with the `AdamW` optimizer, and implement a simple training loop.

In [None]:
# define the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

The code block below is used to estimate the loss. Instead of logging the loss of every training step (a single batch), which is very noisy, it averages the loss over `eval_iters` random batches. We run this block on both the training and the validation data. While training, it is important to see that both the training and the validation loss are decreasing. If the validation loss starts increasing while the training loss keeps decreasing, it is a sign of overfitting.
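
To see why averaging helps, here is a small simulation using only the standard library. The numbers are made up for illustration: `noisy_batch_loss` is a hypothetical stand-in for the loss of one batch (a "true" value of 2.5 plus Gaussian noise), not the loss of the actual model.

```python
import random
import statistics

random.seed(0)

def noisy_batch_loss():
    # Hypothetical stand-in for one mini-batch loss: a true value of 2.5 plus noise.
    return 2.5 + random.gauss(0.0, 0.3)

eval_iters = 200

# 1000 single-batch losses versus 100 averages over eval_iters batches each.
single_estimates = [noisy_batch_loss() for _ in range(1000)]
averaged_estimates = [
    statistics.fmean([noisy_batch_loss() for _ in range(eval_iters)])
    for _ in range(100)
]

spread_single = statistics.stdev(single_estimates)
spread_averaged = statistics.stdev(averaged_estimates)
print(f"std of single-batch loss:      {spread_single:.3f}")
print(f"std of {eval_iters}-batch average loss: {spread_averaged:.3f}")
```

The spread of the averaged estimate shrinks roughly with the square root of `eval_iters`, which is why the averaged curve is much easier to read.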

In [None]:
eval_iters = 200
@torch.no_grad()
def estimate_loss():
    # estimate the loss on the training and validation sets
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters).to(device)
        for k in range(eval_iters):
            xb, yb = get_batch(split)
            logits, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()
    return out
        

In [None]:
# define the training loop
num_epochs = 10000
batch_size = 32
eval_interval = 100

model = model.to(device)
for step in range(num_epochs):
    
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"Step: {step}, Train Loss: {losses['train']}, Val Loss: {losses['val']}")

    # sample a batch of data
    xb, yb = get_batch(split="train")
    # forward pass
    logits, loss = model(xb, yb)

    # backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f"Loss = {loss.item()}") # this prints the loss of the last batch

### 3.2. Generating Text From Trained B-Gram Model

In [None]:
idx = torch.zeros((1,1), dtype=torch.long, device=device)
new_tokens = model.generate(idx, 500)[0].tolist()

# decode the generated tokens
gen_txt = decode(new_tokens)
print(gen_txt)

Looking at the result, we can see that the model has learned to generate text that vaguely resembles Shakespeare. However, the result is of low quality, because the model is still very simple: it predicts the next token from the previous token only. To improve the model, we would like to learn the relations between tokens in the sentence, even if they are far apart. This is where the **transformer** model comes in.

## 4. Building the Transformer Model
Up until now, using only the `Bigram Language Model`, we could not learn the relations between tokens: to predict the next token, we only used the previous token. **The transformer model, on the other hand, can learn the relations between tokens, even if they are far apart.**

### 4.1 The Attention Module - Learning the Relations Between Tokens

In a language model, we would like to predict the next token based only on the previous (context) tokens. This is different from models used for images, where we would like to learn the relations between one token and the entire set of tokens representing the image.

The attention module, which is at the core of the transformer model, is the building block used to learn and represent the relations between tokens. When the attention module is applied to a sequence of tokens coming from the same input sequence, it is called **self-attention**. When the attention module is applied to two different sequences of tokens, it is called **cross-attention**.


### 4.2 Preliminaries - Representing the Information of Previous Tokens as Average
As a first step, we would like to develop a method to represent the information of previous tokens. The simplest way to do that is to average the embeddings of all previous tokens.

Let's look at the following toy example:
Let's say our `batch` size is `4`, our `block_size` is `8`, and the number of channels (the embedding size) is `2`.

In [None]:
import torch
torch.manual_seed(1337)
B,T,C = 4,8,2 #batch size, time, channels

x = torch.randn(B,T,C)
x.shape

In [None]:
# We want to create a new tensor, where each element in the T dimension
# is the mean of all elements up to and including it:
# xbow[b,t] = mean_{i<=t} x[b,i]

xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]
        xbow[b,t] = torch.mean(xprev, 0)

print(x[0])
print(xbow[0])

#### 4.2.1 A Mathematical Trick

Although the above method outputs the desired averaging, it is very inefficient because it uses `for` loops. To improve the efficiency, we will use matrix multiplication.

Left-multiplying a matrix by a square matrix of size `nxn` filled with `1` sums its rows: every row of the result is the sum of all rows of the original matrix.

So given the matrices:

In [None]:
a = torch.ones(3,3)
b = torch.randint(0,10, (3,2)).float()

print('a=')
print(a)
print('b=')
print(b)

The sum of the rows of `b` is given by:

In [None]:
c = a@b
print('c=')
print(c)

Now, because matrix `a` is filled entirely with `1`, all the rows in `c` are the same (each one is the sum of all rows in `b`).
If instead matrix `a` is a lower triangular matrix:

In [None]:
tril = torch.tril(torch.ones(3,3))
print(tril)

Every element in `c` will be the running sum of the same column in `b`: the sum of all elements up to and including that row:

In [None]:

c = tril@b
print('b=')
print(b)
print('c=')
print(c)

Finally, because we want to average all the previous elements, we can normalize matrix `a` so that the sum of each row is `1`:

In [None]:
tril = torch.tril(torch.ones(3,3))
a = tril / torch.sum(tril, dim=1, keepdim=True)
print('a=')
print(a)

So now, multiplying `a` by `b` will give us the desired result:

In [None]:
print('b=')
print(b)

c = a@b
print('c=')
print(c)

Returning to the original `xbow`, with size `(B,T,C)`, we would like to do the same trick with batched matrix multiplication.
So, if we have a tensor of size `(T,T)` and we multiply it by a tensor of size `(B,T,C)`, PyTorch will automatically broadcast the first tensor to size `(B,T,T)` and multiply it by the second tensor.

`(B,T,T) x (B,T,C) = (B,T,C)`

In [None]:
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei@x

print("xbow and xbow2 are the same. For example:")
print("xbow[0,:3] ")
print(xbow[0,:3])
print("xbow2[0,:3] ")
print(xbow2[0,:3])


#### 4.2.2. Using Softmax
When we train the self-attention module, we would like the weights to be learned, because not every token is equally important. So instead of a fixed average, we are going to use `softmax` normalization.

We start by initializing all of our weights to `0`.

In [None]:
wei2 = torch.zeros(T, T)
wei2

Because we want to aggregate only past information, we will set all future elements to `-inf`. This will make the softmax of all future elements `0`.
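
The effect of `-inf` can be checked by hand on a single row. The sketch below uses only the standard library; the `softmax` helper is written here for illustration and is not the PyTorch function. It works because `exp(-inf)` evaluates to `0`, so masked positions contribute nothing to the normalization.

```python
import math

def softmax(row):
    # Subtract the max for numerical stability; exp(-inf) evaluates to 0.0.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

neg_inf = float("-inf")
# Row t=1 of a 4x4 causal mask: positions 0 and 1 are visible, 2 and 3 are in the future.
masked_row = [0.0, 0.0, neg_inf, neg_inf]
weights = softmax(masked_row)
print(weights)  # [0.5, 0.5, 0.0, 0.0]
```

With all visible scores equal to `0`, the softmax reduces to a plain average over the visible positions, which is exactly the averaging matrix built earlier.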

In [None]:
tril = torch.tril(torch.ones(T, T))
wei2 = wei2.masked_fill(tril==0, float('-inf'))
wei2

Now we are going to apply softmax function, and we will get the desired result.

In [None]:
wei2 = F.softmax(wei2, dim=1)
wei2

And the new `xbow` will be:

In [None]:
xbow3 = wei2@x
print("Are xbow2 and xbow3 the same: ", torch.allclose(xbow2, xbow3))

## 5. Building the Self-Attention Block
In the above section, we represented the information stored in previous tokens as an average of their embeddings. As seen from the `xbow` matrices, this gives every previous token the same weight.
Instead, we want each previous token to get a different weight, depending on the current token. For example, if the current token is a noun, we might want the weights of the previous tokens to be higher if they are verbs, and lower if they are adjectives.

To do this, we introduce two more vectors for every token: the `query` vector `Q` and the `key` vector `K`.
For a given token, the `Q` vector represents `what am I looking for`, and the `K` vector represents `what do I have`. The weight between two tokens is the dot product of the query of the current token with the key of the other token: the better they match, the higher the weight.
Following the noun example, if the current token is a noun, its `Q` vector will have high values in the directions it associates with verbs, and low values in the directions it associates with adjectives. The `K` vector of a verb will have high values in those same directions, so the dot product is large.

`Q` and `K` are produced by learnable linear layers applied to the token embeddings.

In [None]:
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

x = torch.randn(B, T, C)

k = key(x)  #(B,T,16)
q = query(x) #(B,T,16)

wei = q@k.transpose(-2, -1) # (B,T,16) @ (B,16,T) --> (B, T, T)

wei[0]

Now, because we don't want to aggregate any information from future tokens, we are going to use the same trick as before - use a lower triangular matrix and set all future elements to `-inf`.

In [None]:
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril==0, float('-inf'))
wei = F.softmax(wei, dim=-1)
wei[0]

The last thing we need to do is to multiply the weights by the value `V` of each token (also produced by a learnable linear layer), and to scale the raw scores by `1/sqrt(head_size)` before the softmax.
The scaling keeps the variance of the scores close to `1` when `Q` and `K` have unit variance, so that the softmax does not saturate and the weights don't become one-hot vectors.
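
The effect of the scaling can be checked numerically. The sketch below uses only the standard library and random vectors with unit variance as stand-ins for `Q` and `K` (an assumption for illustration; the real queries and keys come from the linear layers). It shows that the raw dot products have a variance of about `head_size`, that dividing by `sqrt(head_size)` brings it back to about `1`, and that large scores push the softmax towards a one-hot vector.

```python
import math
import random
import statistics

random.seed(0)
head_size = 16
n_samples = 5000

def random_vector(dim):
    # Stand-in for a query or key vector with unit-variance entries.
    return [random.gauss(0.0, 1.0) for _ in range(dim)]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

raw_scores = [dot(random_vector(head_size), random_vector(head_size)) for _ in range(n_samples)]
scaled_scores = [s / math.sqrt(head_size) for s in raw_scores]

var_raw = statistics.variance(raw_scores)
var_scaled = statistics.variance(scaled_scores)
print(f"variance of raw scores:    {var_raw:.2f}  (about head_size)")
print(f"variance of scaled scores: {var_scaled:.2f}  (about 1)")

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

logits = [0.1, -0.2, 0.3, -0.2, 0.5]
soft = softmax(logits)
sharp = softmax([8 * v for v in logits])  # larger magnitudes push softmax towards one-hot
print("largest weight, small scores:", round(max(soft), 3))
print("largest weight, large scores:", round(max(sharp), 3))
```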

In [None]:
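value = nn.Linear(C, head_size, bias=False)
v = value(x) #(B,T,16)

tril = torch.tril(torch.ones(T, T))
# recompute the raw scores (wei above already went through softmax) and scale them
wei = q @ k.transpose(-2, -1) * head_size**(-0.5) # (B,T,T)
wei = wei.masked_fill(tril==0, float('-inf'))
wei = F.softmax(wei, dim=-1)

out = wei@v # (B,T,16)
out[0]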

#### **So, finally here is the attention module:**

In [None]:
class SelfAttentionHead(nn.Module):
    """One head of self attention"""

    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    
    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        
        # scale by 1/sqrt(head_size): k has shape (B, T, head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**(-0.5)
        wei = wei.masked_fill(self.tril[:T, :T]==0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        
        out = wei @ v
        return out

## 6. Integrating the Self-Attention Module into the Bigram Language Model
Now, let's add the self-attention module to the bigram model, so that it will learn the relations between tokens.
To do this, we first reduce the embedding size of the token embedding layer from `vocab_size` to a smaller `n_embd`, and add a linear layer after the attention head that maps the result back to `vocab_size` logits. This will allow the model to learn more complex patterns in the data and increase stability.

Second, we add a positional embedding to the token embeddings. This is important because attention by itself has no notion of token order; the positional embedding lets the model also use the position of each token in the sequence.
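
The following sketch illustrates the positional embedding with plain Python lists. All sizes and tables here are hypothetical toy values chosen for illustration, not the ones used by the model: the input at position `t` is the token vector plus the vector stored for position `t`, so the same character ends up with a different input vector at different positions.

```python
import random

random.seed(0)
n_embd = 4       # hypothetical small embedding size for illustration
block_size = 3
vocab = ["a", "b"]

def random_vector(dim):
    return [random.gauss(0.0, 1.0) for _ in range(dim)]

# Lookup tables: one vector per token and one vector per position.
token_table = {ch: random_vector(n_embd) for ch in vocab}
pos_table = [random_vector(n_embd) for _ in range(block_size)]

def embed(sequence):
    # x[t] = token_embedding[sequence[t]] + position_embedding[t]
    return [
        [tok + pos for tok, pos in zip(token_table[ch], pos_table[t])]
        for t, ch in enumerate(sequence)
    ]

x = embed("aab")
# The same character "a" gets a different input vector at positions 0 and 1.
print(x[0] != x[1])
```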

In [None]:
class BigramLanguageModelSignleAttention(nn.Module):
    """ Bigram Language Model.
        Args:
            vocab_size (int): Number of tokens (size of vocabulary)
    """
 
    def __init__(self, vocab_size, n_embd=24, head_size=16):
        super().__init__()
        # Each token reads off an n_embd-dimensional embedding from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.pos_embedding_table = nn.Embedding(block_size, n_embd)
        self.atten_head = SelfAttentionHead(n_embd, head_size=head_size, block_size=block_size)
        # head_size is now an explicit argument instead of a global left over from an earlier cell
        self.head = nn.Linear(head_size, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_embd = self.token_embedding_table(idx) # (B,T,n_embd)
        pos_embd = self.pos_embedding_table(torch.arange(T).to(device)) # (T,n_embd)
        x = token_embd + pos_embd   # adding positional embedding to token embedding
        x = self.atten_head(x) # apply one head of attention
        logits = self.head(x) # (B,T,vocab_size)
        # F.cross_entropy expects logits of shape (N, C), so (B, T, C) is flattened to (B*T, C) below
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B*T, C),  targets.view(B*T))

        return logits ,loss

    def generate(self, idx, max_new_tokens):
        """ Generate new tokens, given current context.
            In a simple bigram model, we just need to look up the logits for the next token.
            In more complex models, we are also interested in the history of the tokens. In other
            words, a larger context.

            Args:
                idx (Tensor): Initial context
                max_new_tokens (int): Maximum number of new tokens to generate.
            Returns:
                Tensor: Generated tokens

        """

        # idx is a tensor of shape (B, T)
        idx = idx.to(device)  # .to() is not in-place, so the result must be reassigned
        for i in range(max_new_tokens):
            idx_cond = idx[:, -block_size:] # only consider the last block_size tokens
            # get the prediction for the current context
            logits, loss = self.forward(idx_cond) # logits shape B, T, C
            # The next token is predicted from the last time step only.
            logits = logits[:, -1, :]
            # convert logits to probabilities
            probs = F.softmax(logits, dim=-1)
            # sample the next token from the distribution
            next_token = torch.multinomial(probs, num_samples=1).to(device)
            idx = torch.cat([idx, next_token], dim=1) # shape B, T+1

        return idx
            
            


model = BigramLanguageModelSignleAttention(vocab_size)
model.to(device)
out, loss = model(xb, yb)
print(out.shape)
print(loss)

# We expect the loss to be around -log(1/vocab_size) = -log(1/65) = 4.17
print(f"Expected loss for uniform distribution: {-torch.log(torch.tensor(1/vocab_size))}")

#### Training the Model

Let's see if the model has improved.

In [None]:
# define the training loop
num_epochs = 10000
batch_size = 32
eval_interval = 100
lr = 1e-3

model = BigramLanguageModelSignleAttention(vocab_size)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
for step in range(num_epochs):
    
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"Step: {step}, Train Loss: {losses['train']}, Val Loss: {losses['val']}")

    # sample a batch of data
    xb, yb = get_batch(split="train")
    # forward pass
    logits, loss = model(xb, yb)

    # backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

losses = estimate_loss()
print(f"Step: {step}, Train Loss: {losses['train']}, Val Loss: {losses['val']}")

Previously we got the following loss:

`Step: 9900, Train Loss: 2.461491346359253, Val Loss: 2.5023200511932373`

After training the model with the self attention block, we get the following loss:

`Step: 9900, Train Loss: 2.3991994857788086, Val Loss: 2.40938138961792`

So we got some improvement.

## 7. Further Improvement: Multi-Head Attention

A single attention head can only learn a single type of relation between tokens. To learn multiple types of relations, we can use multiple attention heads in parallel. This is called `multi-head attention`.


![multi-head-attention](assets/attention_heads.png)

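To create a multi-head attention model, all we have to do is to create multiple sets of `Q`, `K`, and `V` projections, one per head, and apply the self-attention module with each of them. The outputs of the heads are then concatenated along the feature dimension. In the original paper, the concatenated result is also passed through a learned linear layer; the simple module below only concatenates, and the head size is chosen so that the concatenated output has the same size as the input embedding.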

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self attention"""

    def __init__(self, n_heads, n_embd, head_size, block_size):
        super().__init__()        
        self.heads = nn.ModuleList([SelfAttentionHead(n_embd, head_size, block_size) for _ in range(n_heads)])
        
    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        return out

Now let's integrate the multi-head attention into our model.
The important thing to note is that the size of each head is reduced according to the number of heads (`head_size = n_embd // n_heads`), so that once the heads are concatenated, the result has the same size as the original embeddings.
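
The size bookkeeping can be checked with plain Python lists. This is a minimal sketch with stand-in values: `head_outputs` is hypothetical data for a single token, not the output of the attention heads.

```python
n_embd = 32
n_heads = 4
assert n_embd % n_heads == 0, "n_embd must be divisible by n_heads"
head_size = n_embd // n_heads  # 8 features per head

# Stand-in head outputs for a single token: each head returns head_size numbers.
head_outputs = [[float(h)] * head_size for h in range(n_heads)]

# Concatenating along the feature dimension restores the original width.
concatenated = [value for head in head_outputs for value in head]
print(len(concatenated))  # 32
```

If `n_embd` is not divisible by `n_heads`, the integer division silently drops features and the final linear layer no longer matches, which is why the assertion is worth keeping.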

In [None]:
class BigramLanguageModelMultiHeadAttention(nn.Module):
    """ Bigram Language Model.
        Args:
            vocab_size (int): Number of tokens (size of vocabulary)
    """
 
    def __init__(self, vocab_size, n_embd=32, n_heads=4):
        super().__init__()
        # Each token reads off an n_embd-dimensional embedding from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.pos_embedding_table = nn.Embedding(block_size, n_embd)
        head_size = n_embd // n_heads
        self.multi_att_head = MultiHeadAttention(n_heads=n_heads, n_embd=n_embd, head_size=head_size, block_size=block_size)
        self.head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_embd = self.token_embedding_table(idx) # (B,T,n_embd)
        pos_embd = self.pos_embedding_table(torch.arange(T).to(device)) # (T,n_embd)
        x = token_embd + pos_embd   # adding positional embedding to token embedding
        x = self.multi_att_head(x) # apply multi-head attention
        logits = self.head(x) # (B,T,vocab_size)
        # F.cross_entropy expects logits of shape (N, C), so (B, T, C) is flattened to (B*T, C) below
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B*T, C),  targets.view(B*T))

        return logits ,loss

    def generate(self, idx, max_new_tokens):
        """ Generate new tokens, given current context.
            In a simple bigram model, we just need to look up the logits for the next token.
            In more complex models, we are also interested in the history of the tokens. In other
            words, a larger context.

            Args:
                idx (Tensor): Initial context
                max_new_tokens (int): Maximum number of new tokens to generate.
            Returns:
                Tensor: Generated tokens

        """

        # idx is a tensor of shape (B, T)
        idx = idx.to(device)  # .to() is not in-place, so the result must be reassigned
        for i in range(max_new_tokens):
            idx_cond = idx[:, -block_size:] # only consider the last block_size tokens
            # get the prediction for the current context
            logits, loss = self.forward(idx_cond) # logits shape B, T, C
            # The next token is predicted from the last time step only.
            logits = logits[:, -1, :]
            # convert logits to probabilities
            probs = F.softmax(logits, dim=-1)
            # sample the next token from the distribution
            next_token = torch.multinomial(probs, num_samples=1).to(device)
            idx = torch.cat([idx, next_token], dim=1) # shape B, T+1

        return idx
            
            


model = BigramLanguageModelMultiHeadAttention(vocab_size)
model.to(device)
out, loss = model(xb, yb)
print(out.shape)
print(loss)

# We expect the loss to be around -log(1/vocab_size) = -log(1/65) = 4.17
print(f"Expected loss for uniform distribution: {-torch.log(torch.tensor(1/vocab_size))}")

Now let's train the new model and see if we have any improvement.

In [None]:
# define the training loop
num_epochs = 10000
batch_size = 32
eval_interval = 100
lr = 1e-3

model = BigramLanguageModelMultiHeadAttention(vocab_size)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
for step in range(num_epochs):
    
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"Step: {step}, Train Loss: {losses['train']}, Val Loss: {losses['val']}")

    # sample a batch of data
    xb, yb = get_batch(split="train")
    # forward pass
    logits, loss = model(xb, yb)

    # backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

losses = estimate_loss()
print(f"Step: {step}, Train Loss: {losses['train']}, Val Loss: {losses['val']}")

And we got further improvement:

`Step: 9999, Train Loss: 2.1877453327178955, Val Loss: 2.239150047302246`.

Let's generate some text and see what we got:

In [None]:
# Example - generate new tokens
idx = torch.zeros((1,1), dtype=torch.long).to(device)
new_tokens = model.generate(idx, 500)[0].tolist()

# decode the generated tokens
gen_txt = decode(new_tokens)
print(gen_txt)

There is still room for improvement, but the model is already generating text with some correct English words (e.g. `he`, `I`, `that`, `making`, `here`).

## 8. The Transformer Model

Now that we have implemented the multi-head attention module, we can build the transformer model. The strength of the transformer model is that it can learn complex relations between tokens, even if they are far apart.

Let's look at the transformer block, as described in the paper:

![transformer-model](assets/transformer_model.png)

We are now going to break down the transformer model into its components:

#### **Masked Multi-Head Attention**
This is the multi-head attention we have implemented before.

#### **Layer Normalization**
This is a normalization technique similar to batch normalization, but instead of normalizing over the batch dimension, we normalize over the feature (embedding) dimension of each token. Every token in every sequence is normalized independently, so the result does not depend on the other tokens or on the other sequences in the batch.

In PyTorch we can use `nn.LayerNorm` to implement this.

*Note:* In contrast to the paper, it is now common to apply layer normalization to the input of the multi-head attention module, and not to its output.
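
As a rough illustration of what `nn.LayerNorm(n_embd)` computes for each token, here is a sketch in plain Python. It is a simplified stand-in, not the PyTorch implementation: it leaves out the learnable scale and shift parameters, and the input values are made up.

```python
import math

def layer_norm(features, eps=1e-5):
    # Normalize ONE token's feature vector to zero mean and unit variance.
    mean = sum(features) / len(features)
    var = sum((f - mean) ** 2 for f in features) / len(features)  # biased variance
    return [(f - mean) / math.sqrt(var + eps) for f in features]

# A toy (T=2, C=4) input: two tokens with very different scales.
tokens = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]
normalized = [layer_norm(tok) for tok in tokens]
for row in normalized:
    print([round(v, 3) for v in row])
```

Each token is normalized on its own, so the two rows end up almost identical even though their original scales differ by a factor of ten.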

#### **Skip Connections**
As our network gets deeper, the gradients can become very small, and training can become very difficult. A skip connection simply adds the input of a layer to its output. This lets the gradients flow directly to the input of the layer, while the weights of the layer are still updated.
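
The effect on the gradient can be seen in a scalar toy example. The sketch below is an illustration under made-up assumptions (a stack of 20 identical `tanh` layers with a hypothetical weight `w`), not the transformer itself. It multiplies the local derivatives layer by layer, as the chain rule does during backpropagation.

```python
import math

n_layers = 20
w = 0.5   # hypothetical weight shared by every layer, chosen for illustration
x0 = 0.3  # hypothetical input

def grad_through_stack(residual):
    # Chain rule: multiply the local derivatives d x_{l+1} / d x_l layer by layer.
    x, grad = x0, 1.0
    for _ in range(n_layers):
        h = math.tanh(w * x)
        local = w * (1.0 - h * h)  # derivative of tanh(w * x) with respect to x
        if residual:
            x, grad = x + h, grad * (1.0 + local)  # y = x + f(x)
        else:
            x, grad = h, grad * local              # y = f(x)
    return grad

plain_grad = grad_through_stack(residual=False)
residual_grad = grad_through_stack(residual=True)
print(f"plain stack:    {plain_grad:.2e}")
print(f"residual stack: {residual_grad:.2e}")
```

In the plain stack every factor is at most `0.5`, so the gradient shrinks exponentially with depth. With skip connections every factor is at least `1`, so the gradient reaching the input cannot vanish.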

#### **Feed Forward Network**
The feed forward network consists of a linear layer followed by a non-linearity (`ReLU`) and a second linear layer. It is applied to every token independently, and helps the model learn more complex patterns in the data.

#### **Projection Layer**
The projection layer is the last linear layer of the feed forward network, at the end of the transformer block. It projects the hidden representation (of size `4 * n_embd` in the code below) back to the original size of the embeddings. It is also learnable.

#### **Dropout**
Dropout is a regularization technique used to prevent overfitting. It works by randomly setting a fraction of the input to zero. In pytorch we can use `nn.Dropout` to implement this.
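
The idea can be sketched in a few lines of plain Python. This is a simplified stand-in for `nn.Dropout`, written for illustration: during training each element is zeroed with probability `p` and the survivors are scaled by `1 / (1 - p)`, so the expected value stays the same; at evaluation time the input passes through unchanged.

```python
import random

def dropout(values, p, training=True):
    # Inverted dropout: zero each element with probability p and rescale the
    # survivors by 1 / (1 - p) so the expected value is unchanged.
    if not training or p == 0.0:
        return list(values)
    keep = 1.0 - p
    return [v / keep if random.random() < keep else 0.0 for v in values]

random.seed(0)
activations = [1.0] * 10000
dropped = dropout(activations, p=0.2)

zero_fraction = dropped.count(0.0) / len(dropped)
mean_after = sum(dropped) / len(dropped)
print(f"fraction zeroed:    {zero_fraction:.3f}")
print(f"mean after dropout: {mean_after:.3f}")
```

This is also why `model.eval()` is called before estimating the loss: it switches dropout off so that the evaluation is deterministic.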


Putting all of this together we get the transformer block:

In [None]:
class FeedForward(nn.Module):
    """Two linear layers with a ReLU activation in between, followed by dropout"""
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(0.2)
        )
        
    def forward(self, x):
        return self.net(x)

In [None]:
class TransformerBlock(nn.Module):
    """A single transformer block"""
    def __init__(self, n_embd, n_heads, block_size):
        super().__init__()
        head_size = n_embd // n_heads
        self.att = MultiHeadAttention(n_heads, n_embd, head_size, block_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.att(self.ln1(x))   # layer norm and skip connection
        x = x + self.ffwd(self.ln2(x))  # layer norm and skip connection
        return x

Now let's build the BigramTransformer model, and train it.

In [None]:
class BigramLanguageModelTransformer(nn.Module):
    def __init__(self, vocab_size, n_embd=32, n_heads=4, n_blocks=4):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.pos_embedding_table = nn.Embedding(block_size, n_embd)
        # nn.Sequential registers the blocks itself, so no separate nn.ModuleList is needed
        self.blocks = nn.Sequential(*[TransformerBlock(n_embd, n_heads, block_size) for _ in range(n_blocks)])
        self.ln = nn.LayerNorm(n_embd)  # layer norm before the final head
        self.head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_embd = self.token_embedding_table(idx) # (B,T,n_embd)
        pos_embd = self.pos_embedding_table(torch.arange(T, device=device)) # (T,n_embd)

        x = token_embd + pos_embd   # adding positional embedding to token embedding
        x = self.blocks(x) # apply transformer blocks (B,T,n_embd)
        x = self.ln(x) # layer norm
        logits = self.head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B*T, C),  targets.view(B*T))

        return logits ,loss
    
    def generate(self, idx, max_new_tokens):
        """ Generate new tokens, given current context.
            In a simple bigram model, we just need to look up the logits for the next token.
            In more complex models, we are also interested in the history of the tokens. In other
            words, a larger context.

            Args:
                idx (Tensor): Initial context
                max_new_tokens (int): Maximum number of new tokens to generate.
            Returns:
                Tensor: Generated tokens

        """

        # idx is a tensor of shape (B, T)
        idx = idx.to(device)  # .to() is not in-place, so the result must be reassigned
        for i in range(max_new_tokens):
            idx_cond = idx[:, -block_size:] # only consider the last block_size tokens
            # get the prediction for the current context
            logits, loss = self.forward(idx_cond) # logits shape B, T, C
            # The next token is predicted from the last time step only.
            logits = logits[:, -1, :]
            # convert logits to probabilities
            probs = F.softmax(logits, dim=-1)
            # sample the next token from the distribution
            next_token = torch.multinomial(probs, num_samples=1).to(device)
            idx = torch.cat([idx, next_token], dim=1) # shape B, T+1

        return idx

Now let's update the hyperparameters and train the model.

In [None]:
# hyper parameters
num_epochs = 10000
batch_size = 32
n_embd = 32
eval_interval = 100
lr = 1e-3
n_heads = 4
n_transformer_blocks = 4

model = BigramLanguageModelTransformer(vocab_size, n_embd, n_heads, n_transformer_blocks)
model = model.to(device)

# pytorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# training loop
for step in range(num_epochs):
    
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"Step: {step}, Train Loss: {losses['train']}, Val Loss: {losses['val']}")

    # sample a batch of data
    xb, yb = get_batch(split="train")
    # forward pass
    logits, loss = model(xb, yb)

    # backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

losses = estimate_loss()
print(f"Step: {step}, Train Loss: {losses['train']}, Val Loss: {losses['val']}")

We got further improvement, and the loss dropped to:

`Step: 9999, Train Loss: 1.8907207250595093, Val Loss: 2.0092897415161133`

Let's generate some text and see what we got:

In [None]:
# generate text from the model
context = torch.zeros((1,1), dtype=torch.long).to(device)
new_tokens = model.generate(context, 500)[0].tolist()

# decode the generated tokens
gen_txt = decode(new_tokens)
print(gen_txt)

This looks much better. The words look a lot more like English words, and the model has learned to generate text that resembles Shakespeare.
This is not perfect, but it is a good result for a simple transformer model.
