In [71]:
from collections import defaultdict
import nltk
from string import punctuation
from nltk.corpus import gutenberg
from nltk import word_tokenize

# Resources used below: the Gutenberg texts for training/testing, and the Punkt
# models that nltk.word_tokenize relies on. NLTK 3.9+ loads Punkt from the
# 'punkt_tab' package instead of the pickled 'punkt' one, so both are fetched
# to keep word_tokenize working across NLTK versions.
nltk.download('gutenberg')
nltk.download('punkt')
nltk.download('punkt_tab')

class BPETokenizer:
    """Byte-pair-encoding (BPE) subword tokenizer trained on a single text.

    Words are delimited by whitespace and ASCII punctuation (each punctuation
    character becomes its own word). Every word is terminated with the
    end-of-word marker ``"_"`` before merges are learned or applied.

    Attributes:
        token_vocab: Set that no method of this class reads or writes.
        corpus: Mapping from a space-separated token sequence (one word plus
            ``"_"``) to its frequency; populated by :meth:`BPE`.
        num_merges: Maximum number of merge operations :meth:`BPE` learns.
        merges: Set of learned ``(left, right)`` merge pairs.
        merge_order: The same pairs as a list, in the order they were learned;
            :meth:`encode` uses this order as merge priority.
        text: The preprocessed training text (words joined by single spaces).
    """

    def __init__(self, text, num_merges):
        self.token_vocab = set()
        self.corpus = {}
        self.num_merges = num_merges
        self.merges = set()
        # A set cannot remember learning order, which encode() needs in order
        # to apply merges by priority, so the order is tracked separately.
        self.merge_order = []
        self.text = self._preprocess_text(text)

    def _preprocess_text(self, text):
        """Split ``text`` on whitespace and ``string.punctuation``.

        Each punctuation character is kept as a standalone word; whitespace is
        dropped. Returns the words joined by single spaces.
        """
        words = []
        current_word = ''

        for char in text:
            if char.isspace() or char in punctuation:
                if current_word:
                    words.append(current_word)
                    current_word = ''
                if not char.isspace():
                    words.append(char)
            else:
                current_word += char

        if current_word:
            words.append(current_word)

        return ' '.join(words)

    def get_vocab(self):
        """Return the set of characters in the training text plus ``"_"``."""
        vocab = set(char for word in self.text.split() for char in word)
        vocab.add("_")
        return vocab

    def count_words(self):
        """Return how often each training word (with ``"_"`` appended) occurs."""
        freqs = defaultdict(int)
        for word in self.text.split():
            word += "_"
            freqs[word] += 1
        return freqs

    def find_pairs(self):
        """Score every adjacent token pair in ``self.corpus``.

        The score of one occurrence is ``count * (1 + overlap / merged_len)``,
        where ``overlap`` counts characters of the left token that also occur
        in the right token. This is a frequency-weighted heuristic rather than
        plain BPE pair counting.

        Returns:
            defaultdict mapping ``(left, right)`` pairs to accumulated scores.
        """
        pairs = defaultdict(int)

        for word, count in self.corpus.items():
            tokens = word.split()
            if len(tokens) < 2:
                continue

            for i in range(len(tokens) - 1):
                pair = (tokens[i], tokens[i + 1])
                overlap = sum(1 for c in tokens[i] if c in tokens[i + 1])
                # Always >= 1: overlap <= len(tokens[i]) and tokens are non-empty.
                merged_len = len(tokens[i]) + len(tokens[i + 1]) - overlap
                pair_score = count * (1 + overlap / merged_len)
                pairs[pair] += pair_score

        return pairs

    @staticmethod
    def _merge_tokens(tokens, pair):
        """Merge non-overlapping adjacent occurrences of ``pair``, left to right.

        Works on whole tokens, so ``["ba", "b"]`` is never mistaken for an
        occurrence of ``("a", "b")``.
        """
        left, right = pair
        merged = []
        i = 0
        while i < len(tokens):
            if i + 1 < len(tokens) and tokens[i] == left and tokens[i + 1] == right:
                merged.append(left + right)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        return merged

    def merge_pair(self, pair):
        """Return a new corpus with every adjacent occurrence of ``pair`` merged.

        Args:
            pair: ``(left, right)`` tokens to fuse into ``left + right``.

        Returns:
            defaultdict mapping the re-segmented words to their counts.
        """
        new_dict = defaultdict(int)

        for word, count in self.corpus.items():
            new_word = " ".join(self._merge_tokens(word.split(), pair))
            # Accumulate defensively in case two entries map to the same key.
            new_dict[new_word] += count

        return new_dict

    def BPE(self):
        """Learn up to ``num_merges`` merges from the training text.

        Any merges from a previous call are discarded first. Training stops
        early once no word has two or more tokens left.

        Returns:
            ``(corpus, vocab, merges)``: the final segmented corpus, the
            character vocabulary (plus ``"_"``), and the set of learned pairs.
        """
        vocab = self.get_vocab()
        self.corpus = {" ".join(word): count
                       for word, count in self.count_words().items()}
        # Fresh containers so a caller holding the previous set is unaffected.
        self.merges = set()
        self.merge_order = []

        for _ in range(self.num_merges):
            pairs = self.find_pairs()
            if not pairs:
                break

            # Ties resolve to the first pair in iteration order.
            best_pair = max(pairs.items(), key=lambda x: x[1])[0]
            self.corpus = self.merge_pair(best_pair)
            if best_pair not in self.merges:
                self.merges.add(best_pair)
                self.merge_order.append(best_pair)

        return self.corpus, vocab, self.merges

    def _apply_merges(self, tokens, ranks):
        """Repeatedly merge the present pair with the lowest learned rank."""
        while len(tokens) > 1:
            candidates = [p for p in zip(tokens, tokens[1:]) if p in ranks]
            if not candidates:
                break
            best = min(candidates, key=ranks.__getitem__)
            tokens = self._merge_tokens(tokens, best)
        return tokens

    def encode(self, text, return_subwords=False):
        """Tokenize ``text`` with the learned merges.

        Args:
            text: Raw text; preprocessed the same way as the training text.
            return_subwords: If False (default), return one string per word:
                its BPE pieces concatenated, i.e. the word followed by ``"_"``.
                If True, return the individual BPE pieces of every word in
                order (the last piece of each word ends with ``"_"``).

        Returns:
            List of strings as described above.
        """
        # Lower rank = learned earlier = higher priority, mirroring training.
        ranks = {}
        for rank, pair in enumerate(self.merge_order):
            ranks.setdefault(pair, rank)

        cache = {}
        result = []
        for word in self._preprocess_text(text).split():
            if word not in cache:
                cache[word] = self._apply_merges(list(word + "_"), ranks)
            pieces = cache[word]
            if return_subwords:
                result.extend(pieces)
            else:
                result.append("".join(pieces))

        return result

    def decode(self, encoded_text):
        """Concatenate tokens and turn every ``"_"`` marker into a space.

        Punctuation split off during preprocessing comes back as its own
        space-separated word, and the output ends with a trailing space.
        """
        decoded_text = "".join("".join(token) for token in encoded_text)
        return decoded_text.replace("_", " ")

    def calculate_metrics(self, reference_tokens, bpe_tokens):
        """Compare the vocabularies (unique tokens) of two tokenizations.

        Args:
            reference_tokens: Sized sequence of reference tokens.
            bpe_tokens: Sized sequence of tokens produced by this tokenizer.

        Returns:
            Dict of set-overlap metrics (precision/recall/F1/Jaccard over
            unique tokens) and simple size/length statistics. Ratios with an
            empty denominator are reported as 0.
        """
        ref_vocab = set(reference_tokens)
        bpe_vocab = set(bpe_tokens)
        union_size = len(ref_vocab | bpe_vocab)

        true_positives = len(ref_vocab & bpe_vocab)
        false_positives = len(bpe_vocab - ref_vocab)
        false_negatives = len(ref_vocab - bpe_vocab)

        precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
        recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        jaccard_similarity = true_positives / union_size if union_size else 0

        # Unique BPE tokens also present in the reference (== true_positives).
        correct_tokens = true_positives
        tokenization_accuracy = (correct_tokens / union_size) * 100 if union_size > 0 else 0
        tokenization_coverage = (len(bpe_vocab) / len(ref_vocab)) * 100 if len(ref_vocab) > 0 else 0

        ref_avg_len = sum(len(t) for t in reference_tokens) / len(reference_tokens) if len(reference_tokens) > 0 else 0
        bpe_avg_len = sum(len(t) for t in bpe_tokens) / len(bpe_tokens) if len(bpe_tokens) > 0 else 0

        return {
            'correct_tokens': correct_tokens,
            'ref_vocab_size': len(ref_vocab),
            'bpe_vocab_size': len(bpe_vocab),
            'ref_avg_token_length': ref_avg_len,
            'bpe_avg_token_length': bpe_avg_len,
            'total_ref_tokens': len(reference_tokens),
            'total_bpe_tokens': len(bpe_tokens),
            'tokenization_accuracy': tokenization_accuracy,
            'tokenization_coverage': tokenization_coverage,
            'precision': precision,
            'recall': recall,
            'f1_score': f1_score,
            'jaccard_similarity': jaccard_similarity,
            'true_positives': true_positives,
            'false_positives': false_positives,
            'false_negatives': false_negatives
        }

[nltk_data] Downloading package gutenberg to
[nltk_data]     /Users/dhruvgorasiya/nltk_data...
[nltk_data]   Package gutenberg is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     /Users/dhruvgorasiya/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [72]:
# Quick smoke test: train on a tiny corpus, then round-trip a phrase that
# contains a word ("asdf") never seen during training.
training_text = "Low lower lowest new newer"
test_text = "Low asdf"

bpe_tokenizer = BPETokenizer(training_text, 10)
bpe_tokenizer.BPE()

encoded_test = bpe_tokenizer.encode(test_text)
print(encoded_test)
print(bpe_tokenizer.decode(encoded_test))

['Low_', 'asdf_']
Low asdf 


In [None]:


if __name__ == "__main__":
    # Train on three Gutenberg books, then compare the BPE word segmentation
    # against NLTK's word_tokenize on three held-out books.
    books = {
        'training': [
            gutenberg.raw("austen-emma.txt"),
            gutenberg.raw("blake-poems.txt"),
            gutenberg.raw("shakespeare-hamlet.txt"),
        ],
        'testing': {
            'shakespeare-caesar': gutenberg.raw("shakespeare-caesar.txt"),
            'carroll-alice': gutenberg.raw("carroll-alice.txt"),
            'chesterton-ball': gutenberg.raw("chesterton-ball.txt"),
        },
    }

    training_text = " ".join(books['training'])
    # NOTE(review): every merge step rescans the whole training corpus, so a
    # large num_merges makes training on three full books slow.
    bpe_tokenizer = BPETokenizer(training_text, num_merges=100000)
    corpus, vocab, merges = bpe_tokenizer.BPE()

    print("\nTokenization Comparison:")
    print("-" * 50)

    for book_name, test_book in books['testing'].items():
        ref_tokens = word_tokenize(test_book)
        # Strip only the trailing end-of-word marker. Removing every "_" would
        # turn the encoded underscore word "__" into an empty-string token.
        bpe_tokens = [
            token[:-1] if token.endswith('_') else token
            for token in bpe_tokenizer.encode(test_book)
        ]

        metrics = bpe_tokenizer.calculate_metrics(ref_tokens, bpe_tokens)

        print(f"\nResults for {book_name}:")
        for metric, value in metrics.items():
            label = metric.replace('_', ' ').title()
            if isinstance(value, float):
                print(f"{label}: {value:.2f}")
            else:
                print(f"{label}: {value}")
