In [15]:
import os
import regex as re
from typing import BinaryIO


PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
compiled_pattern = re.compile(PAT)
input_path = "../data/TinyStoriesV2-GPT4-train.txt"
vocab_size=10000
special_tokens=["<|endoftext|>"]

In [16]:
with open(input_path, "rb") as f:
    raw_data = f.read(10000).decode("utf-8", errors="ignore")
    print(f"Read {len(raw_data)} characters")

stats = {}

# Step 1: Split on special tokens if provided
if special_tokens:
    # Build pattern: escape each special token and join with |
    split_pattern = "|".join(re.escape(token) for token in special_tokens)
    # Split the text, keeping only non-empty chunks
    chunks = [chunk for chunk in re.split(split_pattern, raw_data) if chunk]
else:
    chunks = [raw_data]

# Step 2: Pre-tokenize each chunk separately
for chunk in chunks:
    # iterate using finditer (memory efficient)
    for match in compiled_pattern.finditer(chunk):
        # Extract the actual string matched
        token = match.group().encode("utf-8")
        
        # Manual counting logic
        if token in stats:
            stats[token] += 1
        else:
            stats[token] = 1


print(f"Found {len(stats)} unique pre-tokens")
print("\nFirst 10 pre-tokens:")
for i, (token, count) in enumerate(list(stats.items())[:10]):
    print(f"{token!r}: {count}")

Read 9992 characters
Found 523 unique pre-tokens

First 10 pre-tokens:
b'\n': 70
b'Once': 8
b' upon': 8
b' a': 64
b' time': 8
b' there': 8
b' was': 49
b' little': 10
b' boy': 6
b' named': 10


In [17]:
vocab = {i: bytes([i]) for i in range(256)}
vocab[0], vocab[1], vocab[2]

(b'\x00', b'\x01', b'\x02')

In [18]:
vocab_size

10000

In [19]:
len(vocab)

256

In [20]:
num_merges = vocab_size - len(vocab)
num_merges

9744

In [23]:
def get_token_frequencies(text, pattern, special_tokens=None):
    """
    Takes text and a compiled regex pattern.
    Returns a frequency table: { 'token': count, ... }
    """
    stats = {}
    
    # Step 1: Split on special tokens if provided
    if special_tokens:
        # Build pattern: escape each special token and join with |
        split_pattern = "|".join(re.escape(token) for token in special_tokens)
        # Split the text, keeping only non-empty chunks
        chunks = [chunk for chunk in re.split(split_pattern, text) if chunk]
    else:
        chunks = [text]
    
    # Step 2: Pre-tokenize each chunk separately
    for chunk in chunks:
        # iterate using finditer (memory efficient)
        for match in pattern.finditer(chunk):
            # Extract the actual string matched
            token = match.group().encode("utf-8")
            
            # Manual counting logic
            if token in stats:
                stats[token] += 1
            else:
                stats[token] = 1
            
    return stats

In [25]:
token_frequencies = get_token_frequencies(raw_data, compiled_pattern, special_tokens)
pairs = {}

# Loop through each token and its frequency
for token_bytes, freq in token_frequencies.items():
    # Look at consecutive pairs in this token
    # Example: ('l','o','w') → pairs are ('l','o') and ('o','w')
    for i in range(len(token_bytes) - 1):
        # Get pair at position i
        pair = (bytes([token_bytes[i]]), bytes([token_bytes[i + 1]]))
        
        # Add this pair's count (weighted by token frequency)
        if pair in pairs:
            pairs[pair] += freq
        else:
            pairs[pair] = freq

print(pairs)

{(b'O', b'n'): 20, (b'n', b'c'): 12, (b'c', b'e'): 15, (b' ', b'u'): 16, (b'u', b'p'): 12, (b'p', b'o'): 9, (b'o', b'n'): 50, (b' ', b'a'): 202, (b' ', b't'): 277, (b't', b'i'): 29, (b'i', b'm'): 35, (b'm', b'e'): 54, (b't', b'h'): 205, (b'h', b'e'): 269, (b'e', b'r'): 102, (b'r', b'e'): 59, (b' ', b'w'): 132, (b'w', b'a'): 71, (b'a', b's'): 71, (b' ', b'l'): 67, (b'l', b'i'): 47, (b'i', b't'): 81, (b't', b't'): 19, (b't', b'l'): 11, (b'l', b'e'): 50, (b' ', b'b'): 86, (b'b', b'o'): 25, (b'o', b'y'): 9, (b' ', b'n'): 43, (b'n', b'a'): 11, (b'a', b'm'): 36, (b'e', b'd'): 124, (b' ', b'B'): 25, (b'B', b'e'): 13, (b'e', b'n'): 74, (b'l', b'o'): 40, (b'o', b'v'): 9, (b'v', b'e'): 47, (b't', b'o'): 95, (b' ', b'e'): 26, (b'e', b'x'): 8, (b'x', b'p'): 3, (b'p', b'l'): 29, (b'o', b'r'): 42, (b'w', b'o'): 4, (b'r', b'l'): 3, (b'l', b'd'): 18, (b'a', b'r'): 53, (b'r', b'o'): 14, (b'o', b'u'): 58, (b'u', b'n'): 26, (b'n', b'd'): 142, (b' ', b'h'): 118, (b'h', b'i'): 49, (b' ', b'H'): 24, (b'H', 

In [27]:
def get_pair_frequencies(vocab):
    """
    Given a vocab dict like {('l','o','w'): 5, ('l','o','w','e','r'): 2}
    Returns pair counts like {('l','o'): 7, ('o','w'): 7, ('w','e'): 2, ...}
    """
    pairs = {}
    
    # Loop through each token and its frequency
    for token_bytes, freq in vocab.items():
        # Look at consecutive pairs in this token
        # Example: ('l','o','w') → pairs are ('l','o') and ('o','w')
        for i in range(len(token_bytes) - 1):
            # Get pair at position i
            pair = (bytes([token_bytes[i]]), bytes([token_bytes[i + 1]]))
            
            # Add this pair's count (weighted by token frequency)
            if pair in pairs:
                pairs[pair] += freq
            else:
                pairs[pair] = freq
    
    return pairs
    
pair_freq = get_pair_frequencies(token_frequencies)

In [28]:
max_freq = max(pair_freq.values())
max_freq

277

In [29]:
top_pairs = [pair for pair, freq in pair_freq.items() if freq == max_freq]
top_pairs

[(b' ', b't')]

In [30]:
max(top_pairs)

(b' ', b't')

In [31]:
def get_best_pair(pair_frequencies):
    """
    Returns the pair with highest frequency.
    In case of tie, returns lexicographically greater pair.
    """
    if not pair_frequencies:
        return None
    
    # Find the maximum frequency
    max_freq = max(pair_frequencies.values())
    
    # Get all pairs with that frequency
    top_pairs = [pair for pair, freq in pair_frequencies.items() if freq == max_freq]
    
    # Return the lexicographically greatest one
    # max() on tuples compares element by element
    return max(top_pairs)
    
best_pair = get_best_pair(pair_freq)

In [33]:
def merge_pair_in_token(token_bytes, pair_to_merge):
    """
    Merge a specific pair in a single token.
    Returns the modified token as bytes.
    """
    if len(token_bytes) < 2:
        return token_bytes
    
    new_token = []
    i = 0
    
    while i < len(token_bytes):
        # Check if we can merge at position i
        if (i < len(token_bytes) - 1 and 
            bytes([token_bytes[i]]) == pair_to_merge[0] and 
            bytes([token_bytes[i + 1]]) == pair_to_merge[1]):
            # Merge! Combine the two bytes
            new_token.append(pair_to_merge[0] + pair_to_merge[1])
            i += 2
        else:
            # Don't merge, just copy this byte
            new_token.append(bytes([token_bytes[i]]))
            i += 1
    
    # Concatenate all byte sequences
    return b''.join(new_token)

new_token = merge_pair_in_token(token_bytes, best_pair)

In [39]:
token_frequencies.items()

dict_items([(b'\n', 70), (b'Once', 8), (b' upon', 8), (b' a', 64), (b' time', 8), (b' there', 8), (b' was', 49), (b' little', 10), (b' boy', 6), (b' named', 10), (b' Ben', 12), (b'.', 188), (b' loved', 7), (b' to', 56), (b' explore', 1), (b' the', 91), (b' world', 1), (b' around', 3), (b' him', 6), (b' He', 21), (b' saw', 12), (b' many', 3), (b' amazing', 5), (b' things', 4), (b',', 97), (b' like', 8), (b' beautiful', 3), (b' vases', 1), (b' that', 12), (b' were', 12), (b' on', 11), (b' display', 1), (b' in', 16), (b' store', 3), (b' One', 5), (b' day', 18), (b' walking', 2), (b' through', 1), (b' when', 2), (b' he', 10), (b' came', 4), (b' across', 1), (b' very', 12), (b' special', 3), (b' vase', 10), (b' When', 3), (b' it', 27), (b' amazed', 2), (b'!', 10), (b'  ', 1), (b'He', 2), (b' said', 19), (b' \xe2\x80\x9c', 2), (b'Wow', 1), (b' is', 7), (b' really', 2), (b' Can', 1), (b' I', 4), (b' buy', 1), (b'?\xe2\x80\x9d', 1), (b' ', 3), (b'The', 5), (b' shopkeeper', 1), (b' smiled', 7),

In [44]:
for i, (token_bytes, freq) in enumerate(token_frequencies.items()):
    if i >= 3:  # Stop after 3
        break
    print(f"Token {i+1}: {token_bytes!r} -> frequency: {freq}")

Token 1: b'\n' -> frequency: 70
Token 2: b'Once' -> frequency: 8
Token 3: b' upon' -> frequency: 8


In [48]:
token_bytes, freq = next(iter(token_frequencies.items()))
token_bytes

b'\n'

In [49]:
freq

70

In [46]:
# merge the most frequent pair of bytes that appear together
new_token = []
i = 0

while i < len(token_bytes):
    # Check: Is token_bytes[3:5] == (b'e', b's') == pair_to_merge?
    # bytes([101]) == b'e'?  YES
    # bytes([115]) == b's'?  YES
    # Check if we can merge at position i
    if (i < len(token_bytes) - 1 and 
        bytes([token_bytes[i]]) == best_pair[0] and 
        bytes([token_bytes[i + 1]]) == best_pair[1]):
        # Merge! Combine the two bytes
        new_token.append(best_pair[0] + best_pair[1])
        print(new_token)
        i += 2
    else:
        # Don't merge, just copy this byte
        new_token.append(bytes([token_bytes[i]]))
        i += 1 # Move to next position
print(b''.join(new_token))
# new_token = [b'n', b'e', b'w', b'es']

b' a'


In [42]:
def merge_pair_in_token(token_bytes, pair_to_merge):
    """
    Merge a specific pair in a single token.
    Returns the modified token as bytes.
    """
    if len(token_bytes) < 2:
        return token_bytes
    
    new_token = []
    i = 0
    
    while i < len(token_bytes):
        # Check if we can merge at position i
        if (i < len(token_bytes) - 1 and 
            bytes([token_bytes[i]]) == pair_to_merge[0] and 
            bytes([token_bytes[i + 1]]) == pair_to_merge[1]):
            # Merge! Combine the two bytes
            new_token.append(pair_to_merge[0] + pair_to_merge[1])
            i += 2
        else:
            # Don't merge, just copy this byte
            new_token.append(bytes([token_bytes[i]]))
            i += 1
    
    # Concatenate all byte sequences
    return b''.join(new_token)

new_token = merge_pair_in_token(token_bytes, best_pair)
new_token

b'\n'

In [51]:
new_freq_table = {}

for token_bytes, freq in token_frequencies.items():
    # We'll build a new version of this token
    new_token = merge_pair_in_token(token_bytes, best_pair)
    
    if new_token in token_frequencies:
        token_frequencies[new_token] += freq
    else:
        token_frequencies[new_token] = freq

print(new_freq_table)

{}


In [57]:
def merge_pair(token_frequencies, pair_to_merge):
    """
    Takes vocab and a pair like ('s', 't')
    Returns new vocab where that pair is merged everywhere
    Example: ('n','e','w','e','s','t') → ('n','e','w','e','st') if merging ('s','t')
    """
    new_freq_table = {}
    
    for token_bytes, freq in token_frequencies.items():
        # We'll build a new version of this token
        new_token = merge_pair_in_token(token_bytes, pair_to_merge)
        
        if new_token in new_freq_table:  
            new_freq_table[new_token] += freq
        else:
            new_freq_table[new_token] = freq
    
    return new_freq_table
    
token_frequencies = merge_pair(token_frequencies, best_pair)


In [58]:
merged_token = best_pair[0] + best_pair[1]
merged_token

b' t'

In [59]:
vocab[next_id] = merged_token
vocab[next_id]

b' t'

In [62]:
merges = []  # Track the sequence of merges
    
for merge_num in range(3):
    # Step 1: Count all pairs
    pair_freq = get_pair_frequencies(token_frequencies)
    print(f"pair_freq: {pair_freq}")
    
    if not pair_freq:
        print(f"No more pairs to merge at step {merge_num}")
        break
    
    # Step 2: Find best pair
    best_pair = get_best_pair(pair_freq)
    print(f"best_pair: {best_pair}")
    
    # Step 3: Merge it
    token_frequencies = merge_pair(token_frequencies, best_pair)
    print(f"token_frequencies: {token_frequencies}")
    # Step 4: Record this merge
    merged_token = best_pair[0] + best_pair[1]
    vocab[next_id] = merged_token
    next_id += 1
    print(merges)
    merges.append(best_pair)
    
    if (merge_num + 1) % 100 == 0 or merge_num < 10:
        print(f"Merge {merge_num + 1}/{num_merges}: "
                f"'{best_pair[0]!r} {best_pair[1]!r} -> {merged_token!r} "
                f"(freq: {pair_freq[best_pair]})")

print(f"\nTraining complete! Final vocab size: {len(vocab)}")
print(f"merges: {merges}")
print(f"vocab: {vocab}")


pair_freq: {(b'O', b'n'): 40, (b'n', b'c'): 24, (b'c', b'e'): 30, (b' ', b'u'): 32, (b'u', b'p'): 24, (b'p', b'o'): 18, (b'o', b'n'): 100, (b' ', b'a'): 404, (b' ', b't'): 554, (b't', b'i'): 58, (b'i', b'm'): 70, (b'm', b'e'): 108, (b't', b'h'): 410, (b'h', b'e'): 538, (b'e', b'r'): 204, (b'r', b'e'): 118, (b' ', b'w'): 264, (b'w', b'a'): 142, (b'a', b's'): 142, (b' ', b'l'): 134, (b'l', b'i'): 94, (b'i', b't'): 162, (b't', b't'): 38, (b't', b'l'): 22, (b'l', b'e'): 100, (b' ', b'b'): 172, (b'b', b'o'): 50, (b'o', b'y'): 18, (b' ', b'n'): 86, (b'n', b'a'): 22, (b'a', b'm'): 72, (b'e', b'd'): 248, (b' ', b'B'): 50, (b'B', b'e'): 26, (b'e', b'n'): 148, (b'l', b'o'): 80, (b'o', b'v'): 18, (b'v', b'e'): 94, (b't', b'o'): 190, (b' ', b'e'): 52, (b'e', b'x'): 16, (b'x', b'p'): 6, (b'p', b'l'): 58, (b'o', b'r'): 84, (b'w', b'o'): 8, (b'r', b'l'): 6, (b'l', b'd'): 36, (b'a', b'r'): 106, (b'r', b'o'): 28, (b'o', b'u'): 116, (b'u', b'n'): 52, (b'n', b'd'): 284, (b' ', b'h'): 236, (b'h', b'i'): 9