In [2]:
from collections import OrderedDict
import typing

import regex as re

We are looking for a way to tokenize strings of text. One possible approach is to represent the text as Unicode code points (via `ord`), or as the bytes produced by an encoding such as UTF-8.

In [3]:
# Sample string mixing Korean (Hangul) characters, an emoji, and plain ASCII text.
s = '안녕하세요 👋 (hello in Korean)!'
s

'안녕하세요 👋 (hello in Korean)!'

In [4]:
# ord() maps each character to its Unicode code point.
print(list(map(ord, s)))

[50504, 45397, 54616, 49464, 50836, 32, 128075, 32, 40, 104, 101, 108, 108, 111, 32, 105, 110, 32, 75, 111, 114, 101, 97, 110, 41, 33]


In [5]:
s.encode(encoding='utf-8')

b'\xec\x95\x88\xeb\x85\x95\xed\x95\x98\xec\x84\xb8\xec\x9a\x94 \xf0\x9f\x91\x8b (hello in Korean)!'

In [6]:
# Unpacking a bytes object yields its byte values as ints in 0..255.
print([*s.encode('utf-8')])

[236, 149, 136, 235, 133, 149, 237, 149, 152, 236, 132, 184, 236, 154, 148, 32, 240, 159, 145, 139, 32, 40, 104, 101, 108, 108, 111, 32, 105, 110, 32, 75, 111, 114, 101, 97, 110, 41, 33]


In [7]:
# The leading 255, 254 in the output is the byte-order mark the 'utf-16' codec prepends.
print([*s.encode('utf-16')])

[255, 254, 72, 197, 85, 177, 88, 213, 56, 193, 148, 198, 32, 0, 61, 216, 75, 220, 32, 0, 40, 0, 104, 0, 101, 0, 108, 0, 108, 0, 111, 0, 32, 0, 105, 0, 110, 0, 32, 0, 75, 0, 111, 0, 114, 0, 101, 0, 97, 0, 110, 0, 41, 0, 33, 0]


In [8]:
# The leading 255, 254, 0, 0 in the output is the byte-order mark the 'utf-32' codec prepends.
print([*s.encode('utf-32')])

[255, 254, 0, 0, 72, 197, 0, 0, 85, 177, 0, 0, 88, 213, 0, 0, 56, 193, 0, 0, 148, 198, 0, 0, 32, 0, 0, 0, 75, 244, 1, 0, 32, 0, 0, 0, 40, 0, 0, 0, 104, 0, 0, 0, 101, 0, 0, 0, 108, 0, 0, 0, 108, 0, 0, 0, 111, 0, 0, 0, 32, 0, 0, 0, 105, 0, 0, 0, 110, 0, 0, 0, 32, 0, 0, 0, 75, 0, 0, 0, 111, 0, 0, 0, 114, 0, 0, 0, 101, 0, 0, 0, 97, 0, 0, 0, 110, 0, 0, 0, 41, 0, 0, 0, 33, 0, 0, 0]


Note how the UTF-16 and UTF-32 encodings above contain many 0s (each also begins with a byte-order mark, `255, 254`). For mostly-ASCII text this redundancy is wasteful, which is why we prefer UTF-8.

However, using raw UTF-8 bytes as tokens gives a vocabulary of only 256 possible tokens (one per byte value). Such a small vocabulary produces long token sequences, since every non-ASCII character takes two to four tokens. Long sequences use up the model's context length quickly.

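We can solve this with the byte pair encoding (BPE) algorithm. BPE grows the vocabulary by repeatedly replacing the most frequent adjacent pair of tokens with a new token.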

## Byte-Pair Encoding

In [9]:
text = 'Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to “support Unicode” in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don’t blame programmers for still finding the whole thing mysterious, even 30 years after Unicode’s inception.'
# Iterating a bytes object yields ints in 0..255, so list() gives the raw byte values directly.
tokens = list(text.encode("utf-8"))
print(
    '---',
    f'Text: {text}',
    f'Length of text: {len(text)}',
    '---',
    f'Tokens: {tokens}',
    f'Length of tokens: {len(tokens)}',
    sep='\n',
)

---
Text: Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to “support Unicode” in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don’t blame programmers for still finding the whole thing mysterious, even 30 years after Unicode’s inception.
Length of text: 533
---
Tokens: [239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131, 239, 189, 143, 239, 189, 132, 239, 189, 133, 33, 32, 240, 159, 133, 164, 240, 159, 133, 157, 240, 159, 133, 152, 240, 159, 133, 146, 240, 159, 133, 158, 240, 159, 133, 147, 240, 159, 133, 148, 226, 128, 189, 32, 240, 159, 135, 186, 226, 128, 140, 240, 159, 135, 179, 226, 128, 140, 240, 159, 135, 174, 226, 128, 140, 240, 159, 135, 168, 226, 128, 140, 240, 159, 1

In [10]:
def pair_counts(code_points: typing.List[int]) -> typing.Dict[typing.Tuple[int, int], int]:
    """Count how often each adjacent pair of ids occurs in ``code_points``.

    Overlapping pairs are all counted (``[2, 2, 2]`` yields ``{(2, 2): 2}``).
    Keys appear in order of each pair's first occurrence. A sequence shorter
    than two elements yields an empty dict.
    """
    tally: typing.Dict[typing.Tuple[int, int], int] = {}
    for idx in range(len(code_points) - 1):
        key = (code_points[idx], code_points[idx + 1])
        tally[key] = tally.get(key, 0) + 1
    return tally

In [11]:
counts_dict = pair_counts(tokens)
# Most frequent pairs first. sorted() is stable, so pairs with equal counts keep their
# first-occurrence order.
print(sorted(counts_dict.items(), key=lambda kv: -kv[1]))

[((101, 32), 20), ((240, 159), 15), ((226, 128), 12), ((105, 110), 12), ((115, 32), 10), ((32, 97), 10), ((97, 110), 10), ((32, 116), 9), ((116, 104), 8), ((159, 133), 7), ((159, 135), 7), ((97, 114), 7), ((239, 189), 6), ((128, 140), 6), ((140, 240), 6), ((104, 101), 6), ((101, 114), 6), ((109, 101), 6), ((114, 32), 6), ((110, 100), 6), ((32, 105), 6), ((111, 114), 6), ((116, 32), 6), ((110, 103), 6), ((32, 115), 5), ((115, 116), 5), ((100, 101), 5), ((110, 32), 5), ((117, 115), 5), ((44, 32), 5), ((97, 109), 4), ((114, 105), 4), ((32, 102), 4), ((101, 97), 4), ((100, 32), 4), ((110, 116), 4), ((32, 111), 4), ((32, 119), 4), ((111, 117), 4), ((32, 85), 4), ((85, 110), 4), ((110, 105), 4), ((105, 99), 4), ((99, 111), 4), ((111, 100), 4), ((104, 97), 4), ((116, 101), 4), ((103, 32), 4), ((115, 44), 4), ((116, 105), 4), ((32, 240), 3), ((118, 101), 3), ((116, 114), 3), ((101, 115), 3), ((116, 111), 3), ((111, 32), 3), ((114, 116), 3), ((116, 115), 3), ((111, 102), 3), ((32, 112), 3), ((1

As the output above shows, the token pair `(101, 32)` is the most common pair of adjacent tokens in the text. Here is what that pair decodes to:

In [12]:
tuple(map(chr, (101, 32)))

('e', ' ')

Here is a Pythonic way to get the most commonly occurring pair from the dictionary of counts.

In [13]:
# The key with the highest count. On ties, max() returns the first maximal key encountered.
top_pair = max(counts_dict.keys(), key=lambda pair: counts_dict[pair])
top_pair

(101, 32)

Now we need a function that replaces every occurrence of a given pair of tokens with a new token.

In [14]:
def replace_pairs(code_points: typing.List[int], pair_to_replace: typing.Tuple[int, int], replacement: int) -> typing.List[int]:
    """Return a new list with every occurrence of ``pair_to_replace`` replaced by ``replacement``.

    The scan runs left to right and is non-overlapping. Once a pair is merged,
    both of its elements are consumed, so ``[2, 2, 2]`` with pair ``(2, 2)``
    becomes ``[replacement, 2]``. The input list is not modified.
    """
    merged: typing.List[int] = []
    last = len(code_points) - 1
    pos = 0
    while pos <= last:
        current = code_points[pos]
        if pos < last and current == pair_to_replace[0] and code_points[pos + 1] == pair_to_replace[1]:
            merged.append(replacement)
            pos += 2
            continue
        merged.append(current)
        pos += 1
    return merged

In [15]:
replace_pairs(
    code_points=[1, 2, 2, 3, 2, 4, 1, 2, 9, 3, 2, 1, 2],
    pair_to_replace=(1, 2),
    replacement=100,
)

[100, 2, 3, 2, 4, 100, 9, 3, 2, 100]

We can replace the pair `(101, 32)` in our tokens with a new token: `256`.

In [16]:
# 256 is the first id beyond the 0..255 byte range, so it cannot collide with a raw byte.
tokens2 = replace_pairs(tokens, top_pair, 256)
old_length, new_length = len(tokens), len(tokens2)
print(tokens2)

[239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131, 239, 189, 143, 239, 189, 132, 239, 189, 133, 33, 32, 240, 159, 133, 164, 240, 159, 133, 157, 240, 159, 133, 152, 240, 159, 133, 146, 240, 159, 133, 158, 240, 159, 133, 147, 240, 159, 133, 148, 226, 128, 189, 32, 240, 159, 135, 186, 226, 128, 140, 240, 159, 135, 179, 226, 128, 140, 240, 159, 135, 174, 226, 128, 140, 240, 159, 135, 168, 226, 128, 140, 240, 159, 135, 180, 226, 128, 140, 240, 159, 135, 169, 226, 128, 140, 240, 159, 135, 170, 33, 32, 240, 159, 152, 132, 32, 84, 104, 256, 118, 101, 114, 121, 32, 110, 97, 109, 256, 115, 116, 114, 105, 107, 101, 115, 32, 102, 101, 97, 114, 32, 97, 110, 100, 32, 97, 119, 256, 105, 110, 116, 111, 32, 116, 104, 256, 104, 101, 97, 114, 116, 115, 32, 111, 102, 32, 112, 114, 111, 103, 114, 97, 109, 109, 101, 114, 115, 32, 119, 111, 114, 108, 100, 119, 105, 100, 101, 46, 32, 87, 256, 97, 108, 108, 32, 107, 110, 111, 119, 32, 119, 256, 111, 117, 103, 104, 116, 32, 116, 111, 32, 226, 128, 156

In [17]:
print(
    f'Original length of tokens: {old_length}, '
    f'Length after replacing the top pair: {new_length}, '
    f'Difference: {old_length - new_length}'
)

Original length of tokens: 616, Length after replacing the top pair: 596, Difference: 20


Before applying the full byte pair encoding algorithm, we're going to get a larger piece of text to get better pair counting statistics.

In [18]:
# Text is from https://www.reedbeta.com/blog/programmers-intro-to-unicode/
# Use a context manager so the file handle is closed as soon as the text has been read.
with open('unicode-article.txt', 'r', encoding='utf-8') as article_file:
    text = article_file.read()
# Iterating a bytes object yields ints in 0..255, so list() gives the byte values directly.
tokens = list(text.encode("utf-8"))
len(tokens)

24636

In [19]:
# Start from one id per possible byte value; each merge adds exactly one new id.
original_vocab_size, target_vocab_size = 256, 276
number_of_merges = target_vocab_size - original_vocab_size
print(f'Number of times to merge: {number_of_merges}')

Number of times to merge: 20


In [20]:
def bpe(
    code_points: typing.List[int],
    original_vocab_size: int,
    number_of_merges: int,
    verbose: bool = True,
) -> typing.Tuple[typing.List[int], typing.Dict[typing.Tuple[int, int], int]]:
    """Learn byte-pair-encoding merges from a sequence of token ids.

    At each step the most frequent adjacent pair is replaced everywhere by a
    new id. New ids start at ``original_vocab_size`` and increase by one per merge.

    Args:
        code_points: Starting token ids (e.g. UTF-8 byte values). The caller's
            list is not modified, because ``replace_pairs`` builds a new list at
            every step.
        original_vocab_size: Id assigned to the first merged pair.
        number_of_merges: Maximum number of merges to perform.
        verbose: If True (the default), print each merge as it is learned.

    Returns:
        ``(merged_ids, replacements)``. ``replacements`` maps each merged pair
        to its new id, in the order the merges were learned. Training stops
        early, with fewer merges, once fewer than two tokens remain, since no
        pair is then left to merge.
    """
    replacements = OrderedDict()
    for i in range(number_of_merges):
        counts_dict = pair_counts(code_points)
        if not counts_dict:
            # Fewer than two tokens left; max() would raise ValueError on an empty dict.
            break
        # On ties, max() returns the first maximal key, i.e. the pair that occurs
        # earliest in the sequence.
        top_pair = max(counts_dict, key=counts_dict.get)
        new_index = original_vocab_size + i
        if verbose:
            print(f'Merging {top_pair} into new token {new_index}')
        code_points = replace_pairs(code_points, top_pair, new_index)
        replacements[top_pair] = new_index
    return code_points, replacements

In [21]:
bpe_tokens, replacements = bpe(
    tokens,
    original_vocab_size=original_vocab_size,
    number_of_merges=number_of_merges,
)

Merging (101, 32) into new token 256
Merging (105, 110) into new token 257
Merging (115, 32) into new token 258
Merging (116, 104) into new token 259
Merging (101, 114) into new token 260
Merging (99, 111) into new token 261
Merging (116, 32) into new token 262
Merging (226, 128) into new token 263
Merging (44, 32) into new token 264
Merging (97, 110) into new token 265
Merging (111, 114) into new token 266
Merging (100, 32) into new token 267
Merging (97, 114) into new token 268
Merging (101, 110) into new token 269
Merging (257, 103) into new token 270
Merging (261, 100) into new token 271
Merging (121, 32) into new token 272
Merging (46, 32) into new token 273
Merging (97, 108) into new token 274
Merging (259, 256) into new token 275


In [22]:
# Learned merges in training order: (left id, right id) -> new id.
replacements

OrderedDict([((101, 32), 256),
             ((105, 110), 257),
             ((115, 32), 258),
             ((116, 104), 259),
             ((101, 114), 260),
             ((99, 111), 261),
             ((116, 32), 262),
             ((226, 128), 263),
             ((44, 32), 264),
             ((97, 110), 265),
             ((111, 114), 266),
             ((100, 32), 267),
             ((97, 114), 268),
             ((101, 110), 269),
             ((257, 103), 270),
             ((261, 100), 271),
             ((121, 32), 272),
             ((46, 32), 273),
             ((97, 108), 274),
             ((259, 256), 275)])

In [23]:
print(
    f'Original tokens length: {len(tokens)}, '
    f'Compressed tokens length: {len(bpe_tokens)}, '
    f'Compression ratio: {len(tokens) / len(bpe_tokens):.2f}x'
)

Original tokens length: 24636, Compressed tokens length: 19484, Compression ratio: 1.26x


Now that we have trained the tokenizer with BPE, we can use it to encode and decode text.

In [24]:
# Base vocabulary: id i decodes to the single byte i.
vocab = {byte_id: bytes((byte_id,)) for byte_id in range(256)}
# `replacements` is iterated in the order the merges were learned, so both halves
# of each pair are already in `vocab` when the merged id is added.
for (p0, p1), idx in replacements.items():
    vocab[idx] = b''.join((vocab[p0], vocab[p1]))

In [25]:
def decode(ids: typing.List[int], vocab: typing.Dict[int, bytes]) -> str:
    """Turn a list of token ids back into text.

    Each id is looked up in ``vocab`` to get its bytes. The bytes are
    concatenated and decoded as UTF-8. Invalid UTF-8, which arbitrary id
    sequences can produce, becomes U+FFFD instead of raising. An id missing
    from ``vocab`` raises KeyError.
    """
    raw = bytearray()
    for token_id in ids:
        raw += vocab[token_id]
    return raw.decode('utf-8', errors='replace')

In [26]:
def encode(text: str, replacements: typing.Dict[typing.Tuple[int, int], int]) -> typing.List[int]:
    """Encode text into token ids using learned BPE merges.

    Starts from the UTF-8 bytes of ``text``. At each step it applies, among the
    pairs currently present, the merge with the lowest new id. When ids were
    assigned in increasing order (as ``bpe`` does), this replays the merges in
    the order they were learned. Stops when fewer than two tokens remain or no
    present pair has a learned merge.
    """
    ids = list(text.encode('utf-8'))
    unmergeable = float('inf')
    while len(ids) > 1:
        stats = pair_counts(ids)
        candidate = min(stats, key=lambda p: replacements.get(p, unmergeable))
        if candidate not in replacements:
            # Every present pair maps to inf, so nothing else can be merged.
            return ids
        ids = replace_pairs(ids, candidate, replacements[candidate])
    return ids

In [27]:
print(
    encode('hello world', replacements),
    decode(encode('hello world', replacements), vocab),
    sep='\n',
)

[104, 101, 108, 108, 111, 32, 119, 266, 108, 100]
hello world


In [28]:
# Round trip: decoding the encoded article should reproduce it exactly.
text2 = decode(ids=encode(text, replacements), vocab=vocab)
text2 == text

True

## Splitting the Text into Mergeable Portions

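Sometimes we want to prevent the tokenizer from merging certain tokens together, for example a letter with the punctuation that follows it. We use the following regular expression (the split pattern used by GPT-2) to divide the text into chunks. The tokenizer is applied to each chunk individually and the results are concatenated, so no merge can cross a chunk boundary.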

In [29]:
# Split pattern built from its alternatives; joined with "|" it is the same pattern string.
# Alternatives are tried left to right at each position.
gpt2pat = re.compile("|".join([
    r"'s", r"'t", r"'re", r"'ve", r"'m", r"'ll", r"'d",  # contraction suffixes (lowercase, ASCII apostrophe)
    r" ?\p{L}+",             # optional leading space + a run of letters
    r" ?\p{N}+",             # optional leading space + a run of numbers
    r" ?[^\s\p{L}\p{N}]+",   # optional leading space + a run of other non-space symbols
    r"\s+(?!\S)",            # whitespace not immediately followed by a non-space character
    r"\s+",                  # any remaining whitespace
]))

In [32]:
print([*gpt2pat.findall('hello world')])

['hello', ' world']


In [40]:
# The output shows the last space of a whitespace run being attached to the following chunk (' 123').
print([*gpt2pat.findall("hello're     123world's!!!?   ")])

['hello', "'re", '    ', ' 123', 'world', "'s", '!!!?', '   ']


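Note that this pre-processing step does not always split text well. The contraction alternatives only match lowercase suffixes and the ASCII apostrophe. As a result, `HOW'S` is split into `' HOW', "'", 'S'`, and `how’s` (with a Unicode apostrophe) is split into `' how', '’', 's'`.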

In [42]:
# Only the lowercase, ASCII-apostrophe form is split off as a contraction suffix.
print([*gpt2pat.findall("how's HOW'S how’s")])

['how', "'s", ' HOW', "'", 'S', ' how', '’', 's']


We can also use the `tiktoken` library to access OpenAI's GPT tokenizers. Notice how the `gpt2` tokenizer (used in GPT-2) does not merge the leading spaces: each `220` is a single space. The `cl100k_base` tokenizer (used in GPT-4) instead merges the run of spaces into a single token, `262`.

In [6]:
import tiktoken

example_text = '    hello world!!!'

# Same text through each tokenizer, so the handling of the leading spaces can be compared.
for encoding_name in ('gpt2', 'cl100k_base'):
    enc = tiktoken.get_encoding(encoding_name)
    print(f'{encoding_name} tokenizer tokenizing "{example_text}": {enc.encode(example_text)}')

gpt2 tokenizer tokenizing "    hello world!!!": [220, 220, 220, 23748, 995, 10185]
cl100k_base tokenizer tokenizing "    hello world!!!": [262, 24748, 1917, 12340]


## minbpe