## Coding a Transformer from scratch in PyTorch

In [None]:
# Importing "lightning" to make it way easier to write our code  and for automatic code optimization and scaling in the cloud!
! pip3 install lightning

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

import lightning as L

from position_encoding import PositionEncoding
from attention import Attention
from decoder_only_transformer import DecoderOnlyTransformer

Using device: cpu



In [3]:
# Prompts: 1. What is Statquest? Awesome 2. Statquest is what? Awesome
# Vocabulary lookup (token -> id): each token's id is its position in the tuple below.
token_to_id = {
    token: index
    for index, token in enumerate(("what", "is", "statquest", "awesome", "<EOS>"))
}

In [4]:
# Reverse lookup (id -> token), used to turn the ids the model predicts back into readable tokens.
id_to_token = dict(zip(token_to_id.values(), token_to_id.keys()))

In [5]:
# One row per training sequence, stored as token ids: the prompt, <EOS>, then the answer.
inputs = torch.tensor([
    [token_to_id[token] for token in ("what", "is", "statquest", "<EOS>", "awesome")],
    [token_to_id[token] for token in ("statquest", "is", "what", "<EOS>", "awesome")],
])

# Each label row is the matching input row shifted one position to the left,
# with a closing <EOS> appended as the final target.
labels = torch.tensor([
    [token_to_id[token] for token in ("is", "statquest", "<EOS>", "awesome", "<EOS>")],
    [token_to_id[token] for token in ("is", "what", "<EOS>", "awesome", "<EOS>")],
])

In [6]:
# Pair every input row with its label row, then wrap the pairs in a DataLoader.
# batch_size=1 and shuffle=False are the DataLoader defaults, written out here for clarity.
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)

In [7]:
# Build the decoder-only transformer with one token slot per vocabulary entry.
# NOTE(review): max_len is presumably the longest sequence the position encoding
# supports — confirm in decoder_only_transformer.py.
model = DecoderOnlyTransformer(
    num_tokens=len(token_to_id),
    d_model=2,
    max_len=6,
)

Seed set to 42


In [8]:
# Fit the model on the two training sequences, stopping after 30 epochs.
trainer = L.Trainer(max_epochs=30)
trainer.fit(model=model, train_dataloaders=dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type             | Params | Mode 
------------------------------------------------------------
0 | we             | Embedding        | 10     | train
1 | pe             | PositionEncoding | 0      | train
2 | self_attention | Attention        | 12     | train
3 | fc_layer       | Linear           | 15     | train
4 | loss           | CrossEntropyLoss | 0      | train
------------------------------------------------------------
37        Trainable params
0         Non-trainable params
37        Total params
0.000     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode
/opt/homebrew/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument

Epoch 29: 100%|██████████| 2/2 [00:00<00:00, 143.97it/s, v_num=4]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 2/2 [00:00<00:00, 97.18it/s, v_num=4] 


In [9]:
# Ask the trained model the first question: "what is statquest <EOS>".
model_input = torch.tensor([token_to_id["what"],
                            token_to_id["is"],
                            token_to_id["statquest"],
                            token_to_id["<EOS>"]])
input_length = model_input.size(dim=0)

# Upper bound on the sequence length (prompt + generated tokens) fed to the model.
# NOTE(review): presumably must not exceed the max_len=6 the model was built with — confirm.
max_length = 6

# Inference only: put the model in eval mode and skip gradient tracking.
model.eval()
with torch.no_grad():
    # The argmax over the last row of the model output is taken as the next token id.
    predictions = model(model_input)
    predicted_id = torch.argmax(predictions[-1, :]).unsqueeze(0)
    print(predicted_id)
    predicted_ids = predicted_id

    # Greedy generation: append the latest prediction to the input and predict again,
    # until the model emits <EOS> or the sequence reaches max_length.
    for _ in range(input_length, max_length):
        if predicted_id.item() == token_to_id["<EOS>"]:
            break

        model_input = torch.cat((model_input, predicted_id))

        predictions = model(model_input)
        predicted_id = torch.argmax(predictions[-1, :]).unsqueeze(0)
        predicted_ids = torch.cat((predicted_ids, predicted_id))

print("Predicted Tokens:\n")
# The loop variable is `token_id`, not `id`, so the built-in id() is not shadowed in the kernel.
for token_id in predicted_ids:
    print("\t", id_to_token[token_id.item()])

tensor([3])
Predicted Tokens:

	 awesome
	 <EOS>


In [10]:
## Now let's ask the other question...
model_input = torch.tensor([token_to_id["statquest"],
                            token_to_id["is"],
                            token_to_id["what"],
                            token_to_id["<EOS>"]])
input_length = model_input.size(dim=0)

# Inference only: put the model in eval mode and skip gradient tracking.
model.eval()
with torch.no_grad():
    # The argmax over the last row of the model output is taken as the next token id.
    predictions = model(model_input)
    predicted_id = torch.argmax(predictions[-1, :]).unsqueeze(0)
    predicted_ids = predicted_id

    # Greedy generation; `max_length` is the limit defined in the previous generation cell.
    for _ in range(input_length, max_length):
        if predicted_id.item() == token_to_id["<EOS>"]:  # if the prediction is <EOS>, then we are done
            break

        model_input = torch.cat((model_input, predicted_id))

        predictions = model(model_input)
        predicted_id = torch.argmax(predictions[-1, :]).unsqueeze(0)
        predicted_ids = torch.cat((predicted_ids, predicted_id))

print("Predicted Tokens:\n")
# The loop variable is `token_id`, not `id`, so the built-in id() is not shadowed in the kernel.
for token_id in predicted_ids:
    print("\t", id_to_token[token_id.item()])

Predicted Tokens:

	 awesome
	 <EOS>
