In [1]:
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass

class Tokenizer(ABC):
    """Abstract interface for a tokenizer: text <-> list of integer token ids.

    Subclasses must implement both ``encode`` and ``decode``. Both are marked
    ``@abstractmethod`` so Python refuses to instantiate a subclass that forgets
    one (TypeError at construction) instead of failing later at call time.
    """

    @abstractmethod
    def encode(self, string: str) -> list[int]:
        """Convert ``string`` into a list of token ids."""
        raise NotImplementedError

    @abstractmethod
    def decode(self, indices: list[int]) -> str:
        """Convert a list of token ids back into a string."""
        raise NotImplementedError
    
def get_compression_ratio(string: str, indices: list[int]) -> float:
    """Return how many UTF-8 bytes of ``string`` each token covers on average.

    Computed as ``len(string.encode("utf-8")) / len(indices)``; larger values
    mean each token stands for more raw bytes. Raises ZeroDivisionError when
    ``indices`` is empty.
    """
    byte_count = len(string.encode("utf-8"))
    return byte_count / len(indices)

In [2]:
class CharacterTokenizer(Tokenizer):
    """Tokenize text into Unicode code points: one token id per character."""

    def encode(self, string: str) -> list[int]:
        """Map each character of ``string`` to its code point with ``ord``."""
        return list(map(ord, string))

    def decode(self, indices: list[int]) -> str:
        """Turn each code point back into a character with ``chr`` and join them."""
        return "".join(map(chr, indices))
   
tokenizer = CharacterTokenizer()
string = "Hello, 🌍! 你好!"
indices = tokenizer.encode(string)
reconstructed_string = tokenizer.decode(indices)

# Smallest vocabulary that covers *this* string only -- a lower bound, since a
# character tokenizer can emit any Unicode code point.
vocabulary_size = max(indices) + 1
compression_ratio = get_compression_ratio(string, indices)

assert string == reconstructed_string
print(
    f"string: {string}",
    f"reconstructed_string: {reconstructed_string}",
    f"vocabulary_size: {vocabulary_size}",
    f"compression_ratio: {compression_ratio}",
    sep="\n",
)

string: Hello, 🌍! 你好!
reconstructed_string: Hello, 🌍! 你好!
vocabulary_size: 127758
compression_ratio: 1.5384615384615385


In [3]:
class ByteTokenizer(Tokenizer):
    """Tokenize text into its raw UTF-8 bytes, so token ids lie in 0..255."""

    def encode(self, string: str) -> list[int]:
        """Return the UTF-8 byte values of ``string`` as a list of ints."""
        return list(string.encode("utf-8"))

    def decode(self, indices: list[int]) -> str:
        """Pack ``indices`` into ``bytes`` and decode them as strict UTF-8.

        ``bytes(...)`` raises ValueError for ids outside range(256), and the
        UTF-8 decode raises UnicodeDecodeError on an invalid byte sequence.
        """
        return bytes(indices).decode("utf-8")
    
tokenizer = ByteTokenizer()
string = "Hello, 🌍! 你好!"
indices = tokenizer.encode(string)
reconstructed_string = tokenizer.decode(indices)

# A byte tokenizer always has exactly one id per possible byte value.
vocabulary_size = 256
# One token per byte, so this ratio is 1.0 by construction.
compression_ratio = get_compression_ratio(string, indices)

assert string == reconstructed_string
print(
    f"string: {string}",
    f"reconstructed_string: {reconstructed_string}",
    f"vocabulary_size: {vocabulary_size}",
    f"compression_ratio: {compression_ratio}",
    sep="\n",
)

string: Hello, 🌍! 你好!
reconstructed_string: Hello, 🌍! 你好!
vocabulary_size: 256
compression_ratio: 1.0


In [4]:
@dataclass(frozen=True)
class BPETokenizerParams:
    """Parameters of a byte-level BPE tokenizer: its vocabulary and merge rules.

    ``frozen=True`` prevents reassigning the fields after construction, but the
    dicts themselves remain mutable.
    """
    vocab: dict[int, bytes] # index -> bytes
    merges: dict[tuple[int, int], int] # (index1, index2) -> new_index; BPETokenizer.encode applies these in dict iteration order

def merge(indices: list[int], pair: tuple[int, int], new_index: int) -> list[int]:
    """Replace every non-overlapping occurrence of ``pair`` with ``new_index``.

    The scan runs left to right and is greedy: once two ids are merged, the
    second one cannot start another match (``[1, 1, 1]`` with pair ``(1, 1)``
    becomes ``[new_index, 1]``). The input list is not modified.

    Args:
        indices: token ids to rewrite.
        pair: the adjacent ``(first, second)`` ids to look for.
        new_index: id that replaces each matched pair.

    Returns:
        A new list of token ids.
    """
    merged: list[int] = []
    last = len(indices) - 1
    consumed = False  # True when the current id was already folded into the previous match
    for pos, token in enumerate(indices):
        if consumed:
            consumed = False
            continue
        if pos < last and token == pair[0] and indices[pos + 1] == pair[1]:
            merged.append(new_index)
            consumed = True
        else:
            merged.append(token)
    return merged
# -------------------------------------------------------------- #
# merge([1, 2, 3, 4, 2, 3, 5], (2, 3), 9)

# Walkthrough:
#     At i=0: 1 → not part of (2, 3) → keep → [1]
#     At i=1: 2,3 → match → replace with 9 → [1,9]
#     At i=3: 4 → keep → [1,9,4]
#     At i=4: 2,3 again → match → replace → [1,9,4,9]
#     At i=6: 5 → keep → [1,9,4,9,5]

# Result: [1, 9, 4, 9, 5]
# -------------------------------------------------------------- #

def train_bpe(string: str, num_merges: int) -> BPETokenizerParams:
    """Learn up to ``num_merges`` byte-pair merges from ``string``.

    Training starts from the UTF-8 bytes of ``string`` (vocab ids 0-255). Each
    round counts adjacent id pairs and picks the most frequent one. Ties go to
    the pair seen first, because ``max`` returns the first maximal key in dict
    insertion order. The pair gets the next free id ``256 + i``, and the
    sequence is rewritten with that merge applied.

    Args:
        string: training text.
        num_merges: maximum number of merges to learn.

    Returns:
        BPETokenizerParams whose ``merges`` are stored in the order they were
        learned. Training stops early, with fewer than ``num_merges`` merges,
        once no adjacent pair is left. This happens when the text is shorter
        than two bytes or has collapsed to a single token.
    """
    indices = list(string.encode("utf-8"))
    merges: dict[tuple[int, int], int] = {}  # (index1, index2) -> merged_index
    vocab: dict[int, bytes] = {x: bytes([x]) for x in range(256)}  # index -> bytes

    for i in range(num_merges):
        counts = defaultdict(int)
        for index1, index2 in zip(indices, indices[1:]):
            counts[(index1, index2)] += 1
        if not counts:
            # Nothing left to merge; max() on an empty dict would raise ValueError.
            break

        pair = max(counts, key=counts.get)
        index1, index2 = pair

        new_index = 256 + i
        merges[pair] = new_index
        vocab[new_index] = vocab[index1] + vocab[index2]
        indices = merge(indices, pair, new_index)
    return BPETokenizerParams(vocab=vocab, merges=merges)

# Train a tiny BPE model: learn 3 merges from one short sentence.
string = "the cat in the hat"
params = train_bpe(string=string, num_merges=3)

In [None]:
class BPETokenizer(Tokenizer):
    """Byte-level BPE tokenizer driven by a ``BPETokenizerParams`` instance."""

    def __init__(self, params: BPETokenizerParams):
        self.params = params

    def encode(self, string: str) -> list[int]:
        """Start from the UTF-8 bytes of ``string`` and replay every merge.

        Merges are applied one after another in the iteration order of
        ``params.merges``; for params produced by ``train_bpe`` that is the
        order in which the merges were learned.
        """
        tokens = list(string.encode("utf-8"))
        for pair, new_index in self.params.merges.items():
            tokens = merge(tokens, pair, new_index)
        return tokens

    def decode(self, indices: list[int]) -> str:
        """Concatenate the byte strings of all token ids and decode as UTF-8.

        NOTE(review): an id missing from ``params.vocab`` maps to None, which
        makes ``bytes.join`` raise TypeError.
        """
        lookup = self.params.vocab.get
        pieces = [lookup(index) for index in indices]
        return b"".join(pieces).decode("utf-8")

tokenizer = BPETokenizer(params)
string = "the quick brown fox"
indices = tokenizer.encode(string)
reconstructed_string = tokenizer.decode(indices)

# 256 byte ids plus one id per learned merge.
vocabulary_size = len(tokenizer.params.vocab)
compression_ratio = get_compression_ratio(string, indices)

assert string == reconstructed_string
print(
    f"string: {string}",
    f"reconstructed_string: {reconstructed_string}",
    f"vocabulary_size: {vocabulary_size}",
    f"compression_ratio: {compression_ratio}",
    sep="\n",
)

string: the quick brown fox
reconstructed_string: the quick brown fox
vocabulary_size: 259
compression_ratio: 1.1875
