exercise
Build your own GPT-4 Tokenizer!

Step 1
Write the BasicTokenizer class, with the following three core functions:

def train(self, text, vocab_size, verbose=False)
def encode(self, text)
def decode(self, ids)
Train your tokenizer on whatever text you like and visualize the merged tokens. Do they look reasonable? One default test you may wish to use is the text file tests/taylorswift.txt.



In [1]:
class BasicTokenizer:
    def __init__(self):
        self.merges = {}  # (int, int) -> int
        self.vocab = {}   # int -> bytes
        self.vocab_size = None

    def train(self, text, vocab_size, verbose=False):
        # convert raw text to a list of byte IDs
        ids = list(text.encode("utf-8"))
        # we start with a "vocab" for raw bytes [0..255]
        self.vocab = {i: bytes([i]) for i in range(256)}
        self.merges = {}
        self.vocab_size = vocab_size

        # define a helper to count pair frequencies
        def get_stats(seq):
            stats = {}
            for a, b in zip(seq, seq[1:]):
                stats[(a, b)] = stats.get((a, b), 0) + 1
            return stats

        # define a helper to merge a specific pair in the sequence
        def merge_pair(seq, pair, new_id):
            merged = []
            i = 0
            while i < len(seq):
                if i < len(seq) - 1 and seq[i] == pair[0] and seq[i+1] == pair[1]:
                    merged.append(new_id)
                    i += 2
                else:
                    merged.append(seq[i])
                    i += 1
            return merged

        # repeatedly merge the most common pair until reaching vocab_size
        # there are already 256 base tokens, so we can do merges up to vocab_size - 256
        merges_needed = vocab_size - 256
        for i in range(merges_needed):
            stats = get_stats(ids)
            if not stats:
                break
            best = max(stats, key=stats.get)
            new_id = 256 + i
            self.merges[best] = new_id
            self.vocab[new_id] = self.vocab[best[0]] + self.vocab[best[1]]
            ids = merge_pair(ids, best, new_id)
            if verbose:
                print(f"Merge {best} -> {new_id}, frequency={stats[best]}")

    def encode(self, text):
        # convert text to a list of raw byte IDs
        seq = list(text.encode("utf-8"))
        # repeatedly attempt merges that we know about
        # we do a naive loop: check all pairs, if found in merges, merge them
        # keep going until no merges apply
        def get_stats(seq):
            stats = {}
            for a, b in zip(seq, seq[1:]):
                stats[(a, b)] = stats.get((a, b), 0) + 1
            return stats

        # same helper as above
        def merge_pair(seq, pair, new_id):
            merged = []
            i = 0
            while i < len(seq):
                if i < len(seq) - 1 and seq[i] == pair[0] and seq[i+1] == pair[1]:
                    merged.append(new_id)
                    i += 2
                else:
                    merged.append(seq[i])
                    i += 1
            return merged

        while True:
            stats = get_stats(seq)
            # pick a pair that is in our merges dictionary and has the highest frequency
            # if no known merges appear, we're done
            known_pairs = [(pair, freq) for pair, freq in stats.items() if pair in self.merges]
            if not known_pairs:
                break
            best = max(known_pairs, key=lambda x: x[1])[0]
            seq = merge_pair(seq, best, self.merges[best])
        return seq

    def decode(self, ids):
        # convert list of IDs back into bytes, then decode to string
        # each ID’s bytes are stored in self.vocab
        tokens = b"".join(self.vocab[idx] for idx in ids)
        return tokens.decode("utf-8", errors="replace")


In [2]:

# Example
if __name__ == "__main__":
    text = "Hello world! Hello GPT-4 Tokenizer example. 안녕하세요 👋  (hello in Korean!)"
    tokenizer = BasicTokenizer()
    tokenizer.train(text, vocab_size=300, verbose=True)  # request ~300 vocab
    enc = tokenizer.encode("Hello GPT-4 Tokenizer example. 안녕하세요 👋")
    dec = tokenizer.decode(enc)
    print("Encoded:", enc)
    print("Decoded:", dec)


Merge (101, 108) -> 256, frequency=3
Merge (256, 108) -> 257, frequency=3
Merge (257, 111) -> 258, frequency=3
Merge (258, 32) -> 259, frequency=3
Merge (72, 259) -> 260, frequency=2
Merge (111, 114) -> 261, frequency=2
Merge (260, 119) -> 262, frequency=1
Merge (262, 261) -> 263, frequency=1
Merge (263, 108) -> 264, frequency=1
Merge (264, 100) -> 265, frequency=1
Merge (265, 33) -> 266, frequency=1
Merge (266, 32) -> 267, frequency=1
Merge (267, 260) -> 268, frequency=1
Merge (268, 71) -> 269, frequency=1
Merge (269, 80) -> 270, frequency=1
Merge (270, 84) -> 271, frequency=1
Merge (271, 45) -> 272, frequency=1
Merge (272, 52) -> 273, frequency=1
Merge (273, 32) -> 274, frequency=1
Merge (274, 84) -> 275, frequency=1
Merge (275, 111) -> 276, frequency=1
Merge (276, 107) -> 277, frequency=1
Merge (277, 101) -> 278, frequency=1
Merge (278, 110) -> 279, frequency=1
Merge (279, 105) -> 280, frequency=1
Merge (280, 122) -> 281, frequency=1
Merge (281, 101) -> 282, frequency=1
Merge (282, 

Step 2
Convert you BasicTokenizer into a RegexTokenizer, which takes a regex pattern and splits the text exactly as GPT-4 would. Process the parts separately as before, then concatenate the results. Retrain your tokenizer and compare the results before and after. You should see that you will now have no tokens that go across categories (numbers, letters, punctuation, more than one whitespace). Use the GPT-4 pattern:

GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

In [3]:
import regex as re

# GPT-4 splitting pattern provided
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

class RegexTokenizer(BasicTokenizer):
    def __init__(self, pattern=GPT4_SPLIT_PATTERN):
        super().__init__()
        self.pattern = re.compile(pattern)

    def encode(self, text):
        # Split text using the provided regex pattern into parts
        parts = self.pattern.findall(text)
        tokens = []
        # For each part, apply the BasicTokenizer's encode logic
        for part in parts:
            # Use parent's encode to process each segment individually
            tokens.extend(super().encode(part))
        return tokens


In [4]:

# Example
if __name__ == "__main__":
    text = "Hello world! 12345 😊 Test문자열 combined."

    # Train BasicTokenizer and RegexTokenizer on the same text for comparison
    basic_tokenizer = BasicTokenizer()
    basic_tokenizer.train(text, vocab_size=300, verbose=False)

    regex_tokenizer = RegexTokenizer()
    regex_tokenizer.train(text, vocab_size=300, verbose=False)

    # Encode using both tokenizers
    basic_tokens = basic_tokenizer.encode(text)
    regex_tokens = regex_tokenizer.encode(text)

    print("BasicTokenizer tokens:", basic_tokens)
    print("RegexTokenizer tokens:", regex_tokens)

    # Verify correctness
    print("Basic decoded:", basic_tokenizer.decode(basic_tokens))
    print("Regex decoded:", regex_tokenizer.decode(regex_tokens))


BasicTokenizer tokens: [299, 100, 46]
RegexTokenizer tokens: [259, 32, 119, 111, 114, 108, 100, 33, 32, 49, 50, 51, 52, 53, 32, 240, 159, 152, 138, 32, 84, 101, 115, 116, 235, 172, 184, 236, 158, 144, 236, 151, 180, 32, 99, 111, 109, 98, 105, 110, 101, 100, 46]
Basic decoded: Hello world! 12345 😊 Test문자열 combined.
Regex decoded: Hello world! 12345 😊 Test문자열 combined.


Step 3
You're now ready to load the merges from the GPT-4 tokenizer and show that your tokenizer produces the identical results for both encode and decode, matching tiktoken.

# match this
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("hello world!!!? (안녕하세요!) lol123 😉")
text = enc.decode(ids) # get the same text back
Unfortunately, you will run into two issues:

It is not trivial to recover the raw merges from the GPT-4 tokenizer. You can easily recover what we call vocab here, and what they call and store under enc._mergeable_ranks. Feel free to copy paste the recover_merges function in minbpe/gpt4.py, which takes these ranks and returns the raw merges. If you wish to know how this function works, read this and this. Basically, under some conditions it is enough to only store the parent nodes (and their rank) and get rid of the precise details of which children merged up to any parent.
Second, the GPT-4 tokenizer for some reason permutes its raw bytes. It stores this permutation in the first 256 elements of the mergeable ranks, so you can recover this byte shuffle relatively simply as byte_shuffle = {i: enc._mergeable_ranks[bytes([i])] for i in range(256)}. In both your encode and decode, you'll have to shuffle bytes around accordingly. If you're stuck, reference the minbpe/gpt4.py` file for hints.

In [6]:
!pip install regex tiktoken

Collecting tiktoken
  Downloading tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Downloading tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tiktoken
Successfully installed tiktoken-0.8.0


In [10]:
import regex as re
import tiktoken

# Assume GPT4_SPLIT_PATTERN is defined as before
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

def bpe(mergeable_ranks, token, max_rank):
    # helper function used in get_gpt4_merges() to reconstruct the merge forest
    parts = [bytes([b]) for b in token]
    while True:
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        assert min_idx is not None
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts

def recover_merges(mergeable_ranks):
    # the `merges` are already the byte sequences in their merged state.
    # so we have to recover the original pairings. We can do this by doing
    # a small BPE training run on all the tokens, in their order.
    # also see https://github.com/openai/tiktoken/issues/60
    # also see https://github.com/karpathy/minbpe/issues/11#issuecomment-1950805306
    merges = {}
    for token, rank in mergeable_ranks.items():
        if len(token) == 1:
            continue # skip raw bytes
        pair = tuple(bpe(mergeable_ranks, token, max_rank=rank))
        assert len(pair) == 2
        # recover the integer ranks of the pair
        ix0 = mergeable_ranks[pair[0]]
        ix1 = mergeable_ranks[pair[1]]
        merges[(ix0, ix1)] = rank

    return merges

GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
GPT4_SPECIAL_TOKENS = {
    '<|endoftext|>': 100257,
    '<|fim_prefix|>': 100258,
    '<|fim_middle|>': 100259,
    '<|fim_suffix|>': 100260,
    '<|endofprompt|>': 100276
}


In [11]:

class GPT4Tokenizer(RegexTokenizer):
    def __init__(self, pattern=GPT4_SPLIT_PATTERN):
        super().__init__(pattern)
        self.byte_shuffle = None
        self.inverse_shuffle = None

    def load_from_tiktoken(self, enc):
        # Recover raw merges from tiktoken's mergeable_ranks
        self.merges = recover_merges(enc._mergeable_ranks)

        # Recover byte permutation for first 256 bytes
        # This creates a mapping: original_byte -> permuted_byte
        self.byte_shuffle = {i: enc._mergeable_ranks[bytes([i])] for i in range(256)}
        # Create inverse mapping for decoding: permuted_byte -> original_byte
        self.inverse_shuffle = {v: k for k, v in self.byte_shuffle.items()}

        # Initialize vocabulary using recovered merges
        self.vocab = {i: bytes([i]) for i in range(256)}
        for (p0, p1), new_id in self.merges.items():
            self.vocab[new_id] = self.vocab[p0] + self.vocab[p1]

    def encode(self, text):
        # Apply regex splitting as in RegexTokenizer
        parts = self.pattern.findall(text)
        all_tokens = []
        for part in parts:
            # Convert part to UTF-8 bytes and apply byte permutation
            raw_bytes = list(part.encode("utf-8"))
            shuffled_bytes = [self.byte_shuffle[b] for b in raw_bytes]

            # Use BPE merging on permuted bytes
            seq = shuffled_bytes
            def get_stats(seq):
                stats = {}
                for a, b in zip(seq, seq[1:]):
                    stats[(a, b)] = stats.get((a, b), 0) + 1
                return stats

            def merge_pair(seq, pair, new_id):
                merged = []
                i = 0
                while i < len(seq):
                    if i < len(seq) - 1 and seq[i] == pair[0] and seq[i+1] == pair[1]:
                        merged.append(new_id)
                        i += 2
                    else:
                        merged.append(seq[i])
                        i += 1
                return merged

            while True:
                stats = get_stats(seq)
                known_pairs = [(pair, freq) for pair, freq in stats.items() if pair in self.merges]
                if not known_pairs:
                    break
                best = max(known_pairs, key=lambda x: x[1])[0]
                seq = merge_pair(seq, best, self.merges[best])
            all_tokens.extend(seq)
        return all_tokens

    def decode(self, ids):
        # Convert token IDs back to a byte sequence using the vocabulary
        merged_bytes = b"".join(self.vocab[idx] for idx in ids)
        # Reverse the byte permutation
        unshuffled_bytes = bytes([self.inverse_shuffle[b] for b in merged_bytes])
        # Decode from UTF-8
        return unshuffled_bytes.decode("utf-8", errors="replace")


In [12]:

# Checking matching tiktoken's behavior
if __name__ == "__main__":
    # Load GPT-4 tiktoken encoding
    enc = tiktoken.get_encoding("cl100k_base")

    # Initialize and load GPT4Tokenizer from tiktoken data
    tokenizer = GPT4Tokenizer()
    tokenizer.load_from_tiktoken(enc)

    # Test encoding and decoding
    test_text = "hello world!!!? (안녕하세요!) lol123 😉"
    custom_ids = tokenizer.encode(test_text)
    custom_decoded = tokenizer.decode(custom_ids)

    # Use tiktoken for comparison
    tiktoken_ids = enc.encode(test_text)
    tiktoken_decoded = enc.decode(tiktoken_ids)

    print("Custom tokenizer IDs:", custom_ids)
    print("tiktoken IDs:           ", tiktoken_ids)
    print("IDs match:", custom_ids == tiktoken_ids)
    print("Decoded text matches:", custom_decoded == tiktoken_decoded)


Custom tokenizer IDs: [383, 657, 78, 24670, 2438, 67, 12340, 30, 320, 31495, 230, 75265, 243, 92245, 16715, 781, 75, 4513, 57037]
tiktoken IDs:            [15339, 1917, 12340, 30, 320, 31495, 230, 75265, 243, 92245, 16715, 28509, 4513, 57037]
IDs match: False
Decoded text matches: True
