# Shakespeare GPT -- GPT From Scratch!


In [318]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Dataset Exploration and Modification


In [319]:
# read in input
with open('./input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [320]:
# inpeact the length and the first 1000 chars
print("Length of dataset in chars:", len(text))
print(text[:1000])

Length of dataset in chars: 1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for

In [321]:
# create an alphabet based on the input
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(chars)
print(vocab_size)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
65


In [322]:
# time to tokenize the alphabet (character tokenizer)

# lets create a mapping for encoding (char to int) and decoding (int to char)
stoi = {s:i for i,s in enumerate(chars)}
itos = {i:s for s, i in stoi.items()}

encode = lambda s: [stoi[c] for c in s] # string to list of ints
decode = lambda l: ''.join([itos[i] for i in l]) # list of ints to string

print(encode("Hello World!"))
print(decode(encode("Hello World!")))

[20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42, 2]
Hello World!


In [323]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

# Dataset Splitting


In [324]:
# splitting the dataset into a train (90%) set and a validation (10%) set
n = int(0.9 * len(data))
train = data[:n]
val = data[n:]

print("TRAIN:", train.shape)
print("VAL:", val.shape)

TRAIN: torch.Size([1003854])
VAL: torch.Size([111540])


In [325]:
# creating a batch with visualization

torch.manual_seed(314159365) # for reproducibility
batch_size = 4 # number of blocks in each batche
block_size = 8 # context length

def get_batch(split):
    # create a batch of inputs (x) and targets (y)
    
    data = train if split == 'train' else val # choose split to use
    
    idx = torch.randint(len(data) - block_size, (batch_size,)) # get index of random token
        
    # get inputs and targets based in index
    x = torch.stack([data[i : i + block_size] for i in idx])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in idx])
    
    return x, y

Xb, Yb = get_batch('train')
print("Inputs shape:", Xb.shape)
print("Inputs:")
print(Xb)
print("\nTargets shape:", Yb.shape)
print("Targets:\n")
print(Yb)

for b in range(batch_size):
    print(f'\nBlock #{b+1}:')
    for i in range(block_size):
        context = Xb[b, :i+1] # context 'slides' over input
        target = Yb[b, i]
        print(f'{context.tolist()} ===> {target}')

Inputs shape: torch.Size([4, 8])
Inputs:
tensor([[51, 47, 53,  6,  1, 58, 46, 53],
        [37, 30, 30, 17, 24, 10,  0, 21],
        [21,  5, 50, 50,  1, 54, 56, 53],
        [41, 43, 47, 60, 43,  0, 37, 53]])

Targets shape: torch.Size([4, 8])
Targets:

tensor([[47, 53,  6,  1, 58, 46, 53, 59],
        [30, 30, 17, 24, 10,  0, 21,  1],
        [ 5, 50, 50,  1, 54, 56, 53, 51],
        [43, 47, 60, 43,  0, 37, 53, 59]])

Block #1:
[51] ===> 47
[51, 47] ===> 53
[51, 47, 53] ===> 6
[51, 47, 53, 6] ===> 1
[51, 47, 53, 6, 1] ===> 58
[51, 47, 53, 6, 1, 58] ===> 46
[51, 47, 53, 6, 1, 58, 46] ===> 53
[51, 47, 53, 6, 1, 58, 46, 53] ===> 59

Block #2:
[37] ===> 30
[37, 30] ===> 30
[37, 30, 30] ===> 17
[37, 30, 30, 17] ===> 24
[37, 30, 30, 17, 24] ===> 10
[37, 30, 30, 17, 24, 10] ===> 0
[37, 30, 30, 17, 24, 10, 0] ===> 21
[37, 30, 30, 17, 24, 10, 0, 21] ===> 1

Block #3:
[21] ===> 5
[21, 5] ===> 50
[21, 5, 50] ===> 50
[21, 5, 50, 50] ===> 1
[21, 5, 50, 50, 1] ===> 54
[21, 5, 50, 50, 1, 54] ===> 

# Bigram Language Model


In [326]:
# lets start with a very simple language model: the Bigram
# we will use PyTorch modules rather than building it ourselves
# (see my makemore-CityNames repo for Bigram and Trigram models from scratch)

torch.manual_seed(314159265) # fro reproducibility

class Bigram(nn.Module):
    
    # initialize the model
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size) # create an 2D embedding table
    
    # forward pass
    def forward(self, idx, targets=None):
        
        # idx and targets are a both (B, T) tensors of integers where B = batch and T = time (or block)
        logits = self.token_embedding_table(idx) # (B, T, C) | C = channels (or vocab_size)
        
        if targets is None:
            loss = None
        else:
            # dimension manipulation
            B, T, C = logits.shape
            logits = logits.view((B*T, C))
            targets = targets.view((B*T))
            loss = F.cross_entropy(logits, targets) # get loss using cross_entropy
        return logits, loss
    
    # generate new tokens
    def generate(self, idx, max_new_tokens):
        # idx is a (B, T) tensor
        for _ in range(max_new_tokens):
            
            logits, loss = self(idx) # get predictions
            
            logits = logits[:, -1, :] # focus on last time step (B, C)
            probs = F.softmax(logits, dim=1) # get probabilities over rows (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) # sample from probs distribution (B, 1)
            idx = torch.cat((idx, idx_next), dim=1) # concatenate new token (B, T+1)
        
        return idx
    
m = Bigram(vocab_size)
logits, loss = m(Xb, Yb)
print(logits.shape)
print(loss.item())
print("Expected loss:", -torch.log(torch.tensor([1/65])).item())

# lets generate tokens from the Bigram
idx = torch.zeros((1,1), dtype=torch.long) # get index (first token is index 0 or '/n')
out = m.generate(idx, max_new_tokens=100)[0].tolist() # generate the new tokens
print(decode(out)) # decode the tokens

torch.Size([32, 65])
4.797330856323242
Expected loss: 4.174387454986572

MtcZK!kHeruAiysl3sI
&
aalhb$GxSyyysk3RkdWF?Yk
&iqfF?oHRwm?cqAZxSb  eU$WkqZlD.gnN-zYcjoduqpR!NPZtqjjA


### Yay, garbage! But thats ok because the Bigram model is random right now... so let's train it!


In [327]:
# create a PyTorch optimizer (to see an optimizer (gradient descent) from scratch go to my makemore-CityNames repo)
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-03)

In [328]:
batch_size = 32
for steps in range(100):
    Xb, Yb = get_batch('train') # get sample batch
    
    # evaluate loss
    logits, loss = m(Xb, Yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
print(loss.item())

4.5410284996032715


In [329]:
# lets generate tokens from the trained Bigram
idx = torch.zeros((1,1), dtype=torch.long) # get index (first token is index 0 or '/n')
out = m.generate(idx, max_new_tokens=300)[0].tolist() # generate the new tokens
print(decode(out)) # decode the tokens


DA.WfRiSTiw.Ovs-ezvjeXZMv VNTaxgjFpEPG$YFcBp'oEdbZBHEJzj .vSb
Qe CMuavoEgI?lL' GLsFyUBFp.mhVS.Op'&$fSvXTEXFDPpRgP-oe'aF,vlPYOIqbwFfQTEkbruXs-rh&3sFRxsMGM3ND.-UGHJkLQPUz;&ISKblGsNs!yM,PG!kcPYZSNkomOeXFrEdDkr;;xhcvGgnyj n&hpjKueskstDuGGIFhUBzOq?MsNT3v Hv'MA'TovSwIuZb ksv'kxy!:;KWzvS;
UO!keKWRiBH?RqDRO


# The Math Behind Self-Attention !


In [330]:
# consider the following:

torch.manual_seed(314159265)

B, T, C = 4, 8, 2 # batch, time, channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [331]:
# version 1: inefficient approach

# we want x[b, t] = mean(x[b,i]) where i <= t
x_bow = torch.zeros((B,T,C)) # bow: bag of words
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t, C)
        x_bow[b,t] = torch.mean(xprev, 0)


In [332]:
# version 2: efficient approach but not what we want yet

weights = torch.tril(torch.ones((T,T))) # get trianglular ones tensor
weights = weights / torch.sum(weights, 1, keepdims=True) # convert ones to distributions
mat = weights @ x  # matrix multiplication

torch.allclose(x_bow,mat) # evaluates to true if all values are the same

True

In [333]:
# version 3: efficient and uses Softmax
tril = torch.tril(torch.ones((T,T)))
weights = torch.zeros((T,T)) # get trianglular ones tensor
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=1)
mat = weights @ x  # matrix multiplication

torch.allclose(x_bow,mat) # evaluates to true if all values are the same

True

### As we can see, we can use matrix multiplication to optimize the calculation of mean over the time steps


In [334]:
# version 4: Attention!
torch.manual_seed(314159265)

B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# a single Head of self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False) # token we are at
query = nn.Linear(C, head_size, bias=False) # past tokens we want context from
value = nn.Linear(C, head_size, bias=False) # communicates data if interested

k = key(x) # (B, T, head_size) | in cross attention, this would come from somewhere else
q = query(x) # (B, T, head_size)
weights = q @ k.transpose(-2, -1) # (B, T, head_size) @ (B, head_size, T) ===> (B, T, T) 

tril = torch.tril(torch.ones((T,T)))
weights = weights.masked_fill(tril == 0, float('-inf')) # delete for encoder attention blocks
weights = F.softmax(weights, dim=-1)

v = value(x) # in cross attention, this would come from somewhere else
out = weights @ v  # matrix multiplication

out.shape

torch.Size([4, 8, 16])

#### Notes:

- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- No notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode of tokens.
- Each example across batch dimension is of course processed completely independently and never "talk" to each other.
- _Self-Attention_ just means that the keys and values are produced from the same source as queries. In _Cross-Attention_, the queries get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- In an _encoder_ attention block, just delete the single line that does the masking with `tril`, allowing all tokens to communicate. The example above is a _decoder_ attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- _Scaled_ attention additionally divides `weights` by 1 / sqrt(head_size). This makes it so when Q, K are unit variance, `weights` will be unit variance too and Softmax will stay diffuse and not saturate too much i.e. softmax will start to act like one-hot encoding. Example below:


In [335]:
# Scaled attention
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
weights = q @ k.transpose(-2, -1) * head_size**-0.5 # scaled by head_size**-0.5
print(f'k var: {k.var().item():.4f} | q var: {q.var().item():.4f} | weights var: {weights.var().item():.4f}')

k var: 1.1114 | q var: 1.1233 | weights var: 1.0711


# Tokenization!

#### Disclaimer
Tokenization is a very, very important to LLMs. It should not be ignored! Maybe one day we won't need it, but we need it right now.\
Check out [this website](https://tiktokenizer.vercel.app) for a visual representation of tokenization!

## Byte Pair Encoding Algorithm

In [336]:
[ord(x) for x in 'Good morning! 😊'] # getting unicode

[71, 111, 111, 100, 32, 109, 111, 114, 110, 105, 110, 103, 33, 32, 128522]

In [337]:
list('Good morning! 😊'.encode('utf-8')) # getting utf-8 encoding

[71,
 111,
 111,
 100,
 32,
 109,
 111,
 114,
 110,
 105,
 110,
 103,
 33,
 32,
 240,
 159,
 152,
 138]

In [338]:
# Byte Pair Encoding
# text copied from https://www.reedbeta.com/blog/programmers-intro-to-unicode/
text2 = "Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to “support Unicode” in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don’t blame programmers for still finding the whole thing mysterious, even 30 years after Unicode’s inception."
tokens = text2.encode("utf-8")
tokens = list(map(int, tokens))
print('---')
print(text2)
print("Text length:", len(text))
print('---')
print(tokens)
print("Encoding length:", len(tokens))

---
Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to “support Unicode” in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don’t blame programmers for still finding the whole thing mysterious, even 30 years after Unicode’s inception.
Text length: 1115394
---
[239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131, 239, 189, 143, 239, 189, 132, 239, 189, 133, 33, 32, 240, 159, 133, 164, 240, 159, 133, 157, 240, 159, 133, 152, 240, 159, 133, 146, 240, 159, 133, 158, 240, 159, 133, 147, 240, 159, 133, 148, 226, 128, 189, 32, 240, 159, 135, 186, 226, 128, 140, 240, 159, 135, 179, 226, 128, 140, 240, 159, 135, 174, 226, 128, 140, 240, 159, 135, 168, 226, 128, 140, 240, 159, 135, 180, 226,

In [339]:
# let's get the counts of all the unique two character sequences in the paragraph above using the encodings
def get_counts(ids):
    counts = {}
    for pair in zip(ids, ids[1:]): # iterating over consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts
    
counts = get_counts(tokens)
print(sorted({(v, k) for k, v in counts.items()}, reverse=True)) # sort elements by value

[(20, (101, 32)), (15, (240, 159)), (12, (226, 128)), (12, (105, 110)), (10, (115, 32)), (10, (97, 110)), (10, (32, 97)), (9, (32, 116)), (8, (116, 104)), (7, (159, 135)), (7, (159, 133)), (7, (97, 114)), (6, (239, 189)), (6, (140, 240)), (6, (128, 140)), (6, (116, 32)), (6, (114, 32)), (6, (111, 114)), (6, (110, 103)), (6, (110, 100)), (6, (109, 101)), (6, (104, 101)), (6, (101, 114)), (6, (32, 105)), (5, (117, 115)), (5, (115, 116)), (5, (110, 32)), (5, (100, 101)), (5, (44, 32)), (5, (32, 115)), (4, (116, 105)), (4, (116, 101)), (4, (115, 44)), (4, (114, 105)), (4, (111, 117)), (4, (111, 100)), (4, (110, 116)), (4, (110, 105)), (4, (105, 99)), (4, (104, 97)), (4, (103, 32)), (4, (101, 97)), (4, (100, 32)), (4, (99, 111)), (4, (97, 109)), (4, (85, 110)), (4, (32, 119)), (4, (32, 111)), (4, (32, 102)), (4, (32, 85)), (3, (118, 101)), (3, (116, 115)), (3, (116, 114)), (3, (116, 111)), (3, (114, 116)), (3, (114, 115)), (3, (114, 101)), (3, (111, 102)), (3, (111, 32)), (3, (108, 108)), (

In [340]:
chr(101), chr(32)

('e', ' ')

So the most common sequence of two characters in the paragraph is 'e' and ' '.

In [341]:
# lets make a function to replace this sequence with the int 256 
# (256 because utf-8 goes to 255 and we need a new encoding)
top_pair = max(counts, key=counts.get)

def replace_seq(ids, pair, replacement):
    i = 0
    while i < len(ids) - 1: # iterating over n - 1 elements
        if ids[i:i+len(pair)] == list(pair): # if the pair matches the current items
            ids = ids[:i] + [replacement] + ids[i+2:] # replace the items
        i += 1 # add one to i
    return ids

mod = replace_seq(tokens, top_pair, 256)
print(mod)
print("Modified Length:", len(mod))

[239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131, 239, 189, 143, 239, 189, 132, 239, 189, 133, 33, 32, 240, 159, 133, 164, 240, 159, 133, 157, 240, 159, 133, 152, 240, 159, 133, 146, 240, 159, 133, 158, 240, 159, 133, 147, 240, 159, 133, 148, 226, 128, 189, 32, 240, 159, 135, 186, 226, 128, 140, 240, 159, 135, 179, 226, 128, 140, 240, 159, 135, 174, 226, 128, 140, 240, 159, 135, 168, 226, 128, 140, 240, 159, 135, 180, 226, 128, 140, 240, 159, 135, 169, 226, 128, 140, 240, 159, 135, 170, 33, 32, 240, 159, 152, 132, 32, 84, 104, 256, 118, 101, 114, 121, 32, 110, 97, 109, 256, 115, 116, 114, 105, 107, 101, 115, 32, 102, 101, 97, 114, 32, 97, 110, 100, 32, 97, 119, 256, 105, 110, 116, 111, 32, 116, 104, 256, 104, 101, 97, 114, 116, 115, 32, 111, 102, 32, 112, 114, 111, 103, 114, 97, 109, 109, 101, 114, 115, 32, 119, 111, 114, 108, 100, 119, 105, 100, 101, 46, 32, 87, 256, 97, 108, 108, 32, 107, 110, 111, 119, 32, 119, 256, 111, 117, 103, 104, 116, 32, 116, 111, 32, 226, 128, 156

Using the function we made, we can see that (101, 32) does not appear anymore and is now replaced by the single integer 256. Becuse the length was originally 616, and (101,32) appeared 20 times, we can verify the function works properly because the modified length is 596 which is 20 less that 616.

In [342]:
# now lets put both functions together into a larger function.
# the amount of repititions is a hyperparameter we can tune


def bp_enc(ids, iters):
    merges = {}
    tmp = list(ids)
    s = 256
    for i in range(iters):
        counts = get_counts(tmp)
        pair = max(counts, key=counts.get)
        merges[s+i] = pair
        print(f"Pair {i+1:4d}: {str(pair):>10s} {counts.get(pair):4d} ===> {s + i:4d}")
        tmp = replace_seq(tmp, pair, s + i)
    
    return tmp, merges

tokens_bp, merges = bp_enc(tokens, 10)
print(tokens_bp)
print("token_bp length:", len(tokens_bp))

Pair    1:  (101, 32)   20 ===>  256
Pair    2: (240, 159)   15 ===>  257
Pair    3: (226, 128)   12 ===>  258
Pair    4: (105, 110)   12 ===>  259
Pair    5:  (115, 32)   10 ===>  260
Pair    6:  (97, 110)   10 ===>  261
Pair    7: (116, 104)    8 ===>  262
Pair    8: (257, 133)    7 ===>  263
Pair    9: (257, 135)    7 ===>  264
Pair   10:  (97, 114)    7 ===>  265
[239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131, 239, 189, 143, 239, 189, 132, 239, 189, 133, 33, 32, 263, 164, 263, 157, 263, 152, 263, 146, 263, 158, 263, 147, 263, 148, 258, 189, 32, 264, 186, 258, 140, 264, 179, 258, 140, 264, 174, 258, 140, 264, 168, 258, 140, 264, 180, 258, 140, 264, 169, 258, 140, 264, 170, 33, 32, 257, 152, 132, 32, 84, 104, 256, 118, 101, 114, 121, 32, 110, 97, 109, 256, 115, 116, 114, 105, 107, 101, 260, 102, 101, 265, 32, 261, 100, 32, 97, 119, 256, 259, 116, 111, 32, 262, 256, 104, 101, 265, 116, 260, 111, 102, 32, 112, 114, 111, 103, 114, 97, 109, 109, 101, 114, 260, 119, 111, 114,

The top ten pair counts sum up to 108. 616 - 108 = 508 so the function works as expected!

In [343]:
sample = text[:20000] # first 10000 characters from tiny shakespeare
tokens = sample.encode("utf-8") # raw encoding bytes
tokens = list(map(int, tokens)) # get ints in range 0-255

In [344]:
vocab_size = 276
num_merges = vocab_size - 256
ids = list(tokens)

ids, merges = bp_enc(ids, num_merges)
print("tokens Length:", len(tokens))
print("ids Length:", len(ids))
print(f"Compression ratio: {len(tokens) / len(ids):.2f}x")

Pair    1:  (101, 32)  517 ===>  256
Pair    2: (116, 104)  402 ===>  257
Pair    3:  (116, 32)  321 ===>  258
Pair    4:  (115, 32)  291 ===>  259
Pair    5: (111, 117)  270 ===>  260
Pair    6:   (44, 32)  248 ===>  261
Pair    7:  (100, 32)  234 ===>  262
Pair    8:  (114, 32)  203 ===>  263
Pair    9: (105, 110)  183 ===>  264
Pair   10:  (97, 110)  170 ===>  265
Pair   11: (101, 110)  167 ===>  266
Pair   12:   (58, 10)  160 ===>  267
Pair   13:  (121, 32)  147 ===>  268
Pair   14:   (10, 10)  146 ===>  269
Pair   15: (101, 114)  140 ===>  270
Pair   16: (111, 110)  138 ===>  271
Pair   17: (108, 108)  131 ===>  272
Pair   18:  (97, 114)  126 ===>  273
Pair   19: (257, 256)  126 ===>  274
Pair   20: (121, 260)  124 ===>  275
tokens Length: 20000
ids Length: 15756
Compression ratio: 1.27x


#### Note:
The tokenizer is a completely seperate, independent module from the LLM. It has it's own training dataset (which could be different from the LLM) on which it will train on vocabularu using the Byte Pair Encoding (BPE) Algorithm. Only later does the LLM actually recieve the tokens the tokenizer produces. This means the LLM never directly deals with any text, only tokens.

### Encoding and Decoding Using A BPE Tokenizer

#### Decoding
Given a sequence of integers in the range [0, vocab_size], what is the text?

In [345]:
# pre-processing
vocab = {idx: bytes([idx]) for idx in range(256)} # int to byte list mapping

# adding merges to vocab mapping
for idx, (p0, p1) in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]
    
def decode(ids):
    # given list if ids (integers), returns Python string
    tokens = b"".join([vocab[idx] for idx in ids])
    text = tokens.decode('utf-8', errors='replace')
    return text

output = decode(ids)


#### Encoding
Given a string, what is the encoding in integers?

In [346]:
def encode(text):
    tokens  = list(text.encode('utf-8'))
    while len(tokens) >= 2:
        counts = get_counts(tokens)
                   
        # for any pair in counts, look at merges and get the min count
        pair = min(counts, key = lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        
        idx = merges[pair]
        tokens = bp_enc(tokens, pair, idx)
    
    return tokens
    

ids = encode(output)
print(ids[:10])

[70, 105, 114, 115, 116, 32, 67, 105, 116, 105]


In [347]:
text2 = decode(encode(text))
print(text == text2)

True


In [348]:
valtext = "hello world! this is validation text for the tokenizer :)"
valtext2 = decode(encode(valtext))
print(valtext == valtext2)

True


### Yay! Byte Pair Encoding is implemented!

### ...Now onto more complicated tokenization methods

### Forcing splits using regex patterns (like GPT2)

In [354]:
import regex as re
gpt2pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll'd| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

print(re.findall(gpt2pat, "Hello've world123 how's are       you!I!?HOW'S  "))

['Hello', "'ve", ' world', '123', ' how', "'s", ' are', '      ', ' you', '!', 'I', '!?', 'HOW', "'", 'S', '  ']


Notice how the regex referring to `'` only deal with lowercase letters... So `'s` will work, but `'S` will not.

In [359]:
import tiktoken

# GPT-2 (does not merge spaces)
enc = tiktoken.get_encoding("gpt2")
print(enc.encode("    hello world!!!"))

# GPT-4 (merges spaces)
enc = tiktoken.get_encoding("cl100k_base")
print(enc.encode("    hello world!!!"))

[220, 220, 220, 23748, 995, 10185]
[262, 24748, 1917, 12340]


### Special Tokens
Tokens used as delimiters, etc.