Build your own GPT-4 Tokenizer!

In [1]:
!pip install regex



In [2]:
import regex as re
from collections import defaultdict

# Step 1

Write the BasicTokenizer class, with the following three core functions:

def train(self, text, vocab_size, verbose=False)

def encode(self, text)

def decode(self, ids)

Train your tokenizer on whatever text you like and visualize the merged tokens. Do they look reasonable?
One default test you may wish to use is the text file tests/taylorswift.txt.

In [3]:
class BasicTokenizer:

    def __init__(self):
        self.vocab = {}  # Mapping from token index to byte string
        self.merges = {}  # Mapping from token pairs to new token ID

    def get_stats(self, ids):
      counts = {}
      for pair in zip(ids, ids[1:]):
          counts[pair] = counts.get(pair, 0) + 1
      return counts

    def merge(self, ids, pair, idx):
      newids = []
      i = 0
      while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
          newids.append(idx)
          i += 2
        else:
          newids.append(ids[i])
          i += 1
      return newids

    def train(self, text, vocab_size, verbose=False):
      """ Train tokenizer using Byte Pair Encoding (BPE). """
      tokens = list(text.encode("utf-8"))  # Convert text to byte values
      self.vocab = {idx: bytes([idx]) for idx in range(256)}  # Initialize with single bytes

      vocab_size = 276 # the desired final vocabulary size
      num_merges = vocab_size - 256
      ids = list(tokens) # copy so we don't destroy the original list

      merges = {} # (int, int) -> int
      for i in range(num_merges):
        stats = self.get_stats(ids)
        pair = max(stats, key=stats.get)
        idx = 256 + i
        #print(f"merging {pair} into a new token {idx}")
        ids = self.merge(ids, pair, idx)
        merges[pair] = idx


    def encode(self, text):
      """ Convert text into token IDs. """
      tokens = list(text.encode("utf-8"))

      while len(tokens) >= 2:
        stats = self.get_stats(tokens)
        pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))

        if pair not in self.merges:
          break  # Stop merging if the pair is not in the learned merges

        tokens = self.merge(tokens, pair, self.merges[pair])

      return tokens


    def decode(self, ids):
      """ Convert token IDs into text. """
      tokens = b"".join(self.vocab[idx] for idx in ids)

      return tokens.decode("utf-8", errors="replace")




In [4]:
if __name__ == "__main__":
    with open("/content/drive/MyDrive/dataSet_for_practice/taylorswift.txt", "r", encoding="utf-8") as f:
        text = f.read()

    tokenizer = BasicTokenizer()
    tokenizer.train(text, vocab_size=276, verbose=True)

    sample_text = "Hello, world!"
    encoded = tokenizer.encode(sample_text)
    decoded = tokenizer.decode(encoded)

    print("\nEncoded:", encoded)
    print("Decoded:", decoded)


Encoded: [72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33]
Decoded: Hello, world!


# **Step 2**
Convert you BasicTokenizer into a RegexTokenizer, which takes a regex pattern and splits the text exactly as GPT-4 would. Process the parts separately as before, then concatenate the results. Retrain your tokenizer and compare the results before and after. You should see that you will now have no tokens that go across categories (numbers, letters, punctuation, more than one whitespace). Use the GPT-4 pattern:

`GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""`


In [5]:
# GPT-4 tokenization pattern
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""


In [6]:
class RegexTokenizer:
    def __init__(self):
        self.vocab = {idx: bytes([idx]) for idx in range(256)}  # Initial byte-level vocabulary
        self.merges = {}  # Stores merge rules

    def tokenize(self, text):
        """Splits text into tokens using GPT-4's regex pattern."""
        return re.findall(GPT4_SPLIT_PATTERN, text)

    def get_stats(self, ids):
        """Counts occurrences of adjacent token pairs."""
        counts = {}
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    def merge(self, ids, pair, idx):
        """Replaces occurrences of a pair with a new token."""
        newids = []
        i = 0
        while i < len(ids):
            if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
                newids.append(idx)  # Merge pair
                i += 2
            else:
                newids.append(ids[i])
                i += 1
        return newids


    def train(self, text, vocab_size):
        """Trains the tokenizer using Byte Pair Encoding (BPE)."""
        tokenized_text = self.tokenize(text)
        ids = [byte for word in tokenized_text for byte in word.encode("utf-8")]
        vocab_size = 276
        num_merges = vocab_size - 256  # Additional tokens needed

        for _ in range(num_merges):
            stats = self.get_stats(ids)
            if not stats:
                break  # Stop if no more frequent pairs exist

            pair = max(stats, key=stats.get)  # Find most frequent pair
            idx = 256 + len(self.merges)  # Assign new token ID
            ids = self.merge(ids, pair, idx)  # Merge pairs in text
            self.merges[pair] = idx  # Store merge rule

            # Update vocabulary with new token
            self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]


    def encode(self, text):
        """Encodes text into token IDs."""
        tokenized_text = self.tokenize(text)
        tokens = [byte for word in tokenized_text for byte in word.encode("utf-8")]

        while len(tokens) >= 2:
            stats = self.get_stats(tokens)
            pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))
            if pair not in self.merges:
                break
            tokens = self.merge(tokens, pair, self.merges[pair])

        return tokens

    def decode(self, ids):
        """Decodes token IDs back into text."""
        tokens = b"".join(self.vocab[idx] for idx in ids)
        return tokens.decode("utf-8", errors='replace')

In [7]:
tokenizer = RegexTokenizer()
tokenizer.train("Hello world! This is a tokenizer test.", vocab_size=300)


tokenized = tokenizer.tokenize("Hello world!")
print("Tokenized:", tokenized)

encoded = tokenizer.encode("Hello world!")
print("Encoded:", encoded)

decoded = tokenizer.decode(encoded)
print("Decoded:", decoded)

Tokenized: ['Hello', ' world', '!']
Encoded: [269]
Decoded: Hello world!


# Step 3
You're now ready to load the merges from the GPT-4 tokenizer and show that your tokenizer produces the identical results for both encode and decode, matching tiktoken.



```
# match this
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("hello world!!!? (안녕하세요!) lol123 😉")
text = enc.decode(ids) # get the same text back
```

Unfortunately, you will run into two issues:

It is not trivial to recover the raw merges from the GPT-4 tokenizer. You can easily recover what we call vocab here, and what they call and store under enc._mergeable_ranks. Feel free to copy paste the recover_merges function in minbpe/gpt4.py, which takes these ranks and returns the raw merges. If you wish to know how this function works, read this and this. Basically, under some conditions it is enough to only store the parent nodes (and their rank) and get rid of the precise details of which children merged up to any parent.
Second, the GPT-4 tokenizer for some reason permutes its raw bytes. It stores this permutation in the first 256 elements of the mergeable ranks, so you can recover this byte shuffle relatively simply as byte_shuffle = {i: enc._mergeable_ranks[bytes([i])] for i in range(256)}. In both your encode and decode, you'll have to shuffle bytes around accordingly. If you're stuck, reference the

In [8]:
!pip install tiktoken

Collecting tiktoken
  Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.2/1.2 MB[0m [31m64.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m29.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tiktoken
Successfully installed tiktoken-0.9.0


In [9]:
import tiktoken

In [12]:
def recover_merges(mergeable_ranks):


