In this chapter we build a GPT tokenizer.

Byte-pair encoding is pretty simple:
* we determine the most popular pair of adjacent bytes
* we invent a new "byte" (a new token id) and replace every occurrence of the pair with it
* repeat

In [38]:
import collections
import itertools
from typing import Optional


# Two adjacent token ids (raw byte values or ids created by earlier merges).
Pair = tuple[int, int]
# Maps a merged pair to the token id that replaces it.
Merges = dict[Pair, int]

# TODO: rename, so it won't conflict with byte_pair_encode that treats merges as
# an input not as output
def byte_pair_encode(
    text: str, num_iterations: int, merges: Optional[Merges] = None
) -> list[int]:
    """Learn up to `num_iterations` byte-pair merges from `text` and encode it.

    The text is converted to its UTF-8 bytes; each iteration then replaces
    the most frequent adjacent pair of tokens with a single token id.

    Args:
        text: The text to encode.
        num_iterations: Maximum number of merge rounds.  Fewer are performed
            if the sequence shrinks to less than two tokens.
        merges: Output mapping, updated in place with ``pair -> token id``
            for every new merge.  May be omitted when only the encoded
            sequence is needed.  If it already contains merges, a pair that
            is present keeps its id and new ids continue after the largest
            id already in use, so an id never stands for two different pairs.

    Returns:
        The encoded sequence of raw byte values and merge token ids.
    """
    if merges is None:
        merges = {}
    encoded = list(text.encode('utf-8'))
    # Ids 0-255 are reserved for raw bytes; never hand out an id that is
    # already taken by a merge recorded in `merges`.
    new_ids = itertools.count(max([255, *merges.values()]) + 1)
    for _ in range(num_iterations):
        if len(encoded) < 2:
            # No adjacent pair left to merge.
            break
        pair = find_most_popular_pair(encoded)
        replacement = merges.get(pair)
        if replacement is None:
            replacement = next(new_ids)
            merges[pair] = replacement
        encoded = replace(encoded, pair, replacement)
    return encoded


def find_most_popular_pair(encoded: list[int]) -> tuple[int, int]:
    """Return the adjacent pair of tokens that occurs most often in `encoded`.

    Overlapping occurrences are all counted, e.g. ``[1, 1, 1]`` contains the
    pair ``(1, 1)`` twice.

    Raises:
        IndexError: if `encoded` holds fewer than two tokens (no pairs exist).
    """
    # Pairing the sequence with itself shifted by one yields every adjacent
    # pair in left-to-right order.
    pair_counts = collections.Counter(zip(encoded, encoded[1:]))
    top = pair_counts.most_common(1)
    best_pair, _occurrences = top[0]
    return best_pair


def replace(encoded: list[int], what: tuple[int, int], replacement: int):
    """Return a new list where each occurrence of the pair `what` is `replacement`.

    The scan runs left to right and matches never overlap: once two tokens
    have been merged neither can take part in another match, so with
    ``what == (1, 1)`` the input ``[1, 1, 1]`` becomes ``[replacement, 1]``.
    `encoded` itself is left untouched.
    """
    output = []
    # At most one token that has been read but not emitted yet, because it
    # may still turn out to be the first half of a match.
    held = ()
    for token in encoded:
        window = (*held, token)
        if window == what:
            output.append(replacement)
            held = ()
        else:
            # The held token (if any) cannot start a match any more; emit it
            # and keep the current token as the next candidate.
            output.extend(held)
            held = (token,)

    # Flush a trailing token that never found its partner.
    output.extend(held)
    return output

In [37]:
# The learned merges are not inspected here, so a throwaway dict is passed.
byte_pair_encode("aaabdaaabac", num_iterations=3, merges={})

[258, 100, 258, 97, 99]

In [32]:
# Mixed Cyrillic/ASCII text; only the encoded length is of interest, so the
# learned merges go into a throwaway dict.
len(byte_pair_encode("приветики вам, хочу проверить byte-pair encoding", num_iterations=20, merges={}))

39

In [49]:
def byte_pair_decode(encoded: list[int], merges: Merges) -> str:
    """Decode a token sequence back into text using the given merges.

    Args:
        encoded: Token ids: raw byte values and ids introduced by merges.
        merges: Mapping ``pair -> token id`` as filled in by
            byte_pair_encode.  It is not modified.

    Returns:
        The decoded text.

    Raises:
        ValueError: if a token that is not a merge id lies outside 0-255.
        UnicodeDecodeError: if the expanded bytes are not valid UTF-8.
    """
    # Invert the mapping: token id -> the two tokens it stands for.
    token_values = {m: (a, b) for (a, b), m in merges.items()}

    decoded = bytearray()
    # Expand depth-first with an explicit stack so every token is visited
    # once, instead of re-scanning the whole sequence once per nesting level.
    # The stack is reversed so tokens are popped in left-to-right order.
    pending = list(encoded)
    pending.reverse()
    while pending:
        token = pending.pop()
        parts = token_values.get(token)
        if parts is None:
            # Not a merge id, so it must be a raw byte.
            decoded.append(token)
        else:
            first, second = parts
            # Push the second half first so the first half is popped next.
            pending.append(second)
            pending.append(first)
    return decoded.decode("utf-8")

# TODO: implement a byte_pair_encode that is symmetric to byte_pair_decode:
# it should take merges as input and not modify them

In [50]:
# Round-trip check: learn the merges while encoding, then decode with them.
ascii_merges = {}
ascii_encoded = byte_pair_encode("aaabdaaabac", num_iterations=3, merges=ascii_merges)

# Should print the original text back.
print(byte_pair_decode(ascii_encoded, ascii_merges))

aaabdaaabac


Real implementations of byte-pair encoding/decoding in OpenAI's GPT models do extra processing: they split the text into different categories of characters (letters, numbers, punctuation), and BPE merges can't cross these boundaries.

The reason is that it's not a good idea to mix punctuation and letters:
let's say you have separate tokens for "dog", "dog.", and "dog?". They express the same concept, but it would be represented by different tokens, which doesn't help the model.