# Byte Pair Encoder

In [182]:
from collections import defaultdict
import re
from concurrent.futures import ThreadPoolExecutor

class BytePairEncoder:
    """Learn byte-pair merges from a word-frequency table and encode text.

    During training every word is represented as a string of symbols
    separated by single spaces (e.g. ``"t h e"``). Each merge step joins the
    most frequent adjacent symbol pair into one symbol.

    Attributes:
        vocab: Symbol -> accumulated frequency, filled in by ``fit``.
        n_merges: Maximum number of merge steps ``fit`` performs.
        str_to_int: Symbol -> integer id, rebuilt by ``generate_ids``.
        int_to_str: Integer id -> symbol, rebuilt by ``generate_ids``.
    """

    def __init__(self, n_merges):
        self.vocab = defaultdict(int)
        self.n_merges = n_merges
        self.str_to_int = defaultdict(int)
        # Defined up front so ``decode`` before ``fit`` fails with a plain
        # KeyError instead of an AttributeError.
        self.int_to_str = {}

    @staticmethod
    def _symbol_pattern(symbols):
        """Compile a regex matching ``symbols`` as whole, space-separated symbols.

        Whitespace lookarounds are used instead of ``\\b``: ``\\b`` never
        matches next to punctuation (so pairs such as ``('-', '-')`` could not
        be merged) and it can match inside a symbol like ``-a``.
        """
        spaced = re.escape(" ".join(symbols))
        return re.compile(r"(?<!\S)" + spaced + r"(?!\S)")

    def get_stats(self, corpus):
        """Return the frequency-weighted count of every adjacent symbol pair.

        Args:
            corpus: Mapping of space-separated symbol strings to frequencies.

        Returns:
            A ``defaultdict(int)`` keyed by ``(left, right)`` symbol tuples.
        """
        pairs = defaultdict(int)
        for word, freq in corpus.items():
            symbols = word.split()
            for left, right in zip(symbols, symbols[1:]):
                pairs[(left, right)] += freq
        return pairs

    def process_word(self, word, pair, pattern):
        """Return ``word`` with every match of ``pattern`` replaced by the joined pair."""
        merged = "".join(pair)
        # A callable replacement keeps backslashes in ``merged`` literal;
        # a plain string would be parsed for escape sequences by ``re``.
        return pattern.sub(lambda _match: merged, word)

    def merge_pairs(self, pair, corpus):
        """Return a new corpus in which ``pair`` is merged into a single symbol.

        Args:
            pair: ``(left, right)`` tuple of symbols to join.
            corpus: Mapping of space-separated symbol strings to frequencies.

        Returns:
            A new dict with rewritten keys and the original frequencies.
        """
        pattern = self._symbol_pattern(pair)
        # The substitution is CPU-bound, so a thread pool gives no speed-up
        # under the GIL; a plain comprehension is simpler and deterministic.
        return {
            self.process_word(word, pair, pattern): freq
            for word, freq in corpus.items()
        }

    def generate_ids(self):
        """Assign consecutive integer ids to the vocabulary in sorted order."""
        tokens = sorted(self.vocab)
        self.str_to_int = {token: index for index, token in enumerate(tokens)}
        self.int_to_str = dict(enumerate(tokens))

    def fit(self, corpus):
        """Learn up to ``n_merges`` merges from ``corpus`` and build the vocabulary.

        Args:
            corpus: Mapping of words to their frequencies.

        Prints each merged pair, accumulates the resulting symbols into
        ``self.vocab`` and regenerates the id tables.
        """
        corpus = {" ".join(word): freq for word, freq in corpus.items()}
        for _ in range(self.n_merges):
            pairs = self.get_stats(corpus)
            if not pairs:
                break
            best_pair = max(pairs, key=pairs.get)
            corpus = self.merge_pairs(best_pair, corpus)
            print(f"Merging {best_pair}")
        for word, freq in corpus.items():
            for token in word.split():
                self.vocab[token] += freq
        self.generate_ids()

    def encode(self, word):
        """Encode ``word`` into a list of vocabulary ids.

        The text is split into characters, then vocabulary symbols are
        applied longest first wherever their characters appear as consecutive
        single-character symbols.

        Raises:
            KeyError: If a resulting symbol is not in ``str_to_int``
                (after ``fit`` has replaced it with a plain dict).
        """
        spaced = " ".join(word)
        for token in sorted(self.vocab, key=len, reverse=True):
            if len(token) < 2:
                # Sorted by descending length: only single symbols remain.
                break
            # Whole-symbol matching stops a token from being glued onto the
            # tail of an already merged symbol (e.g. "ab c" + "bc" -> "abc").
            spaced = self._symbol_pattern(token).sub(
                lambda _match, token=token: token, spaced
            )
        return [self.str_to_int[symbol] for symbol in spaced.split()]

    def decode(self, word):
        """Map a sequence of ids back to the list of their symbols."""
        return [self.int_to_str[token_id] for token_id in word]

In [183]:
# Read the whole book into memory. The encoding is given explicitly so the
# result does not depend on the platform's locale default.
# NOTE(review): assumes book.txt is UTF-8 encoded — confirm.
with open('book.txt', encoding='utf-8') as file:
    raw_text = file.read()

In [184]:
raw_text[:30]

'THE VERDICT\nJune 1908\n\nI had a'

In [185]:
# Split on punctuation, double dashes and whitespace. The capturing group
# keeps the delimiters in the result; empty and whitespace-only pieces are
# then dropped.
preprocessed = [
    piece
    for piece in re.split(r'([,.:;?_!"()\']|--|\s)', raw_text)
    if piece.strip()
]

In [186]:
# Frequency table of lower-cased tokens.
corpus = defaultdict(int)
for word in preprocessed:
    token = word.lower()
    corpus[token] += 1

In [187]:
# Zero the count of the double-dash token so its symbol pair carries no weight.
# NOTE(review): presumably keeps ('-', '-') from being picked as a merge — confirm.
corpus['--'] = 0

In [188]:
# Train an encoder with at most 25 merge steps on the token frequencies.
bpe = BytePairEncoder(n_merges=25)
bpe.fit(corpus)

Merging ('h', 'e')
Merging ('i', 'n')
Merging ('t', 'he')
Merging ('h', 'a')
Merging ('o', 'u')
Merging ('r', 'e')
Merging ('a', 'n')
Merging ('i', 't')
Merging ('o', 'n')
Merging ('s', 't')
Merging ('e', 'd')
Merging ('i', 's')
Merging ('e', 'n')
Merging ('a', 's')
Merging ('in', 'g')
Merging ('t', 'o')
Merging ('o', 'f')
Merging ('e', 'r')
Merging ('an', 'd')
Merging ('a', 't')
Merging ('s', 'e')
Merging ('o', 'r')
Merging ('l', 'e')
Merging ('b', 'e')
Merging ('w', 'as')


In [189]:
bpe.vocab

defaultdict(int,
            {'the': 236,
             'v': 148,
             'er': 105,
             'd': 482,
             'i': 571,
             'c': 388,
             't': 709,
             'j': 43,
             'u': 335,
             'n': 252,
             'e': 665,
             '1': 1,
             '9': 1,
             '0': 1,
             '8': 1,
             'ha': 204,
             'a': 654,
             'l': 505,
             'w': 303,
             'y': 317,
             's': 523,
             'h': 394,
             'ou': 198,
             'g': 214,
             'k': 144,
             'is': 135,
             'b': 183,
             'r': 477,
             'he': 195,
             'p': 259,
             'en': 130,
             '-': 38,
             'o': 413,
             'f': 219,
             'it': 177,
             'was': 73,
             're': 183,
             'at': 91,
             'to': 126,
             'm': 397,
             ',': 229,
             'in': 162,
             '

In [190]:
# Encode a sample string into vocabulary ids.
encoded = bpe.encode('hello world. tallest')

In [191]:
encoded

[32, 40, 40, 44, 61, 47, 40, 23, 7, 56, 15, 40, 41, 55]

In [192]:
bpe.decode(encoded)

['he', 'l', 'l', 'o', 'w', 'or', 'l', 'd', '.', 't', 'a', 'l', 'le', 'st']