# Tokenisation 
## Byte Pair Encoding
Ref: [Neural Machine Translation of Rare Words with Subword Units](https://aclanthology.org/P16-1162.pdf)


Byte pair encoding ([Gage 1994](http://www.pennelynn.com/Documents/CUJ/HTML/94HTML/19940045.HTM)) was initially developed as an algorithm to encode/compress text based on the **most frequently occurring bytes** (a byte, i.e. 8 bits, corresponds to a single character token for practical purposes).

The algorithm merges the most frequently occurring pair of consecutive bytes and replaces it with a new representative token that is not part of the existing lexicon.

In [16]:
import re, collections

def get_stats(vocab):
    """Count how often each pair of adjacent symbols occurs in the vocabulary.

    Args:
        vocab: Mapping from a word, written as whitespace-separated symbols
            (e.g. ``'l o w </w>'``), to that word's frequency.

    Returns:
        A ``collections.defaultdict(int)`` keyed by ``(left, right)`` symbol
        tuples; each value is the sum of the frequencies of the words in
        which that adjacent pair appears (counted once per occurrence).
    """
    pair_counts = collections.defaultdict(int)

    for entry, count in vocab.items():
        tokens = entry.split()
        # Pair every symbol with its right-hand neighbour.
        for left, right in zip(tokens, tokens[1:]):
            pair_counts[left, right] += count

    return pair_counts

def merge_vocab(pair, v_in):
    """Merge every occurrence of a symbol pair in the vocabulary.

    Args:
        pair: Tuple ``(left, right)`` of symbols to fuse into one symbol.
        v_in: Mapping from a word, written as whitespace-separated symbols,
            to its frequency.

    Returns:
        A new dict in which each standalone ``'left right'`` occurrence has
        been replaced by the concatenated symbol ``'leftright'``. If several
        input words become identical after merging, their frequencies are
        summed instead of one overwriting the other.
    """
    merged_symbol = ''.join(pair)
    # Escape the bigram so symbols containing regex metacharacters match
    # literally; the lookarounds make sure only whole symbols are matched,
    # not fragments of longer symbols.
    bigram = re.escape(' '.join(pair))
    pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')

    def _replacement(_match):
        # A callable replacement is inserted verbatim. A plain string would
        # be treated as a template, so a backslash in a symbol would be read
        # as an escape or group reference (and can raise re.error).
        return merged_symbol

    v_out = {}
    for word, freq in v_in.items():
        w_out = pattern.sub(_replacement, word)
        # Accumulate rather than assign so colliding words keep their counts.
        v_out[w_out] = v_out.get(w_out, 0) + freq

    return v_out

In [10]:
def token_learner_bpe(vocab, num_merges):
    """Learn an ordered list of BPE merge rules from a vocabulary.

    Args:
        vocab: Mapping from a word, written as whitespace-separated symbols,
            to its frequency. The caller's mapping is not modified; each
            round works on the dict returned by ``merge_vocab``.
        num_merges: Maximum number of merge rounds to run.

    Returns:
        List of ``(left, right)`` symbol pairs in the order they were merged.
        It is shorter than ``num_merges`` when no adjacent pairs remain.
    """
    learned = []
    for _round in range(num_merges):
        pair_counts = get_stats(vocab)
        if pair_counts:
            # Pick the pair with the highest count (first one seen on ties).
            top_pair = max(pair_counts, key=pair_counts.get)
            learned.append(top_pair)
            vocab = merge_vocab(top_pair, vocab)
        else:
            # Every word is a single symbol: nothing left to merge.
            break
    return learned

In [11]:
def token_segmenter_bpe(test_word, merge_rules):
    """Segment a word into subword symbols using learned BPE merge rules.

    Args:
        test_word: The word to segment; it is split into single characters
            and the end-of-word marker ``'</w>'`` is appended.
        merge_rules: Ordered sequence of two-symbol rules; earlier rules take
            priority. Each rule may be any two-item sequence (tuple or list),
            since rules are normalised to tuples before lookup.

    Returns:
        List of symbols remaining once no adjacent pair matches any rule.
    """
    # Map each rule to its priority once, so lookups inside the loop are O(1)
    # instead of scanning the rule list for every pair. ``setdefault`` keeps
    # the earliest position if a rule appears more than once.
    ranks = {}
    for priority, rule in enumerate(merge_rules):
        ranks.setdefault(tuple(rule), priority)

    # Split into characters and add end-of-word symbol
    symbols = list(test_word) + ['</w>']
    while True:
        best_rank = None
        best_pos = None
        for pos in range(len(symbols) - 1):
            rank = ranks.get((symbols[pos], symbols[pos + 1]))
            # Strict '<' keeps the left-most occurrence of the best rule.
            if rank is not None and (best_rank is None or rank < best_rank):
                best_rank = rank
                best_pos = pos
        if best_pos is None:
            break
        # Fuse the chosen pair into a single symbol and rescan.
        fused = ''.join((symbols[best_pos], symbols[best_pos + 1]))
        symbols = symbols[:best_pos] + [fused] + symbols[best_pos + 2:]
    return symbols

In [15]:
# Toy training corpus: each key is a word spelled out as space-separated
# symbols ending in the end-of-word marker, each value is its frequency.
train_vocab = {
    'l o w </w>': 5,
    'l o w e r </w>': 2,
    'n e w e s t </w>': 6,
    'w i d e s t </w>': 3,
}

# Run the learner for at most this many merge rounds.
num_merges = 15
merge_rules = token_learner_bpe(train_vocab, num_merges)
print("Learned Merge Rules:", merge_rules, sep="\n")

# Apply the learned rules to words, including one ('lowest') that is not
# in the training vocabulary.
test_data = ['lowest', 'newest', 'widest']
print("\nSegmented Test Data:")
for word in test_data:
    segmented_word = token_segmenter_bpe(word, merge_rules)
    print(f"{word} -> {' '.join(segmented_word)}")

Learned Merge Rules:
[('e', 's'), ('es', 't'), ('est', '</w>'), ('l', 'o'), ('lo', 'w'), ('n', 'e'), ('ne', 'w'), ('new', 'est</w>'), ('low', '</w>'), ('w', 'i'), ('wi', 'd'), ('wid', 'est</w>'), ('low', 'e'), ('lowe', 'r'), ('lower', '</w>')]

Segmented Test Data:
lowest -> low est</w>
newest -> newest</w>
widest -> widest</w>
