# Byte Pair Encoding (BPE) from Scratch

Step 1: Prepare Training Data

In [1]:
# Import the corpus data from the text file 

# corpus = []
# with open('D:\\dev\\Llama_from_scrtach\\Reviews.txt', 'r', encoding='utf-8') as file:
#     for line in file:
#         corpus.append(line.strip())


corpus = [
    "This is the first document.",
    "This document is the second document.",
    "And this is the third one.",
    "Is this the first document?",
]

print("Training data[1]: ",corpus[1])  

Training data[1]:  This document is the second document.


Step 2: Initialize and Pre-tokenize

In [2]:
# add the chars to a set to get unique chars
unique_chars = set()
for line in corpus:
    for char in line:
        unique_chars.add(char)
        
# sort        
vocab = list(unique_chars)
vocab.sort()
vocab.append('</w>')  # end of word token
print("Vocabulary: ",vocab)
print("Vocabulary size: ",len(vocab))


            

Vocabulary:  [' ', '.', '?', 'A', 'I', 'T', 'c', 'd', 'e', 'f', 'h', 'i', 'm', 'n', 'o', 'r', 's', 't', 'u', '</w>']
Vocabulary size:  20


In [3]:
word_splits = {}
for doc in corpus:
    words = doc.split(' ')
    for word in words:
        if word:
            chars_list = list(word) + ['</w>']
            # Tuple for immutability
            word_tuple = tuple(chars_list)
            if word_tuple not in word_splits:
                word_splits[word_tuple] = 0
            word_splits[word_tuple] += 1
print("Pre-tokenized Word Frequencies: ",word_splits)
print("Pre-tokenized Word Frequencies size: ",len(word_splits))

Pre-tokenized Word Frequencies:  {('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 3, ('t', 'h', 'e', '</w>'): 4, ('f', 'i', 'r', 's', 't', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '</w>'): 1, ('s', 'e', 'c', 'o', 'n', 'd', '</w>'): 1, ('A', 'n', 'd', '</w>'): 1, ('t', 'h', 'i', 's', '</w>'): 2, ('t', 'h', 'i', 'r', 'd', '</w>'): 1, ('o', 'n', 'e', '.', '</w>'): 1, ('I', 's', '</w>'): 1, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '?', '</w>'): 1}
Pre-tokenized Word Frequencies size:  13


Helper Function ```get_pair_stats```
find the most frequent pair of symbols in the vocabulary.

In [4]:
import collections 

def get_pair_stats(splits):
    """Get the frequency of all pairs of adjacent symbols in the splits."""
    pairs_counts = collections.defaultdict(int)
    for word_tuple, freq in splits.items():
        symbols = list(word_tuple)
        for i in range(len(symbols) - 1):
            pair = (symbols[i], symbols[i + 1])
            pairs_counts[pair] += freq
    return pairs_counts
print("Pairs Counts: ",get_pair_stats(word_splits))

Pairs Counts:  defaultdict(<class 'int'>, {('T', 'h'): 2, ('h', 'i'): 5, ('i', 's'): 7, ('s', '</w>'): 8, ('t', 'h'): 7, ('h', 'e'): 4, ('e', '</w>'): 4, ('f', 'i'): 2, ('i', 'r'): 3, ('r', 's'): 2, ('s', 't'): 2, ('t', '</w>'): 3, ('d', 'o'): 4, ('o', 'c'): 4, ('c', 'u'): 4, ('u', 'm'): 4, ('m', 'e'): 4, ('e', 'n'): 4, ('n', 't'): 4, ('t', '.'): 2, ('.', '</w>'): 3, ('s', 'e'): 1, ('e', 'c'): 1, ('c', 'o'): 1, ('o', 'n'): 2, ('n', 'd'): 2, ('d', '</w>'): 3, ('A', 'n'): 1, ('r', 'd'): 1, ('n', 'e'): 1, ('e', '.'): 1, ('I', 's'): 1, ('t', '?'): 1, ('?', '</w>'): 1})


Helper Function ```merge_pair```

In [5]:
def merge_pair(pair_to_merge, splits):
    """Merge the specified pair in the word splits."""
    new_splits = {}
    (first, second) = pair_to_merge
    merged_token = first + second
    for word_tuple, freq in splits.items():
        symbols = list(word_tuple)
        new_symbols = []
        i = 0
        while i < len(symbols):
            if i < len(symbols) - 1 and symbols[i] == first and symbols[i + 1] == second:
                new_symbols.append(merged_token)
                i += 2
            else:
                new_symbols.append(symbols[i])
                i += 1
        new_splits[tuple(new_symbols)] = freq
    return new_splits

def get_most_frequent_pair(pairs_counts):
    """
        Get the most frequent pair from the pairs counts.
    """
    return max(pairs_counts, key=pairs_counts.get)

print("Most Frequent Pair: ",get_most_frequent_pair(get_pair_stats(word_splits)))


Most Frequent Pair:  ('s', '</w>')


Step 3: Iterative BPE Merging Loop

In [8]:
num_merges = 15

merges = {}

current_splits = word_splits.copy()

print(f"Initial Word Splits: {current_splits}")

for i in range(num_merges):
    print(f"\n Merge Iteration {i + 1}/ {num_merges}")
    
    
    # Calculate pair frequency 
    pair_stats = get_pair_stats(current_splits)
    if not pair_stats:
        print("No more pairs to merge.")
        break
    
    # Print top 5 pairs for inspection
    sorted_pairs = sorted(pair_stats.items(), key=lambda item: item[1], reverse=True)
    print("Top 5 pairs: ", sorted_pairs[:5])
    
    # Find the best pair 
    best_pair = max(pair_stats, key=pair_stats.get)
    best_freq = pair_stats[best_pair]
    print(f"Best pair: {best_pair} with frequency: {best_freq}")
    
    # Merge the best pair
    current_splits = merge_pair(best_pair, current_splits)
    
    

Initial Word Splits: {('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 3, ('t', 'h', 'e', '</w>'): 4, ('f', 'i', 'r', 's', 't', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '</w>'): 1, ('s', 'e', 'c', 'o', 'n', 'd', '</w>'): 1, ('A', 'n', 'd', '</w>'): 1, ('t', 'h', 'i', 's', '</w>'): 2, ('t', 'h', 'i', 'r', 'd', '</w>'): 1, ('o', 'n', 'e', '.', '</w>'): 1, ('I', 's', '</w>'): 1, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '?', '</w>'): 1}

 Merge Iteration 1/ 15
Top 5 pairs:  [(('s', '</w>'), 8), (('i', 's'), 7), (('t', 'h'), 7), (('h', 'i'), 5), (('h', 'e'), 4)]
Best pair: ('s', '</w>') with frequency: 8

 Merge Iteration 2/ 15
Top 5 pairs:  [(('i', 's</w>'), 7), (('t', 'h'), 7), (('h', 'i'), 5), (('h', 'e'), 4), (('e', '</w>'), 4)]
Best pair: ('i', 's</w>') with frequency: 7

 Merge Iteration 3/ 15
Top 5 pairs:  [(('t', 'h'), 7), (('h', 'is</w>'), 4), (('h', 'e'), 4), (('e', '</w>'), 4), (('d', 'o'), 4)]
Best pair: (