## My review and implementation of a BPE (Byte Pair Encoding) tokenizer as outlined in Andrej Karpathy's Zero To Hero lecture series. 

In Python, strings are immutable sequences of Unicode code points. What are Unicode code points? The Unicode Standard, maintained by the Unicode Consortium, currently defines 149,813 characters across 161 scripts. It specifies what each character "looks like" and which integer (its code point) represents it.

In [1]:
# ord() returns the integer Unicode code point of a single character.
ord("h")

104

In [2]:
ord("🔥")

128293

In [3]:
ord("中")

20013

In [4]:
[ord(x) for x in "很多人都会说中文"]

[24456, 22810, 20154, 37117, 20250, 35828, 20013, 25991]

Naturally, one might ask: why do we need to tokenize a sequence at all when we already have these integers at our disposal? One reason is that the vocabulary would be very, very large (one entry per defined code point). The second reason is that the Unicode Standard is "alive" and is updated regularly, which is concerning from a stability point of view.

To overcome these issues, we can initially turn to encodings. The Unicode Consortium defines three options: UTF-8, UTF-16, and UTF-32. UTF-8 is by far the most commonly used. It translates every Unicode code point into a sequence of 1 to 4 bytes, which is why it's considered a variable-length encoding. All three encodings come with trade-offs, but one of UTF-8's many advantages is that it is backwards compatible with plain ASCII text.
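A quick check of the variable-length property and of the ASCII compatibility claim, using a few arbitrarily chosen characters (the "é" is written as the escape `\u00e9` so it is unambiguously a single code point):

```python
# Number of UTF-8 bytes needed per character: 1 for ASCII, up to 4 for e.g. emoji.
samples = ["h", "\u00e9", "中", "🔥"]
byte_lengths = {ch: len(ch.encode("utf-8")) for ch in samples}
print(byte_lengths)  # {'h': 1, 'é': 2, '中': 3, '🔥': 4}

# Backwards compatibility: pure ASCII text encodes to identical bytes under ASCII and UTF-8.
ascii_text = "hello world"
print(ascii_text.encode("ascii") == ascii_text.encode("utf-8"))  # True
```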

In [5]:
"很多人都会说中文".encode("utf-8")

b'\xe5\xbe\x88\xe5\xa4\x9a\xe4\xba\xba\xe9\x83\xbd\xe4\xbc\x9a\xe8\xaf\xb4\xe4\xb8\xad\xe6\x96\x87'

In [6]:
list("很多人都会说中文".encode("utf-8"))

[229,
 190,
 136,
 229,
 164,
 154,
 228,
 186,
 186,
 233,
 131,
 189,
 228,
 188,
 154,
 232,
 175,
 180,
 228,
 184,
 173,
 230,
 150,
 135]

In [7]:
list("很多人都会说中文".encode("utf-32"))

[255,
 254,
 0,
 0,
 136,
 95,
 0,
 0,
 26,
 89,
 0,
 0,
 186,
 78,
 0,
 0,
 253,
 144,
 0,
 0,
 26,
 79,
 0,
 0,
 244,
 139,
 0,
 0,
 45,
 78,
 0,
 0,
 135,
 101,
 0,
 0]

For our purposes, UTF-32 is clearly wasteful (notice all the superfluous zeroes): every code point takes 4 bytes, even though these characters only need the lower two.
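Part of that overhead is a byte-order mark (BOM): the leading 255, 254, 0, 0 in the output above, which Python's generic "utf-32" codec prepends. A size comparison on the same string, using only built-in codecs:

```python
s = "很多人都会说中文"  # 8 code points, all in the Basic Multilingual Plane

# "utf-16" and "utf-32" add a BOM (2 and 4 bytes); "utf-32-le" fixes the byte order and omits it.
sizes = {codec: len(s.encode(codec)) for codec in ["utf-8", "utf-16", "utf-32", "utf-32-le"]}
print(sizes)  # {'utf-8': 24, 'utf-16': 18, 'utf-32': 36, 'utf-32-le': 32}

# On a little-endian machine the UTF-32 BOM is [255, 254, 0, 0].
print(list(s.encode("utf-32"))[:4])
```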

However, if we naively use UTF-8 byte streams, we're implicitly working with a vocabulary size of 256, which is far too small, and every piece of text becomes a long sequence of bytes. There has been some published research on tokenizer-free autoregressive sequence modeling for LLMs, which in theory would be fantastic. However, as things currently stand, the cost of standard Transformer attention grows quadratically with sequence length, so feeding raw UTF-8 bytes is very inefficient and becomes exceedingly so as sequences get longer.

## Enter the Byte Pair Encoding algorithm

The algorithm itself is thankfully not very complicated, and yet it is incredibly enabling. Say we have some input sequence. We iteratively find the pair of adjacent tokens that occurs most frequently in that sequence, replace every occurrence of that pair with a single new token, and append that new token to our vocabulary.

`'aaabdaaabac'` (vocab_size=4) -> `'ZabdZabac'` (Z=aa) -> `'ZYdZYac'` (Y=ab, Z=aa) -> `'XdXac'` (X=ZY, Y=ab, Z=aa)

Thus we compressed a sequence of length 11 with a vocabulary of 4 symbols down to a sequence of length 5 with a vocabulary of 7 symbols.
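The toy example can be replayed in a few lines. This sketch applies the three listed merges in order rather than recomputing the most frequent pair, because after the first merge "Za" and "ab" both occur twice and the example breaks that tie in favor of "ab". `merge_symbols` is a hypothetical helper working on lists of characters, left to right and non-overlapping, like the notebook's `merge` function further down.

```python
def merge_symbols(seq, pair, new_symbol):
    # hypothetical helper: replace non-overlapping occurrences of `pair`, scanning left to right
    out, i = [], 0
    while i < len(seq):
        if i < len(seq) - 1 and (seq[i], seq[i + 1]) == pair:
            out.append(new_symbol)
            i += 2
        else:
            out.append(seq[i])
            i += 1
    return out

seq = list("aaabdaaabac")  # 11 symbols over the vocabulary {a, b, c, d}
for pair, new_symbol in [(("a", "a"), "Z"), (("a", "b"), "Y"), (("Z", "Y"), "X")]:
    seq = merge_symbols(seq, pair, new_symbol)
    print("".join(seq))
# ZabdZabac
# ZYdZYac
# XdXac
```

The final sequence has 5 symbols, and the vocabulary has grown to 7 (a, b, c, d, Z, Y, X).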

In [8]:
# text from the first full paragraph of http://utf8everywhere.org/
text = "Our goal is to promote usage and support of the UTF-8 encoding and to convince that it should be the default choice of encoding for storing text strings in memory or on disk, for communication and all other uses. We believe that our approach improves performance, reduces complexity of software and helps prevent many Unicode-related bugs. We suggest that other encodings of Unicode (or text, in general) belong to rare edge-cases of optimization and should be avoided by mainstream users."
tokens = text.encode("utf-8") # our raw bytes
tokens = list(map(int, tokens)) # convert the raw bytes to a list of ints in the range 0..255
print('---')
print(text)
print("length", len(text))
print('---')
print(tokens)
print("length", len(tokens))

---
Our goal is to promote usage and support of the UTF-8 encoding and to convince that it should be the default choice of encoding for storing text strings in memory or on disk, for communication and all other uses. We believe that our approach improves performance, reduces complexity of software and helps prevent many Unicode-related bugs. We suggest that other encodings of Unicode (or text, in general) belong to rare edge-cases of optimization and should be avoided by mainstream users.
length 489
---
[79, 117, 114, 32, 103, 111, 97, 108, 32, 105, 115, 32, 116, 111, 32, 112, 114, 111, 109, 111, 116, 101, 32, 117, 115, 97, 103, 101, 32, 97, 110, 100, 32, 115, 117, 112, 112, 111, 114, 116, 32, 111, 102, 32, 116, 104, 101, 32, 85, 84, 70, 45, 56, 32, 101, 110, 99, 111, 100, 105, 110, 103, 32, 97, 110, 100, 32, 116, 111, 32, 99, 111, 110, 118, 105, 110, 99, 101, 32, 116, 104, 97, 116, 32, 105, 116, 32, 115, 104, 111, 117, 108, 100, 32, 98, 101, 32, 116, 104, 101, 32, 100, 101, 102, 97, 1

In [9]:
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]): # A pythonic way to iterate consecutive elements in parallel
        counts[pair] = counts.get(pair, 0) + 1
    return counts

stats = get_stats(tokens)
print(stats)
# print(sorted(((v,k) for k,v in stats.items()), reverse=True))

{(79, 117): 1, (117, 114): 2, (114, 32): 8, (32, 103): 2, (103, 111): 1, (111, 97): 2, (97, 108): 3, (108, 32): 2, (32, 105): 5, (105, 115): 2, (115, 32): 7, (32, 116): 10, (116, 111): 4, (111, 32): 3, (32, 112): 3, (112, 114): 4, (114, 111): 3, (111, 109): 3, (109, 111): 2, (111, 116): 3, (116, 101): 4, (101, 32): 14, (32, 117): 3, (117, 115): 3, (115, 97): 1, (97, 103): 1, (103, 101): 4, (32, 97): 8, (97, 110): 7, (110, 100): 5, (100, 32): 9, (32, 115): 7, (115, 117): 2, (117, 112): 1, (112, 112): 2, (112, 111): 1, (111, 114): 8, (114, 116): 1, (116, 32): 9, (32, 111): 11, (111, 102): 6, (102, 32): 5, (116, 104): 7, (104, 101): 5, (32, 85): 3, (85, 84): 1, (84, 70): 1, (70, 45): 1, (45, 56): 1, (56, 32): 1, (32, 101): 4, (101, 110): 5, (110, 99): 5, (99, 111): 8, (111, 100): 5, (100, 105): 4, (105, 110): 9, (110, 103): 6, (103, 32): 4, (32, 99): 4, (111, 110): 5, (110, 118): 1, (118, 105): 1, (99, 101): 4, (104, 97): 3, (97, 116): 6, (105, 116): 2, (115, 104): 2, (104, 111): 3, (111,

In [10]:
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]): # A pythonic way to iterate consecutive elements in parallel
        counts[pair] = counts.get(pair, 0) + 1
    return counts

stats = get_stats(tokens)
# print(stats)
print(sorted(((v,k) for k,v in stats.items()), reverse=True))

[(14, (101, 32)), (11, (32, 111)), (10, (32, 116)), (9, (116, 32)), (9, (105, 110)), (9, (100, 32)), (8, (114, 32)), (8, (111, 114)), (8, (99, 111)), (8, (32, 97)), (7, (116, 104)), (7, (115, 32)), (7, (97, 110)), (7, (32, 115)), (6, (114, 101)), (6, (111, 102)), (6, (110, 103)), (6, (97, 116)), (6, (32, 98)), (5, (111, 110)), (5, (111, 100)), (5, (110, 100)), (5, (110, 99)), (5, (110, 32)), (5, (104, 101)), (5, (102, 32)), (5, (101, 115)), (5, (101, 114)), (5, (101, 110)), (5, (32, 105)), (4, (121, 32)), (4, (116, 111)), (4, (116, 101)), (4, (115, 116)), (4, (112, 114)), (4, (105, 99)), (4, (103, 101)), (4, (103, 32)), (4, (101, 108)), (4, (101, 100)), (4, (100, 105)), (4, (100, 101)), (4, (99, 101)), (4, (98, 101)), (4, (32, 101)), (4, (32, 99)), (3, (118, 101)), (3, (117, 115)), (3, (117, 108)), (3, (116, 105)), (3, (115, 101)), (3, (115, 46)), (3, (114, 111)), (3, (111, 117)), (3, (111, 116)), (3, (111, 109)), (3, (111, 32)), (3, (110, 105)), (3, (109, 97)), (3, (104, 111)), (3, (1

In [11]:
# chr() is the inverse of ord(): it maps a code point back to its character
chr(101), chr(32)

('e', ' ')

Unsurprisingly to anyone familiar with English, "e" followed by " " (space) is the most frequent pair in this paragraph, since so many English words end in "e".

In [12]:
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]): # A pythonic way to iterate consecutive elements in parallel
        counts[pair] = counts.get(pair, 0) + 1
    return counts

stats = get_stats(tokens)
# print(stats)
# print(sorted(((v,k) for k,v in stats.items()), reverse=True))

top_pair = max(stats, key=stats.get)
top_pair

(101, 32)
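As an aside, the pair counting in `get_stats` can also be done with the standard library's `collections.Counter`. A sketch on the short made-up string "the theme", using fresh variable names so the notebook's `tokens` and `top_pair` are left untouched. It also shows that ties are broken by first occurrence: `max()` returns the first maximal key it encounters, and `Counter.most_common` orders equal counts by first appearance.

```python
from collections import Counter

demo_ids = list("the theme".encode("utf-8"))
pair_counts = Counter(zip(demo_ids, demo_ids[1:]))  # same counts get_stats() would return

print(pair_counts.most_common(2))  # [((116, 104), 2), ((104, 101), 2)]

# max() returns the first maximal key it encounters, so ties go to the pair seen first.
demo_top = max(pair_counts, key=pair_counts.get)
print(demo_top)  # (116, 104), i.e. "th" beats "he"
```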

In [13]:
def merge(ids, pair, idx):
    # in the list of ints (ids), replace all consecutive occurrences of pair with the new token idx
    newids = []
    i = 0
    while i < len(ids):
        # if we are not at the very last position AND the pair matches, replace it
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2 
        else:
            newids.append(ids[i])
            i += 1
    return newids

print(merge([5, 6, 6, 7, 9, 1], (6, 7), 99))

#tokens2 = merge(tokens, top_pair, 256)
#print(tokens2)
#print("length:", len(tokens2))

[5, 6, 99, 9, 1]


In [14]:
def merge(ids, pair, idx):
    # in the list of ints (ids), replace all consecutive occurrences of pair with the new token idx
    newids = []
    i = 0
    while i < len(ids):
        # if we are not at the very last position AND the pair matches, replace it
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2 
        else:
            newids.append(ids[i])
            i += 1
    return newids

#print(merge([5, 6, 6, 7, 9, 1], (6, 7), 99))

tokens2 = merge(tokens, top_pair, 256)
print(tokens2)
print("length:", len(tokens2))

[79, 117, 114, 32, 103, 111, 97, 108, 32, 105, 115, 32, 116, 111, 32, 112, 114, 111, 109, 111, 116, 256, 117, 115, 97, 103, 256, 97, 110, 100, 32, 115, 117, 112, 112, 111, 114, 116, 32, 111, 102, 32, 116, 104, 256, 85, 84, 70, 45, 56, 32, 101, 110, 99, 111, 100, 105, 110, 103, 32, 97, 110, 100, 32, 116, 111, 32, 99, 111, 110, 118, 105, 110, 99, 256, 116, 104, 97, 116, 32, 105, 116, 32, 115, 104, 111, 117, 108, 100, 32, 98, 256, 116, 104, 256, 100, 101, 102, 97, 117, 108, 116, 32, 99, 104, 111, 105, 99, 256, 111, 102, 32, 101, 110, 99, 111, 100, 105, 110, 103, 32, 102, 111, 114, 32, 115, 116, 111, 114, 105, 110, 103, 32, 116, 101, 120, 116, 32, 115, 116, 114, 105, 110, 103, 115, 32, 105, 110, 32, 109, 101, 109, 111, 114, 121, 32, 111, 114, 32, 111, 110, 32, 100, 105, 115, 107, 44, 32, 102, 111, 114, 32, 99, 111, 109, 109, 117, 110, 105, 99, 97, 116, 105, 111, 110, 32, 97, 110, 100, 32, 97, 108, 108, 32, 111, 116, 104, 101, 114, 32, 117, 115, 101, 115, 46, 32, 87, 256, 98, 101, 108, 105,

A single pass replaced all 14 occurrences of the (101, 32) pair, reducing the sequence length from 489 to 489 - 14 = 475.
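Here the reduction matches the pair count exactly because (101, 32) cannot overlap with itself. For a pair made of the same token twice, `get_stats` counts overlapping occurrences while `merge` replaces non-overlapping ones from left to right, so the sequence shrinks by less than the count suggests. A small check, with the notebook's two functions copied in so the cell runs on its own (the `[5, 5, 5]` input is made up):

```python
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):  # consecutive pairs, so overlapping occurrences are all counted
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    # left-to-right, non-overlapping replacement of `pair` with `idx`
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

demo_ids = [5, 5, 5]
print(get_stats(demo_ids))          # {(5, 5): 2}
print(merge(demo_ids, (5, 5), 99))  # [99, 5]: length drops by 1, not 2
```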

In [15]:
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]): # A pythonic way to iterate consecutive elements in parallel
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    # in the list of ints (ids), replace all consecutive occurrences of pair with the new token idx
    newids = []
    i = 0
    while i < len(ids):
        # if we are not at the very last position AND the pair matches, replace it
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2 
        else:
            newids.append(ids[i])
            i += 1
    return newids

# ---
vocab_size = 276 # our desired final vocab size for this example
num_merges = vocab_size - 256
ids = list(tokens) # makes a copy so we don't destroy the original list

merges = {} # (int, int) -> int | child1, child2 merges
for i in range(num_merges):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)
    idx = 256 + i
    print(f"merging {pair} into new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx

merging (101, 32) into new token 256
merging (32, 111) into new token 257
merging (100, 32) into new token 258
merging (105, 110) into new token 259
merging (99, 111) into new token 260
merging (114, 32) into new token 261
merging (97, 110) into new token 262
merging (116, 104) into new token 263
merging (97, 116) into new token 264
merging (115, 32) into new token 265
merging (262, 258) into new token 266
merging (101, 110) into new token 267
merging (260, 100) into new token 268
merging (259, 103) into new token 269
merging (116, 32) into new token 270
merging (116, 111) into new token 271
merging (112, 114) into new token 272
merging (257, 102) into new token 273
merging (258, 98) into new token 274
merging (101, 108) into new token 275


Notice that even our newly minted tokens are, in turn, eligible for merging as we continue to iterate: for example, token 266 is the merge of 262 ("an") and 258 ("d ").

In [16]:
print("tokens length:", len(tokens))
print("ids length:", len(ids))
print(f"compression ratio: {len(tokens) / len(ids):.2f}X")

tokens length: 489
ids length: 362
compression ratio: 1.35X


It's important to remember that a tokenizer is a completely separate, independent module from the LLM itself. It has its own training dataset/corpus, which can be completely different from the dataset/corpus used to train the model. Later on, the LLM only ever sees the BPE tokens and never deals directly with any text.

## Decoding

Given a sequence of integers in the range [0, vocab_size), i.e. 0 to vocab_size - 1, what is the text?

In [17]:
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1] # adding two byte-objects together, concat

def decode(ids):
    # given ids (list of integers), returns a Python string
    tokens = b"".join(vocab[idx] for idx in ids)
    text = tokens.decode("utf-8")
    return text
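Because `vocab` is built by iterating over `merges` in insertion order (the order the merges were learned), every child token already exists when its parent is built. Printing a few merged entries shows what the new tokens stand for. The dictionary below is copied from the training output in this notebook; `saved_merges` and `demo_vocab` are illustrative names so the real `merges` and `vocab` are not overwritten.

```python
# Merges copied from the training output above (insertion order = order learned).
saved_merges = {
    (101, 32): 256, (32, 111): 257, (100, 32): 258, (105, 110): 259,
    (99, 111): 260, (114, 32): 261, (97, 110): 262, (116, 104): 263,
    (97, 116): 264, (115, 32): 265, (262, 258): 266, (101, 110): 267,
    (260, 100): 268, (259, 103): 269, (116, 32): 270, (116, 111): 271,
    (112, 114): 272, (257, 102): 273, (258, 98): 274, (101, 108): 275,
}

demo_vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in saved_merges.items():
    demo_vocab[idx] = demo_vocab[p0] + demo_vocab[p1]  # children are always defined first

for idx in (256, 266, 268, 269, 273):
    print(idx, demo_vocab[idx])
# 256 b'e '
# 266 b'and '
# 268 b'cod'
# 269 b'ing'
# 273 b' of'
```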

We want to be careful with this decode implementation: it raises an error for id sequences that are not valid UTF-8, for example `[128]`.

In [18]:
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1] # adding two byte-objects together, concat

def decode(ids):
    # given ids (list of integers), returns a Python string
    tokens = b"".join(vocab[idx] for idx in ids)
    text = tokens.decode("utf-8")
    return text

print(decode([128]))

UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 0: invalid start byte

In [19]:
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1] # adding two byte-objects together, concat

def decode(ids):
    # given ids (list of integers), returns a Python string
    tokens = b"".join(vocab[idx] for idx in ids)
    text = tokens.decode("utf-8", errors='replace') # replace invalid bytes with U+FFFD instead of raising
    return text

print(decode([128]))

�
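Why does 128 in particular fail? In UTF-8, the bytes 0x80 to 0xBF (128 to 191) are continuation bytes: they may only follow a lead byte and can never start a character, hence "invalid start byte". A related failure occurs when a token sequence stops partway through a multi-byte character. A short sketch using only built-in `str`/`bytes` methods:

```python
# Bytes 0x80..0xBF all have the bit pattern 10xxxxxx, which marks a UTF-8 continuation byte.
print(all(b >> 6 == 0b10 for b in range(0x80, 0xC0)))  # True

partial = "中".encode("utf-8")[:2]  # first 2 of the 3 bytes that encode "中"

failed = False
try:
    partial.decode("utf-8")
except UnicodeDecodeError as err:
    failed = True
    print("strict decode failed:", err.reason)

# errors="replace" substitutes U+FFFD (the replacement character) instead of raising.
print("\ufffd" in partial.decode("utf-8", errors="replace"))  # True
```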


## Encoding 

The other way around: given a string, what are its tokens?

In [20]:
merges

{(101, 32): 256,
 (32, 111): 257,
 (100, 32): 258,
 (105, 110): 259,
 (99, 111): 260,
 (114, 32): 261,
 (97, 110): 262,
 (116, 104): 263,
 (97, 116): 264,
 (115, 32): 265,
 (262, 258): 266,
 (101, 110): 267,
 (260, 100): 268,
 (259, 103): 269,
 (116, 32): 270,
 (116, 111): 271,
 (112, 114): 272,
 (257, 102): 273,
 (258, 98): 274,
 (101, 108): 275}

In [21]:
stats

{(79, 117): 1,
 (117, 261): 2,
 (261, 103): 1,
 (103, 111): 1,
 (111, 97): 2,
 (97, 108): 3,
 (108, 32): 1,
 (32, 105): 3,
 (105, 265): 1,
 (265, 271): 1,
 (271, 32): 3,
 (32, 272): 1,
 (272, 111): 3,
 (111, 109): 1,
 (109, 111): 2,
 (111, 116): 1,
 (116, 256): 1,
 (256, 117): 1,
 (117, 115): 3,
 (115, 97): 1,
 (97, 103): 1,
 (103, 256): 1,
 (256, 266): 2,
 (266, 115): 2,
 (115, 117): 2,
 (117, 112): 1,
 (112, 112): 1,
 (112, 111): 1,
 (111, 114): 3,
 (114, 116): 1,
 (116, 273): 1,
 (273, 32): 3,
 (32, 263): 1,
 (263, 256): 2,
 (256, 85): 1,
 (85, 84): 1,
 (84, 70): 1,
 (70, 45): 1,
 (45, 56): 1,
 (56, 32): 1,
 (32, 267): 2,
 (267, 268): 3,
 (268, 269): 3,
 (269, 32): 3,
 (32, 266): 3,
 (266, 271): 1,
 (32, 260): 1,
 (260, 110): 1,
 (110, 118): 1,
 (118, 259): 1,
 (259, 99): 1,
 (99, 256): 2,
 (256, 263): 3,
 (263, 264): 3,
 (264, 32): 1,
 (105, 270): 1,
 (270, 115): 2,
 (115, 104): 2,
 (104, 111): 3,
 (111, 117): 2,
 (117, 108): 3,
 (108, 274): 2,
 (274, 256): 2,
 (256, 100): 1,
 (100

In [22]:
def encode(text):
    # given a string, return list of integers (the tokens)
    tokens = list(text.encode("utf-8"))
    while True:
        stats = get_stats(tokens)
        # min() over a dict iterates over its keys, i.e. the pairs in `stats`.
        # The key looks up each pair's merge index in `merges`; pairs that were never
        # merged get float("inf"), so the earliest-learned mergeable pair wins.
        pair = min(stats, key=lambda p: merges.get(p, float("inf"))) 
        if pair not in merges:
            break # nothing else can be merged
        idx = merges[pair]
        tokens = merge(tokens, pair, idx)
    return tokens
        

print(encode("hello world!"))

[104, 275, 108, 111, 32, 119, 111, 114, 108, 100, 33]


This is good and it works, but it turns out we are leaving out a special case. 

In [23]:
def encode(text):
    # given a string, return list of integers (the tokens)
    tokens = list(text.encode("utf-8"))
    while True:
        stats = get_stats(tokens)
        # min() over a dict iterates over its keys, i.e. the pairs in `stats`.
        # The key looks up each pair's merge index in `merges`; pairs that were never
        # merged get float("inf"), so the earliest-learned mergeable pair wins.
        pair = min(stats, key=lambda p: merges.get(p, float("inf"))) 
        if pair not in merges:
            break # nothing else can be merged
        idx = merges[pair]
        tokens = merge(tokens, pair, idx)
    return tokens
        

print(encode("h"))

ValueError: min() iterable argument is empty

If the text encodes to fewer than two bytes (e.g. a single ASCII character or an empty string), `stats` is empty and `min` raises a `ValueError`. The fix is to only loop while there are at least two tokens to pair up.

In [24]:
def encode(text):
    # given a string, return list of integers (the tokens)
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        stats = get_stats(tokens)
        # min() over a dict iterates over its keys, i.e. the pairs in `stats`.
        # The key looks up each pair's merge index in `merges`; pairs that were never
        # merged get float("inf"), so the earliest-learned mergeable pair wins.
        pair = min(stats, key=lambda p: merges.get(p, float("inf"))) 
        if pair not in merges:
            break # nothing else can be merged
        idx = merges[pair]
        tokens = merge(tokens, pair, idx)
    return tokens
        

print(encode("hello world!"))
print(encode("h"))
print(encode(""))

[104, 275, 108, 111, 32, 119, 111, 114, 108, 100, 33]
[104]
[]
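Why pick the pair with the lowest merge index? Merges must be replayed in the order they were learned, because later tokens are built out of earlier ones: 266 is (262, 258), so it can only appear once 262 and 258 exist. A self-contained trace on "and ", assuming only the three relevant entries of `merges` (kept in a separate `demo_merges` dict, a hypothetical name, so the real `merges` is untouched); `encode_trace` is the same loop as `encode` above, plus a print.

```python
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

# Only the three merges needed for "and " (a subset of the trained `merges`).
demo_merges = {(100, 32): 258, (97, 110): 262, (262, 258): 266}

def encode_trace(text, merges):
    # same loop as encode() above, but printing each merge as it is applied
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        stats = get_stats(tokens)
        # lowest merge index = earliest-learned merge; unmergeable pairs rank as infinity
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        tokens = merge(tokens, pair, merges[pair])
        print(pair, "->", merges[pair], tokens)
    return tokens

encode_trace("and ", demo_merges)
# (100, 32) -> 258 [97, 110, 258]
# (97, 110) -> 262 [262, 258]
# (262, 258) -> 266 [266]
```

Note that (100, 32) is merged first even though (97, 110) appears earlier in the text, simply because 258 was learned before 262.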


If we encode a string and then immediately decode it, we expect to get the same string back. Does that hold for every string?

In [25]:
print(decode(encode("hello world")))

hello world


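In this direction it does: encoding only produces valid UTF-8 bytes and every merge is lossless, so `decode(encode(s))` returns `s` for any string that can be UTF-8 encoded. The reverse is not guaranteed, though: not every token sequence is a valid UTF-8 byte stream, so an arbitrary list of ids may not be decodable.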

In [26]:
text2 = decode(encode(text))
print(text2 == text)

True
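The reverse direction is where things can break. An id sequence that is not valid UTF-8 decodes (with `errors="replace"`) to U+FFFD, and re-encoding that character yields three different bytes, so ids -> text -> ids is not guaranteed to give back the original ids. A minimal sketch using raw byte values only, with no merges involved:

```python
bad_ids = [128]  # a lone continuation byte: not valid UTF-8 on its own
text_out = bytes(bad_ids).decode("utf-8", errors="replace")  # becomes "\ufffd"

round_trip = list(text_out.encode("utf-8"))
print(round_trip)  # [239, 191, 189]: the UTF-8 bytes of U+FFFD, not [128]
```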


Let's now look at some more recent state-of-the-art (SOTA) LLMs and the kinds of tokenizers they use.

## Forced splits using regex patterns (GPT-2)
Start by reading "Language Models are Unsupervised Multitask Learners" (Radford et al., 2019), a.k.a. the GPT-2 paper, specifically the section "Input Representation". In it, the authors mention that they encountered issues with the naive byte-level implementation of BPE that we've worked with so far in this notebook. They observed BPE including many versions of common words like "dog", since they occur in many variations such as "dog.", "dog!", "dog?", etc. This resulted in a sub-optimal allocation of limited vocabulary slots and model capacity.

To avoid this, they prevent BPE from merging across character categories for any given byte sequence. They also state that they added an exception for spaces which significantly improved their compression efficiency while adding minimal fragmentation of words across multiple vocab tokens. 

Below is the regex pattern used in GPT-2's encoder.py (its tokenizer), which can be found on GitHub.

In [27]:
import regex as re  # third-party regex package (pip install regex); the stdlib re module has no \p{L} / \p{N}
gpt2pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

print(re.findall(gpt2pat, "Hello world"))

['Hello', ' world']


In [28]:
print(re.findall(gpt2pat, "Hello world123 how are you"))

['Hello', ' world', '123', ' how', ' are', ' you']


In [29]:
print(re.findall(gpt2pat, "Hello world123 how've you been"))

['Hello', ' world', '123', ' how', "'ve", ' you', ' been']


In [30]:
print(re.findall(gpt2pat, "Hello world123 how've you been!?!?"))

['Hello', ' world', '123', ' how', "'ve", ' you', ' been', '!?!?']


In [31]:
print(re.findall(gpt2pat, "Hello world123 how've you     been!?!?    "))

['Hello', ' world', '123', ' how', "'ve", ' you', '    ', ' been', '!?!?', '    ']


The regex is written as a raw string literal (r"""...""") and begins with alternatives for common English contraction suffixes ('s, 't, 're, 've, 'm, 'll, 'd), separated by the regex "or" symbol (|).

We then see " ?\p{L}+", which allows an optional space followed by one or more letters of any kind from any language.

" ?\p{N}+" means an optional space followed by one or more numeric characters from any script.

We then see " ?[^\s\p{L}\p{N}]+", another optional space followed by one or more characters that are not whitespace, letters, or numbers, essentially matching punctuation. Lastly we see "\s+(?!\S)|\s+". The first alternative uses a **negative lookahead assertion**: when a run of whitespace is followed by a word, it matches the run up to *but not including* its last whitespace character, which is left for " ?\p{L}+" to attach to the next word. The final \s+ catches any whitespace the other alternatives leave behind.
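The whitespace handling is easiest to see in isolation. This sketch uses Python's standard `re` module, so it relies on a hypothetical ASCII-only stand-in for `gpt2pat` (an assumption: `[A-Za-z]` and `[0-9]` replace `\p{L}` and `\p{N}`, which stdlib `re` does not support); it is not OpenAI's exact pattern.

```python
import re

sample = "you" + " " * 5 + "been"  # five spaces, as in the example above

# Only the two whitespace alternatives from gpt2pat.
ws = re.compile(r"\s+(?!\S)|\s+")
print(ws.findall(sample))  # ['    ', ' ']: four spaces, then the last one on its own

# Hypothetical ASCII-only approximation of gpt2pat: [A-Za-z] and [0-9] replace \p{L} and \p{N}.
gpt2_ascii = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+""")
print(gpt2_ascii.findall(sample))  # ['you', '    ', ' been']: the leftover space joins "been"
```

On ASCII input like the earlier examples, this approximation produces the same chunks as `gpt2pat`.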

Regex patterns like this one are one way you can enforce rules prohibiting certain kinds of merges when it comes to chunking the text.  
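To see how splitting prevents certain merges: GPT-2's encoder applies BPE to each regex chunk separately, and this sketch assumes training counts pairs the same way (OpenAI has not released its training code). Pairs are counted only within each chunk, so a pair spanning a chunk boundary, such as "g" + ".", can never be learned. It again uses a hypothetical ASCII-only stand-in for `gpt2pat` so it runs with the standard `re` module.

```python
import re
from collections import Counter

# Hypothetical ASCII-only approximation of gpt2pat (stdlib re has no \p{L} / \p{N}).
gpt2_ascii = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+""")

sample = "dog. dog! dog?"
chunks = gpt2_ascii.findall(sample)
print(chunks)  # ['dog', '.', ' dog', '!', ' dog', '?']

# Count adjacent byte pairs within each chunk only; pairs spanning two chunks are never counted.
pair_counts = Counter()
for chunk in chunks:
    b = chunk.encode("utf-8")
    pair_counts.update(zip(b, b[1:]))

print((ord("g"), ord(".")) in pair_counts)  # False: "g." spans a chunk boundary
print(pair_counts[(ord("o"), ord("g"))])    # 3
```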

It's important to notice that OpenAI's code includes the comment "Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions", because their regex pattern only matches the lowercase contractions as written. Moreover, the apostrophe-based contraction rules are potentially English-specific, but I've not investigated this rigorously enough to determine how severe a limitation it is in practice.
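To see what the missing flag means in practice, here are just the contraction alternatives of `gpt2pat`, matched with the standard `re` module with and without `re.IGNORECASE`:

```python
import re

# The contraction alternatives from gpt2pat; these need no \p{...} classes, so stdlib re is enough.
gpt2_contractions = r"'s|'t|'re|'ve|'m|'ll|'d"

print(re.findall(gpt2_contractions, "how've HOW'VE"))                      # ["'ve"]
print(re.findall(gpt2_contractions, "how've HOW'VE", flags=re.IGNORECASE))  # ["'ve", "'VE"]
```

Without the flag, "'VE" falls through to the other alternatives, so the uppercase contraction is chunked differently from its lowercase twin.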

In [32]:
example = """
for i in range(1, 101):
    if i % 3 ==0 and i % 5 == 0:
        print("FizzBuzz")
    elif i % 3 == 0:
        print("Fizz")
    elif i % 5 == 0:
        print("Buzz")
    else:
        print(i)
"""
print(re.findall(gpt2pat, example))

['\n', 'for', ' i', ' in', ' range', '(', '1', ',', ' 101', '):', '\n   ', ' if', ' i', ' %', ' 3', ' ==', '0', ' and', ' i', ' %', ' 5', ' ==', ' 0', ':', '\n       ', ' print', '("', 'FizzBuzz', '")', '\n   ', ' elif', ' i', ' %', ' 3', ' ==', ' 0', ':', '\n       ', ' print', '("', 'Fizz', '")', '\n   ', ' elif', ' i', ' %', ' 5', ' ==', ' 0', ':', '\n       ', ' print', '("', 'Buzz', '")', '\n   ', ' else', ':', '\n       ', ' print', '(', 'i', ')', '\n']


In practice, OpenAI also appears to have enforced some additional rule preventing these whitespace chunks from being merged when they trained the GPT-2 tokenizer. Unfortunately, it is not clear how they implemented that rule, since the training code has never been released. You can see the effect by going to `https://tiktokenizer.vercel.app/`, selecting the gpt2 encoding, and noticing that spaces are tokenized individually.

## Tiktoken 

In [33]:
import tiktoken

# GPT-2 (does not merge spaces)
enc = tiktoken.get_encoding("gpt2")
print(enc.encode("    hello world!!!"))

# GPT-4 (merges spaces)
enc = tiktoken.get_encoding("cl100k_base")
print(enc.encode("    hello world!!!"))

[220, 220, 220, 23748, 995, 10185]
[262, 24748, 1917, 12340]


Tiktoken is OpenAI's official tokenizer library, and `cl100k_base` is the encoding used for GPT-4. Unfortunately, OpenAI has again only published tiktoken's inference code, not its training code. They also changed the regex pattern:

In [34]:
import regex as re
gpt4pat = re.compile(r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""")

print(re.findall(gpt4pat, "Hello world123 how've you     been!?!?    "))

['Hello', ' world', '123', ' how', "'ve", ' you', '    ', ' been', '!?!?', '    ']


With this regex pattern, we can see some pretty stark differences. It is still written as a Python raw string literal (`r"""..."""`).

We then see `'(?i:[sdmt]|ll|ve|re)`, where `'` matches an apostrophe, `(?i:...)` makes its contents case-insensitive, and `[sdmt]|ll|ve|re` matches 's', 'd', 'm', 't', 'll', 've', or 're' in order to handle contractions. Unlike GPT-2, this also covers uppercase forms such as 'VE.

`|` indicates an "or" like we saw before.

With `[^\r\n\p{L}\p{N}]?+`, the `[^...]` negates a character set made up of `\r` (carriage return), `\n` (newline), `\p{L}` (any kind of letter), and `\p{N}` (any kind of numeric character), and `?+` is a possessive optional quantifier. Together they match at most one character that is not a letter, number, or line break, such as a space or a punctuation mark.

`\p{L}+` then matches one or more letters from any language, like we saw with GPT-2's regex pattern, so a word can carry one leading space or punctuation character.

`\p{N}{1,3}` matches 1 to 3 numeric characters, so long numbers are split into chunks of at most three digits.

` ?[^\s\p{L}\p{N}]++` matches an optional space followed by one or more characters that are not whitespace, letters, or numbers (i.e. punctuation and symbols); `++` is the possessive one-or-more quantifier.

`[\r\n]*` then matches zero or more carriage returns or newlines, so a punctuation chunk can absorb the line breaks that immediately follow it.

`\s*[\r\n]` matches optional whitespace ending in a line break.

`\s+(?!\S)` works as in GPT-2: `\s+` matches one or more whitespace characters, and the negative lookahead `(?!\S)` forbids a non-whitespace character right after the match. It therefore matches trailing whitespace at the end of a string, and otherwise stops one character short of the next word, leaving that last space to attach to the word.

Finally, `\s+` again matches one or more whitespace characters, catching whatever is left.
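Two of these changes can be tried with Python's standard `re` module. This is an approximation, not the exact GPT-4 pattern: `\d` stands in for `\p{N}` (stdlib `re` has no `\p{...}` classes), and the scoped `(?i:...)` group requires Python 3.6 or newer.

```python
import re

# Approximations of two cl100k_base alternatives (assumption: \d stands in for \p{N}).
gpt4_contractions = re.compile(r"'(?i:[sdmt]|ll|ve|re)")  # case-insensitive only inside the group
gpt4_numbers = re.compile(r"\d{1,3}")

print(gpt4_contractions.findall("HOW'VE how've"))  # ["'VE", "'ve"]
print(gpt4_numbers.findall("1234567"))             # ['123', '456', '7']
```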

## Special Tokens

Special tokens let us delimit certain parts of the data or add structure to the token stream.

One commonly used special token is the EOT (End Of Text) token, which is inserted between documents during training to tell the LLM that one document has ended and a new one is starting. This and other special tokens are essential for fine-tuning a foundation model into something like a chatbot.
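A special token has to bypass BPE: the string "<|endoftext|>" must map to one reserved id rather than being split into bytes and merged. A minimal sketch, assuming raw UTF-8 bytes as a stand-in for a trained BPE encoder; 50256 is the id GPT-2 uses for <|endoftext|>, while `encode_ordinary` and `encode_with_special` are hypothetical helpers.

```python
import re

# Assumption: only one special token; 50256 is GPT-2's id for <|endoftext|>.
SPECIAL = {"<|endoftext|>": 50256}

def encode_ordinary(text):
    # hypothetical stand-in for a trained BPE encoder: raw UTF-8 byte values
    return list(text.encode("utf-8"))

def encode_with_special(text):
    # split on special strings first (the capturing group keeps them in the result),
    # so they map straight to their reserved ids and never pass through BPE
    pattern = "(" + "|".join(re.escape(s) for s in SPECIAL) + ")"
    ids = []
    for part in re.split(pattern, text):
        if part in SPECIAL:
            ids.append(SPECIAL[part])
        elif part:  # re.split yields "" when a special token sits at either edge; skip those
            ids.extend(encode_ordinary(part))
    return ids

print(encode_with_special("hi<|endoftext|>yo"))  # [104, 105, 50256, 121, 111]
```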

## SentencePiece

SentencePiece is commonly used in other LLMs such as the Llama series, T5, the Mistral series, and others. The reason is that, unlike tiktoken, it can efficiently both train BPE tokenizers and run inference with them, AND it runs BPE on Unicode code points directly!

Tiktoken encodes text to UTF-8 and then runs BPE on the bytes, whereas SentencePiece runs BPE on the code points themselves and can optionally fall back to UTF-8 for rare code points (rarity is determined by the character_coverage hyperparameter); the bytes of those rare code points are then translated to byte tokens.
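A rough sketch of the byte-fallback idea, operating on code points rather than bytes. Everything here is hypothetical: `to_pieces` and the two-character vocabulary stand in for what SentencePiece would learn (with character_coverage deciding which code points are kept), and only the `<0xNN>` spelling of byte pieces mirrors what SentencePiece models such as Llama's use.

```python
def to_pieces(text, char_vocab):
    """Hypothetical byte-fallback sketch: keep known code points, spell unknown ones as UTF-8 byte pieces."""
    pieces = []
    for ch in text:  # iterate over code points, not bytes
        if ch in char_vocab:
            pieces.append(ch)
        else:
            # fall back to one piece per UTF-8 byte, named like SentencePiece's <0xNN> byte pieces
            pieces.extend(f"<0x{b:02X}>" for b in ch.encode("utf-8"))
    return pieces

# Assumption: a tiny character vocabulary standing in for what character_coverage would select.
print(to_pieces("hi中", {"h", "i"}))  # ['h', 'i', '<0xE4>', '<0xB8>', '<0xAD>']
```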