# Byte Pair Encoder

In [151]:
from collections import defaultdict
import re

class BytePairEncoder:
    def __init__(self, n_merges):
        self.vocab = {}
        self.n_merges = n_merges
        self.str_to_int=defaultdict(int)

    def get_stats(self, corpus):
        pairs = defaultdict(int)
        for word, freq in corpus.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += freq
        return pairs

    def merge_pairs(self, pair, corpus):
        new_corpus = {}
        bigram = re.escape(" ".join(pair))
        pattern = re.compile(r"\b" + bigram + r"\b")

        for word, freq in corpus.items():
            new_word = pattern.sub("".join(pair), word)
            new_corpus[new_word] = freq
        return new_corpus

    def generate_ids(self):
        tokens=sorted(self.vocab.keys())
        self.str_to_int={token: id for id, token in enumerate(tokens)}
        self.int_to_str={id: token for token, id in self.str_to_int.items()}

    def fit(self, corpus):
        corpus = {" ".join(word): freq for word, freq in corpus.items()}
        for _ in range(self.n_merges):
            pairs = self.get_stats(corpus)
            if not pairs:
                break
            best_pair = max(pairs, key=pairs.get)
            corpus = self.merge_pairs(best_pair, corpus)
            print(f"Merging {best_pair}")
        for word, freq in corpus.items():
            for token in word.split():
                self.vocab[token] = freq
        self.generate_ids()

    def encode(self, word):
        word = " ".join(word)
        for merged_token in sorted(self.vocab.keys(), key=len, reverse=True):
            word = word.replace(" ".join(merged_token), merged_token)
        return [self.str_to_int[w] for w in word.split()]
    def decode(self, word):
        return [self.int_to_str[w] for w in word]

In [152]:
with open('book.txt') as file:
    raw_text=file.read()

In [153]:
raw_text[:30]

'THE VERDICT\nJune 1908\n\nI had a'

In [154]:
preprocessed=re.split(r'([,.:;?_!"()\']|--|\s)', raw_text)
preprocessed = [r for r in preprocessed if r.strip()]

In [155]:
corpus=defaultdict(int)
for word in preprocessed:
    corpus[word.lower()]+=1

In [156]:
corpus['--']=0

In [157]:
bpe=BytePairEncoder(n_merges=25)
bpe.fit(corpus=corpus)

Merging ('h', 'e')
Merging ('i', 'n')
Merging ('t', 'he')
Merging ('h', 'a')
Merging ('o', 'u')
Merging ('r', 'e')
Merging ('a', 'n')
Merging ('i', 't')
Merging ('o', 'n')
Merging ('s', 't')
Merging ('e', 'd')
Merging ('i', 's')
Merging ('e', 'n')
Merging ('a', 's')
Merging ('in', 'g')
Merging ('t', 'o')
Merging ('o', 'f')
Merging ('e', 'r')
Merging ('an', 'd')
Merging ('a', 't')
Merging ('s', 'e')
Merging ('o', 'r')
Merging ('l', 'e')
Merging ('b', 'e')
Merging ('w', 'as')


In [158]:
bpe.vocab

{'the': 1,
 'v': 1,
 'er': 1,
 'd': 1,
 'i': 1,
 'c': 1,
 't': 1,
 'j': 2,
 'u': 1,
 'n': 1,
 'e': 1,
 '1': 1,
 '9': 1,
 '0': 1,
 '8': 1,
 'ha': 1,
 'a': 1,
 'l': 1,
 'w': 1,
 'y': 1,
 's': 1,
 'h': 1,
 'ou': 1,
 'g': 1,
 'k': 1,
 'is': 1,
 'b': 1,
 'r': 1,
 'he': 1,
 'p': 1,
 'en': 1,
 '-': 1,
 'o': 1,
 'f': 1,
 'it': 1,
 'was': 1,
 're': 1,
 'at': 1,
 'to': 1,
 'm': 1,
 ',': 229,
 'in': 1,
 'of': 4,
 'or': 1,
 'ed': 1,
 'ing': 1,
 'and': 1,
 'st': 1,
 'se': 1,
 'on': 3,
 '.': 212,
 '(': 3,
 'be': 1,
 ')': 3,
 '"': 137,
 'an': 1,
 'le': 1,
 "'": 94,
 ';': 22,
 'as': 1,
 'x': 1,
 'q': 3,
 ':': 21,
 '?': 24,
 '!': 24,
 'z': 1}

In [159]:
encoded=bpe.encode('hello world. tallest')

In [160]:
encoded

[32, 40, 40, 44, 61, 47, 40, 23, 7, 56, 15, 40, 41, 55]

In [161]:
bpe.decode(encoded)

['he', 'l', 'l', 'o', 'w', 'or', 'l', 'd', '.', 't', 'a', 'l', 'le', 'st']