# BPE Tokenizer

This is a widely used two-stage algorithm introduced to the LLM domain by the [GPT-2 Paper](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf). The first builds a vocabulary of frequent subword units. Once this is built, the vocabulary is used to encode/decode text into/from tokens.

## Step 1 - Building the Vocabulary

The first step is to build a vocabulary of subword units from training data. The algorithm works by iteratively merging the most frequent pairs of characters or subwords until a predefined vocabulary size is reached. There's a "sweet spot" for the vocabulary size in the LLM domain:
- One the one hand, a larger vocabulary allows for more compressed / accurate representation of words.
- On the other hand, this bloats the LLM vocabulary, meaning the embeddings layer and output layers will be larger, also increasing the complexity of the final softmax operation.

So picking the right vocabulary size is a balancing act. It's also important that the training data is representative of the language and domain the model will be used in, as this will affect the quality of the vocabulary, and thus will impact the downstream representational ability of the model.

In [None]:
VOCABULARY_SIZE = 100000

# TODO: Update with better training data
with open("../data/shakespeare.txt", "r") as f:
    TRAINING_TEXT = f.read()

Note - I'll be performing tokenization at the byte level, the alternative is to tokenize at the Unicode Code Point level.

Want to know more on text representation at this level? Check out [this article](https://www.joelonsoftware.com/2003/10/08/the-absolute-minimum-every-software-developer-absolutely-positively-must-know-about-unicode-and-character-sets-no-excuses/) by Joel Spolsky.

To keep things simple, UTF-8 encoding represents each character as a sequence of up to 4 bytes, where each byte is an integer between 0 and 255. This means, with the default vocabulary size of 100,000, we can represent the first 256 tokens as single byte tokens (0-255), and the rest as merged tokens.

In [None]:
training_bytes = TRAINING_TEXT.encode('utf-8')

next_token_id = 256 # bytes 0-255 are reserved for single byte tokens

In order to create a vocabulary, we need to count the frequency of each pair of tokens in the training data. We'll use a doubly linked list to represent the sequence of tokens, allowing us to efficiently merge pairs of tokens as we build the vocabulary. We'll also use a max heap to efficiently retrieve the most frequent pairs of tokens.

In [None]:
class TokenNode:
    def __init__(self, value):
        self.value = value  # byte / merged token ID
        self.prev = None
        self.next = None

def build_linked_list(token_list):
    nodes = [TokenNode(tok) for tok in token_list]
    for i in range(1, len(nodes)):
        nodes[i].prev = nodes[i - 1]
        nodes[i - 1].next = nodes[i]
    return nodes


In [None]:
doubly_linked_tokens = build_linked_list(training_bytes)

In [None]:
from collections import defaultdict

pair_freqs = defaultdict(int)
pair_positions = defaultdict(set)

def index_pairs(nodes):
    for node in nodes:
        if node.next:
            pair = (node.value, node.next.value)
            pair_freqs[pair] += 1
            pair_positions[pair].add(node)


In [None]:
index_pairs(doubly_linked_tokens)

In [None]:
def merge_pair(pair, new_token_id, max_heap):
    nodes_to_merge = list(pair_positions[pair])
    updated_pairs = set()
    for node in nodes_to_merge:
        if not node.next or (node.value, node.next.value) != pair:
            continue

        right = node.next

        # Unlink the right node
        node.value = new_token_id
        node.next = right.next
        if right.next:
            right.next.prev = node

        # Clean up stale pair positions
        if node.prev:
            old_left = (node.prev.value, pair[0])
            pair_freqs[old_left] -= 1
            pair_positions[old_left].discard(node.prev)
            new_left = (node.prev.value, new_token_id)
            pair_freqs[new_left] += 1
            pair_positions[new_left].add(node.prev)
            updated_pairs.add(new_left)

        if node.next:
            old_right = (pair[1], node.next.value)
            pair_freqs[old_right] -= 1
            pair_positions[old_right].discard(right)
            new_right = (new_token_id, node.next.value)
            pair_freqs[new_right] += 1
            pair_positions[new_right].add(node)
            updated_pairs.add(new_right)

        # Clean up merged pair
        pair_freqs[pair] -= 1
        pair_positions[pair].discard(node)

        # Add new/updated pairs to updated_pairs set, for adding to the heap later
        updated_pairs.add(pair)

    print(f"Merged {pair} into {new_token_id} with {len(nodes_to_merge)} occurrences.")
    # Add updated pairs to the heap
    for updated_pair in updated_pairs:
        if pair_freqs[updated_pair] > 0:
            heapq.heappush(max_heap, (-pair_freqs[updated_pair], updated_pair))

In [None]:
import heapq

heap = [(-freq, pair) for pair, freq in pair_freqs.items()]
heapq.heapify(heap)
merge_list = []
merge_table = {}
while heap and len(merge_table) < VOCABULARY_SIZE - 256:
    neg_freq, pair = heapq.heappop(heap)
    if pair_freqs.get(pair, 0) != -neg_freq:
        continue  # stale
    new_token_id = next_token_id
    merge_pair(pair, new_token_id, heap)
    merge_table[pair] =  new_token_id
    merge_list.append((pair, new_token_id))
    next_token_id += 1


In [None]:
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in merge_list:
    vocab[idx] = vocab[p0] + vocab[p1]

# Step 2 - Encoding and Decoding

Now that we've built a vocabulary, we can use it to encode and decode text by iterating over the merge rules and replacing pairs of tokens with their merged token ID. The encoding process will convert text into a sequence of token IDs, while the decoding process will convert token IDs back into text.

In [None]:
# TODO: Update Naive encoding method to use linked list & prio queues for O(nlog n) time complexity, instead of current O(m*n) where m is the number of merge rules and n is the length of the text.

def encode(text, merge_list):
    text_bytes = list(text.encode('utf-8'))
    for merge_rule in merge_list:
        pair, new_token = merge_rule
        i = 0
        while i < len(text_bytes) - 1 and len(text_bytes) > 1:
            if (text_bytes[i], text_bytes[i + 1]) == pair:
                text_bytes[i] = new_token
                del text_bytes[i + 1]
            i+= 1

    return text_bytes


def decode(ids):
    # given ids (list of integers), return Python string
    tokens = b"".join(vocab[idx] for idx in ids)
    text = tokens.decode("utf-8", errors="replace")
    return text


In [12]:
input_string = "how is this encoding method?"
encoded_ids = encode(input_string, merge_list)
for idx in encoded_ids:
    print(f"token ID: {idx}, decoded individual token: {decode([idx])}")

print(f"Decoded string: {decode(encoded_ids)}")

token ID: 10986, decoded individual token: how is 
token ID: 370, decoded individual token: this 
token ID: 270, decoded individual token: en
token ID: 99, decoded individual token: c
token ID: 12178, decoded individual token: oding 
token ID: 109, decoded individual token: m
token ID: 101, decoded individual token: e
token ID: 257, decoded individual token: th
token ID: 111, decoded individual token: o
token ID: 100, decoded individual token: d
token ID: 63, decoded individual token: ?
Decoded string: how is this encoding method?
