Notebook by Can Pouliquen

# Naïve BPE (don't use this, it's excruciatingly slow)

In [26]:
from collections import Counter
import regex as re

##### Take a random toy corpus

In [28]:
# Toy corpus: every word carries a leading space; adjacent string literals are
# concatenated by Python into one string, grouped here by word.
corpus = (
    " low low low low low"
    " lower lower"
    " widest widest widest"
    " newest newest newest newest newest newest"
)

##### Initialize a vocab `dict[int,bytes]` with the 256 possible byte values and tokenize the corpus with those

In [30]:
# Base vocabulary: token id -> the single byte with that value (ids 0..255).
vocab = dict((byte_value, bytes((byte_value,))) for byte_value in range(256))
# Byte-level tokenization: one length-1 bytes object per UTF-8 byte of the corpus.
tokens = list(map(lambda byte_value: bytes((byte_value,)), corpus.encode("utf-8")))
# Show the last ten tokens.
tokens[-10:]

[b'e', b's', b't', b' ', b'n', b'e', b'w', b'e', b's', b't']

##### Initialize a pair counter `dict[tuple[bytes, bytes], int]`, iterate over all adjacent pairs across the whole corpus and update the counter

In [31]:
# Counter tallies every (left, right) pair of neighbouring tokens in one pass.
# zip stops at the shorter input, so the final token is never used as a left element.
pair_counter = Counter(zip(tokens, tokens[1:]))
pair_counter

Counter({(b'e', b's'): 9,
         (b's', b't'): 9,
         (b'w', b'e'): 8,
         (b't', b' '): 8,
         (b' ', b'l'): 7,
         (b'l', b'o'): 7,
         (b'o', b'w'): 7,
         (b' ', b'n'): 6,
         (b'n', b'e'): 6,
         (b'e', b'w'): 6,
         (b'w', b' '): 5,
         (b' ', b'w'): 3,
         (b'w', b'i'): 3,
         (b'i', b'd'): 3,
         (b'd', b'e'): 3,
         (b'e', b'r'): 2,
         (b'r', b' '): 2})

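##### Pick the most frequent pair (ties are broken by taking the lexicographically greatest pair, which is why `(b's', b't')` wins over `(b'e', b's')` even though both occur 9 times)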

In [32]:
# Rank every (pair, count) entry by (count, pair): the highest count wins and
# ties go to the lexicographically greatest pair. Keep only the pair itself.
most_frequent_pair = max(
    pair_counter.items(),
    key=lambda entry: (entry[1], entry[0]),
)[0]
most_frequent_pair

(b's', b't')

##### Merge that pair into a new bytes token and add it to the vocab

In [33]:
# Concatenate the two halves of the winning pair into a single bytes token.
new_token = b"".join(most_frequent_pair)

# Register it under the next id; the vocab was built from range(256), so its
# ids are contiguous and len(vocab) is the first unused one.
vocab[len(vocab)] = new_token

# Display the entry that was just appended.
vocab[len(vocab) - 1]

b'st'

##### Scan the entire corpus again to find occurrences of that pair and replace them with the merged token

In [34]:
# Left-to-right scan: wherever the two tokens starting at `cursor` form the
# chosen pair, emit the merged token and skip both; otherwise copy one token.
merged = []
cursor = 0
token_count = len(tokens)
while cursor < token_count:
    # At the last position the slice holds a single token, so it can never
    # equal the two-element pair and simply falls through to the copy branch.
    window = tuple(tokens[cursor:cursor + 2])
    if window == most_frequent_pair:
        merged.append(new_token)
        cursor += 2
        continue
    merged.append(tokens[cursor])
    cursor += 1
tokens = merged

In [35]:
# Inspect the last 12 tokens: every adjacent b's', b't' should now show up as the single token b'st'.
tokens[-12:]

[b' ', b'n', b'e', b'w', b'e', b'st', b' ', b'n', b'e', b'w', b'e', b'st']

#### Repeat this cycle until the desired vocab size is reached
#### This is extremely slow because each new merge requires iterating over the whole corpus twice (once to count the pairs, once to apply the merge)

# Improving the naïve implementation with 3 main ideas

##### 1- Use pre-tokenization (needs some guarding against special tokens, handled in Tiny-Tokenizer)

In [36]:
# Same toy corpus as in the naïve section, written as adjacent string literals
# (concatenated by Python into one string) grouped by word.
corpus = (
    " low low low low low"
    " lower lower"
    " widest widest widest"
    " newest newest newest newest newest newest"
)

In [37]:
# Pre-tokenization pattern; the alternatives are tried left to right:
#   '(?:[sdmt]|ll|ve|re)   an apostrophe followed by s, d, m, t, ll, ve or re
#    ?\p{L}+               an optional leading space, then a run of letters
#    ?\p{N}+               an optional leading space, then a run of numeric characters
#    ?[^\s\p{L}\p{N}]+     an optional leading space, then a run of anything that is
#                          neither whitespace, a letter nor a number
#   \s+(?!\S)              a whitespace run that is not followed by a non-whitespace character
#   \s+                    any other whitespace run
# \p{L} / \p{N} need the third-party `regex` module (imported as `re` in this notebook).
pre_tokens_splitter = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# The pattern has no capturing groups, so each match's full text is one pre-token.
words = [found.group() for found in re.finditer(pre_tokens_splitter, corpus)]
words

[' low',
 ' low',
 ' low',
 ' low',
 ' low',
 ' lower',
 ' lower',
 ' widest',
 ' widest',
 ' widest',
 ' newest',
 ' newest',
 ' newest',
 ' newest',
 ' newest',
 ' newest']

In [39]:
# Count how many times each distinct pre-token string occurs.
pre_tokens_ctr = Counter(words)
pre_tokens_ctr

Counter({' newest': 6, ' low': 5, ' widest': 3, ' lower': 2})

##### It will be convenient to represent this as a `dict[tuple[bytes, ...], int]` (pre-token split into single bytes → number of occurrences)

In [40]:
# Re-key the counter: each pre-token string becomes a tuple of its UTF-8 bytes
# (one length-1 bytes object per byte); the occurrence count is carried over.
pre_tokens = {
    tuple(bytes((byte_value,)) for byte_value in word.encode("utf-8")): occurrences
    for word, occurrences in pre_tokens_ctr.items()
}
pre_tokens

{(b' ', b'l', b'o', b'w'): 5,
 (b' ', b'l', b'o', b'w', b'e', b'r'): 2,
 (b' ', b'w', b'i', b'd', b'e', b's', b't'): 3,
 (b' ', b'n', b'e', b'w', b'e', b's', b't'): 6}

#### We can now count adjacent pairs by adding the frequency of the whole pre-token at once, instead of incrementing by one for every occurrence!

In [41]:
# Weighted pair count: each adjacent pair inside a pre-token contributes that
# pre-token's frequency rather than 1.
pair_counter = Counter()
for byte_seq, frequency in pre_tokens.items():
    for left, right in zip(byte_seq, byte_seq[1:]):
        pair_counter[(left, right)] += frequency
pair_counter

Counter({(b'e', b's'): 9,
         (b's', b't'): 9,
         (b'w', b'e'): 8,
         (b' ', b'l'): 7,
         (b'l', b'o'): 7,
         (b'o', b'w'): 7,
         (b' ', b'n'): 6,
         (b'n', b'e'): 6,
         (b'e', b'w'): 6,
         (b' ', b'w'): 3,
         (b'w', b'i'): 3,
         (b'i', b'd'): 3,
         (b'd', b'e'): 3,
         (b'e', b'r'): 2})

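##### For each merge, we still have to go through the data twice, but we have effectively "reduced its size": we now go twice through the unique pre-tokens instead of the whole corpus. Note that pairs spanning two pre-tokens, such as `(b't', b' ')`, are no longer counted.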

##### 2- Parallelize pre-tokenization across CPU cores (see the implementation in Tiny-Tokenizer)

##### 3- Caching: cache the pair counter and the pre-tokens in which each pair occurs
(slightly more involved, so easier to explain orally; see the implementation in Tiny-Tokenizer)

# Once the vocab is built, you can encode/decode 

In [42]:
from pathlib import Path
import pickle

from models.tokenizer import Tokenizer

In [43]:
# Load the vocab and merges pickles from the TinyStories results files.
# SECURITY: pickle.load can execute arbitrary code while unpickling; only load
# files you generated yourself or otherwise trust.
path = Path("./results/TinyStoriesV2-GPT4-train_bpe_vocab.pkl")
with path.open("rb") as f:
    vocab = pickle.load(f)

path = Path("./results/TinyStoriesV2-GPT4-train_bpe_merges.pkl")
with path.open("rb") as f:
    merges = pickle.load(f)

# A non-empty string is always truthy, so a trailing `or None` here would never
# take effect; the special token is passed directly.
# NOTE(review): it is passed as a single string — confirm whether Tokenizer
# expects a list of strings instead.
tokenizer = Tokenizer(vocab=vocab, merges=merges, special_tokens="<|endoftext|>")

# Encode a sample prompt and print the resulting token ids.
prompt = "Hello Ockham, I like cats."
token_ids = tokenizer.encode(prompt)
print(token_ids)

[72, 101, 294, 111, 1420, 348, 104, 306, 44, 338, 366, 107, 101, 2177, 116, 115, 46]


In [44]:
# Round-trip check: decoding the ids produced by `encode` should give back the original prompt.
tokenizer.decode(token_ids)

'Hello Ockham, I like cats.'

##### The encoding depends on the dataset you train on and, obviously, on the vocab size! Let's look at how "Ockham" is encoded when training on TinyStories (vocab = 10k tokens) vs OpenWebText (vocab = 32k tokens)

In [45]:
# With the TinyStories vocab: `tokenizer` is still the one built from the
# TinyStories pickles loaded in the cell above.
tokenizer.encode("Ockham")

[79, 348, 104, 306]

In [50]:
# Load the vocab and merges pickles from the OpenWebText results files; this
# rebinds `vocab`, `merges` and `tokenizer`.
# SECURITY: pickle.load can execute arbitrary code while unpickling; only load
# files you generated yourself or otherwise trust.
path = Path("./results_old/open_web_text_bpe_vocab.pkl")
with path.open("rb") as f:
    vocab = pickle.load(f)

path = Path("./results_old/open_web_text_bpe_merges.pkl")
with path.open("rb") as f:
    merges = pickle.load(f)

# A non-empty string is always truthy, so a trailing `or None` here would never
# take effect; the special token is passed directly.
# NOTE(review): it is passed as a single string — confirm whether Tokenizer
# expects a list of strings instead.
tokenizer = Tokenizer(vocab=vocab, merges=merges, special_tokens="<|endoftext|>")

# Encode the same word with the OpenWebText-trained tokenizer for comparison.
prompt = "Ockham"
token_ids = tokenizer.encode(prompt)
print(token_ids)

[79, 726, 3147]
