## Tokenization with byte pair encoding

This notebook, lecture, and homework are based on Andrej Karpathy's excellent online lecture on tokenizers:
https://www.youtube.com/watch?v=zduSFxRajkE

## Byte Pair Encoding

The original Byte Pair Encoding (BPE) algorithm was introduced for data compression. It iteratively replaces the most common pair of adjacent bytes with a new, unused symbol.

Here is the example from https://en.wikipedia.org/wiki/Byte_pair_encoding:


Suppose the data string is: "aaabdaaabac".

The counts of the byte pairs are:
(aa, 4), (ab, 2), (bd, 1), (da, 1), (ba, 1), (ac, 1)

The most frequent byte pair is aa, so it is replaced with an unused byte, say Z. This yields:

ZabdZabac  
Z = aa

In this new string, ab is among the most frequent pairs (it ties with Za at two occurrences). It is replaced with Y, yielding:

ZYdZYac  
Z = aa  
Y = ab

In this string, the most frequent pair is ZY, which is replaced with X, yielding:

XdXac  
Z = aa  
Y = ab  
X = ZY

Now no pair of symbols appears more than once, so the algorithm stops.
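
The replacement table is all that is needed to undo the compression. Here is a minimal sketch in Python, assuming the table from the example above and plain string replacement (the variable names are illustrative):

```python
from collections import Counter

data = "aaabdaaabac"

# Count every pair of adjacent characters in the original string.
pair_counts = Counter(a + b for a, b in zip(data, data[1:]))
print(pair_counts.most_common(2))

# Replacement table from the example, listed newest rule first so that
# X is expanded before the symbols Z and Y that it is built from.
rules = [("X", "ZY"), ("Y", "ab"), ("Z", "aa")]

restored = "XdXac"
for symbol, pair in rules:
    restored = restored.replace(symbol, pair)

print(restored)  # aaabdaaabac
```

Expanding the rules in reverse order of creation matters: X must be expanded first because its definition refers to Z and Y.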

BPE for natural language processing (NLP) works analogously, except that we start with a vocabulary that initially consists of the 256 possible byte values (0 to 255; the first 128 coincide with ASCII). We then add tokens 256, 257, and so on by repeatedly merging the most frequent pair.


## History of BPE

BPE for language modeling was introduced by Sennrich et al., "Neural Machine Translation of Rare Words with Subword Units".

The paper lists the following as its main contributions:
- "We show that open-vocabulary neural machine translation is possible by encoding (rare) words via subword units."
- "We adapt byte pair encoding (BPE) (Gage, 1994), a compression algorithm, to the task of word segmentation. BPE allows for the representation of an open vocabulary through a fixed-size vocabulary of variable-length character sequences, making it a very suitable word segmentation strategy for neural network models."

The GPT-2 paper uses a byte-level BPE algorithm and justifies the choice as follows:

"A general language model (LM) should be able to compute the probability of (and also generate) any string. Current large scale LMs include pre-processing steps such as lowercasing, tokenization, and out-of-vocabulary tokens which restrict the space of model-able strings. While processing Unicode strings as a sequence of UTF-8 bytes elegantly fulfills this requirement as exemplified in work such as Gillick et al. (2015), current byte-level LMs are not competitive with word-level LMs on large scale datasets such as the One Billion Word Benchmark (Al-Rfou et al., 2018). We observed a similar performance gap in our own attempts to train standard byte-level LMs on WebText.

Byte Pair Encoding (BPE) (Sennrich et al., 2015) is a practical middle ground between character and word level language modeling which effectively interpolates between word level inputs for frequent symbol sequences and character level inputs for infrequent symbol sequences. Despite its name, reference BPE implementations often operate on Unicode code points and not byte sequences. These implementations would require including the full space of Unicode symbols in order to model all Unicode strings. This would result in a base vocabulary of over 130,000 before any multi-symbol tokens are added. This is prohibitively large compared to the 32,000 to 64,000 token vocabularies often used with BPE. In contrast, a byte-level version of BPE only requires a base vocabulary of size 256."






# Step-by-step to a tokenizer

### Converting characters to bytes

Characters are typically encoded using UTF-8, a character encoding that represents every Unicode character as a sequence of one to four bytes. UTF-8 is designed to be backward compatible with ASCII for the first 128 characters, making it efficient for texts where these characters are predominant. UTF-8 is the dominant encoding for the web and many computing systems because it can represent any character in the Unicode standard, yet is space-efficient for texts primarily using Latin characters.

In [None]:
# Convert text to bytes

text = "Hello World!"
tokens = text.encode('utf-8') # convert text to bytes
print( tokens, ", number of bytes: ", len(tokens) ) # one byte per ASCII character

text = "ö"
tokens = text.encode('utf-8') # convert text to bytes
print( tokens, ", number of bytes: ", len(tokens) ) # two bytes for this non-ASCII character

text = "😀"
tokens = text.encode('utf-8') # convert text to bytes
print( tokens, ", number of bytes: ", len(tokens) ) # four bytes for this emoji


# Let's use the same string as in the previous example for illustrating BPE

text = "aaabdaaabac"
tokens = text.encode('utf-8') # convert text to bytes
tokens = list(map(int, tokens)) # convert bytes to list of integers
print(f"String in previous example as tokens: {tokens}")


b'Hello World!' , number of bytes:  12
b'\xc3\xb6' , number of bytes:  2
b'\xf0\x9f\x98\x80' , number of bytes:  4
String in previous example as tokens: [97, 97, 97, 98, 100, 97, 97, 97, 98, 97, 99]
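
To see the full one-to-four-byte range and the ASCII compatibility in one place, here is a small sketch. The sample characters are arbitrary choices for illustration.

```python
# Arbitrary sample characters that need 1, 2, 3 and 4 bytes: a, ö, €, 😀
samples = ["a", "\u00f6", "\u20ac", "\U0001F600"]

byte_lengths = {ch: len(ch.encode("utf-8")) for ch in samples}
print(byte_lengths)

# For ASCII characters the UTF-8 byte equals the code point,
# so the UTF-8 and ASCII encodings of pure ASCII text coincide.
ascii_text = "Hello World!"
same_as_ascii = ascii_text.encode("utf-8") == ascii_text.encode("ascii")
print(same_as_ascii)  # True
```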


### Computing frequencies of consecutive bytes


In [None]:
def get_frequencies(seq):
    counts = {}
    for pair in zip(seq,seq[1:]): # iterate over consecutive elements
        counts[pair] = counts.get(pair,0) + 1
    return counts

frequencies = get_frequencies(tokens)

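# pairs sorted by count, most frequent first (kept in a variable; the unsorted dict is printed below)
sorted_frequencies = sorted(frequencies.items(), key=lambda x: x[1], reverse=True)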

print(frequencies)

{(97, 97): 4, (97, 98): 2, (98, 100): 1, (100, 97): 1, (98, 97): 1, (97, 99): 1}


### Merging function

In [None]:
top_pair = max(frequencies.items(), key=lambda x: x[1])[0]
print(top_pair)

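def merge(seq,pair,index):
    # replace every non-overlapping occurrence of pair in seq by index, scanning left to right
    new_seq = []
    i = 0
    while i < len(seq):
        # check the bound first, then compare the two neighbouring elements with the pair
        if i < len(seq) - 1 and (seq[i], seq[i+1]) == tuple(pair):
            new_seq.append(index)
            i += 2 # skip both elements of the merged pair
        else:
            new_seq.append(seq[i])
            i += 1
    return new_seq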

new_seq = merge(tokens,top_pair,256)
print(tokens)
print(new_seq)



(97, 97)
[97, 97, 97, 98, 100, 97, 97, 97, 98, 97, 99]
[256, 97, 98, 100, 256, 97, 98, 97, 99]


### Tokenizer training loop

In [None]:
vocab_size = 256+5 # desired final vocabulary size
num_merges = vocab_size - 256
seq = list(tokens) # make a copy of the original sequence
merges = {} # dictionary to store merges, int x int -> int
for i in range(num_merges):
    frequencies = get_frequencies(seq)
    pair = max(frequencies.items(), key=lambda x: x[1])[0]
    seq = merge(seq,pair,256+i)
    merges[pair] = 256+i
    print(f"merge {pair} -> {256+i}")

# Trace of the five merges, starting from
# [97, 97, 97, 98, 100, 97, 97, 97, 98, 97, 99]:
#
# 256: pair (97, 97), count 4
#      [256, 97, 98, 100, 256, 97, 98, 97, 99]
#
# 257: pair (256, 97), count 2
#      [257, 98, 100, 257, 98, 97, 99]
#
# 258: pair (257, 98), count 2
#      [258, 100, 258, 97, 99]
#
# 259: pair (258, 100), count 1
#      [259, 258, 97, 99]
#
# 260: pair (259, 258), count 1
#      [260, 97, 99]



merge (97, 97) -> 256
merge (256, 97) -> 257
merge (257, 98) -> 258
merge (258, 100) -> 259
merge (259, 258) -> 260
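
In the run above, the last two merges each replace a pair that occurs only once, so they add vocabulary entries while compressing very little. One possible variation, sketched here under the assumption that such merges are not wanted, wraps the loop in a function and stops early once no pair occurs at least `min_count` times. The name `train_bpe` and the `min_count` parameter are illustrative and not part of the lecture code.

```python
def train_bpe(text, vocab_size, min_count=2):
    """Learn BPE merges on the UTF-8 bytes of `text` (illustrative helper)."""
    seq = list(text.encode("utf-8"))
    merges = {}  # (int, int) -> int
    for new_id in range(256, vocab_size):
        counts = {}
        for pair in zip(seq, seq[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        if not counts:
            break  # fewer than two tokens left
        pair = max(counts, key=counts.get)  # ties go to the pair seen first
        if counts[pair] < min_count:
            break  # assumption: merging a pair seen only once is not worthwhile
        # replace every non-overlapping occurrence of `pair` by `new_id`
        merged, i = [], 0
        while i < len(seq):
            if i < len(seq) - 1 and (seq[i], seq[i + 1]) == pair:
                merged.append(new_id)
                i += 2
            else:
                merged.append(seq[i])
                i += 1
        seq = merged
        merges[pair] = new_id
    return merges, seq

merges_demo, seq_demo = train_bpe("aaabdaaabac", vocab_size=256 + 5)
print(merges_demo)  # three merges instead of five
print(seq_demo)     # [258, 100, 258, 97, 99]
```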


### Decoding

Given a sequence of integers in the range [0, vocab_size - 1], decoding yields the corresponding text.

In [None]:
# construct the vocabulary

# the first 256 entries of the vocabulary are the single bytes 0..255
vocab = {idx: bytes([idx]) for idx in range(256)}

# the rest of the vocabulary is constructed from the merges
for (i,j), idx in merges.items(): # this must be done in the same order as we filled in the merges dictionary
    # concatenates the bytes; for example for our first merge: b'a' + b'a' -> b'aa'
    vocab[idx] = vocab[i] + vocab[j]

print(vocab)

def decode(seq,vocab):
    # convert the sequence of integers to bytes, and then concatenate them
    tokens = b"".join([vocab[s] for s in seq]) # concatenate the bytes
    # errors='replace' substitutes the replacement character U+FFFD for invalid bytes,
    # in case the LLM produces a token sequence that is not valid UTF-8
    text = tokens.decode('utf-8',errors='replace')
    return text

{0: b'\x00', 1: b'\x01', 2: b'\x02', 3: b'\x03', 4: b'\x04', 5: b'\x05', 6: b'\x06', 7: b'\x07', 8: b'\x08', 9: b'\t', 10: b'\n', 11: b'\x0b', 12: b'\x0c', 13: b'\r', 14: b'\x0e', 15: b'\x0f', 16: b'\x10', 17: b'\x11', 18: b'\x12', 19: b'\x13', 20: b'\x14', 21: b'\x15', 22: b'\x16', 23: b'\x17', 24: b'\x18', 25: b'\x19', 26: b'\x1a', 27: b'\x1b', 28: b'\x1c', 29: b'\x1d', 30: b'\x1e', 31: b'\x1f', 32: b' ', 33: b'!', 34: b'"', 35: b'#', 36: b'$', 37: b'%', 38: b'&', 39: b"'", 40: b'(', 41: b')', 42: b'*', 43: b'+', 44: b',', 45: b'-', 46: b'.', 47: b'/', 48: b'0', 49: b'1', 50: b'2', 51: b'3', 52: b'4', 53: b'5', 54: b'6', 55: b'7', 56: b'8', 57: b'9', 58: b':', 59: b';', 60: b'<', 61: b'=', 62: b'>', 63: b'?', 64: b'@', 65: b'A', 66: b'B', 67: b'C', 68: b'D', 69: b'E', 70: b'F', 71: b'G', 72: b'H', 73: b'I', 74: b'J', 75: b'K', 76: b'L', 77: b'M', 78: b'N', 79: b'O', 80: b'P', 81: b'Q', 82: b'R', 83: b'S', 84: b'T', 85: b'U', 86: b'V', 87: b'W', 88: b'X', 89: b'Y', 90: b'Z', 91: b'[',
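
The `errors='replace'` argument matters because not every byte sequence is valid UTF-8: a single token may hold only part of a multi-byte character. Here is a small sketch of the difference, using the two bytes of `ö` from the earlier example:

```python
full = b"\xc3\xb6"   # the complete two-byte encoding of 'ö'
partial = full[:1]   # only the lead byte, as a lone token might contain

print(full.decode("utf-8"))

try:
    partial.decode("utf-8")  # strict decoding raises on the incomplete sequence
    strict_failed = False
except UnicodeDecodeError:
    strict_failed = True

# With errors='replace' the invalid byte becomes U+FFFD instead of raising.
replaced = partial.decode("utf-8", errors="replace")
print(strict_failed, repr(replaced))
```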

In [None]:
def encode(text):
    # convert the string to list of integers
    tokens = list(text.encode('utf-8'))

    while len(tokens) >= 2: # a pair needs at least two tokens
        freqs = get_frequencies(tokens)
        # find the pair that has the lowest index in the merges dictionary,
        # since this prioritizes the pairs that were merged first;
        # the custom key function looks up each pair's index in the merges dictionary
        # and defaults to infinity if the pair is not in the merges dictionary
        pair = min(freqs, key=lambda x: merges.get(x,float('inf')))
        if pair not in merges: # nothing to merge
            break
        # merge the pair
        index = merges[pair]
        tokens = merge(tokens,pair,index)
    return tokens

print(encode("aaabdaaabac"))

print(decode(encode("aaabdaaabac"),vocab))





[260, 97, 99]
aaabdaaabac
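
Because the base vocabulary covers all 256 byte values, encoding never fails on unseen characters: bytes without an applicable merge simply stay as single-byte tokens. The following self-contained sketch re-implements the encode and decode steps in compact form to check this round trip. It assumes the five merges printed by the training loop above; the names `demo_merges`, `demo_vocab`, `bpe_encode` and `bpe_decode` are illustrative.

```python
# Merges learned on "aaabdaaabac" in the training loop above.
demo_merges = {(97, 97): 256, (256, 97): 257, (257, 98): 258,
               (258, 100): 259, (259, 258): 260}

demo_vocab = {i: bytes([i]) for i in range(256)}
for (a, b), idx in demo_merges.items():  # insertion order = training order
    demo_vocab[idx] = demo_vocab[a] + demo_vocab[b]

def bpe_encode(text):
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        # pick the adjacent pair that was learned earliest
        pair = min(zip(ids, ids[1:]),
                   key=lambda p: demo_merges.get(p, float("inf")))
        if pair not in demo_merges:
            break  # no learned merge applies any more
        out, i = [], 0
        while i < len(ids):
            if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
                out.append(demo_merges[pair])
                i += 2
            else:
                out.append(ids[i])
                i += 1
        ids = out
    return ids

def bpe_decode(ids):
    return b"".join(demo_vocab[i] for i in ids).decode("utf-8", errors="replace")

# Round trip on seen text, unseen non-ASCII text, a single character, and the empty string.
for sample in ["aaabdaaabac", "aaab \u00f6 \U0001F600", "x", ""]:
    assert bpe_decode(bpe_encode(sample)) == sample

print(bpe_encode("aaab \u00f6"))  # [258, 32, 195, 182]
```

The two bytes of `ö` (195, 182) pass through unmerged, since no merge involving them was learned.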


### Pre-tokenization

Pre-tokenization splits the text into chunks according to a set of rules before the BPE algorithm is applied. Merges are only allowed within a chunk, which prevents certain tokens from being created. The splitting is often done with regular expressions.

The motivation for this is as follows, from Radford et al. "Language Models are Unsupervised Multitask Learners":

"However, directly applying BPE to the byte sequence results in suboptimal merges due to BPE using a greedy frequency based heuristic for building the token vocabulary. We observed BPE including many versions of common words like dog since they occur in many variations such as dog. dog! dog? . This results in a sub-optimal allocation of limited vocabulary slots and model capacity. To avoid this, we prevent BPE from merging across character categories for any byte sequence. We add an exception for spaces which significantly improves the compression efficiency while adding only minimal fragmentation of words across multiple vocab tokens. This input representation allows us to combine the empirical benefits of word-level LMs with the generality of byte-level approaches. Since our approach can assign a probability to any Unicode string, this allows us to evaluate our LMs on any dataset regardless of pre-processing, tokenization, or vocab size."

In [3]:
import regex as re

# the main GPT text split patterns, see
# https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""


# the GPT-2 split pattern splits the text into words, numbers, punctuation, and whitespace segments.
# a few examples:

print(re.findall(GPT2_SPLIT_PATTERN, "He's in the gym, and he'll work out for an hour."))
print(re.findall(GPT2_SPLIT_PATTERN, "The cost is 200,000.00 dollars. Thats 2x the cost of the previous model."))

print(re.findall(GPT4_SPLIT_PATTERN, "He's in the gym, and he'll work out for an hour."))
print(re.findall(GPT4_SPLIT_PATTERN, "The cost is 200,000.00 dollars. Thats 2x the cost of the previous model."))


# only merges within elements in those lists are considered in the BPE algorithm

['He', "'s", ' in', ' the', ' gym', ',', ' and', ' he', "'ll", ' work', ' out', ' for', ' an', ' hour', '.']
['The', ' cost', ' is', ' 200', ',', '000', '.', '00', ' dollars', '.', ' Thats', ' 2', 'x', ' the', ' cost', ' of', ' the', ' previous', ' model', '.']
['He', "'s", ' in', ' the', ' gym', ',', ' and', ' he', "'ll", ' work', ' out', ' for', ' an', ' hour', '.']
['The', ' cost', ' is', ' ', '200', ',', '000', '.', '00', ' dollars', '.', ' Thats', ' ', '2', 'x', ' the', ' cost', ' of', ' the', ' previous', ' model', '.']
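
To make the effect on BPE concrete, the sketch below counts byte pairs only inside each chunk, so a pair that straddles a chunk boundary is never a merge candidate. To stay within the standard library it uses a simplified, ASCII-only stand-in for the GPT-2 pattern (`SIMPLE_SPLIT_PATTERN`, an assumption made for this illustration); the real pattern needs the third-party `regex` module for `\p{L}` and `\p{N}`.

```python
import re  # standard library; it has no \p{...} classes, hence the simplified pattern

# Assumption: ASCII-only approximation of GPT2_SPLIT_PATTERN.
SIMPLE_SPLIT_PATTERN = r"'(?:[sdmt]|ll|ve|re)| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+"

def chunk_pair_counts(text, pattern=SIMPLE_SPLIT_PATTERN):
    """Count adjacent byte pairs within each pre-tokenized chunk only."""
    counts = {}
    for chunk in re.findall(pattern, text):
        ids = chunk.encode("utf-8")
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
    return counts

text = "dog. dog! dog?"
chunks = re.findall(SIMPLE_SPLIT_PATTERN, text)
counts = chunk_pair_counts(text)

print(chunks)                          # ['dog', '.', ' dog', '!', ' dog', '?']
print((ord("g"), ord(".")) in counts)  # False: 'g' and '.' are in different chunks
```

Since the pair `g.` is never counted, BPE cannot create a `dog.` token, which is the behaviour the quoted passage asks for.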


### Treatment of special tokens
'<|endoftext|>' marks the end of a text. It is assigned its own token ID, is handled outside of the BPE algorithm, and is never split or merged with other tokens.
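
One common way to keep special tokens out of BPE, sketched here as an assumption about how this could be implemented rather than as the reference behaviour of any library, is to split the text on the special-token strings first, map each special token directly to its reserved ID, and pass only the remaining pieces to the ordinary encoder. The ID 100257 is borrowed from the comment in the homework skeleton, and `encode_ordinary` is a hypothetical placeholder that returns raw bytes without applying any merges.

```python
import re

special_tokens = {"<|endoftext|>": 100257}  # illustrative ID

def encode_ordinary(text):
    # Placeholder for the BPE encoder: raw UTF-8 bytes, no merges applied.
    return list(text.encode("utf-8"))

def encode_with_specials(text):
    # The capturing group makes re.split keep the special tokens in the result.
    splitter = "(" + "|".join(re.escape(s) for s in special_tokens) + ")"
    ids = []
    for part in re.split(splitter, text):
        if part in special_tokens:
            ids.append(special_tokens[part])  # a single ID, never merged
        elif part:
            ids.extend(encode_ordinary(part))
    return ids

print(encode_with_specials("Hi<|endoftext|>Yo"))  # [72, 105, 100257, 89, 111]
```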

# Concluding notes

Try this web interface to see different tokenizers in action:
https://huggingface.co/spaces/Xenova/the-tokenizer-playground

Here you can inspect the vocabulary of GPT4:
https://github.com/kaisugi/gpt4_vocab_list/blob/main/cl100k_base_vocab_list.txt

Here are the vocab sizes for a few popular language models:


| Model         | Vocab Size |
|---------------|------------|
| GPT-1         | 40,478     |
| GPT-2         | 50,257     |
| GPT-3 (large) | 50,257     |
| GPT-4         | 100,256    |
| Llama 2       | 32,000     |
| Llama 3       | 128,256    |


- BPE merges the most frequent pair of tokens. WordPiece uses a different criterion: instead of the raw count, it uses a score related to pointwise mutual information to decide which tokens to merge.

- The tokenizer and the language model should be trained on the same or similar data. A tokenizer trained on natural language performs poorly when used with an LLM for code: sequences are encoded inefficiently (the average number of characters per token is small), which leads to sub-optimal LLM performance.

- Good tokenizers are often efficient in that they compress text to some degree (i.e., the average number of characters per token is large). However, fewer tokens do not automatically lead to better downstream performance. Pre-tokenization, for example, restricts merges and therefore costs some compression, yet it is beneficial in practice. See for example the paper: http://arxiv.org/abs/2402.18376
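
The difference between the two merge criteria in the first note can be sketched on the toy string from the beginning. The score used below, the pair count divided by the product of the counts of its two parts, is an assumed PMI-style stand-in for illustration and is not taken from a specific WordPiece implementation.

```python
from collections import Counter

seq = list("aaabdaaabac")
unit_counts = Counter(seq)                # how often each symbol occurs
pair_counts = Counter(zip(seq, seq[1:]))  # how often each adjacent pair occurs

# BPE criterion: the raw pair count.
best_by_count = max(pair_counts, key=pair_counts.get)

# Assumed PMI-style criterion: pair count normalised by the counts of its parts.
def score(pair):
    a, b = pair
    return pair_counts[pair] / (unit_counts[a] * unit_counts[b])

best_by_score = max(pair_counts, key=score)

print(best_by_count)  # ('a', 'a')
print(best_by_score)  # ('b', 'd')
```

The count criterion favours `aa` because `a` is simply frequent, whereas the normalised score favours `bd`, whose parts rarely occur apart from each other.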










# Homework

Implement the train, decode, and encode functions of the tokenizer below.
The tokenizer should use the GPT2_SPLIT_PATTERN and treat the special token <|endoftext|> appropriately.

In [5]:
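import regex as re
import unicodedata

# -----------------------------------------------------------------------------
# helper functions used by Tokenizer.save(), modeled on the helpers of the same name in minbpe

def replace_control_characters(s):
    # escape control characters (e.g. "\n") so that they do not distort the printed vocab
    chars = []
    for ch in s:
        if unicodedata.category(ch)[0] != "C":
            chars.append(ch) # this character is ok
        else:
            chars.append(f"\\u{ord(ch):04x}") # escape
    return "".join(chars)

def render_token(t):
    # pretty print a token (bytes), escaping control characters
    s = t.decode('utf-8', errors='replace')
    return replace_control_characters(s)

# -----------------------------------------------------------------------------
# the base Tokenizer class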

class Tokenizer:
    """Base class for Tokenizers"""

    def __init__(self):
        # default: vocab size of 256 (all bytes), no merges, no patterns
        self.merges = {} # (int, int) -> int
        self.pattern = "" # str
        self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
        self.vocab = self._build_vocab() # int -> bytes


    def _build_vocab(self):
        # vocab is simply and deterministically derived from merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def train(self, text, vocab_size, verbose=False):
        # Tokenizer can train a vocabulary of size vocab_size from text
        # the special token below gets the fixed id 100257, so merge ids must stay below it
        assert 256 <= vocab_size <= 100257

        # GPT2_SPLIT_PATTERN, alternative by alternative:
        #   '(?:[sdmt]|ll|ve|re)   contractions: 's 'd 'm 't 'll 've 're
        #    ?\p{L}+               an optional space followed by Unicode letters (any script)
        #    ?\p{N}+               an optional space followed by Unicode numeric characters
        #    ?[^\s\p{L}\p{N}]+     an optional space followed by characters that are neither
        #                          whitespace, letters, nor numbers (punctuation, symbols)
        #   \s+(?!\S)|\s+          runs of whitespace (spaces, tabs, newlines)
        if self.pattern == "":
            self.pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

        self.special_tokens = {'<|endoftext|>': 100257}

        # pre-tokenize; special tokens are cut out first so that they never take part in merges
        chunks = []
        for part in text.split('<|endoftext|>'):
            chunks.extend(re.findall(self.pattern, part))
        ids = [list(chunk.encode('utf-8')) for chunk in chunks]

        merges = {}  # (int, int) -> int
        for next_id in range(256, vocab_size):
            # count all adjacent pairs, never across chunk boundaries
            pair_counts = {}
            for chunk_ids in ids:
                for pair in zip(chunk_ids, chunk_ids[1:]):
                    pair_counts[pair] = pair_counts.get(pair, 0) + 1

            if not pair_counts:
                break  # no more pairs to merge

            # find the most frequent pair (ties go to the pair seen first)
            most_freq_pair = max(pair_counts, key=pair_counts.get)
            merges[most_freq_pair] = next_id
            if verbose:
                print(f"merge {most_freq_pair} -> {next_id}")

            # replace that pair in all chunks
            new_ids = []
            for chunk_ids in ids:
                new_chunk = []
                i = 0
                while i < len(chunk_ids):
                    if i < len(chunk_ids) - 1 and (chunk_ids[i], chunk_ids[i+1]) == most_freq_pair:
                        new_chunk.append(next_id)
                        i += 2
                    else:
                        new_chunk.append(chunk_ids[i])
                        i += 1
                new_ids.append(new_chunk)
            ids = new_ids

        self.merges = merges
        self.vocab = self._build_vocab()

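    def _apply_merges(self, byte_seq):
        # apply the learned merges to one chunk, earliest-learned merge first
        ids = list(byte_seq)
        while len(ids) >= 2:
            # among the adjacent pairs, pick the one with the lowest merge index
            pair = min(zip(ids, ids[1:]), key=lambda p: self.merges.get(p, float('inf')))
            if pair not in self.merges:
                break  # nothing left to merge
            idx = self.merges[pair]
            new_ids = []
            i = 0
            while i < len(ids):
                if i < len(ids) - 1 and (ids[i], ids[i+1]) == pair:
                    new_ids.append(idx)
                    i += 2
                else:
                    new_ids.append(ids[i])
                    i += 1
            ids = new_ids
        return ids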

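    def encode(self, text):
        # Tokenizer can encode a string into a list of integers
        ids = []
        if self.special_tokens:
            # split on the special tokens; the capturing group keeps them in the result
            splitter = "(" + "|".join(re.escape(s) for s in self.special_tokens) + ")"
            parts = re.split(splitter, text)
        else:
            parts = [text]

        for part in parts:
            if part in self.special_tokens:
                # special tokens map directly to their id and bypass BPE
                ids.append(self.special_tokens[part])
                continue
            # without a pattern (untrained tokenizer) the whole part is one chunk
            chunks = re.findall(self.pattern, part) if self.pattern else [part]
            for chunk in chunks:
                ids.extend(self._apply_merges(chunk.encode('utf-8')))
        return ids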


    def decode(self, ids):
        # Tokenizer can decode a list of integers into a string
        parts = []
        for idx in ids:
            if idx in self.vocab:
                parts.append(self.vocab[idx])
            else:
                # unknown id: emit the UTF-8 bytes of U+FFFD, the replacement character
                parts.append(b'\xef\xbf\xbd')
        # errors='replace' guards against token sequences that are not valid UTF-8
        return b''.join(parts).decode('utf-8', errors='replace')



    def save(self, file_prefix):
        """
        Saves two files: file_prefix.vocab and file_prefix.model
        This is inspired by (but not equivalent to!) sentencepiece's model saving:
        - model file is the critical one, intended for load()
        - vocab file is just a pretty printed version for human inspection only
        """
        # write the model: to be used in load() later
        model_file = file_prefix + ".model"
        with open(model_file, 'w', encoding="utf-8") as f:
            # write the version, pattern and merges, that's all that's needed
            f.write("minbpe v1\n")
            f.write(f"{self.pattern}\n")
            # write the special tokens, first the number of them, then each one
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # the merges dict
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")
        # write the vocab: for the human to look at
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                # note: many tokens may be partial utf-8 sequences
                # and cannot be decoded into valid strings. Here we're using
                # errors='replace' to replace them with the replacement char �.
                # this also means that we couldn't possibly use .vocab in load()
                # because decoding in this way is a lossy operation!
                s = render_token(token)
                # find the children of this token, if any
                if idx in inverted_merges:
                    # if this token has children, render it nicely as a merge
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    # otherwise this is leaf token, just print it
                    # (this should just be the first 256 tokens, the bytes)
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        """Inverse of save() but only for the model file"""
        assert model_file.endswith(".model")
        # read the model file
        merges = {}
        special_tokens = {}
        idx = 256
        with open(model_file, 'r', encoding="utf-8") as f:
            # read the version
            version = f.readline().strip()
            assert version == "minbpe v1"
            # read the pattern
            self.pattern = f.readline().strip()
            # read the special tokens
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)
            # read the merges
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()

In [7]:
tokenizer = Tokenizer()
text = "Don't stop believing!"
tokenizer.train(text, vocab_size=300)

vocab_trained = tokenizer.vocab
print(vocab_trained)

encoded = tokenizer.encode(text)
print("Encoded:", encoded)

decoded = tokenizer.decode(encoded)
print("Decoded:", decoded)

{0: b'\x00', 1: b'\x01', 2: b'\x02', 3: b'\x03', 4: b'\x04', 5: b'\x05', 6: b'\x06', 7: b'\x07', 8: b'\x08', 9: b'\t', 10: b'\n', 11: b'\x0b', 12: b'\x0c', 13: b'\r', 14: b'\x0e', 15: b'\x0f', 16: b'\x10', 17: b'\x11', 18: b'\x12', 19: b'\x13', 20: b'\x14', 21: b'\x15', 22: b'\x16', 23: b'\x17', 24: b'\x18', 25: b'\x19', 26: b'\x1a', 27: b'\x1b', 28: b'\x1c', 29: b'\x1d', 30: b'\x1e', 31: b'\x1f', 32: b' ', 33: b'!', 34: b'"', 35: b'#', 36: b'$', 37: b'%', 38: b'&', 39: b"'", 40: b'(', 41: b')', 42: b'*', 43: b'+', 44: b',', 45: b'-', 46: b'.', 47: b'/', 48: b'0', 49: b'1', 50: b'2', 51: b'3', 52: b'4', 53: b'5', 54: b'6', 55: b'7', 56: b'8', 57: b'9', 58: b':', 59: b';', 60: b'<', 61: b'=', 62: b'>', 63: b'?', 64: b'@', 65: b'A', 66: b'B', 67: b'C', 68: b'D', 69: b'E', 70: b'F', 71: b'G', 72: b'H', 73: b'I', 74: b'J', 75: b'K', 76: b'L', 77: b'M', 78: b'N', 79: b'O', 80: b'P', 81: b'Q', 82: b'R', 83: b'S', 84: b'T', 85: b'U', 86: b'V', 87: b'W', 88: b'X', 89: b'Y', 90: b'Z', 91: b'[',