In [None]:
# BPE Tokenizer from First Principles

This notebook implements Byte-Pair Encoding (BPE) from scratch to understand the fundamental mechanics of subword tokenization used in modern language models.

## Algorithm Overview

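1. **Initialize:** Split each word into character-level tokens and append an end-of-word marker `</w>`
2. **Iterate:** Count every adjacent token pair across the corpus (weighted by word frequency) and merge the most frequent pair into a single new token
3. **Repeat:** Continue until the vocabulary reaches the target size or no adjacent pairs remain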


In [4]:
import re
from collections import defaultdict, Counter

class BPETokenizer:
    """
    Byte-Pair Encoding tokenizer learned from a list of texts.

    Text is lower-cased and split into runs of word characters. Each word
    starts as its characters followed by an end-of-word marker ``</w>``.
    The most frequent adjacent pair is then merged repeatedly until the
    vocabulary reaches ``vocab_size`` or nothing is left to merge.
    """

    # Shared by train() and encode() so both split text into words identically.
    _WORD_RE = re.compile(r'\w+')

    def __init__(self, vocab_size=1000):
        """
        Initialize BPE tokenizer with target vocabulary size.

        Args:
            vocab_size (int): Target size of final vocabulary. Single characters,
                the ``</w>`` marker and merged tokens all count toward it.
        """
        self.vocab_size = vocab_size
        # Distinct training word (as a tuple of its current tokens) -> corpus count
        self.word_freqs = {}
        # Every token string present after training (characters, </w>, merges)
        self.vocab = set()
        # Learned merges in the order they were applied; encode() replays them
        self.merges = []

    def _get_word_tokens(self, word):
        """
        Convert a word into character-level tokens.

        Args:
            word (str): Input word

        Returns:
            list: The word's characters followed by the end-of-word marker "</w>"
        """
        return list(word) + ["</w>"]

    @staticmethod
    def _merge_pair(tokens, pair):
        """
        Fuse every occurrence of ``pair`` in ``tokens`` into a single token.

        The scan is greedy and left-to-right, so occurrences never overlap:
        merging ('a', 'a') in ['a', 'a', 'a'] gives ['aa', 'a'].

        Args:
            tokens (Sequence[str]): Current tokens of one word
            pair (tuple): Adjacent pair to merge, e.g. ('a', 't')

        Returns:
            list: New token list with each occurrence of the pair fused
        """
        first, second = pair
        merged = first + second
        out = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and tokens[i] == first and tokens[i + 1] == second:
                out.append(merged)
                i += 2
            else:
                out.append(tokens[i])
                i += 1
        return out

    def _get_stats(self):
        """
        Count every adjacent token pair across all training words.

        Each pair's count is weighted by how often its word occurs in the corpus.

        Returns:
            Counter: Frequency count of adjacent pairs, in first-seen order
        """
        pairs_count = Counter()
        for word, freq in self.word_freqs.items():
            for pair in zip(word, word[1:]):
                pairs_count[pair] += freq
        return pairs_count

    def _merge_vocab(self, pair):
        """
        Apply one merge to every training word.

        Args:
            pair (tuple): The pair to merge (e.g., ('a', 't'))
        """
        new_word_freqs = {}
        for word, freq in self.word_freqs.items():
            key = tuple(self._merge_pair(word, pair))
            # Sum rather than overwrite, so no counts are lost if two entries
            # ever collapse to the same token tuple.
            new_word_freqs[key] = new_word_freqs.get(key, 0) + freq
        self.word_freqs = new_word_freqs

    def train(self, corpus, verbose=True):
        """
        Train the BPE tokenizer on a corpus.

        Any state from a previous call is discarded first. Calling ``train``
        again therefore retrains from scratch instead of extending old merges.

        Args:
            corpus (list): List of sentences/documents
            verbose (bool): If True (default), print the vocabulary sizes and
                every merge as it is applied
        """
        # Reset learned state. Without this, a second call would start from the
        # old vocabulary (often stopping at once) and append to the old merges.
        self.word_freqs = {}
        self.vocab = set()
        self.merges = []

        # Step 1: word frequencies (lower-cased, split the same way as encode())
        word_counts = Counter(
            word for text in corpus for word in self._WORD_RE.findall(text.lower())
        )

        # Step 2: character-level tokens plus end-of-word marker
        for word, freq in word_counts.items():
            self.word_freqs[tuple(self._get_word_tokens(word))] = freq
        for word_tokens in self.word_freqs:
            self.vocab.update(word_tokens)

        if verbose:
            print(f"Initial vocabulary size: {len(self.vocab)}")
            print(f"Target vocabulary size: {self.vocab_size}")

        # Step 3: merge until the target size is reached or no pairs remain
        while len(self.vocab) < self.vocab_size:
            pairs = self._get_stats()
            if not pairs:
                break

            # On ties, Counter.most_common returns the pair encountered first
            best_pair = pairs.most_common(1)[0][0]
            if verbose:
                print(f"Merging: {best_pair}")

            self._merge_vocab(best_pair)
            self.merges.append(best_pair)
            self.vocab.add(best_pair[0] + best_pair[1])

        if verbose:
            print(f"Final vocabulary size: {len(self.vocab)}")

    def encode(self, text):
        """
        Encode text using the trained BPE model.

        Words are split exactly as in ``train``. The learned merges are then
        replayed in training order. Characters that no merge covers stay as
        single-character tokens.

        Args:
            text (str): Input text to encode

        Returns:
            list: List of BPE tokens. Each word's last token ends with "</w>".
        """
        all_tokens = []
        # A given word always encodes the same way, so merge each distinct word once.
        cache = {}
        for word in self._WORD_RE.findall(text.lower()):
            if word not in cache:
                word_tokens = self._get_word_tokens(word)
                for pair in self.merges:
                    word_tokens = self._merge_pair(word_tokens, pair)
                cache[word] = word_tokens
            all_tokens.extend(cache[word])
        return all_tokens


In [5]:
# Test corpus - simple example
corpus = [
    "the cat sat on the mat",
    "the cat saw the rat",
    "the rat ran from the cat",
    "cats and rats are animals",
    "animals run and cats catch rats"
]

# Target vocabulary size (single characters + "</w>" + learned merges)
VOCAB_SIZE = 50

# Create and train tokenizer
tokenizer = BPETokenizer(vocab_size=VOCAB_SIZE)
tokenizer.train(corpus)

print("\nLearned merges:")
for i, (left, right) in enumerate(tokenizer.merges, start=1):
    print(f"{i}: {left} + {right} -> {left + right}")

print(f"\nFinal vocabulary: {sorted(tokenizer.vocab)}")

# Test encoding
test_text = "the cat catches rats"
tokens = tokenizer.encode(test_text)
print(f"\nEncoding '{test_text}': {tokens}")

# Sanity check: BPE only fuses adjacent tokens. Joining the tokens and splitting
# on the end-of-word marker must therefore give back the original words.
assert "".join(tokens).replace("</w>", " ").split() == test_text.lower().split()


Initial vocabulary size: 17
Target vocabulary size: 50
Merging: ('a', 't')
Merging: ('e', '</w>')
Merging: ('at', '</w>')
Merging: ('t', 'h')
Merging: ('th', 'e</w>')
Merging: ('s', '</w>')
Merging: ('a', 'n')
Merging: ('at', 's</w>')
Merging: ('c', 'at</w>')
Merging: ('n', '</w>')
Merging: ('r', 'at</w>')
Merging: ('c', 'ats</w>')
Merging: ('an', 'd')
Merging: ('and', '</w>')
Merging: ('r', 'ats</w>')
Merging: ('an', 'i')
Merging: ('ani', 'm')
Merging: ('anim', 'a')
Merging: ('anima', 'l')
Merging: ('animal', 's</w>')
Merging: ('s', 'at</w>')
Merging: ('o', 'n</w>')
Merging: ('m', 'at</w>')
Merging: ('s', 'a')
Merging: ('sa', 'w')
Merging: ('saw', '</w>')
Merging: ('r', 'an')
Merging: ('ran', '</w>')
Merging: ('f', 'r')
Merging: ('fr', 'o')
Merging: ('fro', 'm')
Merging: ('from', '</w>')
Merging: ('a', 'r')
Final vocabulary size: 50

Learned merges:
1: a + t -> at
2: e + </w> -> e</w>
3: at + </w> -> at</w>
4: t + h -> th
5: th + e</w> -> the</w>
6: s + </w> -> s</w>
7: a + n -> an
8: