
# Building Transformers from Scratch

This notebook replicates the blog post by Mat Miller (https://blog.matdmiller.com/posts/2023-06-10_transformers/notebook.html).

In that post, he walks through Andrej Karpathy's YouTube video "Let's build GPT".

## 1. Getting Started

In [22]:
# download the dataset: tiny Shakespeare (about 1.1 MB of plain text)

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-01-14 15:47:35--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2024-01-14 15:47:35 (19.4 MB/s) - ‘input.txt.1’ saved [1115394/1115394]



In [23]:
with open('input.txt', 'r') as f:
    text = f.read()
print(f"Length: {len(text)}")
print(f"Initial sample: {text[:500]}")

Length: 1115394
Initial sample: First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor


## 2. Tokenization

LLMs generally tokenize text at the sub-word level, but here we tokenize at the character level.

In [24]:
vocab = sorted(list(set(text)))
vocab_size = len(vocab)
print(f"size: {vocab_size}")
print(vocab)

size: 65
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


To feed these tokens into a model, each character first has to be converted to a number: an integer index into the vocabulary. Inside the model, each index is then mapped to a learned vector representation, called an embedding.
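
The two steps can be seen side by side in a minimal sketch. It assumes a toy vocabulary built from the string "hello" and an arbitrary embedding size of 4 (both illustrative, not values used in this notebook); the `toy_` prefixes just avoid overwriting the notebook's own variables.

```python
import torch
import torch.nn as nn

toy_text = "hello"
toy_vocab = sorted(set(toy_text))                      # ['e', 'h', 'l', 'o']
toy_char2idx = {ch: i for i, ch in enumerate(toy_vocab)}

# Step 1: characters -> integer token ids
toy_ids = torch.tensor([toy_char2idx[ch] for ch in toy_text])
print(toy_ids.tolist())   # [1, 0, 2, 2, 3]

# Step 2: token ids -> learned vectors (embedding_dim=4 is an arbitrary choice)
toy_embedding = nn.Embedding(num_embeddings=len(toy_vocab), embedding_dim=4)
toy_vectors = toy_embedding(toy_ids)
print(toy_vectors.shape)  # torch.Size([5, 4])

# Both 'l' characters share id 2, so they look up the same vector
print(torch.equal(toy_vectors[2], toy_vectors[3]))  # True
```

The vectors start out random; training is what makes them useful.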

In [25]:
char2idx = {char:idx for idx, char in enumerate(vocab)}
char2idx

{'\n': 0,
 ' ': 1,
 '!': 2,
 '$': 3,
 '&': 4,
 "'": 5,
 ',': 6,
 '-': 7,
 '.': 8,
 '3': 9,
 ':': 10,
 ';': 11,
 '?': 12,
 'A': 13,
 'B': 14,
 'C': 15,
 'D': 16,
 'E': 17,
 'F': 18,
 'G': 19,
 'H': 20,
 'I': 21,
 'J': 22,
 'K': 23,
 'L': 24,
 'M': 25,
 'N': 26,
 'O': 27,
 'P': 28,
 'Q': 29,
 'R': 30,
 'S': 31,
 'T': 32,
 'U': 33,
 'V': 34,
 'W': 35,
 'X': 36,
 'Y': 37,
 'Z': 38,
 'a': 39,
 'b': 40,
 'c': 41,
 'd': 42,
 'e': 43,
 'f': 44,
 'g': 45,
 'h': 46,
 'i': 47,
 'j': 48,
 'k': 49,
 'l': 50,
 'm': 51,
 'n': 52,
 'o': 53,
 'p': 54,
 'q': 55,
 'r': 56,
 's': 57,
 't': 58,
 'u': 59,
 'v': 60,
 'w': 61,
 'x': 62,
 'y': 63,
 'z': 64}

In [26]:
idx2char = {idx:char for char, idx in char2idx.items()}
idx2char

{0: '\n',
 1: ' ',
 2: '!',
 3: '$',
 4: '&',
 5: "'",
 6: ',',
 7: '-',
 8: '.',
 9: '3',
 10: ':',
 11: ';',
 12: '?',
 13: 'A',
 14: 'B',
 15: 'C',
 16: 'D',
 17: 'E',
 18: 'F',
 19: 'G',
 20: 'H',
 21: 'I',
 22: 'J',
 23: 'K',
 24: 'L',
 25: 'M',
 26: 'N',
 27: 'O',
 28: 'P',
 29: 'Q',
 30: 'R',
 31: 'S',
 32: 'T',
 33: 'U',
 34: 'V',
 35: 'W',
 36: 'X',
 37: 'Y',
 38: 'Z',
 39: 'a',
 40: 'b',
 41: 'c',
 42: 'd',
 43: 'e',
 44: 'f',
 45: 'g',
 46: 'h',
 47: 'i',
 48: 'j',
 49: 'k',
 50: 'l',
 51: 'm',
 52: 'n',
 53: 'o',
 54: 'p',
 55: 'q',
 56: 'r',
 57: 's',
 58: 't',
 59: 'u',
 60: 'v',
 61: 'w',
 62: 'x',
 63: 'y',
 64: 'z'}

In [27]:
encode = lambda x : [char2idx[char] for char in x]
decode = lambda idxs: ''.join([idx2char[idx] for idx in idxs])
print("Tokenize Hello world! :", encode("Hello world!"))
print("Create string from tokens :", decode([20, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42, 2]))

Tokenize Hello world! : [20, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42, 2]
Create string from tokens : Hello world!
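
Because `char2idx` and `idx2char` are inverse mappings, decoding an encoded string returns it unchanged, while any character outside the vocabulary raises a `KeyError`. A small self-contained check on a toy corpus; the `make_codec` helper is hypothetical and only mirrors the `encode`/`decode` lambdas above.

```python
def make_codec(corpus):
    """Build character-level encode/decode functions from a corpus (hypothetical helper)."""
    chars = sorted(set(corpus))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    enc = lambda s: [stoi[ch] for ch in s]
    dec = lambda ids: ''.join(itos[i] for i in ids)
    return enc, dec

enc, dec = make_codec("hello world")
sample = "hello"
assert dec(enc(sample)) == sample     # the round trip is lossless

try:
    enc("Hello")                      # 'H' never appears in the corpus
except KeyError as err:
    print("unknown character:", err)  # unknown character: 'H'
```

Tiny Shakespeare avoids this issue because the vocabulary is built from the full text before splitting into train and validation sets.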


In [28]:
# encode the full dataset as a tensor of token ids

In [29]:
import torch

In [30]:
encoded_text = torch.tensor(encode(text))
encoded_text.shape, encoded_text.dtype

(torch.Size([1115394]), torch.int64)

In [31]:
encoded_text

tensor([18, 47, 56,  ..., 45,  8,  0])

In [32]:
train_split_pct = 0.9
train_split_idx = int(len(encoded_text)*train_split_pct)
train_split_idx

1003854

In [33]:
train_data = encoded_text[:train_split_idx]
valid_data = encoded_text[train_split_idx:]
print(f"Train data: {len(train_data)}, Valid data: {len(valid_data)}, Train pct: {len(train_data)/len(encoded_text)}")

Train data: 1003854, Valid data: 111540, Train pct: 0.8999994620734916


The context length (also called the block size) is the maximum sequence length the transformer sees during training. From each chunk of data, the transformer is trained on every prefix of that chunk, from a single token up to the full context length, with the following token as the target.

In [34]:
context_length = 8
for i in range(context_length):
    x,y  = train_data[:i+1], train_data[i+1]
    print(f"idx: {i}, x: {x}, y:{y}, | decoded x: {decode(x.tolist())}, y: {decode(y[None].tolist())}")

idx: 0, x: tensor([18]), y:47, | decoded x: F, y: i
idx: 1, x: tensor([18, 47]), y:56, | decoded x: Fi, y: r
idx: 2, x: tensor([18, 47, 56]), y:57, | decoded x: Fir, y: s
idx: 3, x: tensor([18, 47, 56, 57]), y:58, | decoded x: Firs, y: t
idx: 4, x: tensor([18, 47, 56, 57, 58]), y:1, | decoded x: First, y:  
idx: 5, x: tensor([18, 47, 56, 57, 58,  1]), y:15, | decoded x: First , y: C
idx: 6, x: tensor([18, 47, 56, 57, 58,  1, 15]), y:47, | decoded x: First C, y: i
idx: 7, x: tensor([18, 47, 56, 57, 58,  1, 15, 47]), y:58, | decoded x: First Ci, y: t
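
The loop above can also be written as two slices: a chunk of `context_length + 1` tokens yields `context_length` training examples, and the targets are the inputs shifted one position left. A minimal sketch, assuming the chunk is hard-coded to the first nine token ids shown above so that it runs on its own:

```python
import torch

block_size = 8
# First block_size + 1 token ids of the training data ("First Cit")
chunk = torch.tensor([18, 47, 56, 57, 58, 1, 15, 47, 58])

x = chunk[:-1]  # inputs:  positions 0..7
y = chunk[1:]   # targets: positions 1..8

# Every prefix x[:t+1] is paired with the next token y[t]
pairs = [(x[:t + 1].tolist(), y[t].item()) for t in range(block_size)]
for context, target in pairs:
    print(context, "->", target)
```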


## Parameters

In [35]:
TORCH_SEED = 42
torch.manual_seed(TORCH_SEED)
context_length = 8
batch_size = 4

## 3. Data Loader

In [36]:
def get_batch(train_valid):
    data = train_data if train_valid == 'train' else valid_data
    # random start positions; each window needs context_length + 1 tokens (inputs plus shifted targets)
    start_idxs = torch.randint(
        high=len(data) - context_length,
        size=(batch_size,))
    x = torch.stack([data[i: i+context_length] for i in start_idxs])
    y = torch.stack([data[i+1:i+context_length + 1] for i in start_idxs])
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(f"shape: {xb.shape}, xb: {xb}")
print(f"Targets: Shape:{yb.shape} yb: {yb}")

for batch_idx in range(batch_size):
    for seq_idx in range(context_length):
        print(batch_idx, seq_idx)
        context = xb[batch_idx,:seq_idx + 1]
        target = yb[batch_idx, seq_idx]
        print(f"Given input: {context.tolist()} target: {target}")

inputs:
shape: torch.Size([4, 8]), xb: tensor([[57,  1, 46, 47, 57,  1, 50, 53],
        [ 1, 58, 46, 43, 56, 43,  1, 41],
        [17, 26, 15, 17, 10,  0, 32, 53],
        [57, 58,  6,  1, 61, 47, 58, 46]])
Targets: Shape:torch.Size([4, 8]) yb: tensor([[ 1, 46, 47, 57,  1, 50, 53, 60],
        [58, 46, 43, 56, 43,  1, 41, 39],
        [26, 15, 17, 10,  0, 32, 53,  1],
        [58,  6,  1, 61, 47, 58, 46,  0]])
0 0
Given input: [57] target: 1
0 1
Given input: [57, 1] target: 46
0 2
Given input: [57, 1, 46] target: 47
0 3
Given input: [57, 1, 46, 47] target: 57
0 4
Given input: [57, 1, 46, 47, 57] target: 1
0 5
Given input: [57, 1, 46, 47, 57, 1] target: 50
0 6
Given input: [57, 1, 46, 47, 57, 1, 50] target: 53
0 7
Given input: [57, 1, 46, 47, 57, 1, 50, 53] target: 60
1 0
Given input: [1] target: 58
1 1
Given input: [1, 58] target: 46
1 2
Given input: [1, 58, 46] target: 43
1 3
Given input: [1, 58, 46, 43] target: 56
1 4
Given input: [1, 58, 46, 43, 56] target: 43
1 5
Given input: [1, 
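
One property worth checking is that each target row is its input row shifted left by one token, so `yb[:, :-1]` equals `xb[:, 1:]`. The sketch below is a self-contained variant of `get_batch` that takes the data, block size, and batch size as arguments instead of reading globals, and uses its own `torch.Generator` so it leaves the global seed untouched. The name `sample_batch` and the `torch.arange` stand-in data are assumptions for illustration only.

```python
import torch

def sample_batch(data, block_size, batch_size, generator=None):
    """Sample random (input, target) windows from a 1-D tensor of token ids."""
    # Start indices lie in [0, len(data) - block_size - 1], so targets stay in range
    starts = torch.randint(len(data) - block_size, (batch_size,), generator=generator)
    x = torch.stack([data[i:i + block_size] for i in starts])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in starts])
    return x, y

toy_data = torch.arange(100)               # stand-in for train_data
gen = torch.Generator().manual_seed(0)
x, y = sample_batch(toy_data, block_size=8, batch_size=4, generator=gen)

print(x.shape, y.shape)                    # torch.Size([4, 8]) torch.Size([4, 8])
print(torch.equal(x[:, 1:], y[:, :-1]))    # True
print(torch.equal(y, x + 1))               # True for this arange stand-in
```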

## 4. Bigram Model

A bigram model predicts the probability of the next token using only the current token.
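
The same idea can be illustrated without a neural network by counting: tally how often each character follows each other character, then normalize each row into a probability distribution. This count-based table is a swapped-in illustration, not part of the original notebook; the model below instead learns a similar table of logits by gradient descent. A minimal sketch on a toy string, with add-one smoothing as an assumed choice:

```python
import torch

toy_text = "abab"
chars = sorted(set(toy_text))
stoi = {ch: i for i, ch in enumerate(chars)}

# counts[i, j] = number of times character j follows character i
counts = torch.zeros(len(chars), len(chars))
for a, b in zip(toy_text, toy_text[1:]):
    counts[stoi[a], stoi[b]] += 1

# Normalize rows into P(next | current); add-one smoothing avoids zero probabilities
probs = (counts + 1) / (counts + 1).sum(dim=1, keepdim=True)
print(counts)  # rows: current char ('a', 'b'); columns: next char
print(probs)   # row 'a' -> [0.25, 0.75], row 'b' -> [0.6667, 0.3333]
```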

In [37]:
import torch.nn as nn
from torch.nn import functional as F

In [38]:
torch.manual_seed(TORCH_SEED)

<torch._C.Generator at 0x7aba6c296830>

In [43]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size

        self.token_embedding_table = nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=self.vocab_size
            )

    def forward(self, idx, targets=None):
        # idx and targets are both (B, T) tensors of integer token ids
        # each id selects a row of the table, giving logits of shape (B, T, C=vocab_size)
        logits = self.token_embedding_table(idx)
        if targets is not None:
            B, T, C = logits.shape
            logits_reshaped = logits.view(B*T, C)
            targets_reshaped = targets.view(B*T)
            loss = F.cross_entropy(
                input=logits_reshaped,
                target=targets_reshaped
            )
        else:
            loss=None
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is a (B, T) tensor of token indices in the current context
        for _ in range(max_new_tokens):
            # get preds
            logits, loss = self(idx)

            # get the last time step from logits
            logits_last_timestep = logits[:, -1, :]
            print(f"Shape of logits time stamp: {logits_last_timestep.shape}")

            # apply softmax
            probs = F.softmax(input=logits_last_timestep, dim=-1)

            # sample from probs distribution
            idx_next = torch.multinomial(
                input=probs,
                num_samples=1
            )

            idx = torch.cat((idx, idx_next), dim=1)
        return idx
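
Why is an `nn.Embedding(vocab_size, vocab_size)` enough for a bigram model? Looking up row `i` of its weight matrix is the same as multiplying a one-hot vector for token `i` by that matrix, so row `i` holds the logits for whichever token comes next after token `i`. A small check, assuming an arbitrary toy vocabulary size of 5:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

V = 5                                    # toy vocabulary size
table = nn.Embedding(V, V)
toy_idx = torch.tensor([[0, 3, 1]])      # (B=1, T=3) token ids

lookup = table(toy_idx)                  # (1, 3, V) logits via row lookup
one_hot = F.one_hot(toy_idx, num_classes=V).float()
matmul = one_hot @ table.weight          # (1, 3, V) logits via one-hot matmul

print(lookup.shape)                      # torch.Size([1, 3, 5])
print(torch.allclose(lookup, matmul))    # True
```

The lookup is simply the cheaper way to compute the same product.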


In [44]:
bigram_model = BigramLanguageModel(vocab_size=vocab_size)

In [45]:
logits, loss = bigram_model(xb, yb)

In [46]:
print(f"Loss: {loss}")

Loss: 4.731735706329346
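
As a sanity check: if the untrained model assigned equal probability to all 65 characters, the cross-entropy would be -ln(1/65) = ln(65) ≈ 4.17. The observed 4.73 is most likely higher because the randomly initialized embedding table is not uniform, so some targets receive less than 1/65 of the probability mass (and a single batch of 32 targets is a noisy estimate). A quick computation of that baseline:

```python
import math
import torch
import torch.nn.functional as F

n_classes = 65                           # vocabulary size printed earlier
uniform_loss = math.log(n_classes)       # -ln(1/65)
print(round(uniform_loss, 4))            # 4.1744

# Same value from F.cross_entropy: all-zero logits give a uniform softmax
uniform_logits = torch.zeros(32, n_classes)
toy_targets = torch.arange(32)           # any valid class ids give the same loss here
ce = F.cross_entropy(uniform_logits, toy_targets).item()
print(round(ce, 4))                      # 4.1744
```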


In [47]:
idx = torch.zeros((1,1), dtype=torch.long)  # start from a single token, id 0 ('\n')

In [48]:
print(f"100 Generated Tokens:",
      decode(bigram_model.generate(idx, max_new_tokens=100)[0].tolist()))

Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([1, 65])
Shape of logits time stamp: torch.Size([

In [49]:
# create the optimizer that will update the model's parameters during training
optimizer = torch.optim.Adam(
    params=bigram_model.parameters(),
    lr=1e-3
)
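
The optimizer is then used in a standard loop: sample a batch, compute the loss, clear old gradients, backpropagate, and step. The sketch below is self-contained so it runs on its own; it assumes a tiny repeating toy corpus, a bare `nn.Embedding` standing in for `BigramLanguageModel`, and arbitrary hyperparameters (`lr=5e-2`, 200 steps) chosen only so the loss visibly drops. It is not the notebook's training configuration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy corpus: a repeating pattern, so the next character is easy to predict
corpus = "abcd" * 50
chars = sorted(set(corpus))
stoi = {ch: i for i, ch in enumerate(chars)}
data = torch.tensor([stoi[ch] for ch in corpus])

V, block, bsz = len(chars), 8, 16
gen = torch.Generator().manual_seed(0)

def sample(data):
    """Random (input, target) windows, as in get_batch."""
    starts = torch.randint(len(data) - block, (bsz,), generator=gen)
    x = torch.stack([data[i:i + block] for i in starts])
    y = torch.stack([data[i + 1:i + block + 1] for i in starts])
    return x, y

# Bigram model reduced to its core: an embedding table of next-token logits
model = nn.Embedding(V, V)
opt = torch.optim.Adam(model.parameters(), lr=5e-2)

def loss_fn(x, y):
    logits = model(x)                                  # (B, T, V)
    return F.cross_entropy(logits.view(-1, V), y.view(-1))

x0, y0 = sample(data)                                  # fixed batch for before/after comparison
initial_loss = loss_fn(x0, y0).item()

for _ in range(200):
    xb_toy, yb_toy = sample(data)
    loss_toy = loss_fn(xb_toy, yb_toy)
    opt.zero_grad(set_to_none=True)                    # clear gradients from the previous step
    loss_toy.backward()                                # backpropagate
    opt.step()                                         # update the embedding table

final_loss = loss_fn(x0, y0).item()
print(f"initial: {initial_loss:.3f}, final: {final_loss:.3f}")
```

On real text the loss cannot approach zero, because a single preceding character leaves the next character genuinely uncertain.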