# GPT Tokenizer

Auke Bruinsma

Based on the video by Andrej Karpathy: https://www.youtube.com/watch?v=zduSFxRajkE

## Load data

In [303]:
import re

class DataLoader:
    """A class for loading and cleaning text from a file."""

    def __init__(self, file_path: str) -> None:
        """
        Initializes the DataLoader with a file path.

        Args:
            file_path (str): The path to the text file.
        """
        self.file_path = file_path

    def load_text(self) -> str:
        """
        Loads text from the specified file, decoded as UTF-8.

        Errors are reported with ``print`` instead of being raised, so an
        empty return value can mean either an empty file or a failed read.

        Returns:
            (str): The contents of the file as a string, or an empty string if
            an error occurs.
        """
        try:
            with open(self.file_path, "r", encoding="utf-8") as file:
                return file.read()
        except FileNotFoundError:
            print(f"Error: File '{self.file_path}' not found.")
            return ""
        except Exception as e:
            # Best-effort: any other failure (e.g. a UnicodeDecodeError from
            # read()) is reported and swallowed so the notebook keeps running.
            print(f"Error: {e}")
            return ""

    def clean_text(self, text: str) -> str:
        """
        Cleans the given text by removing extra spaces and newlines.

        Runs of two or more spaces become one space, runs of two or more
        newlines become one newline, and leading/trailing whitespace is
        stripped. Tabs are not collapsed.

        Args:
            text (str): The input text to be cleaned.

        Returns:
            (str): The cleaned text with excess spaces and newlines removed.
        """
        text = re.sub(r" {2,}", " ", text)  # collapse runs of spaces
        text = re.sub(r"\n{2,}", "\n", text)  # collapse blank lines
        return text.strip()


In [304]:
file_path = "data/herman_finkers.txt"

# Read the raw corpus, then normalise whitespace before tokenizing.
data_loader = DataLoader(file_path=file_path)
text = data_loader.load_text()
cleaned_text = data_loader.clean_text(text=text)

In [306]:
tokens = bytes(cleaned_text, "utf-8")  # UTF-8 byte string of the corpus
tokens[:20]

b'De beginnend cabaret'

In [307]:
tokens = [int(byte) for byte in tokens]  # bytes -> list of ints in 0..255
print(tokens[:20])

[68, 101, 32, 98, 101, 103, 105, 110, 110, 101, 110, 100, 32, 99, 97, 98, 97, 114, 101, 116]


## Exercise 1:
Find the pair of bytes that occurs most frequently.

### My solution

In [309]:
from collections import defaultdict

def find_top_pairs(encoded_text: list[int]) -> list[tuple[int, int]]:
    """
    Finds the most frequently occurring adjacent byte pairs.

    Every pair that ties for the highest count is returned, in order of first
    occurrence.

    Args:
        encoded_text (list[int]): A list of integers representing encoded text.

    Returns:
        list[tuple[int, int]]: A list of the most frequent byte pairs, or an
        empty list if the input has fewer than two elements.
    """
    byte_pair_count = defaultdict(int)

    # zip stops at the shorter argument, so pairing the list with itself
    # shifted by one yields every adjacent pair, including the final one.
    for current_element, next_element in zip(encoded_text, encoded_text[1:]):
        byte_pair_count[(current_element, next_element)] += 1

    highest_count = max(byte_pair_count.values(), default=0)
    return [
        byte_pair
        for byte_pair, count in byte_pair_count.items()
        if count == highest_count
    ]

top_pairs = find_top_pairs(tokens)
top_pairs


[(110, 32)]

### Video solution

In [310]:
def get_stats(ids):
    """Count how often each adjacent pair of ids occurs.

    Returns a dict mapping ``(left, right)`` pairs to their counts, with keys
    in order of first occurrence.
    """
    pair_counts = {}
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1
    return pair_counts

In [311]:
stats = get_stats(tokens)
# (count, pair) tuples, most frequent first.
print(sorted([(count, pair) for pair, count in stats.items()], reverse=True))

[(140, (110, 32)), (136, (101, 110)), (98, (101, 114)), (86, (116, 32)), (74, (32, 105)), (73, (107, 32)), (73, (101, 101)), (71, (101, 32)), (68, (32, 101)), (68, (32, 100)), (65, (100, 101)), (58, (115, 32)), (57, (114, 32)), (51, (105, 106)), (50, (32, 119)), (49, (44, 32)), (48, (32, 104)), (46, (105, 101)), (46, (101, 108)), (46, (97, 110)), (45, (105, 110)), (45, (101, 116)), (45, (97, 97)), (45, (32, 109)), (43, (105, 107)), (42, (103, 101)), (40, (104, 101)), (40, (32, 111)), (37, (32, 97)), (35, (108, 32)), (34, (97, 108)), (33, (32, 118)), (32, (32, 122)), (31, (110, 100)), (30, (97, 114)), (29, (32, 110)), (28, (226, 128)), (28, (115, 116)), (28, (111, 111)), (28, (98, 101)), (28, (46, 32)), (27, (100, 97)), (27, (32, 98)), (25, (100, 32)), (25, (97, 116)), (25, (32, 103)), (24, (111, 110)), (24, (107, 101)), (23, (114, 101)), (23, (109, 105)), (23, (105, 115)), (23, (32, 107)), (22, (119, 101)), (22, (118, 101)), (22, (111, 114)), (22, (108, 101)), (22, (103, 32)), (22, (10

In [312]:
tuple(map(chr, (101, 32)))  # decodes bytes 101 and 32; the top pair found above for this text is (110, 32), i.e. 'n' + space

('e', ' ')

### My new solution

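- Simplified the for-loop to the more Pythonic `zip(encoded_text, encoded_text[1:])`, which pairs each element with its successor.
- Retrieved the top pair directly with `max(byte_pair_count, key=byte_pair_count.get)` instead of computing the highest count first, which reads more cleanly.
- The function now returns a single pair (ties go to the pair seen first) instead of a list of all tied pairs.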

In [313]:
from collections import defaultdict

def find_top_pair(encoded_text: list[int]) -> tuple[int, int]:
    """
    Return the adjacent byte pair that occurs most often.

    Ties go to the pair that appears first. When the input has fewer than two
    elements there are no pairs, and ``(0, 0)`` is returned as a placeholder.

    Args:
        encoded_text (list[int]): Integer-encoded text.

    Returns:
        tuple[int, int]: The most frequent adjacent pair.
    """
    pair_counts = defaultdict(int)
    for pair in zip(encoded_text, encoded_text[1:]):
        pair_counts[pair] += 1

    if not pair_counts:
        return (0, 0)
    return max(pair_counts, key=lambda pair: pair_counts[pair])

top_pair = find_top_pair(tokens)
top_pair


(110, 32)

## Exercise 2:
Replace every occurrence of the top pair with a single new byte (token) value.

### My solution

In [314]:
def merge_top_pair(
    encoded_text: list[int],
    top_pair_value: tuple[int, int],
    new_byte_value: int = 256,
) -> list[int]:
    """
    Return a copy of the encoded text with every occurrence of a pair merged.

    Matches are found greedily from left to right and never overlap, so
    ``[6, 6, 6]`` with pair ``(6, 6)`` becomes ``[new_byte_value, 6]``.

    Args:
        encoded_text (list[int]): The input list of byte values.
        top_pair_value (tuple[int, int]): The byte pair to be merged.
        new_byte_value (int, optional): The replacement byte. Defaults to 256.

    Returns:
        list[int]: A new list with the matching pairs replaced.
    """
    merged = []
    skip_next = False  # set when the current element was consumed by a match

    for position, value in enumerate(encoded_text):
        if skip_next:
            skip_next = False
            continue

        is_pair_start = (
            position + 1 < len(encoded_text)
            and value == top_pair_value[0]
            and encoded_text[position + 1] == top_pair_value[1]
        )
        if is_pair_start:
            merged.append(new_byte_value)
            skip_next = True
        else:
            merged.append(value)

    return merged

In [315]:
print(merge_top_pair(encoded_text=[5, 6, 6, 7, 9, 1], top_pair_value=(6, 7), new_byte_value=99))

[5, 6, 99, 9, 1]


In [316]:
print(f"{len(tokens)=}")

# One merge of the most frequent pair, using the default new id (256).
new_encoded_text = merge_top_pair(tokens, top_pair)

print(f"{len(new_encoded_text)=}")

len(tokens)=4379
len(new_encoded_text)=4239


### Video solution

In [317]:
def merge(ids, pair, idx):
    """Replace each non-overlapping occurrence of ``pair`` in ``ids`` with ``idx``.

    Scans left to right and returns a new list; ``ids`` is left untouched.
    """
    result = []
    pos = 0
    last = len(ids) - 1
    while pos < len(ids):
        at_pair = pos < last and ids[pos] == pair[0] and ids[pos + 1] == pair[1]
        result.append(idx if at_pair else ids[pos])
        pos += 2 if at_pair else 1
    return result

In [318]:
print(merge(ids=[5, 6, 6, 7, 9, 1], pair=(6, 7), idx=99))

[5, 6, 99, 9, 1]


### My new solution

- Folded the separate trailing `if` into the `while` loop: the loop now visits every element and checks the bounds in its condition, so the last-element edge case is handled inside the loop. This reads more cleanly.

In [319]:
def merge_top_pair(
    encoded_text: list[int],
    top_pair_value: tuple[int, int],
    new_byte_value: int = 256,
) -> list[int]:
    """
    Build a new list in which each occurrence of a byte pair is merged.

    The scan moves left to right; when the pair matches, both elements are
    replaced by ``new_byte_value`` and the scan jumps past them, so matches
    never overlap.

    Args:
        encoded_text (list[int]): The input list of byte values.
        top_pair_value (tuple[int, int]): The byte pair to be merged.
        new_byte_value (int, optional): The replacement byte. Defaults to 256.

    Returns:
        list[int]: The list with merged byte pairs.
    """
    merged = []
    cursor = 0
    length = len(encoded_text)

    while cursor < length:
        if (
            cursor + 1 < length
            and encoded_text[cursor] == top_pair_value[0]
            and encoded_text[cursor + 1] == top_pair_value[1]
        ):
            merged.append(new_byte_value)
            cursor += 2
            continue

        merged.append(encoded_text[cursor])
        cursor += 1

    return merged

In [320]:
print(merge_top_pair(encoded_text=[5, 6, 6, 7, 9, 1], top_pair_value=(6, 7), new_byte_value=99))

[5, 6, 99, 9, 1]


In [322]:
print(f"{len(tokens)=}")

# One merge of the most frequent pair, using the default new id (256).
new_encoded_text = merge_top_pair(tokens, top_pair)

print(f"{len(new_encoded_text)=}")

len(tokens)=4379
len(new_encoded_text)=4239


## Exercise 3:
Do this iteratively, with a hyperparameter that determines the number of iterations.

### My solution

In [323]:
from collections import defaultdict

class GPTTokenizer:
    """
    Implements a simple Byte Pair Encoding (BPE) tokenizer.

    This tokenizer iteratively replaces the most frequent adjacent byte pairs
    with new byte values to form a more compact representation of the text.
    New token ids are assigned sequentially starting at 256, just above the
    range of raw byte values.
    """
    def __init__(
        self,
        encoded_text: list[int],
        num_iterations: int,
    ) -> None:
        """
        Initializes the tokenizer with encoded text and iteration count.

        Args:
            encoded_text (list[int]): A list of integers representing encoded text.
                It is copied, so the result of `tokenize` never aliases the
                caller's sequence.
            num_iterations (int): The maximum number of byte pair merges.
        """
        self.encoded_text = list(encoded_text)
        self.num_iterations = num_iterations
        self.replace_byte_value = 256  # id assigned by the next merge

    def find_top_pair(self) -> tuple[int, int]:
        """
        Finds the most frequently occurring adjacent byte pair.

        Ties go to the pair seen first. If the text has fewer than two
        elements, `max` falls back to its default and `(0, 0)` is returned.

        Returns:
            tuple[int, int]: The most frequent byte pair.
        """
        byte_pair_count = defaultdict(int)

        for current_element, next_element in zip(self.encoded_text, self.encoded_text[1:]):
            byte_pair = (current_element, next_element)
            byte_pair_count[byte_pair] += 1

        return max(byte_pair_count, key=byte_pair_count.get, default=(0, 0))

    def merge_top_pair(
        self,
        top_pair: tuple[int, int],
    ) -> None:
        """
        Replaces occurrences of a given byte pair in the encoded text.

        Scans `self.encoded_text` left to right for non-overlapping occurrences
        of `top_pair`, replaces each with `self.replace_byte_value`, and stores
        the result back in `self.encoded_text`. Nothing is returned.

        Args:
            top_pair (tuple[int, int]): The byte pair to be merged.
        """
        i = 0
        new_bytes = []

        while i < len(self.encoded_text):
            if i < len(self.encoded_text) - 1 and (self.encoded_text[i] == top_pair[0] and self.encoded_text[i+1] == top_pair[1]):
                new_bytes.append(self.replace_byte_value)
                i += 2
            else:
                new_bytes.append(self.encoded_text[i])
                i += 1

        self.encoded_text = new_bytes

    def tokenize(
        self,
    ) -> list[int]:
        """
        Performs byte pair encoding for up to `num_iterations` merges.

        This method iteratively finds and replaces the most frequent byte pairs
        in the encoded text, assigning new byte values sequentially. It stops
        early once fewer than two tokens remain: there is no pair left to
        merge, and `find_top_pair` would only return its `(0, 0)` placeholder.

        Returns:
            (list[int]): The final encoded text after all iterations.
        """
        for _ in range(self.num_iterations):
            if len(self.encoded_text) < 2:
                break
            top_pair = self.find_top_pair()
            self.merge_top_pair(top_pair=top_pair)
            self.replace_byte_value += 1

        return self.encoded_text

In [324]:
num_iterations = 20

tokenizer = GPTTokenizer(encoded_text=tokens, num_iterations=num_iterations)
new_tokens = tokenizer.tokenize()

# Sanity check: the newest id produced by the final merge must be the largest token.
assert max(new_tokens) == 256 + num_iterations - 1

### Video solution

In [326]:
def get_stats(ids):
    """Return a dict counting each adjacent ``(left, right)`` pair in ``ids``.

    Keys appear in order of first occurrence.
    """
    counts = {}
    successors = iter(ids)
    next(successors, None)  # advance one step so zip pairs each id with the next
    for pair in zip(ids, successors):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    """Return a new list in which each non-overlapping ``pair`` in ``ids`` is
    replaced by the single token ``idx`` (matched greedily, left to right)."""
    out = []
    pos = 0
    n_ids = len(ids)
    while pos < n_ids:
        if pos + 1 < n_ids and ids[pos] == pair[0]:
            if ids[pos + 1] == pair[1]:
                out.append(idx)
                pos += 2
                continue
        out.append(ids[pos])
        pos += 1
    return out

vocab_size = 276
num_merges = vocab_size - 256  # ids 0..255 are already taken by raw bytes
ids = list(tokens)

merges = {}  # (pair) -> new token id, in merge order
for i in range(num_merges):
    idx = 256 + i
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)
    print(f"merging {pair} into a new token {idx}")
    merges[pair] = idx
    ids = merge(ids, pair, idx)

merging (110, 32) into a new token 256
merging (101, 114) into a new token 257
merging (116, 32) into a new token 258
merging (101, 256) into a new token 259
merging (107, 32) into a new token 260
merging (101, 32) into a new token 261
merging (115, 32) into a new token 262
merging (101, 110) into a new token 263
merging (105, 106) into a new token 264
merging (44, 32) into a new token 265
merging (101, 108) into a new token 266
merging (97, 97) into a new token 267
merging (105, 260) into a new token 268
merging (257, 32) into a new token 269
merging (105, 101) into a new token 270
merging (105, 110) into a new token 271
merging (101, 259) into a new token 272
merging (100, 261) into a new token 273
merging (46, 32) into a new token 274
merging (111, 111) into a new token 275


In [328]:
merges  # maps each merged pair to the new token id that replaced it

{(110, 32): 256,
 (101, 114): 257,
 (116, 32): 258,
 (101, 256): 259,
 (107, 32): 260,
 (101, 32): 261,
 (115, 32): 262,
 (101, 110): 263,
 (105, 106): 264,
 (44, 32): 265,
 (101, 108): 266,
 (97, 97): 267,
 (105, 260): 268,
 (257, 32): 269,
 (105, 101): 270,
 (105, 110): 271,
 (101, 259): 272,
 (100, 261): 273,
 (46, 32): 274,
 (111, 111): 275}

### My new solution

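- Added progress printing and a `merges` dictionary that maps each merged pair to its new token id; `tokenize` now returns both the encoded text and `merges`.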

In [329]:
from collections import defaultdict
import logging

class GPTTokenizer:
    """
    Implements a simple Byte Pair Encoding (BPE) tokenizer.

    This tokenizer iteratively replaces the most frequent adjacent byte pairs
    with new byte values to form a more compact representation of the text.
    New token ids are assigned sequentially starting at 256, and every merge
    is recorded in `self.merges`.
    """
    def __init__(
        self,
        encoded_text: list[int],
        num_iterations: int,
    ) -> None:
        """
        Initializes the tokenizer with encoded text and iteration count.

        Args:
            encoded_text (list[int]): A list of integers representing encoded text.
                It is copied, so the result of `tokenize` never aliases the
                caller's sequence.
            num_iterations (int): The maximum number of byte pair merges.
        """
        self.encoded_text = list(encoded_text)
        self.num_iterations = num_iterations
        self.replace_byte_value = 256  # id assigned by the next merge
        self.merges = {}  # maps each merged byte pair to the id that replaced it

    def find_top_pair(self) -> tuple[int, int]:
        """
        Finds the most frequently occurring adjacent byte pair.

        Ties go to the pair seen first. If the text has fewer than two
        elements, `max` falls back to its default and `(0, 0)` is returned.

        Returns:
            tuple[int, int]: The most frequent byte pair.
        """
        byte_pair_count = defaultdict(int)

        for current_element, next_element in zip(self.encoded_text, self.encoded_text[1:]):
            byte_pair = (current_element, next_element)
            byte_pair_count[byte_pair] += 1

        return max(byte_pair_count, key=byte_pair_count.get, default=(0, 0))

    def merge_top_pair(
        self,
        top_pair: tuple[int, int],
    ) -> None:
        """
        Replaces occurrences of a given byte pair in the encoded text.

        Scans `self.encoded_text` left to right for non-overlapping occurrences
        of `top_pair`, replaces each with `self.replace_byte_value`, and stores
        the result back in `self.encoded_text`. Nothing is returned.

        Args:
            top_pair (tuple[int, int]): The byte pair to be merged.
        """
        i = 0
        new_bytes = []

        while i < len(self.encoded_text):
            if i < len(self.encoded_text) - 1 and (self.encoded_text[i] == top_pair[0] and self.encoded_text[i+1] == top_pair[1]):
                new_bytes.append(self.replace_byte_value)
                i += 2
            else:
                new_bytes.append(self.encoded_text[i])
                i += 1

        self.encoded_text = new_bytes

    def tokenize(
        self,
    ) -> tuple[list[int], dict[tuple[int, int], int]]:
        """
        Performs byte pair encoding for up to `num_iterations` merges.

        This method iteratively finds and replaces the most frequent byte pairs
        in the encoded text, assigning new byte values sequentially and printing
        one line per merge. It stops early once fewer than two tokens remain,
        so the `(0, 0)` placeholder from `find_top_pair` is never recorded as a
        merge.

        Returns:
            (tuple[list[int], dict[tuple[int, int], int]]): The final encoded
            text and the merges dictionary (pair -> new token id).
        """
        # Column widths so the printed log lines up.
        print_width_1 = len(str(self.num_iterations-1))
        print_width_2 = len(str(self.num_iterations + self.replace_byte_value -1))

        for i in range(self.num_iterations):
            if len(self.encoded_text) < 2:
                break
            top_pair = self.find_top_pair()

            print(f"{i=:0{print_width_1}}: Merging ({top_pair[0]:{print_width_2}}, {top_pair[1]:{print_width_2}}) into a new token {self.replace_byte_value}")
            self.merges[top_pair] = self.replace_byte_value

            self.merge_top_pair(top_pair=top_pair)
            self.replace_byte_value += 1

        return self.encoded_text, self.merges

In [330]:
num_iterations = 20

# Train the tokenizer; it prints one line per merge and returns the merge table too.
tokenizer = GPTTokenizer(encoded_text=tokens, num_iterations=num_iterations)
new_tokens, merges = tokenizer.tokenize()

i=00: Merging (110,  32) into a new token 256
i=01: Merging (101, 114) into a new token 257
i=02: Merging (116,  32) into a new token 258
i=03: Merging (101, 256) into a new token 259
i=04: Merging (107,  32) into a new token 260
i=05: Merging (101,  32) into a new token 261
i=06: Merging (115,  32) into a new token 262
i=07: Merging (101, 110) into a new token 263
i=08: Merging (105, 106) into a new token 264
i=09: Merging ( 44,  32) into a new token 265
i=10: Merging (101, 108) into a new token 266
i=11: Merging ( 97,  97) into a new token 267
i=12: Merging (105, 260) into a new token 268
i=13: Merging (257,  32) into a new token 269
i=14: Merging (105, 101) into a new token 270
i=15: Merging (105, 110) into a new token 271
i=16: Merging (101, 259) into a new token 272
i=17: Merging (100, 261) into a new token 273
i=18: Merging ( 46,  32) into a new token 274
i=19: Merging (111, 111) into a new token 275


In [332]:
print(f"{len(tokens)=}")
print(f"{len(new_tokens)=}")

# Ratio of the original byte count to the token count after merging.
print(f"\nCompression rate: {len(tokens) / len(new_tokens):.2f}")

len(tokens)=4379
len(new_tokens)=3275

Compression rate: 1.34


## Exercise 4

Write the `encode` and `decode` methods between raw text (Unicode code point sequence) and a token sequence.

### My solution

### Video solution

### My new solution

## Final result

- Instead of specifying the number of iterations, specify the desired vocabulary size. Then derive the number of iterations from that target size and the number of unique byte values in the original text.