In [2]:
import os
import regex as re
from typing import BinaryIO

In [4]:
def get_token_frequencies(text, pattern, special_tokens=None):
    """
    Takes text and a compiled regex pattern.
    Returns a frequency table: { 'token': count, ... }
    """
    stats = {}
    
    # Step 1: Split on special tokens if provided
    if special_tokens:
        # Build pattern: escape each special token and join with |
        split_pattern = "|".join(re.escape(token) for token in special_tokens)
        # Split the text, keeping only non-empty chunks
        chunks = [chunk for chunk in re.split(split_pattern, text) if chunk]
    else:
        chunks = [text]
    
    # Step 2: Pre-tokenize each chunk separately
    for chunk in chunks:
        # iterate using finditer (memory efficient)
        for match in pattern.finditer(chunk):
            # Extract the actual string matched
            token = match.group().encode("utf-8")
            
            # Manual counting logic
            if token in stats:
                stats[token] += 1
            else:
                stats[token] = 1
            
    return stats


def get_pair_frequencies(vocab):
    """
    Given a vocab dict like {('l','o','w'): 5, ('l','o','w','e','r'): 2}
    Returns pair counts like {('l','o'): 7, ('o','w'): 7, ('w','e'): 2, ...}
    """
    pairs = {}
    
    # Loop through each token and its frequency
    for token_bytes, freq in vocab.items():
        # Look at consecutive pairs in this token
        # Example: ('l','o','w') → pairs are ('l','o') and ('o','w')
        for i in range(len(token_bytes) - 1):
            # Get pair at position i
            pair = (bytes([token_bytes[i]]), bytes([token_bytes[i + 1]]))
            
            # Add this pair's count (weighted by token frequency)
            if pair in pairs:
                pairs[pair] += freq
            else:
                pairs[pair] = freq
    
    return pairs


def get_best_pair(pair_frequencies):
    """
    Returns the pair with highest frequency.
    In case of tie, returns lexicographically greater pair.
    """
    if not pair_frequencies:
        return None
    
    # Find the maximum frequency
    max_freq = max(pair_frequencies.values())
    
    # Get all pairs with that frequency
    top_pairs = [pair for pair, freq in pair_frequencies.items() if freq == max_freq]
    
    # Return the lexicographically greatest one
    # max() on tuples compares element by element
    return max(top_pairs)

def merge_pair_in_token(token_bytes, pair_to_merge):
    """
    Merge a specific pair in a single token.
    Returns the modified token as bytes.
    """
    if len(token_bytes) < 2:
        return token_bytes
    
    new_token = []
    i = 0
    
    while i < len(token_bytes):
        # Check if we can merge at position i
        if (i < len(token_bytes) - 1 and 
            bytes([token_bytes[i]]) == pair_to_merge[0] and 
            bytes([token_bytes[i + 1]]) == pair_to_merge[1]):
            # Merge! Combine the two bytes
            new_token.append(pair_to_merge[0] + pair_to_merge[1])
            i += 2
        else:
            # Don't merge, just copy this byte
            new_token.append(bytes([token_bytes[i]]))
            i += 1
    
    # Concatenate all byte sequences
    return b''.join(new_token)

def merge_pair(token_frequencies, pair_to_merge):
    """
    Takes vocab and a pair like ('s', 't')
    Returns new vocab where that pair is merged everywhere
    Example: ('n','e','w','e','s','t') → ('n','e','w','e','st') if merging ('s','t')
    """
    new_freq_table = {}
    
    for token_bytes, freq in token_frequencies.items():
        # We'll build a new version of this token
        new_token = merge_pair_in_token(token_bytes, pair_to_merge)
        
        if new_token in token_frequencies:
            new_freq_table[new_token] += freq
        else:
            new_freq_table[new_token] = freq
    
    return new_freq_table


def train_bpe(input_path="", vocab_size=1000, special_tokens=["<|endoftext|>"]):

    if special_tokens is None:
        special_tokens = []
    
    vocab = {i: bytes([i]) for i in range(256)}
    next_id = 256

    for special_token in special_tokens:
        vocab[next_id] = special_token.encode("utf-8")
        next_id += 1

    with open(input_path, "rb") as f:
        raw_data = f.read(vocab_size).decode("utf-8", errors="ignore")
        print(raw_data)

    PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    compiled_pattern = re.compile(PAT)
    token_frequencies = get_token_frequencies(raw_data, compiled_pattern, special_tokens)
    print(f"Found {len(token_frequencies)} unique pre-tokens")

    num_merges = vocab_size - len(vocab)
    if num_merges <= 0:
        return vocab, []
    print(f"Performing {num_merges} merges...")

    merges = []  # Track the sequence of merges
    
    for merge_num in range(num_merges):
        # Step 1: Count all pairs
        pair_freq = get_pair_frequencies(token_frequencies)
        
        if not pair_freq:
            print(f"No more pairs to merge at step {merge_num}")
            break
        
        # Step 2: Find best pair
        best_pair = get_best_pair(pair_freq)
        
        # Step 3: Merge it
        token_frequencies = merge_pair(token_frequencies, best_pair)
        
        # Step 4: Record this merge
        merged_token = best_pair[0] + best_pair[1]
        vocab[next_id] = merged_token
        next_id += 1

        merges.append(best_pair)
        
        if (merge_num + 1) % 100 == 0 or merge_num < 10:
            print(f"Merge {merge_num + 1}/{num_merges}: "
                  f"'{best_pair[0]!r} {best_pair[1]!r} -> {merged_token!r} "
                  f"(freq: {pair_freq[best_pair]})")

    print(f"\nTraining complete! Final vocab size: {len(vocab)}")
    return vocab, merges

vocab, merges = train_bpe(
    input_path = "../data/TinyStoriesV2-GPT4-train.txt", 
    vocab_size=100, 
    special_tokens=["<|endoftext|>"]
)


print("\n" + "="*50)
print(f"Vocabulary size: {len(vocab)}")
print(f"Number of merges: {len(merges)}")
print("\nFirst 10 merges:")
for i, (byte1, byte2) in enumerate(merges[:10]):
    merged = byte1 + byte2
    print(f"{i+1}. {byte1!r} + {byte2!r} -> {merged!r}")

print("\nLast 10 vocabulary entries:")
for idx in sorted(vocab.keys())[-10:]:
    print(f"{idx}: {vocab[idx]!r}")


Once upon a time there was a little boy named Ben. Ben loved to explore the world around him. He sa
Found 21 unique pre-tokens

Vocabulary size: 257
Number of merges: 0

First 10 merges:

Last 10 vocabulary entries:
247: b'\xf7'
248: b'\xf8'
249: b'\xf9'
250: b'\xfa'
251: b'\xfb'
252: b'\xfc'
253: b'\xfd'
254: b'\xfe'
255: b'\xff'
256: b'<|endoftext|>'
