# Understanding and Implementing Byte-Pair Encoding (BPE)

This tutorial walks through the fundamentals of BPE tokenization and then implements a custom tokenizer class called `BPE`.

## What is BPE?
Byte-Pair Encoding is a data compression technique that was adapted for subword tokenization. Many modern language models, including the GPT family and RoBERTa, use it to break words into frequently occurring subword units. (BERT uses WordPiece, a closely related algorithm.)


# Understanding BPE Tokenization

BPE works by:
1. Starting with a vocabulary of individual characters
2. Iteratively finding the most frequent pair of adjacent tokens
3. Merging these pairs to create new tokens
4. Repeating until a desired vocabulary size is reached

Let's see this with a simple example:

In [2]:
# Example of BPE process
from collections import Counter
import re

def visualize_bpe_step(word_list, vocab_size=10):
    # Represent each word as space-separated characters, e.g. 'low' -> 'l o w'
    words = [' '.join(list(word)) for word in word_list]
    
    # Current vocabulary is individual characters
    vocab = set(''.join(word_list))
    print(f"Initial vocabulary size: {len(vocab)}")
    print("Vocabulary:", sorted(list(vocab)))
    print("\nStarting BPE process:")
    
    while len(vocab) < vocab_size:
        # Get pairs of tokens
        pairs = Counter()
        for word in words:
            tokens = word.split()
            for i in range(len(tokens)-1):
                pair = (tokens[i], tokens[i+1])
                pairs[pair] += 1
        
        if not pairs:
            break
            
        # Get most frequent pair
        best_pair = max(pairs.items(), key=lambda x: x[1])
        print(f"\nMost frequent pair: {best_pair[0]} (frequency: {best_pair[1]})")
        
        # Merge pair in all words
        new_token = ''.join(best_pair[0])
        vocab.add(new_token)
        
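        # Match the pair only at token boundaries so that, for example,
        # merging ('a', 'b') never rewrites the tokens 'xa b' into 'xab'
        bigram = re.escape(' '.join(best_pair[0]))
        merge_re = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        words = [merge_re.sub(lambda m: new_token, word) for word in words]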
        
        print(f"Words after merge: {words}")
        print(f"Vocabulary size: {len(vocab)}")
        print("Current vocabulary:", sorted(list(vocab)))

# Example usage
word_list = ['low', 'lowest', 'newer', 'wider', 'new', 'low']
visualize_bpe_step(word_list, vocab_size=15)

Initial vocabulary size: 10
Vocabulary: ['d', 'e', 'i', 'l', 'n', 'o', 'r', 's', 't', 'w']

Starting BPE process:

Most frequent pair: ('l', 'o') (frequency: 3)
Words after merge: ['lo w', 'lo w e s t', 'n e w e r', 'w i d e r', 'n e w', 'lo w']
Vocabulary size: 11
Current vocabulary: ['d', 'e', 'i', 'l', 'lo', 'n', 'o', 'r', 's', 't', 'w']

Most frequent pair: ('lo', 'w') (frequency: 3)
Words after merge: ['low', 'low e s t', 'n e w e r', 'w i d e r', 'n e w', 'low']
Vocabulary size: 12
Current vocabulary: ['d', 'e', 'i', 'l', 'lo', 'low', 'n', 'o', 'r', 's', 't', 'w']

Most frequent pair: ('n', 'e') (frequency: 2)
Words after merge: ['low', 'low e s t', 'ne w e r', 'w i d e r', 'ne w', 'low']
Vocabulary size: 13
Current vocabulary: ['d', 'e', 'i', 'l', 'lo', 'low', 'n', 'ne', 'o', 'r', 's', 't', 'w']

Most frequent pair: ('ne', 'w') (frequency: 2)
Words after merge: ['low', 'low e s t', 'new e r', 'w i d e r', 'new', 'low']
Vocabulary size: 14
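Current vocabulary: ['d', 'e', 'i', 'l', 'lo', 'low', 'n', 'ne', 'new', 'o', 'r', 's', 't', 'w']

Most frequent pair: ('e', 'r') (frequency: 2)
Words after merge: ['low', 'low e s t', 'new er', 'w i d er', 'new', 'low']
Vocabulary size: 15
Current vocabulary: ['d', 'e', 'er', 'i', 'l', 'lo', 'low', 'n', 'ne', 'new', 'o', 'r', 's', 't', 'w']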

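# How to tokenize the corpus

Before any merges are learned, the corpus is split into chunks. There are two common options:

1. Split on whitespace (simple, but punctuation stays attached to words)
2. Split with a regex that applies explicit rules for letters, digits, punctuation, and whitespace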
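
To make the trade-off concrete, here is a minimal sketch of plain whitespace splitting using only `str.split()`. The sample sentence is the same one used in the regex cell; nothing else is assumed.

```python
# Naive approach: split on whitespace only.
text = "Hello, world!"
tokens = text.split()
print(tokens)  # ['Hello,', 'world!']

# Punctuation stays glued to the words, and the spaces themselves are
# discarded, so the original text cannot be rebuilt exactly from the tokens.
print("".join(tokens) == text)  # False
```

A regex-based split keeps punctuation and whitespace as separate pieces, which is why the rest of this tutorial uses it.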

In [3]:
import regex as re
GPT4PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
compiler = re.compile(GPT4PATTERN)
re.findall(compiler, "Hello, world!")

['Hello', ',', ' world', '!']

# Understanding the Regex Pattern

Let's break down the regex pattern used in our tokenizer. First, let's understand what `regex` is and why we use it:

## What is the `regex` module?
```python
import regex as re
```
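- `regex` is a third-party package (`pip install regex`) that is largely a drop-in replacement for Python's built-in `re` module
- It adds support for Unicode properties (`\p{L}`, `\p{N}`, etc.), which the built-in `re` rejects
- It supports the possessive quantifiers (`?+`, `++`) used in our pattern on every Python version; the built-in `re` only gained them in Python 3.11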
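
The difference is easy to check. The snippet below is a small sketch that tries the same Unicode-property pattern with both modules; the sample string is made up for illustration, and it assumes the `regex` package is installed.

```python
import re          # standard library
import regex       # third-party: pip install regex

text = "café 世界 42"

# The standard-library module rejects the \p{...} escape outright.
try:
    re.compile(r"\p{L}+")
    stdlib_ok = True
except re.error:
    stdlib_ok = False
print("stdlib re accepts \\p{L}:", stdlib_ok)  # False

# The regex module understands Unicode properties.
letters = regex.findall(r"\p{L}+", text)
print(letters)  # ['café', '世界']
```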

## Breaking Down the Pattern
```python
GPT4PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
```

This pattern is broken into several parts, separated by `|` (OR operator):

1. `'(?i:[sdmt]|ll|ve|re)`
   - Matches the suffix of a contraction: 's, 'd, 'm, 't, 'll, 've, 're
   - `(?i:...)` makes this part case-insensitive
   - Examples: the `'m` in "I'm", the `'ll` in "I'll", the `'ve` in "I've", the `'re` in "they're"

2. `[^\r\n\p{L}\p{N}]?+\p{L}+`
   - `\p{L}` matches any kind of letter from any language
   - `\p{N}` matches any kind of numeric character
   - `[^\r\n\p{L}\p{N}]?+` optionally matches one leading character that is not a letter, a number, or a line break (typically a space)
   - `\p{L}+` then matches one or more letters
   - Examples: "Hello", " world", "世界", "café"

3. `\p{N}{1,3}`
   - Matches 1 to 3 consecutive numeric characters
   - Longer numbers are split into several tokens
   - Examples: "1", "42", "999"; "2023" becomes "202" and "3"

4. ` ?[^\s\p{L}\p{N}]++[\r\n]*`
   - Matches punctuation and symbols
   - ` ?` optionally matches a leading space
   - `[^\s\p{L}\p{N}]++` matches one or more non-space, non-letter, non-number characters
   - `[\r\n]*` also absorbs any line breaks that follow
   - Examples: "!", ",", "?"

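5. `\s*[\r\n]|\s+(?!\S)|\s+`
   - Matches different types of whitespace
   - `\s*[\r\n]` matches a run of whitespace that ends in a line break
   - `\s+(?!\S)` matches whitespace that is not followed by a non-whitespace character: trailing whitespace at the end of the text, or all but the last space of a run before a word (the last space is left to attach to that word)
   - `\s+` matches any other whitespace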
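
The whitespace alternatives are the least intuitive part, so here is a short sketch that isolates them. The three input strings are made up for illustration, and the snippet assumes the `regex` package is installed.

```python
import regex

GPT4PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
pattern = regex.compile(GPT4PATTERN)

# Three spaces between words: the first two form their own token,
# the last one is attached to the following word.
spaces = pattern.findall("a   b")
print(spaces)    # ['a', '  ', ' b']

# Whitespace that ends in a newline is grouped with that newline.
newline = pattern.findall("a \nb")
print(newline)   # ['a', ' \n', 'b']

# Trailing whitespace at the end of the text is a single token.
trailing = pattern.findall("a  ")
print(trailing)  # ['a', '  ']
```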

In [4]:
# Let's see how this pattern works with examples
import regex as re

def explain_matches(text):
    """Show how the pattern matches parts of the text"""
    pattern = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
    compiler = re.compile(pattern)
    matches = compiler.findall(text)
    
    print(f"Text: {text}")
    print("Tokens:", matches)
    print("\nBreakdown:")
    for i, token in enumerate(matches, 1):
        if token.isspace():
            token_display = "[SPACE]"
        else:
            token_display = token
        print(f"{i}. '{token_display}'")

# Example 1: Simple text with punctuation
print("Example 1: Simple text")
explain_matches("Hello, world!")

# Example 2: Text with contractions
print("\nExample 2: Contractions")
explain_matches("I'm going to the store, they've been there.")

# Example 3: Mixed content
print("\nExample 3: Mixed content")
explain_matches("In 2023, AI's progress is amazing! 😊")

Example 1: Simple text
Text: Hello, world!
Tokens: ['Hello', ',', ' world', '!']

Breakdown:
1. 'Hello'
2. ','
3. ' world'
4. '!'

Example 2: Contractions
Text: I'm going to the store, they've been there.
Tokens: ['I', "'m", ' going', ' to', ' the', ' store', ',', ' they', "'ve", ' been', ' there', '.']

Breakdown:
1. 'I'
2. ''m'
3. ' going'
4. ' to'
5. ' the'
6. ' store'
7. ','
8. ' they'
9. ''ve'
10. ' been'
11. ' there'
12. '.'

Example 3: Mixed content
Text: In 2023, AI's progress is amazing! 😊
Tokens: ['In', ' ', '202', '3', ',', ' AI', "'s", ' progress', ' is', ' amazing', '!', ' 😊']

Breakdown:
1. 'In'
2. '[SPACE]'
3. '202'
4. '3'
5. ','
6. ' AI'
7. ''s'
8. ' progress'
9. ' is'
10. ' amazing'
11. '!'
12. ' 😊'


# Understanding Pattern Compilation

When we use:
```python
compiler = re.compile(GPT4PATTERN)
```

This is important because:

1. **Performance**: Compiling a pattern is more efficient when you plan to use it multiple times. The pattern is parsed once and can be reused.

2. **Validation**: Compilation checks that the pattern is valid. If there is a syntax error, `compile()` raises an exception immediately, before any text is processed.

3. **Options**: You can set flags during compilation that affect how the pattern works (like case-insensitivity).
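
The three points can be illustrated with a few lines. This sketch deliberately uses the standard-library `re` with toy patterns so that it runs without extra dependencies; the same calls (`compile`, `error`, `IGNORECASE`) exist in the `regex` module used elsewhere in this tutorial.

```python
import re  # standard library; regex offers the same calls

# 1. Performance: compile once, reuse for every text.
word = re.compile(r"\w+")
for text in ["low lower", "new newer"]:
    print(word.findall(text))

# 2. Validation: a malformed pattern fails as soon as compile() is called.
try:
    re.compile(r"(unclosed")
    compiled = True
except re.error as exc:
    compiled = False
    print("Invalid pattern:", exc)

# 3. Options: flags are fixed when the pattern is compiled.
hello = re.compile(r"hello", re.IGNORECASE)
print(bool(hello.match("HELLO")))  # True
```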

## Pattern Execution Methods

The `re` module provides several ways to use patterns:

1. `findall()`: Returns all non-overlapping matches in a string
   ```python
   re.findall(compiler, "Hello, world!")
   ```
   - Returns a list of all matches
   - Each match is returned as a string
   - Great for tokenization where you want all pieces

2. `match()`: Tries to match at the beginning of the string

3. `search()`: Looks for a match anywhere in the string

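4. `split()`: Splits the string at each match and returns the text between matches. Because our pattern matches every character, nothing is left between matches and the result contains only empty strings.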

For tokenization, we use `findall()` because we want to:
- Get all tokens in the text
- Preserve the order of tokens
- Include both words and separators
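
One way to check these properties is to join the tokens and compare the result with the input. The sketch below does that for a made-up sentence; it assumes the `regex` package is installed.

```python
import regex

GPT4PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
compiler = regex.compile(GPT4PATTERN)

text = "Hello, world! It's 2023.\n"
tokens = compiler.findall(text)
print(tokens)

# Every character ends up in exactly one token, in order,
# so joining the tokens restores the input.
print("".join(tokens) == text)  # True
```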

In [5]:
# Let's compare different regex methods
import regex as re

text = "Hello, world! This is a test."
pattern = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
compiler = re.compile(pattern)

# 1. findall() - Get all matches
print("1. findall() results:")
findall_results = compiler.findall(text)
print(findall_results)

# 2. match() - Try to match at the beginning
print("\n2. match() result:")
match_result = compiler.match(text)
if match_result:
    print(f"Matched: '{match_result.group()}'")
    print(f"Start: {match_result.start()}, End: {match_result.end()}")
else:
    print("No match at beginning")

# 3. search() - Find first match anywhere
print("\n3. search() result:")
search_result = compiler.search(text)
if search_result:
    print(f"Found: '{search_result.group()}'")
    print(f"Start: {search_result.start()}, End: {search_result.end()}")
else:
    print("No match found")

# 4. split() - Split by pattern
print("\n4. split() results:")
split_results = compiler.split(text)
print(split_results)

# Show which method is best for tokenization
print("\nFor tokenization, findall() is best because:")
print("- Gets all tokens in order")
print("- Preserves all parts of the text")
print("- Returns a simple list of strings")

1. findall() results:
['Hello', ',', ' world', '!', ' This', ' is', ' a', ' test', '.']

2. match() result:
Matched: 'Hello'
Start: 0, End: 5

3. search() result:
Found: 'Hello'
Start: 0, End: 5

4. split() results:
['', '', '', '', '', '', '', '', '', '']

For tokenization, findall() is best because:
- Gets all tokens in order
- Preserves all parts of the text
- Returns a simple list of strings


In [6]:
import regex as re
GPT4PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
compiler = re.compile(GPT4PATTERN)
words = compiler.findall("Hello, world!")
# Convert each chunk to its list of UTF-8 byte values (integers 0-255)
ids = [list(word.encode('utf-8')) for word in words]
ids

[[72, 101, 108, 108, 111], [44], [32, 119, 111, 114, 108, 100], [33]]
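
The numbers above are UTF-8 byte values. The following sketch, using only built-ins and made-up sample strings, shows why working at the byte level gives a fixed base vocabulary of 256 entries regardless of the language of the text.

```python
# ASCII characters occupy one byte each...
print(list("Hi".encode("utf-8")))   # [72, 105]

# ...while other characters take two to four bytes.
print(list("é".encode("utf-8")))    # [195, 169]
print(list("世".encode("utf-8")))   # [228, 184, 150]

# Every value lies in range(256), and the bytes decode back to the text.
ids = list("café".encode("utf-8"))
print(ids)                          # [99, 97, 102, 195, 169]
print(bytes(ids).decode("utf-8"))   # café
```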

# Understanding and Implementing the BPE Class

Let's break down how to implement a Byte-Pair Encoding (BPE) tokenizer step by step. This guide will help you understand each component and implement your own BPE tokenizer.

## Class Structure Overview

The BPE class needs three main components:
1. Initialization and storage of basic information
2. Methods to analyze and count token pairs
3. Methods to merge tokens and update the vocabulary

## Step-by-Step Implementation Guide

### Step 1: Class Initialization
Your class needs to track:
- The target vocabulary size
- The input texts
- The current vocabulary
- The merge operations performed

Think about:
- What data structures would be best for storing the vocabulary?
- How will you keep track of merge operations?
- What initial size should your vocabulary have?
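
One possible answer to these questions, shown as a sketch rather than the required solution: a dictionary from token ID to `bytes` for the vocabulary and a list of ID pairs for the merges. The variable names mirror the attributes of the `BPE` class; the example merge is hypothetical.

```python
from typing import Dict, List, Tuple

# Base vocabulary: token ID -> the bytes it stands for.
vocab: Dict[int, bytes] = {i: bytes([i]) for i in range(256)}

# Merges in the order they were learned: (left ID, right ID).
bpe_merges: List[Tuple[int, int]] = []

print(len(vocab))   # 256
print(vocab[104])   # b'h'

# Recording a hypothetical first merge of 'h' (104) and 'e' (101):
new_id = len(vocab)                       # 256, the first free ID
bpe_merges.append((104, 101))
vocab[new_id] = vocab[104] + vocab[101]
print(new_id, vocab[new_id])              # 256 b'he'
```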

### Step 2: Statistics Collection (get_stats method)
This method needs to:
1. Look at pairs of adjacent tokens in the text
2. Count how often each pair appears
3. Return the counts for all pairs

Consider:
- How to efficiently count pairs
- How to handle pairs that span word boundaries
- What data structure is best for storing counts
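
As a reference point, here is a minimal sketch of one way to count pairs. It is written as a standalone function with the same parameters as the `get_stats` method; the sample chunks are made up. Because each chunk is counted separately, no pair ever spans two chunks.

```python
from collections import defaultdict

def get_stats(ids, pairs=None):
    """Count adjacent ID pairs in one chunk, optionally updating `pairs`."""
    pairs = defaultdict(int) if pairs is None else pairs
    for left, right in zip(ids, ids[1:]):
        pairs[(left, right)] += 1
    return pairs

chunks = [list(b"low"), list(b"lower"), list(b"slow")]
pairs = defaultdict(int)
for chunk in chunks:
    get_stats(chunk, pairs)

print(pairs[(108, 111)])     # 3  ('l', 'o')
print(pairs[(119, 101)])     # 1  ('w', 'e')
print((119, 108) in pairs)   # False: ('w', 'l') would span two chunks
```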

### Step 3: Pair Merging (merge_pair method)
This method should:
1. Take a pair of tokens and their new token ID
2. Find all occurrences of this pair in the text
3. Replace them with the new token ID

Think about:
- How to efficiently find pairs in the text
- How to handle overlapping pairs
- How to maintain the text structure while merging
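
A simple left-to-right scan addresses all three points. The sketch below is one possible standalone version with the same parameters as the `merge_pair` method; it builds a new list instead of modifying `ids` in place, and it resolves overlapping pairs by always merging the leftmost occurrence first.

```python
def merge_pair(ids, pair, idx):
    """Return a copy of `ids` with every occurrence of `pair` replaced by `idx`."""
    new_ids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            new_ids.append(idx)
            i += 2          # skip both members of the merged pair
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids

print(merge_pair([108, 111, 119], (108, 111), 256))  # [256, 119]

# Overlapping case: in [a, a, a] only the first two are merged.
print(merge_pair([97, 97, 97], (97, 97), 256))       # [256, 97]
```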

### Step 4: BPE Learning (learn_bpe method)
This is the main method that:
1. Initializes the base vocabulary (256 bytes)
2. Processes the input text into initial tokens
3. Repeatedly:
   - Counts token pairs
   - Finds the most frequent pair
   - Merges this pair into a new token
   - Updates the vocabulary
   - Continues until reaching the target vocabulary size

Consider:
- How to handle text preprocessing
- When to stop merging
- How to maintain the vocabulary efficiently
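
Putting the steps together, the loop can be sketched as a standalone function. This is an illustration under simplifying assumptions, not the required solution: the input is a list of already-split chunks (so no regex is needed here), ties between equally frequent pairs go to the pair seen first, and the helper functions are compact versions of the ones described in Steps 2 and 3.

```python
from collections import defaultdict

def get_stats(ids, pairs):
    for pair in zip(ids, ids[1:]):
        pairs[pair] += 1
    return pairs

def merge_pair(ids, pair, idx):
    new_ids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            new_ids.append(idx)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids

def learn_bpe(chunks, vocab_size):
    """Learn merges from pre-split text chunks (a list of strings)."""
    vocab = {i: bytes([i]) for i in range(256)}      # base vocabulary
    merges = []
    ids = [list(chunk.encode("utf-8")) for chunk in chunks]
    while len(vocab) < vocab_size:
        pairs = defaultdict(int)
        for chunk_ids in ids:
            get_stats(chunk_ids, pairs)
        if not pairs:                                # nothing left to merge
            break
        best_pair = max(pairs, key=pairs.get)        # most frequent pair
        idx = len(vocab)                             # next free token ID
        merges.append(best_pair)
        vocab[idx] = vocab[best_pair[0]] + vocab[best_pair[1]]
        ids = [merge_pair(chunk_ids, best_pair, idx) for chunk_ids in ids]
    return vocab, merges

vocab, merges = learn_bpe(["low", " low", " lower", " lowest"], vocab_size=259)
print([vocab[i] for i in range(256, len(vocab))])    # [b'lo', b'low', b' low']
```

The loop stops either when the target size is reached or when no pairs remain, which covers very short inputs.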

## Implementation Tips

1. Data Structures:
   - Use dictionaries for the vocabulary (fast lookup)
   - Use defaultdict for counting pairs (convenient counting)
   - Consider using lists for storing merge operations

2. Text Processing:
   - Start with byte-level tokenization
   - Use regex for initial text splitting
   - Handle UTF-8 encoding properly

3. Optimization Tips:
   - Cache frequently accessed data
   - Avoid unnecessary string operations
   - Use efficient data structures for frequent operations

4. Error Handling:
   - Check input parameters
   - Handle edge cases (empty text, small vocab size)
   - Validate vocabulary size (>= 256)
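
For the error-handling tip, a small check at the start of training is usually enough. The helper below is hypothetical (its name and messages are not part of the `BPE` class) and only sketches the kind of validation meant here.

```python
def check_bpe_inputs(vocab_size, texts):
    """Hypothetical helper: raise early on inputs the training loop cannot use."""
    if not isinstance(vocab_size, int) or vocab_size < 256:
        raise ValueError("vocab_size must be an integer >= 256 (the byte vocabulary)")
    if isinstance(texts, str):
        texts = [texts]              # accept a single string as a one-item corpus
    texts = list(texts)
    if not any(texts):
        raise ValueError("texts must contain at least one non-empty string")
    return texts

print(check_bpe_inputs(400, "some text"))   # ['some text']

try:
    check_bpe_inputs(100, ["some text"])
except ValueError as exc:
    print("Rejected:", exc)
```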

## Common Pitfalls to Avoid

1. Vocabulary Size:
   - Don't forget the base vocabulary of 256 bytes
   - Ensure target size is greater than base size

2. Merging Process:
   - Be careful with overlapping pairs
   - Don't modify data while iterating
   - Track indices carefully during merges

3. Text Handling:
   - Remember to handle UTF-8 properly
   - Consider whitespace and special characters
   - Handle case sensitivity correctly

## Testing Your Implementation

Test your implementation with:
1. Simple repeated words
2. Mixed case text
3. Special characters
4. Unicode text
5. Various vocabulary sizes

## Assignment Tips

When implementing your own BPE:
1. Start with the basic structure
2. Implement one method at a time
3. Test each method thoroughly
4. Add features incrementally
5. Document your code
6. Add error handling last

Remember: BPE is an iterative process. Make sure each step works correctly before moving to the next one.

In [None]:
from collections import defaultdict
import regex as re
class BPE():

    def __init__(self,vocab_size,texts) -> None:
        self.vocab_size = vocab_size
        self.texts = texts
        self.vocab = dict()
        self.bpe_merges = []

    def get_stats(self, ids, pairs= None):
        """
        Get counts of all pairs of adjacent symbols in the dataset
        """
        pairs = defaultdict(int) if pairs is None else pairs
        
        # Code here to count pairs 
        # --------------------------------------------  

        return pairs
    
    def merge_pair(self,ids, pair, idx):
        """
        Merge all occurrences of the given pair in the dataset
        """
        new_ids = []

        #code here to merge pair
        # ---------------------------------------
        return new_ids
    
    def learn_bpe(self):
        """
        Learn BPE merges until reaching the desired vocabulary size
        """
        # Code here to initialize the vocabulary with the 256 single-byte
        # tokens, mapping each token ID to a bytes object
        self.vocab = None

        # Works for a single string as well as for a list of strings
        corpus = "".join(self.texts)

        pattern = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
        compiler = re.compile(pattern)
        words = compiler.findall(corpus)

        # One list of UTF-8 byte values per chunk
        ids = [list(word.encode('utf-8')) for word in words]


    
        print(f"Initial vocabulary size: {len(self.vocab)}")
        
        idx = len(self.vocab)
        while len(self.vocab) < self.vocab_size:
            pairs = defaultdict(int)
            for chunk_ids in ids:
                self.get_stats(chunk_ids,pairs)
            
            if not pairs:
                break
            
            # Code here to Get most frequent pair
            best_pair = None


            self.bpe_merges.append(best_pair)
            new_token = self.vocab[best_pair[0]] + self.vocab[best_pair[1]]
            self.vocab[idx] = new_token
            
            # Merge pair in all texts
            new_ids = []
            for chunk_ids in ids:
                new_text = self.merge_pair(chunk_ids, best_pair, idx)
                new_ids.append(new_text)
            ids = new_ids
            idx+=1
        
        print("vocab size:", len(self.vocab))

# Validation

In [21]:
def validate_bpe_results(tokenizer, test_text, description):
    """
    Validate the BPE tokenizer results for different test cases
    """
    print(f"\n=== Validating {description} ===")
    print(f"Test text: {test_text}")
    
    # Check vocabulary fundamentals
    vocab_size = len(tokenizer.vocab)
    print(f"Current vocab size: {vocab_size}")
    print(f"Base vocabulary size (bytes): 256")
    print(f"Additional merged tokens: {vocab_size - 256}")
    
    # Analyze merge patterns
    if tokenizer.bpe_merges:
        print("\nSample merges that affect this text:")
        relevant_merges = []
        test_bytes = ''.join(test_text).encode('utf-8')
        
        for merge in tokenizer.bpe_merges:
            merged = tokenizer.vocab[merge[0]] + tokenizer.vocab[merge[1]]
            if merged in test_bytes:
                relevant_merges.append(merged.decode('utf-8', errors='ignore'))
                if len(relevant_merges) >= 5:  # Show up to 5 relevant merges
                    break
        
        if relevant_merges:
            for i, merge in enumerate(relevant_merges, 1):
                print(f"Relevant merge {i}: '{merge}'")
        else:
            print("No directly relevant merges found for this text")
    
    # Special case validations
    print("\nValidation checks:")
    
    # Check case sensitivity
    if any(word.isupper() for word in test_text.split()):
        upper_preserved = any(any(c.isupper() for c in v.decode('utf-8', errors='ignore'))
                            for v in tokenizer.vocab.values())
        print(f"Case sensitivity preserved: {'✓' if upper_preserved else '✗'}")
    
    # Check for common subwords
    if "play" in test_text.lower():
        play_merged = any(b"play" in v for v in tokenizer.vocab.values())
        print(f"Common prefix 'play' merged: {'✓' if play_merged else '✗'}")
    
    # Check for repetition handling
    if len(set(test_text.split())) < len(test_text.split()):
        repeated_words = [word for word in set(test_text.split()) 
                         if test_text.split().count(word) > 1]
        for word in repeated_words[:3]:  # Check up to 3 repeated words
            word_merged = any(word.encode('utf-8') in v for v in tokenizer.vocab.values())
            print(f"Repeated word '{word}' merged: {'✓' if word_merged else '✗'}")
    
    # Check for special character handling
    if any(not c.isalnum() and not c.isspace() for c in test_text):
        spec_chars_preserved = any(not all(c.isalnum() or c.isspace() 
                                         for c in v.decode('utf-8', errors='ignore'))
                                 for v in tokenizer.vocab.values())
        print(f"Special characters preserved: {'✓' if spec_chars_preserved else '✗'}")
    
    print("\n" + "="*50)

In [23]:
# Create a single tokenizer with comprehensive training data
print("Creating and training tokenizer...")
training_text = """
The quick brown fox jumps over the lazy dog.
Hello World! This is a test of the tokenizer.
We have numbers like 123, 456, and 789.
Playing player played plays playful playground
UPPERCASE, lowercase, and MiXeDcAsE words.
Special characters: !@#$%^&*()_+-=[]{}|;:',.<>?
Multiple     spaces    and   tabs   are    here.
Unicode characters: café résumé château škoda
Contractions: don't won't can't I'm you're they've
"""

# Initialize the tokenizer with training data
tokenizer = BPE(vocab_size=400, texts=training_text)
tokenizer.learn_bpe()

print("\nRunning validation tests...\n")

# Test 1: Basic English text
print("Test 1: Basic English")
validate_bpe_results(tokenizer, "The quick brown fox jumps over the lazy dog", 
                    "Basic English text")

# Test 2: Repeated words
print("\nTest 2: Repeated Words")
validate_bpe_results(tokenizer, "the the the quick quick fox fox dog dog", 
                    "Text with repeated words")

# Test 3: Mixed case
print("\nTest 3: Mixed Case")
validate_bpe_results(tokenizer, "Hello WORLD wOrLd World HELLO", 
                    "Mixed case handling")

# Test 4: Numbers and special characters
print("\nTest 4: Numbers and Special Characters")
validate_bpe_results(tokenizer, "User123 has $100.00 in their account!", 
                    "Numbers and special characters")

# Test 5: Common subwords
print("\nTest 5: Subwords")
validate_bpe_results(tokenizer, "playing player played plays playful playground", 
                    "Common subword patterns")

# Test 6: Contractions
print("\nTest 6: Contractions")
validate_bpe_results(tokenizer, "I'm don't won't can't they've you're", 
                    "Contractions")

# Test 7: Whitespace handling
print("\nTest 7: Whitespace")
validate_bpe_results(tokenizer, "Multiple    spaces   and   tabs   here", 
                    "Whitespace handling")

# Test 8: Unicode characters
print("\nTest 8: Unicode")
validate_bpe_results(tokenizer, "café résumé château škoda", 
                    "Unicode character handling")

# Test 9: Punctuation
print("\nTest 9: Punctuation")
validate_bpe_results(tokenizer, "Hello, world! How are you? This is: amazing.", 
                    "Punctuation handling")

# Test 10: Long compound words
print("\nTest 10: Compound Words")
validate_bpe_results(tokenizer, "tokenization implementation methodology understanding", 
                    "Long compound words")

Creating and training tokenizer...
Initial vocabulary size: 256
vocab size: 400

Running validation tests...

Test 1: Basic English

=== Validating Basic English text ===
Test text: The quick brown fox jumps over the lazy dog
Current vocab size: 400
Base vocabulary size (bytes): 256
Additional merged tokens: 144

Sample merges that affect this text:
Relevant merge 1: 'er'
Relevant merge 2: 'la'
Relevant merge 3: ' t'
Relevant merge 4: 'he'
Relevant merge 5: 'um'

Validation checks:


Test 2: Repeated Words

=== Validating Text with repeated words ===
Test text: the the the quick quick fox fox dog dog
Current vocab size: 400
Base vocabulary size (bytes): 256
Additional merged tokens: 144

Sample merges that affect this text:
Relevant merge 1: ' t'
Relevant merge 2: 'he'
Relevant merge 3: ' the'
Relevant merge 4: 'ic'
Relevant merge 5: ' d'

Validation checks:
Repeated word 'dog' merged: ✓
Repeated word 'the' merged: ✓
Repeated word 'quick' merged: ✓


Test 3: Mixed Case

=== Valid