# A simple walkthrough for training your very own BPE model

## Load your training corpus

We'll be using a dummy training corpus with 4 lines of text in `ex_corpus.txt`. Here's how it looks:

```python
with open("ex_corpus.txt", "r") as f:
    print(f.read())
```

```
Trying to learn about BPE
I'm learning about byte-pair encoding
My friend learnt that digram coding and byte pair encoding mean the same
I love Jacques Cousteau
```


## Extract a list of words from the training corpus

```python
from collections import defaultdict
EOW_TOKEN = '</w>'
def get_initial_words(filename):
    segmented_word_to_freq = defaultdict(int) # {"w o r d </w>": 1, ...}
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            words = line.strip().split() # words are split on whitespace; this is our pre-tokenization step
            for word in words:
                # Separate each character with spaces, to show tokens for clarity. "word" -> "w o r d </w>"
                segmented_word = ' '.join(list(word)) + " " + EOW_TOKEN
                segmented_word_to_freq[segmented_word] += 1
    return segmented_word_to_freq
```

```python
word_to_freq = get_initial_words('ex_corpus.txt')
print(word_to_freq)
```

```
defaultdict(<class 'int'>, {'T r y i n g </w>': 1, 't o </w>': 1, 'l e a r n </w>': 1, 'a b o u t </w>': 2, 'B P E </w>': 1, "I ' m </w>": 1, 'l e a r n i n g </w>': 1, 'b y t e - p a i r </w>': 1, 'e n c o d i n g </w>': 2, 'M y </w>': 1, 'f r i e n d </w>': 1, 'l e a r n t </w>': 1, 't h a t </w>': 1, 'd i g r a m </w>': 1, 'c o d i n g </w>': 1, 'a n d </w>': 1, 'b y t e </w>': 1, 'p a i r </w>': 1, 'm e a n </w>': 1, 't h e </w>': 1, 's a m e </w>': 1, 'I </w>': 1, 'l o v e </w>': 1, 'J a c q u e s </w>': 1, 'C o u s t e a u </w>': 1})
```


All the words are actually represented as space-separated strings of symbols/tokens (i.e., the dictionary holds _segmented_ words). Here we start off with simple character-level tokenization and also append an end-of-word token, `</w>`, to match the original BPE algorithm for subword tokenization.
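Why bother with `</w>`? It records where each word ends, so a token sequence can be turned back into text even after merges glue characters together. The sketch below illustrates this with two helpers introduced here for illustration (`segment` and `detokenize` are hypothetical, not part of this notebook): decoding concatenates the tokens and turns every `</w>` back into a space. The second example uses a merged token list of the kind BPE produces, such as `['learn', 'ing</w>']`.

```python
EOW_TOKEN = '</w>'

def segment(word):
    # Same segmentation as get_initial_words: characters separated by spaces, plus </w>
    return ' '.join(word) + ' ' + EOW_TOKEN

def detokenize(tokens):
    # Hypothetical helper: glue tokens together, then turn each end-of-word marker into a space
    text = ''.join(tokens)
    return text.replace(EOW_TOKEN, ' ').strip()

segmented = [segment(w) for w in "I love BPE".split()]
print(segmented)  # ['I </w>', 'l o v e </w>', 'B P E </w>']

# Flatten the segmented words into a single token stream, then decode it
tokens = [tok for seg in segmented for tok in seg.split()]
print(detokenize(tokens))  # I love BPE

# Word boundaries survive merges because </w> stays attached to the last token of each word
print(detokenize(['learn', 'ing</w>', 'ab', 'ou', 't</w>']))  # learning about
```

Without the marker, `['learn', 'ing']` could not tell us whether the text was "learning" or "learn ing".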

## Get initial vocabulary

```python
def get_tokens(word_to_freq):
    tokens = defaultdict(int)
    for word in word_to_freq.keys():
        for token in word.split():
            if token not in tokens:
                tokens[token] = len(tokens)
    return tokens

vocab = get_tokens(word_to_freq)
print(vocab)
```

```
defaultdict(<class 'int'>, {'T': 0, 'r': 1, 'y': 2, 'i': 3, 'n': 4, 'g': 5, '</w>': 6, 't': 7, 'o': 8, 'l': 9, 'e': 10, 'a': 11, 'b': 12, 'u': 13, 'B': 14, 'P': 15, 'E': 16, 'I': 17, "'": 18, 'm': 19, '-': 20, 'p': 21, 'c': 22, 'd': 23, 'M': 24, 'f': 25, 'h': 26, 's': 27, 'v': 28, 'J': 29, 'q': 30, 'C': 31})
```


Our vocabulary is a mapping from each token to a unique token ID.
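Because the vocabulary maps tokens to IDs, turning a segmented word into model input is a dictionary lookup, and decoding needs the inverse mapping. This is a minimal sketch on a hypothetical two-word counter (not `ex_corpus.txt`); `id_to_token` is a name introduced here for illustration. It uses a plain `dict` rather than `defaultdict(int)`: since `get_tokens` assigns every value explicitly, the default factory is never needed, and a plain dict raises `KeyError` on an unknown token instead of silently inserting it with ID 0.

```python
def get_tokens(word_to_freq):
    # Assign each unseen token the next free ID, in order of first appearance
    # (same logic as the notebook, but with a plain dict)
    tokens = {}
    for word in word_to_freq:
        for token in word.split():
            if token not in tokens:
                tokens[token] = len(tokens)
    return tokens

# Hypothetical two-word counter, not the notebook's corpus
word_to_freq = {'a b o u t </w>': 2, 'B P E </w>': 1}
vocab = get_tokens(word_to_freq)
print(vocab)  # {'a': 0, 'b': 1, 'o': 2, 'u': 3, 't': 4, '</w>': 5, 'B': 6, 'P': 7, 'E': 8}

# Encoding: look up each token's ID
ids = [vocab[token] for token in 'B P E </w>'.split()]
print(ids)  # [6, 7, 8, 5]

# Decoding: invert the mapping so IDs turn back into tokens
id_to_token = {idx: token for token, idx in vocab.items()}
print([id_to_token[i] for i in ids])  # ['B', 'P', 'E', '</w>']
```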

## Revisit the training algorithm

1. Extract a list of words from the training corpus. (**Done**)
2. Make a word counter dictionary with keys being words and values being frequencies in the training corpus. (**Done**)
3. Keep a vocabulary of symbols (variable length strings), initialized with unique characters present in the training corpus. (**Done**)

We now proceed to the core part of the algorithm:

4. Initialize a list of _merge rules_ to be empty. Each merge rule is a pair of symbols/tokens to merge at test time.
5. Iteratively do the following:
    - Get the most frequent pair of adjacent symbols across all segmented words, by going over the word counter and weighting each pair by its word's frequency. (Ex: `("l", "e")`)
    - Merge the two symbols into a new symbol, and _add_ this to the vocabulary. This is like a new byte in the original BPE compression algorithm, except that our symbols are variable-length strings (_character n-grams_).
    - Add our pair of symbols to the list of merge rules.
    - Replace all occurrences of the pair of symbols with the new symbol. For example, a word segmented as `("l", "e", "a", "r", "n")` becomes `("le", "a", "r", "n")`.
    - Repeat until you reach the target vocabulary size (a hyperparameter). Each merge typically adds one new token, so the code below simply fixes the number of merges (`num_merges`).

The training algorithm should look something like this:

```python
# Pseudocode sketch of the training loop:
# merges = []
# num_merges = N
# for i in range(num_merges):
#     pairs = <get statistics for frequencies of symbol pairs>
#     best_pair = max(pairs, key=pairs.get)  # argmax over pairs: the most frequent pair
#     word_to_freq = merge_word_splits(best_pair, word_to_freq)  # merge that pair in every word
#     new_token = ''.join(best_pair)  # the merged symbol is the new token
#     vocab[new_token] = len(vocab)  # add the new token to the vocab
#     merges.append(best_pair)  # record the merge rule for test time
```

The full algorithm is below:

```python
import re
def get_stats(word_to_freq: dict):
    pairs = defaultdict(int)
    for word, freq in word_to_freq.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
    return pairs

def merge_word_splits(pair, v_in: dict):
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

merges = []
print("Initial vocab: ", vocab)
print("##################")
num_merges = 15
print_until = 5
for i in range(num_merges):
    pairs = get_stats(word_to_freq)
    best_pair = max(pairs, key=pairs.get)
    word_to_freq = merge_word_splits(best_pair, word_to_freq)
    new_token = ''.join(best_pair)
    vocab[new_token] = len(vocab)
    merges.append(best_pair)
    if i < print_until:
        print(f"Iteration {i+1}")
        print("Best pair: ", best_pair)
        print("New token: ", new_token)
        print("All words: ", list(word_to_freq.keys()))
        print("##################")
    else:
        print(f"Iteration {i+1} done")
print("Final vocab: ", vocab)
print("All merges: ", merges)
```

```
Initial vocab:  defaultdict(<class 'int'>, {'T': 0, 'r': 1, 'y': 2, 'i': 3, 'n': 4, 'g': 5, '</w>': 6, 't': 7, 'o': 8, 'l': 9, 'e': 10, 'a': 11, 'b': 12, 'u': 13, 'B': 14, 'P': 15, 'E': 16, 'I': 17, "'": 18, 'm': 19, '-': 20, 'p': 21, 'c': 22, 'd': 23, 'M': 24, 'f': 25, 'h': 26, 's': 27, 'v': 28, 'J': 29, 'q': 30, 'C': 31})
##################
Iteration 1
Best pair:  ('i', 'n')
New token:  in
All words:  ['T r y in g </w>', 't o </w>', 'l e a r n </w>', 'a b o u t </w>', 'B P E </w>', "I ' m </w>", 'l e a r n in g </w>', 'b y t e - p a i r </w>', 'e n c o d in g </w>', 'M y </w>', 'f r i e n d </w>', 'l e a r n t </w>', 't h a t </w>', 'd i g r a m </w>', 'c o d in g </w>', 'a n d </w>', 'b y t e </w>', 'p a i r </w>', 'm e a n </w>', 't h e </w>', 's a m e </w>', 'I </w>', 'l o v e </w>', 'J a c q u e s </w>', 'C o u s t e a u </w>']
##################
Iteration 2
Best pair:  ('in', 'g')
New token:  ing
All words:  ['T r y ing </w>', 't o </w>', 'l e a r n </w>', 'a b o u t </w>', 'B P
```

Note that I've only printed the full training state for the first 5 iterations (`print_until = 5`). If you want to see all of it, set `print_until` to `num_merges`. Before we analyse what happened, let's see how our initial set of words is tokenized after all the merges:

```python
print(list(word_to_freq.keys()))
```

```
['T r y ing</w>', 't o </w>', 'learn </w>', 'ab ou t</w>', 'B P E </w>', "I ' m </w>", 'learn ing</w>', 'b y t e - p a i r </w>', 'en coding</w>', 'M y </w>', 'f r i en d </w>', 'learn t</w>', 't h a t</w>', 'd i g r a m </w>', 'coding</w>', 'a n d </w>', 'b y t e</w>', 'p a i r </w>', 'm ea n </w>', 't h e</w>', 's a m e</w>', 'I </w>', 'l o v e</w>', 'J a c q u e s </w>', 'C ou s t ea u </w>']
```


Observe what happens to the last line of the corpus ("I love Jacques Cousteau"): it is out of place compared to the three lines above it. Almost every word in it gets near character-level tokenization ("love" gets 4 tokens), while repeated words such as "learn" and "encoding" are represented with one or two tokens. This is the essence of BPE: _outliers_ do _not_ get _compressed_ as much. With a large training corpus, "outlier" English sentences are rare, but you can imagine data from other domains (code, for example) that wasn't well represented in the training corpus being segmented into many tokens.
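The merge rules are what we keep for test time. One straightforward way to tokenize a new word is to segment it into characters plus `</w>` and replay the merges in the order they were learned; other implementations instead repeatedly merge the lowest-ranked pair present in the word. The sketch below assumes the replay strategy and trains on a hypothetical three-word counter rather than `ex_corpus.txt`, so its merges differ from the ones above. `merge_pair` and `tokenize` are names introduced here for illustration; `get_stats` and `merge_word_splits` follow the notebook's logic.

```python
import re
from collections import defaultdict

EOW_TOKEN = '</w>'

def get_stats(word_to_freq):
    # Count adjacent symbol pairs, weighted by word frequency (as in the notebook)
    pairs = defaultdict(int)
    for word, freq in word_to_freq.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs

def merge_pair(pair, word):
    # Replace "a b" with "ab" only where both are whole symbols (space-delimited)
    pattern = re.compile(r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)')
    merged = ''.join(pair)
    # A function replacement avoids backslash-escape processing in the merged string
    return pattern.sub(lambda _match: merged, word)

def merge_word_splits(pair, v_in):
    return {merge_pair(pair, word): freq for word, freq in v_in.items()}

def tokenize(word, merges):
    # Hypothetical test-time encoder: start from characters + </w>,
    # then replay the learned merges in the order they were learned
    segmented = ' '.join(word) + ' ' + EOW_TOKEN
    for pair in merges:
        segmented = merge_pair(pair, segmented)
    return segmented.split()

# Toy training run on a hypothetical counter (not the notebook's corpus)
word_to_freq = {'l e a r n </w>': 3, 'l e a r n i n g </w>': 2, 'l o v e </w>': 1}
merges = []
for _ in range(5):
    pairs = get_stats(word_to_freq)
    best_pair = max(pairs, key=pairs.get)
    word_to_freq = merge_word_splits(best_pair, word_to_freq)
    merges.append(best_pair)

print(merges)  # [('l', 'e'), ('le', 'a'), ('lea', 'r'), ('lear', 'n'), ('learn', '</w>')]
print(tokenize('learn', merges))     # ['learn</w>']
print(tokenize('learning', merges))  # ['learn', 'i', 'n', 'g', '</w>']
print(tokenize('lover', merges))     # ['l', 'o', 'v', 'e', 'r', '</w>']
```

The frequent word "learn" collapses to a single token, while the unseen word "lover", which contains none of the learned pairs, stays at character level: the same outlier effect described above.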