# Byte Pair Encoding (BPE) Implementation

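This notebook implements the Byte Pair Encoding (BPE) tokenization algorithm in plain Python and wraps it in a small PyTorch module that converts tokens to index tensors. BPE is a subword tokenization algorithm that iteratively merges the most frequent pair of adjacent symbols (characters or previously merged subwords) into a new symbol.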

## Features:
- Complete BPE implementation, plus a PyTorch `nn.Module` wrapper for tensor conversion and embeddings
- PDF text extraction using PyPDF
- Vocabulary building and tokenization
- Encoding and decoding functionality


In [2]:
# Import required libraries
import torch
import torch.nn as nn
import numpy as np
from collections import defaultdict, Counter
import re
import os
from typing import List, Dict, Tuple, Set
import pypdf
from io import BytesIO
import matplotlib.pyplot as plt
import seaborn as sns

# Report the PyTorch build and the device it will use, so outputs can be tied to an environment.
print(f"PyTorch version: {torch.__version__}")
print(f"Available device: {'cuda' if torch.cuda.is_available() else 'cpu'}")


PyTorch version: 2.8.0
Available device: cpu


## PDF Text Extraction Utility


In [7]:
class PDFTextExtractor:
    """Utility class for extracting text from PDF files.

    The cleaned (preprocessed) text of the most recent successful extraction is
    kept on ``self.extracted_text``.
    """

    def __init__(self):
        # Cleaned text from the last successful extract_text_from_pdf call.
        self.extracted_text = ""

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Read every page of a PDF and return the concatenated raw text.

        Side effect: stores ``preprocess_text(raw_text)`` in ``self.extracted_text``.

        Args:
            pdf_path (str): Path to the PDF file.

        Returns:
            str: The raw extracted text, each page followed by a newline, or ""
            if the file cannot be opened or parsed (the error is printed, not raised).
        """
        try:
            with open(pdf_path, 'rb') as file:
                pdf_reader = pypdf.PdfReader(file)
                # Pages are read while the file is still open. A page with no
                # extractable text is treated as empty instead of failing the join.
                page_texts = [(page.extract_text() or "") for page in pdf_reader.pages]
                page_count = len(pdf_reader.pages)
        except Exception as e:
            print(f"Error extracting text from PDF: {e}")
            return ""

        # Build the document once instead of repeated `+=`, which is quadratic on long PDFs.
        text = "".join(page_text + "\n" for page_text in page_texts)
        self.extracted_text = self.preprocess_text(text)

        print(f"Successfully extracted text from {pdf_path}")
        print(f"Total pages: {page_count}")
        print(f"Text length: {len(text)} characters")

        return text

    def preprocess_text(self, text: str) -> str:
        """Normalize text for tokenization.

        Removes every character that is not a word character, whitespace, or one of
        ``. , ! ? ; : ( ) -``, then collapses whitespace runs to a single space and
        strips the ends.

        Args:
            text (str): Raw text.

        Returns:
            str: Cleaned text.
        """
        # Remove special characters first: doing it after whitespace normalization
        # would leave double spaces (and leading/trailing spaces) where they were.
        text = re.sub(r'[^\w\s.,!?;:()\-]', '', text)
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

# Location of the source PDF. Set the BPE_PDF_PATH environment variable to point at
# your file instead of hard-coding a machine-specific absolute path.
PDF_PATH = os.environ.get("BPE_PDF_PATH", "ux_law.pdf")

pdf_extractor = PDFTextExtractor()
# extract_text_from_pdf returns the raw page text; the preprocessed version is kept
# on pdf_extractor.extracted_text.
text = pdf_extractor.extract_text_from_pdf(PDF_PATH)


Successfully extracted text from /Users/maverick/pankaj/LLM-Ground/ux_law.pdf
Total pages: 186
Text length: 257256 characters


## Byte Pair Encoding Implementation


In [None]:
class BytePairEncoder:
    """
    Byte Pair Encoding tokenizer.

    BPE is a subword tokenization algorithm that iteratively merges the most frequent
    pairs of characters or subwords to create a vocabulary. Training and encoding are
    plain Python; PyTorch is only used to record the preferred device.

    After ``train``:
        vocab (Dict[str, int]): token -> frequency of that token in the segmented
            training corpus. It holds the special tokens first (``<unk>``, ``<pad>``,
            ``<s>``, ``</s>``, i.e. indices 0-3), then the base symbols (characters
            plus ``</w>``, sorted for a reproducible order), then every merged symbol
            in merge order. Merged symbols later absorbed into longer ones keep
            frequency 0.
        merges (List[Tuple[str, str]]): merged pairs in the order they were learned.
        word_splits (Dict[str, int]): space-separated segmentation of each training
            word after all merges -> number of occurrences of that word.
    """

    # Reserved tokens placed at the front of the vocabulary, in index order.
    SPECIAL_TOKENS = ('<unk>', '<pad>', '<s>', '</s>')
    # Marker appended to every word, so a word-final symbol differs from a word-internal one.
    END_OF_WORD = '</w>'

    def __init__(self, vocab_size: int = 1000):
        """
        Initialize BPE encoder

        Args:
            vocab_size (int): Target vocabulary size (special tokens + base symbols
                + merged symbols). Training stops earlier if no pairs remain.
        """
        self.vocab_size = vocab_size
        self.word_freqs = Counter()
        self.vocab = {}
        self.merges = []
        self.word_splits = {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def get_word_freqs(self, text: str) -> Counter:
        """
        Get word frequencies from text

        Args:
            text (str): Input text; words are whitespace-separated.

        Returns:
            Counter: Maps each word, written as space-separated characters followed
            by ``</w>`` (e.g. ``"c a t </w>"``), to its number of occurrences.
        """
        word_freqs = Counter()
        for word in text.split():
            word_freqs[' '.join(list(word)) + ' ' + self.END_OF_WORD] += 1
        return word_freqs

    def get_stats(self, vocab: Dict[str, int]) -> Dict[Tuple[str, str], int]:
        """
        Count adjacent symbol pairs, weighted by word frequency.

        Args:
            vocab (Dict[str, int]): Space-separated segmentations -> frequency

        Returns:
            Dict[Tuple[str, str], int]: Pair frequencies, in first-seen order
        """
        pairs = defaultdict(int)

        for word, freq in vocab.items():
            symbols = word.split()

            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += freq

        return pairs

    @staticmethod
    def _merge_symbols(symbols: List[str], pair: Tuple[str, str]) -> List[str]:
        """Merge every non-overlapping occurrence of ``pair`` in ``symbols``, scanning left to right."""
        left, right = pair
        merged = []
        i = 0
        n = len(symbols)
        while i < n:
            if i + 1 < n and symbols[i] == left and symbols[i + 1] == right:
                merged.append(left + right)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        return merged

    def merge_vocab(self, pair: Tuple[str, str], vocab: Dict[str, int]) -> Dict[str, int]:
        """
        Merge a pair in every segmentation of ``vocab``.

        Args:
            pair (Tuple[str, str]): Pair to merge
            vocab (Dict[str, int]): Space-separated segmentations -> frequency

        Returns:
            Dict[str, int]: New dict with the pair merged; ``vocab`` is not modified.
        """
        new_vocab: Dict[str, int] = {}
        for word, freq in vocab.items():
            new_word = ' '.join(self._merge_symbols(word.split(), pair))
            # Accumulate rather than overwrite, in case two entries collapse to the same key.
            new_vocab[new_word] = new_vocab.get(new_word, 0) + freq
        return new_vocab

    def train(self, text: str) -> None:
        """
        Train BPE model on text. Any state from a previous ``train`` call is replaced.

        Args:
            text (str): Training text
        """
        print("Starting BPE training...")

        # Reset so that retraining does not append to merges learned earlier.
        self.merges = []

        self.word_freqs = self.get_word_freqs(text)
        print(f"Initial vocabulary size: {len(self.word_freqs)}")

        # Segmentation of every training word; starts at character level.
        splits = dict(self.word_freqs)

        base_symbols = set()
        for word in splits:
            base_symbols.update(word.split())

        # Ordered set of tokens: specials first (fixed indices), then sorted base
        # symbols so indices do not depend on hash-randomized set iteration.
        token_set = dict.fromkeys(self.SPECIAL_TOKENS)
        for symbol in sorted(base_symbols):
            token_set.setdefault(symbol, None)

        print(f"Character vocabulary size: {len(token_set)}")

        num_merges = self.vocab_size - len(token_set)

        for i in range(num_merges):
            pairs = self.get_stats(splits)
            if not pairs:
                break

            # Most frequent pair; ties go to the pair seen first.
            best_pair = max(pairs, key=pairs.get)

            splits = self.merge_vocab(best_pair, splits)
            self.merges.append(best_pair)
            token_set.setdefault(''.join(best_pair), None)

            if (i + 1) % 100 == 0:
                print(f"Completed {i + 1} merges. Current vocab size: {len(token_set)}")

        # Token frequencies in the final segmentation of the corpus.
        token_freqs = Counter()
        for segmented, freq in splits.items():
            for symbol in segmented.split():
                token_freqs[symbol] += freq

        self.word_splits = splits
        self.vocab = {token: token_freqs.get(token, 0) for token in token_set}
        print(f"Training completed. Final vocabulary size: {len(self.vocab)}")
        print(f"Number of merges performed: {len(self.merges)}")

    def encode_word(self, word: str) -> List[str]:
        """
        Encode a single word by replaying the learned merges in order.

        Args:
            word (str): Word to encode (no whitespace)

        Returns:
            List[str]: BPE tokens; the last one always ends with ``</w>``.
        """
        symbols = list(word) + [self.END_OF_WORD]

        for pair in self.merges:
            if len(symbols) == 1:
                break  # fully merged; no pair can match anymore
            symbols = self._merge_symbols(symbols, pair)

        return symbols

    def encode(self, text: str) -> List[str]:
        """
        Encode text using BPE

        Args:
            text (str): Text to encode; split on whitespace into words

        Returns:
            List[str]: List of BPE tokens
        """
        tokens = []
        for word in text.split():
            tokens.extend(self.encode_word(word))
        return tokens

    def decode(self, tokens: List[str]) -> str:
        """
        Decode BPE tokens back to text

        Args:
            tokens (List[str]): BPE tokens

        Returns:
            str: Decoded text, with each ``</w>`` turned into a space and the ends stripped.
        """
        text = ''.join(tokens).replace(self.END_OF_WORD, ' ')
        return text.strip()

    def get_vocab_stats(self) -> Dict[str, int]:
        """
        Get vocabulary statistics

        Returns:
            Dict[str, int]: ``total_vocab_size``, ``number_of_merges`` and
            ``most_frequent_tokens`` (the 10 highest-frequency vocab entries).
        """
        stats = {
            'total_vocab_size': len(self.vocab),
            'number_of_merges': len(self.merges),
            'most_frequent_tokens': dict(Counter(self.vocab).most_common(10))
        }
        return stats

# Initialize BPE encoder (untrained; the usage cell below calls bpe_encoder.train).
# vocab_size is the target number of vocabulary entries.
bpe_encoder = BytePairEncoder(vocab_size=1000)


## PyTorch Integration for Tokenization


In [None]:
class PyTorchTokenizer(nn.Module):
    """
    nn.Module wrapper that maps BPE token strings to integer index tensors and back.

    Token indices follow the key order of ``bpe_encoder.vocab`` at construction time.
    """

    def __init__(self, bpe_encoder: BytePairEncoder):
        """
        Build token<->index lookups from a trained encoder.

        Args:
            bpe_encoder (BytePairEncoder): Trained BPE encoder
        """
        super().__init__()
        self.bpe_encoder = bpe_encoder
        self.vocab_size = len(bpe_encoder.vocab)

        # Index i is the i-th key of the encoder vocabulary.
        self.token_to_idx = {}
        for position, vocab_token in enumerate(bpe_encoder.vocab.keys()):
            self.token_to_idx[vocab_token] = position
        self.idx_to_token = dict(
            (position, vocab_token) for vocab_token, position in self.token_to_idx.items()
        )

        # Special-token indices; the numeric fallbacks apply only if a token is absent.
        find = self.token_to_idx.get
        self.unk_idx = find('<unk>', 0)
        self.pad_idx = find('<pad>', 1)
        self.sos_idx = find('<s>', 2)
        self.eos_idx = find('</s>', 3)

    def encode_to_tensor(self, text: str, max_length: int = 512) -> torch.Tensor:
        """
        Encode text into a fixed-length LongTensor of token indices.

        Args:
            text (str): Input text
            max_length (int): Output length; longer sequences are cut, shorter ones
                are right-padded with ``pad_idx``.

        Returns:
            torch.Tensor: Token indices tensor
        """
        pieces = self.bpe_encoder.encode(text)
        ids = [self.token_to_idx.get(piece, self.unk_idx) for piece in pieces]

        # Clip, then pad; a non-positive pad count contributes nothing.
        ids = ids[:max_length]
        ids.extend([self.pad_idx] * (max_length - len(ids)))

        return torch.tensor(ids, dtype=torch.long)

    def decode_from_tensor(self, tensor: torch.Tensor) -> str:
        """
        Turn a 1-D tensor of indices back into text.

        Unknown indices and the ``<pad>``, ``<s>``, ``</s>`` tokens are dropped.

        Args:
            tensor (torch.Tensor): Token indices tensor

        Returns:
            str: Decoded text
        """
        dropped = {'<pad>', '<s>', '</s>'}
        kept = []

        for token_id in tensor.tolist():
            piece = self.idx_to_token.get(token_id)
            if piece is None or piece in dropped:
                continue
            kept.append(piece)

        return self.bpe_encoder.decode(kept)

    def get_embedding_layer(self, embedding_dim: int = 128) -> nn.Embedding:
        """
        Create an embedding table sized to this vocabulary.

        Args:
            embedding_dim (int): Embedding dimension

        Returns:
            nn.Embedding: Embedding layer whose ``padding_idx`` is ``pad_idx``
        """
        return nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=self.pad_idx,
        )

    def forward(self, text: str, max_length: int = 512) -> torch.Tensor:
        """
        Module entry point; same as ``encode_to_tensor``.

        Args:
            text (str): Input text
            max_length (int): Maximum sequence length

        Returns:
            torch.Tensor: Token indices tensor
        """
        return self.encode_to_tensor(text, max_length=max_length)

# Placeholder only: the tokenizer is built in the usage cell after bpe_encoder.train(),
# because PyTorchTokenizer reads bpe_encoder.vocab in its constructor.
pytorch_tokenizer = None


## Usage Examples and Testing


In [None]:
# Example usage with sample text
sample_text = """
Byte Pair Encoding is a subword tokenization algorithm that iteratively merges 
the most frequent pairs of characters or subwords. It is widely used in 
natural language processing for handling out-of-vocabulary words and 
reducing vocabulary size while maintaining semantic meaning.
"""

print("Sample text:")
print(sample_text)
print("\n" + "="*50 + "\n")

# Train BPE on sample text
print("Training BPE encoder...")
bpe_encoder.train(sample_text)

# Get vocabulary statistics
stats = bpe_encoder.get_vocab_stats()
print("\nVocabulary Statistics:")
print(f"Total vocabulary size: {stats['total_vocab_size']}")
print(f"Number of merges: {stats['number_of_merges']}")
print(f"Most frequent tokens: {list(stats['most_frequent_tokens'].keys())[:10]}")

# Initialize PyTorch tokenizer (requires the trained encoder)
pytorch_tokenizer = PyTorchTokenizer(bpe_encoder)
print(f"\nPyTorch tokenizer initialized with vocabulary size: {pytorch_tokenizer.vocab_size}")

# Test encoding and decoding
test_text = "Byte Pair Encoding algorithm"
print(f"\nTest text: '{test_text}'")

# Encode using BPE
tokens = bpe_encoder.encode(test_text)
print(f"BPE tokens: {tokens}")

# Encode to PyTorch tensor
tensor = pytorch_tokenizer.encode_to_tensor(test_text, max_length=20)
print(f"PyTorch tensor shape: {tensor.shape}")
print(f"Token indices: {tensor.tolist()}")

# Decode back to text
decoded_text = pytorch_tokenizer.decode_from_tensor(tensor)
print(f"Decoded text: '{decoded_text}'")

# Every word of test_text occurs in sample_text and the sequence fits in max_length,
# so the tensor round trip must reproduce the input exactly.
assert tensor.shape == (20,)
assert decoded_text == test_text, f"round-trip mismatch: {decoded_text!r}"


## PDF Processing Example


In [None]:
# Example function to process PDF files
def process_pdf_with_bpe(pdf_path: str, vocab_size: int = 2000, extractor: "PDFTextExtractor" = None):
    """
    Process a PDF file and train a fresh BPE encoder on its content.

    Args:
        pdf_path (str): Path to the PDF file
        vocab_size (int): Target vocabulary size for BPE
        extractor (PDFTextExtractor, optional): Extractor to use. Defaults to the
            notebook-level ``pdf_extractor`` (the previous, implicit behavior).

    Returns:
        (BytePairEncoder, PyTorchTokenizer) on success, or ``None`` when no text
        could be extracted. Callers that unpack the result must check for ``None``.
    """
    # Allow injection instead of silently depending on global notebook state.
    if extractor is None:
        extractor = pdf_extractor

    print(f"Processing PDF: {pdf_path}")

    # extract_text_from_pdf returns raw text ("" on failure)
    extracted_text = extractor.extract_text_from_pdf(pdf_path)

    if not extracted_text:
        print("No text extracted from PDF")
        return None

    processed_text = extractor.preprocess_text(extracted_text)
    print(f"Processed text length: {len(processed_text)} characters")

    # A new encoder per document, so the notebook-level encoder is not retrained.
    doc_bpe_encoder = BytePairEncoder(vocab_size=vocab_size)
    doc_bpe_encoder.train(processed_text)

    doc_pytorch_tokenizer = PyTorchTokenizer(doc_bpe_encoder)

    stats = doc_bpe_encoder.get_vocab_stats()
    print("\nDocument BPE Statistics:")
    print(f"Vocabulary size: {stats['total_vocab_size']}")
    print(f"Number of merges: {stats['number_of_merges']}")

    # Show tokenization of the first few '.'-separated chunks
    sample_sentences = processed_text.split('.')[:3]
    print("\nTokenization examples:")

    for i, sentence in enumerate(sample_sentences):
        cleaned = sentence.strip()
        if not cleaned:
            continue

        tokens = doc_bpe_encoder.encode(cleaned)
        tensor = doc_pytorch_tokenizer.encode_to_tensor(cleaned, max_length=50)
        decoded = doc_pytorch_tokenizer.decode_from_tensor(tensor)

        print(f"\nSentence {i+1}:")
        print(f"Original: {cleaned[:100]}...")
        print(f"Tokens: {tokens[:10]}...")  # Show first 10 tokens
        print(f"Tensor shape: {tensor.shape}")
        print(f"Decoded: {decoded[:100]}...")

    return doc_bpe_encoder, doc_pytorch_tokenizer

# Example usage (uncomment when you have a PDF file)
# pdf_path = "path/to/your/document.pdf"
# bpe_encoder, pytorch_tokenizer = process_pdf_with_bpe(pdf_path)

# Usage instructions, emitted as one newline-joined message.
print("\n".join([
    "PDF processing function ready!",
    "To use with your PDF file:",
    "1. Place your PDF file in the TOKENIZATION folder",
    "2. Update the pdf_path variable with the correct filename",
    "3. Uncomment and run the process_pdf_with_bpe() function call",
]))


## Visualization and Analysis


In [None]:
def visualize_bpe_stats(bpe_encoder: BytePairEncoder):
    """
    Visualize BPE statistics and token distribution.

    Draws a 2x2 figure (most frequent tokens, token length histogram, vocabulary
    size across merge steps, character frequency) and prints summary numbers.

    Length and character statistics use each vocabulary entry's surface form:
    the special tokens ``<unk>``, ``<pad>``, ``<s>``, ``</s>`` are skipped, and the
    ``</w>`` end-of-word marker and any space separators are removed. This keeps the
    marker from adding 4 to word-final lengths and from inflating the counts of
    '<', '/', 'w' and '>'.

    Args:
        bpe_encoder (BytePairEncoder): Trained BPE encoder
    """
    stats = bpe_encoder.get_vocab_stats()

    special_tokens = {'<unk>', '<pad>', '<s>', '</s>'}
    surface_forms = []
    for token in bpe_encoder.vocab.keys():
        if token in special_tokens:
            continue
        form = ''.join(token.split()).replace('</w>', '')
        if form:  # the bare '</w>' symbol has no surface characters
            surface_forms.append(form)

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('BPE Tokenization Analysis', fontsize=16)

    # 1. Most frequent tokens (as reported by get_vocab_stats)
    most_frequent = stats['most_frequent_tokens']
    tokens = list(most_frequent.keys())
    frequencies = list(most_frequent.values())

    axes[0, 0].bar(range(len(tokens)), frequencies)
    axes[0, 0].set_title('Most Frequent Tokens')
    axes[0, 0].set_xlabel('Token Rank')
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].set_xticks(range(len(tokens)))
    axes[0, 0].set_xticklabels(tokens, rotation=45, ha='right')

    # 2. Token length distribution (surface characters only)
    token_lengths = [len(form) for form in surface_forms]
    axes[0, 1].hist(token_lengths, bins=20, alpha=0.7, color='skyblue')
    axes[0, 1].set_title('Token Length Distribution')
    axes[0, 1].set_xlabel('Token Length')
    axes[0, 1].set_ylabel('Count')

    # 3. Merge progression: one new vocabulary entry per merge, counted back from the final size
    if bpe_encoder.merges:
        merge_indices = list(range(len(bpe_encoder.merges)))
        vocab_sizes = [len(bpe_encoder.vocab) - len(bpe_encoder.merges) + i for i in merge_indices]

        axes[1, 0].plot(merge_indices, vocab_sizes, marker='o', markersize=3)
        axes[1, 0].set_title('Vocabulary Size Growth')
        axes[1, 0].set_xlabel('Merge Step')
        axes[1, 0].set_ylabel('Vocabulary Size')

    # 4. Character frequency across vocabulary entries
    char_freq = Counter()
    for form in surface_forms:
        char_freq.update(form)

    top_chars = dict(char_freq.most_common(15))
    axes[1, 1].bar(range(len(top_chars)), list(top_chars.values()))
    axes[1, 1].set_title('Character Frequency in Vocabulary')
    axes[1, 1].set_xlabel('Character')
    axes[1, 1].set_ylabel('Frequency')
    axes[1, 1].set_xticks(range(len(top_chars)))
    axes[1, 1].set_xticklabels(list(top_chars.keys()), rotation=45)

    plt.tight_layout()
    plt.show()

    print(f"\nBPE Summary Statistics:")
    print(f"Total vocabulary size: {stats['total_vocab_size']}")
    print(f"Number of merges performed: {stats['number_of_merges']}")
    # Avoid NaN (and a RuntimeWarning) from the mean/std of an empty list.
    if token_lengths:
        print(f"Average token length: {np.mean(token_lengths):.2f}")
        print(f"Token length std: {np.std(token_lengths):.2f}")
    print(f"Unique characters: {len(char_freq)}")

# Visualize only once an encoder exists in the notebook namespace and has a vocabulary.
if 'bpe_encoder' not in locals() or not bpe_encoder.vocab:
    print("Please train a BPE encoder first to see visualizations")
else:
    print("Generating BPE visualization...")
    visualize_bpe_stats(bpe_encoder)


## Instructions for Using with Your PDF File

To use this notebook with your own PDF file:

1. **Place your PDF file** in the `TOKENIZATION` folder
2. **Update the PDF path** in the cell below with your filename
3. **Run the processing function** to train BPE on your document
4. **Analyze the results** using the visualization functions

### Example Usage:
```python
# Replace 'your_document.pdf' with your actual PDF filename
pdf_path = "your_document.pdf"
bpe_encoder, pytorch_tokenizer = process_pdf_with_bpe(pdf_path, vocab_size=2000)

# Visualize the results
visualize_bpe_stats(bpe_encoder)
```


In [None]:
# Ready to process your PDF file!
# Uncomment and modify the following lines when you have a PDF file:

# pdf_path = "your_document.pdf"  # Replace with your PDF filename
# bpe_encoder, pytorch_tokenizer = process_pdf_with_bpe(pdf_path, vocab_size=2000)

# Closing summary: one newline-joined message per section.
print("\n".join([
    "🎉 Byte Pair Encoding implementation complete!",
    "\nFeatures implemented:",
    "✅ Complete BPE algorithm with PyTorch integration",
    "✅ PDF text extraction using PyPDF",
    "✅ Vocabulary building and tokenization",
    "✅ Encoding and decoding functionality",
    "✅ Visualization and analysis tools",
    "✅ Ready for your PDF files!",
]))

print("\n".join([
    "\nNext steps:",
    "1. Add your PDF file to the TOKENIZATION folder",
    "2. Update the pdf_path variable above",
    "3. Run the process_pdf_with_bpe() function",
    "4. Explore the tokenization results!",
]))
