In [13]:
import time
import json
import psutil
import os
from pathlib import Path

In [14]:
%pip install regex
import regex as re

def train_bpe(input_path:str, vocab_size:int, special_tokens:list[str])->tuple[dict[int,bytes], list[tuple[bytes, bytes]]]:
    with open(input_path, 'rb') as f:
        text_bytes=f.read()
    text_str=text_bytes.decode('utf-8')

    PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    words=re.findall(PAT, text_str)

    from collections import Counter
    word_freq=Counter(words)
    
    word_tokens={}
    for word, freq in word_freq.items():
        word_bytes=word.encode('utf-8')
        token_seq=tuple(word_bytes) # "abs" -> (64,37,37)
        word_tokens[token_seq]=freq 

    vocab={}
    for i in range(256):
        vocab[i]=bytes([i])
    nxt_tid=256
    for x in special_tokens:
        vocab[nxt_tid]=x.encode('utf-8')
        nxt_tid+=1

    def cnt_pairs(word_tokens):
        pair_cnt=Counter()
        for token_seq, freq in word_tokens.items():
            for i in range(len(token_seq)-1):
                pair=(token_seq[i], token_seq[i+1])
                pair_cnt[pair]+=freq
        return pair_cnt
    
    def mfr_pair(pair_cnt, word_tokens):
        if not pair_cnt:
            return None
        max_freq = max(pair_cnt.values())
        most_freq_pairs = [pair for pair, freq in pair_cnt.items() if freq == max_freq]
        return max(most_freq_pairs, key=lambda p: (vocab[p[0]], vocab[p[1]])) #This was the error that gpt resolved
    
    def merge(word_tokens, pair, new_token_id):
        new_word_tokens={}
        tk1, tk2=pair
        for token_seq, freq in word_tokens.items():
            new_seq=[]
            i=0
            while i<len(token_seq):
                if(i<len(token_seq)-1 and token_seq[i]==tk1 and token_seq[i+1]==tk2):
                    new_seq.append(new_token_id)
                    i+=2
                else:
                    new_seq.append(token_seq[i])
                    i+=1

            new_word_tokens[tuple(new_seq)]=freq
        return new_word_tokens
    
    merges=[]
    while len(vocab)<vocab_size :
        pair_count=cnt_pairs(word_tokens)
        mfp=mfr_pair(pair_count, word_tokens)
        if mfp is None:
            break
        tk1_bytes=vocab[mfp[0]]
        tk2_bytes=vocab[mfp[1]]
        new_token_id=nxt_tid
        nxt_tid+=1
        vocab[new_token_id]=tk1_bytes+tk2_bytes
        merges.append((tk1_bytes, tk2_bytes))
        word_tokens=merge(word_tokens, mfp, new_token_id)

    return vocab, merges
    


Collecting regex
  Downloading regex-2025.9.18-cp39-cp39-macosx_11_0_arm64.whl.metadata (40 kB)
Downloading regex-2025.9.18-cp39-cp39-macosx_11_0_arm64.whl (286 kB)
Installing collected packages: regex
Successfully installed regex-2025.9.18
Note: you may need to restart the kernel to use updated packages.


In [15]:
def get_mem_gb():
    process=psutil.Process(os.getpid())
    return process.memory_info().rss / (1024**3)

In [16]:
def save_vo_merges(vocab, merges, output_dir="tokenizer_output"):
    os.makedirs(output_dir, exist_ok=True)
    import base64
    vocab_serial={
        str(k): base64.b64encode(v).decode('utf-8') for k,v in vocab.items()
    }
    vocab_path=os.path.join(output_dir, "vocab.json")
    with open(vocab_path, 'w') as f:
        json.dump(vocab_serial, f, indent=2, sort_keys=True)
    print(f" Vocabulary saved to : {vocab_path} ")
    merges_path=os.path.join(output_dir, "merges.txt")
    with open(merges_path, 'w') as f:
        for tk1, tk2 in merges:
            try:
                tk1_str=tk1.decode('utf-8')
                tk2_str=tk2.decode('utf-8')
            except UnicodeDecodeError:
                tk1_str=repr(tk1)
                tk2_str=repr(tk2)
            f.write(f"{tk1_str} {tk2_str}\n")
    print(f" Merges saved to : {merges_path} ")

    vocab_read_path=os.path.join(output_dir, "vocab_read.txt")
    with open(vocab_read_path, 'w', encoding='utf-8') as f:
        for tk_id, tk_bytes in sorted(vocab.items()):
            try:
                tk_str=tk_bytes.decode('utf-8')
            except UnicodeDecodeError:
                tk_str=repr(tk_bytes)
            f.write(f"{tk_id}\t{tk_str}\n")
    print(f" Vocabulary (readable) saved to : {vocab_read_path} ")


In [17]:
def main():
    dataset_path="/Users/vsingh/Documents/CS336 Tatsu/data/TinyStoriesV2-GPT4-train.txt"
    vocab_size=10000
    special_tokens=["<|endoftext|>"]
    output_dir="tinystories_tokenizer"

    print("TRAINING BPE TOKENIZER ON TINYSTORIES")
    if not os.path.exists(dataset_path):
        print(f" File not found: {dataset_path}")
        print("\nPlease update dataset_path to point to your TinyStories dataset.")
        print("Common locations:")
        print("  - ./TinyStories.txt")
        print("  - ./data/TinyStories.txt")
        print("  - ~/datasets/TinyStories.txt")
        return
    
    initial_memory = get_mem_gb()
    start_time = time.time()
    
    print(f"\n{'='*60}")
    print("TRAINING STARTED")
    print(f"{'='*60}")
    print(f"Initial memory usage: {initial_memory:.2f} GB")
    print(f"Start time: {time.strftime('%Y-%m-%d %H:%M:%S')}")


    vocab, merges = train_bpe(
        input_path=dataset_path,
        vocab_size=vocab_size,
        special_tokens=special_tokens
    )

    end_time = time.time()
    final_memory = get_mem_gb()
    peak_memory = final_memory  # Note: This is current, not peak
    
    training_time_seconds = end_time - start_time
    training_time_hours = training_time_seconds / 3600
    training_time_minutes = training_time_seconds / 60
    
    print(f"\n{'='*60}")
    print("TRAINING COMPLETED")
    print(f"{'='*60}")
    print(f"End time: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Training time: {training_time_hours:.2f} hours ({training_time_minutes:.2f} minutes)")
    print(f"Final memory usage: {final_memory:.2f} GB")
    print(f"Memory increase: {final_memory - initial_memory:.2f} GB")

    print(f"\n{'='*60}")
    print("SAVING RESULTS")
    print(f"{'='*60}")
    save_vo_merges(vocab, merges, output_dir)


In [18]:
if __name__ == "__main__":
    main()

TRAINING BPE TOKENIZER ON TINYSTORIES

TRAINING STARTED
Initial memory usage: 0.95 GB
Start time: 2025-10-04 22:13:38

TRAINING COMPLETED
End time: 2025-10-04 22:52:43
Training time: 0.65 hours (39.09 minutes)
Final memory usage: 0.03 GB
Memory increase: -0.91 GB

SAVING RESULTS
 Vocabulary saved to : tinystories_tokenizer/vocab.json 
 Merges saved to : tinystories_tokenizer/merges.txt 
 Vocabulary (readable) saved to : tinystories_tokenizer/vocab_read.txt 
