# ChatGPT-Like Transformer From Scratch Using PyTorch

We will be coding a ChatGPT-like Transformer model in PyTorch from scratch in this notebook. This is also known as a Decoder-Only Transformer.

This notebook is based on a tutorial from [StatQuest](https://www.youtube.com/watch?v=C9QSpl5nmrY&list=PLblh5JKOoLUIxGDQs4LFFD--41Vzf-ME1&index=30).

## 1. Set Up

In [33]:
#!pip install torch
#!pip install lightning

In [34]:
import torch
import torch.nn as nn 
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

import lightning as L

## 2. Generate Dataset

To keep things simple, we'll train the Transformer to respond to just two prompts:
1. What is StatQuest?
2. StatQuest is what?

The answer to both prompts should be "Awesome".

This means that the vocabulary (the set of tokens) will be:
- what
- is
- statquest
- awesome
- `<EOS>` (end of sequence), which marks the end of the prompt and the end of the answer
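
Each training sequence is a prompt, an `<EOS>` token, and the answer. The labels are the same sequence shifted one position to the left, so at every position the model learns to predict the next token. The sketch below illustrates that shift in plain Python; `make_example` is a hypothetical helper used only for illustration and is not part of the tutorial.

```python
# Vocabulary as used in this notebook
token_to_id = {'what': 0, 'is': 1, 'statquest': 2, 'awesome': 3, '<EOS>': 4}

def make_example(prompt, answer):
    """Hypothetical helper: build (input_ids, label_ids) for one prompt/answer pair."""
    # full sequence: prompt <EOS> answer <EOS>
    sequence = prompt + ['<EOS>'] + answer + ['<EOS>']
    ids = [token_to_id[token] for token in sequence]
    # inputs drop the last token, labels drop the first token
    return ids[:-1], ids[1:]

example_inputs, example_labels = make_example(['what', 'is', 'statquest'], ['awesome'])
print(example_inputs)  # [0, 1, 2, 4, 3]
print(example_labels)  # [1, 2, 4, 3, 4]
```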

In [35]:
# dictionary that maps vocabulary to id numbers
token_to_id = {'what': 0,
                'is': 1,
                'statquest':2,
                'awesome':3,
                '<EOS>':4
                }

id_to_token = dict(map(reversed, token_to_id.items()))

# create a tensor with the input data
inputs = torch.tensor([[token_to_id["what"], ## input #1: what is statquest <EOS> awesome
                        token_to_id["is"], 
                        token_to_id["statquest"], 
                        token_to_id["<EOS>"],
                        token_to_id["awesome"]], 
                       
                       [token_to_id["statquest"], # input #2: statquest is what <EOS> awesome
                        token_to_id["is"], 
                        token_to_id["what"], 
                        token_to_id["<EOS>"], 
                        token_to_id["awesome"]]])

# create a tensor with the labels
labels = torch.tensor([[token_to_id["is"], 
                        token_to_id["statquest"], 
                        token_to_id["<EOS>"], 
                        token_to_id["awesome"], 
                        token_to_id["<EOS>"]],  
                       
                       [token_to_id["is"], 
                        token_to_id["what"], 
                        token_to_id["<EOS>"], 
                        token_to_id["awesome"], 
                        token_to_id["<EOS>"]]])

# create a tensor dataset and dataloader
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)

## 3. Position Encoding

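Position encoding adds a position-dependent value to each word embedding so that the model can keep track of word order in the input and output. Without it, attention would treat the tokens as an unordered set.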
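
The encoding used here is the sinusoidal one: even embedding indices get `sin(pos / 10000^(i / d_model))` and the following odd index gets the matching cosine. As a rough illustration, the plain-Python sketch below computes the same table without PyTorch so the arithmetic is visible. The function name `position_encoding` is made up for this example, and the defaults `d_model=2` and `max_len=6` are the values used in this notebook.

```python
import math

def position_encoding(max_len=6, d_model=2):
    """Illustrative sketch: table of sinusoidal position encodings as nested lists."""
    pe = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        # i runs over the even embedding indices 0, 2, 4, ...
        for i in range(0, d_model, 2):
            angle = pos / (10000 ** (i / d_model))
            pe[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                pe[pos][i + 1] = math.cos(angle)
    return pe

pe_table = position_encoding()
for pos, row in enumerate(pe_table):
    print(pos, [round(x, 3) for x in row])
```

With `d_model=2` there is only one sine/cosine pair, so the row for position `pos` is simply `[sin(pos), cos(pos)]`.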

In [36]:
class PositionEncoding(nn.Module):
    def __init__(self, d_model=2, max_len=6):
        
        super().__init__()

        pe = torch.zeros(max_len, d_model)

        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)
        embedding_index = torch.arange(start=0, end=d_model, step=2).float()

        div_term = 1/torch.tensor(10000.0)**(embedding_index/d_model)

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe)

    def forward(self, word_embeddings):
        return word_embeddings + self.pe[:word_embeddings.size(0), :]

## 4. Attention Mechanism

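The Attention Mechanism determines how much each word should draw on the other words when building its representation. In a decoder-only Transformer the attention is masked, so each word can only look at itself and the words that come before it.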
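
To make the masking concrete, here is a small plain-Python sketch of scaled dot-product attention with a causal mask. It is an illustration under simplifying assumptions, not the notebook's implementation: there are no learned weight matrices, the queries, keys, and values are passed in directly, and the names `softmax` and `masked_attention` are made up for this example.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def masked_attention(q, k, v):
    """Return (attention weights, attention output) with a causal mask applied."""
    d_k = len(k[0])
    weights, outputs = [], []
    for i in range(len(q)):
        scores = []
        for j in range(len(k)):
            score = sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d_k)
            # token i must not look at later tokens (j > i)
            scores.append(-1e9 if j > i else score)
        w = softmax(scores)
        weights.append(w)
        # weighted sum of the value vectors
        outputs.append([sum(w[j] * v[j][c] for j in range(len(v)))
                        for c in range(len(v[0]))])
    return weights, outputs

x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # three toy token encodings
attn_weights, attn_output = masked_attention(x, x, x)
print(attn_weights[0])  # [1.0, 0.0, 0.0]
```

The first token can only attend to itself, so its weights are `[1.0, 0.0, 0.0]`; the last token spreads its attention over all three positions.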

In [37]:
class Attention(nn.Module):
    def __init__(self, d_model=2):
        
        super().__init__()

        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.row_dim = 0
        self.col_dim = 1

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):

        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)
        
        # q, k, and v stay on whatever device the inputs and layer weights are on

        # similarity between every query and every key
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # scale by the square root of the key dimension
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # True entries mark the positions a token is not allowed to attend to
            mask = mask.to(scaled_sims.device)
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)

        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores

## 5. Decoder-Only Transformer

Here, we'll piece together the earlier steps and complete the Transformer.
1. Word Embeddings
2. Position Encoding
3. Masked Self-Attention
4. Residual Connections
5. Fully Connected Layer
6. Softmax (applied inside `nn.CrossEntropyLoss` during training, so `forward` returns the raw scores)
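
The sizes of these components determine how many trainable parameters the model has. The back-of-the-envelope count below assumes the values used in this notebook (5 tokens including `<EOS>`, `d_model=2`, no bias in the attention weights, and a bias in the fully connected layer). The variable names are only for illustration.

```python
num_tokens = 5  # what, is, statquest, awesome, <EOS>
d_model = 2     # embedding size used in this notebook

# one d_model-sized vector per token
embedding_params = num_tokens * d_model

# W_q, W_k, and W_v are each d_model x d_model, with no bias
attention_params = 3 * d_model * d_model

# fully connected layer: weights plus one bias per output token
fc_params = d_model * num_tokens + num_tokens

total_params = embedding_params + attention_params + fc_params
print(embedding_params, attention_params, fc_params, total_params)  # 10 12 15 37
```

The position encoding contributes no trainable parameters because it is stored as a fixed buffer.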

In [38]:
class DecoderOnlyTransformer(L.LightningModule):

    def __init__(self, num_tokens=4, d_model=2, max_len=6):

        super().__init__()

        self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)

        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)

        self.self_attention = Attention(d_model=d_model)

        self.fc_layer = nn.Linear(in_features=d_model, out_features=num_tokens)

        self.loss = nn.CrossEntropyLoss()

    def forward(self, token_ids):

        word_embeddings = self.we(token_ids)
        position_encoded = self.pe(word_embeddings)

        # lower-triangular matrix of ones, created on the same device as the input tokens
        mask = torch.tril(torch.ones((token_ids.size(dim=0), token_ids.size(dim=0)), device=token_ids.device))
        # True marks the positions to hide (tokens that come later in the sequence)
        mask = mask == 0

        self_attention_values = self.self_attention(position_encoded, position_encoded, position_encoded, mask=mask)

        residual_connection_values = position_encoded + self_attention_values

        fc_layer_output = self.fc_layer(residual_connection_values)

        return fc_layer_output
    
    def configure_optimizers(self):
        return Adam(self.parameters(), lr=0.1)
    
    def training_step(self, batch, batch_idx):
        input_tokens, labels = batch
        output = self.forward(input_tokens[0])
        loss = self.loss(output, labels[0])

        return loss

## 6. Run the Model before Training

Input prompt: "What is StatQuest"

In [39]:
model = DecoderOnlyTransformer(num_tokens=len(token_to_id), d_model=2, max_len=6)

model.to('mps')

DecoderOnlyTransformer(
  (we): Embedding(5, 2)
  (pe): PositionEncoding()
  (self_attention): Attention(
    (W_q): Linear(in_features=2, out_features=2, bias=False)
    (W_k): Linear(in_features=2, out_features=2, bias=False)
    (W_v): Linear(in_features=2, out_features=2, bias=False)
  )
  (fc_layer): Linear(in_features=2, out_features=5, bias=True)
  (loss): CrossEntropyLoss()
)

In [40]:
model_input = torch.tensor([token_to_id["what"], 
                            token_to_id["is"], 
                            token_to_id["statquest"], 
                            token_to_id["<EOS>"]], device='mps')
input_length = model_input.size(dim=0)

predictions = model(model_input) 
predicted_id = torch.tensor([torch.argmax(predictions[-1,:])], device='mps')
predicted_ids = predicted_id

max_length = 6
for i in range(input_length, max_length):
    if (predicted_id == token_to_id["<EOS>"]): # if the prediction is <EOS>, then we are done
        break
    
    model_input = torch.cat((model_input, predicted_id))
    
    predictions = model(model_input) 
    predicted_id = torch.tensor([torch.argmax(predictions[-1,:])], device='mps')
    predicted_ids = torch.cat((predicted_ids, predicted_id))

print("Predicted Tokens:\n") 
for id in predicted_ids: 
    print("\t", id_to_token[id.item()])

Predicted Tokens:

	 <EOS>


The answer is wrong because the model has not been trained yet, so its weights are still random.
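
The generation loop above is a form of greedy decoding: at each step the highest-scoring token is appended to the input, and generation stops at `<EOS>` or at a length limit. The sketch below shows the same idea in plain Python. It is a simplified variant, not the notebook's exact loop: `max_length` here caps the total sequence length, and `toy_scores` is a hypothetical stand-in for a trained model.

```python
def greedy_decode(next_token_scores, prompt_ids, eos_id, max_length):
    """Repeatedly append the highest-scoring token until <EOS> or max_length."""
    ids = list(prompt_ids)
    generated = []
    while len(ids) < max_length:
        scores = next_token_scores(ids)
        # argmax over the vocabulary
        next_id = max(range(len(scores)), key=scores.__getitem__)
        generated.append(next_id)
        ids.append(next_id)
        if next_id == eos_id:
            break
    return generated

token_to_id = {'what': 0, 'is': 1, 'statquest': 2, 'awesome': 3, '<EOS>': 4}
id_to_token = {i: t for t, i in token_to_id.items()}

def toy_scores(ids):
    """Hypothetical stand-in for a model: 'awesome' after <EOS>, otherwise <EOS>."""
    target = token_to_id['awesome'] if ids[-1] == token_to_id['<EOS>'] else token_to_id['<EOS>']
    return [1.0 if i == target else 0.0 for i in range(len(token_to_id))]

prompt = [token_to_id[t] for t in ['what', 'is', 'statquest', '<EOS>']]
generated_ids = greedy_decode(toy_scores, prompt, token_to_id['<EOS>'], max_length=6)
print([id_to_token[i] for i in generated_ids])  # ['awesome', '<EOS>']
```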

## 7. Train the Model

We'll use Lightning's `Trainer` to train the model.
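
Each training step compares the model's scores at every position with the label for that position using cross-entropy loss, which is what `nn.CrossEntropyLoss` computes (averaged over positions by default). As a rough illustration of the arithmetic, here is a plain-Python sketch; the helper names `cross_entropy` and `sequence_loss` and the example scores are made up for this example.

```python
import math

def cross_entropy(logits, label):
    """Loss for one position: negative log of the softmax probability of the label."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]

def sequence_loss(all_logits, labels):
    """Mean cross-entropy over all positions in one sequence."""
    losses = [cross_entropy(logits, y) for logits, y in zip(all_logits, labels)]
    return sum(losses) / len(losses)

example_labels = [1, 2, 4, 3, 4]  # is statquest <EOS> awesome <EOS>

# a model that scores all 5 tokens equally has a loss of ln(5)
uniform_logits = [[0.0] * 5 for _ in example_labels]
uniform_loss = sequence_loss(uniform_logits, example_labels)
print(round(uniform_loss, 3))  # 1.609

# a model that strongly favors the correct token has a loss close to 0
confident_logits = [[10.0 if i == y else 0.0 for i in range(5)] for y in example_labels]
confident_loss = sequence_loss(confident_logits, example_labels)
```

Training lowers this loss by adjusting the weights so that the correct next token receives the highest score.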

In [41]:
trainer = L.Trainer(max_epochs=30)
trainer.fit(model, train_dataloaders=dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type             | Params | Mode 
------------------------------------------------------------
0 | we             | Embedding        | 10     | train
1 | pe             | PositionEncoding | 0      | train
2 | self_attention | Attention        | 12     | train
3 | fc_layer       | Linear           | 15     | train
4 | loss           | CrossEntropyLoss | 0      | train
------------------------------------------------------------
37        Trainable params
0         Non-trainable params
37        Total params
0.000     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode
/opt/homebrew/Caskroom/miniconda/base/envs/playground/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing

Epoch 29: 100%|██████████| 2/2 [00:00<00:00, 231.81it/s, v_num=2]

`Trainer.fit` stopped: `max_epochs=30` reached.



In [46]:
model_input = torch.tensor([token_to_id["what"], 
                            token_to_id["is"], 
                            token_to_id["statquest"], 
                            token_to_id["<EOS>"]], device='mps')
input_length = model_input.size(dim=0)

# Ensure all model parameters are on 'mps' device
for param in model.parameters():
    param.data = param.data.to('mps')

# Ensure positional encoding tensor is on 'mps' device
model.pe = model.pe.to('mps')

predictions = model(model_input) 
predicted_id = torch.tensor([torch.argmax(predictions[-1,:])], device='mps')
predicted_ids = predicted_id

max_length = 6
for i in range(input_length, max_length):
    if (predicted_id == token_to_id["<EOS>"]): # if the prediction is <EOS>, then we are done
        break
    
    model_input = torch.cat((model_input, predicted_id))
    
    predictions = model(model_input) 
    predicted_id = torch.tensor([torch.argmax(predictions[-1,:])], device='mps')
    predicted_ids = torch.cat((predicted_ids, predicted_id))

print("Predicted Tokens:\n") 
for id in predicted_ids: 
    print("\t", id_to_token[id.item()])

Predicted Tokens:

	 awesome
	 <EOS>


After training, the model gives the correct answer: "awesome", followed by `<EOS>`.