# Byte Pair Encoding 
## Implementing the algorithm to better understand the moving parts

BPE runs the following steps, repeating the last one until the target vocabulary size is reached:
- pre-tokenize the text into individual words
- append an end-of-word marker (`#` in the code below) to each word and split it into characters
- count the frequency of every adjacent pair of tokens; the most frequent pair is merged into a single token, which is added to the vocabulary
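The loop can be sketched in a few lines of plain Python, independent of the `BytePair` class below. This is a minimal sketch that rests on several assumptions: whitespace pre-tokenization, no punctuation stripping, `#` as the end-of-word marker, ties broken in favour of the pair counted first, and a fixed number of merges instead of a target vocabulary size (each merge adds at most one token, so the two are closely related). `train_bpe` is a hypothetical name.

```python
from collections import Counter

def train_bpe(text: str, num_merges: int, end_token: str = "#"):
    """Learn up to `num_merges` merge rules from `text` (hypothetical helper)."""
    # pre-tokenize on whitespace; count how often each padded word occurs
    word_freq = Counter(word + end_token for word in text.lower().split())
    # every word starts out split into single characters
    splits = {word: list(word) for word in word_freq}
    merges = []
    for _ in range(num_merges):
        # count adjacent pairs, weighted by word frequency
        pair_freq = Counter()
        for word, freq in word_freq.items():
            split = splits[word]
            for pair in zip(split, split[1:]):
                pair_freq[pair] += freq
        if not pair_freq:
            break  # every word is already a single token
        best = max(pair_freq, key=pair_freq.get)  # ties: first pair counted wins
        merges.append(best)
        # replace each occurrence of `best` with the merged token, left to right
        for word, split in splits.items():
            merged, i = [], 0
            while i < len(split):
                if i + 1 < len(split) and (split[i], split[i + 1]) == best:
                    merged.append(split[i] + split[i + 1])
                    i += 2
                else:
                    merged.append(split[i])
                    i += 1
            splits[word] = merged
    return merges, splits

merges, splits = train_bpe("low low low lower lowest", num_merges=3)
print(merges)  # [('l', 'o'), ('lo', 'w'), ('low', '#')]
```

After these three merges, `low#` is a single token while `lower#` is still `['low', 'e', 'r', '#']`.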

### Data structures
- vocab: the set of unique tokens, needs to be updated after every new merge rule
- corpus: the set of tokenized words with their frequency 

In [None]:
from collections import Counter, defaultdict
from pprint import pprint as pp
import random

In [20]:
lang = """
I am a hero. 
I am looking for a villain because I am bored. 
What is the point of being a hero if there is no villain. 
Then all I am is a person, and that is not fun enough. 
So, villain, where are you?
"""

In [None]:
lang = """
Entropy is a scientific concept, as well as a measurable physical property, that is most commonly associated with a state of disorder, randomness, or uncertainty. The term and the concept are used in diverse fields, from classical thermodynamics, where it was first recognized, to the microscopic description of nature in statistical physics, and to the principles of information theory. It has found far-ranging applications in chemistry and physics, in biological systems and their relation to life, in cosmology, economics, sociology, weather science, climate change, and information systems including the transmission of information in telecommunication.

The thermodynamic concept was referred to by Scottish scientist and engineer William Rankine in 1850 with the names thermodynamic function and heat-potential. In 1865, German physicist Rudolf Clausius, one of the leading founders of the field of thermodynamics, defined it as the quotient of an infinitesimal amount of heat to the instantaneous temperature. He initially described it as transformation-content, in German Verwandlungsinhalt, and later coined the term entropy from a Greek word for transformation. Referring to microscopic constitution and structure, in 1862, Clausius interpreted the concept as meaning disgregation.

A consequence of entropy is that certain processes are irreversible or impossible, aside from the requirement of not violating the conservation of energy, the latter being expressed in the first law of thermodynamics. Entropy is central to the second law of thermodynamics, which states that the entropy of isolated systems left to spontaneous evolution cannot decrease with time, as they always arrive at a state of thermodynamic equilibrium, where the entropy is highest.

Austrian physicist Ludwig Boltzmann explained entropy as the measure of the number of possible microscopic arrangements or states of individual atoms and molecules of a system that comply with the macroscopic condition of the system. He thereby introduced the concept of statistical disorder and probability distributions into a new field of thermodynamics, called statistical mechanics, and found the link between the microscopic interactions, which fluctuate about an average configuration, to the macroscopically observable behavior, in form of a simple logarithmic law, with a proportionality constant, the Boltzmann constant, that has become one of the defining universal constants for the modern International System of Units (SI).

In 1948, Bell Labs scientist Claude Shannon developed similar statistical concepts of measuring microscopic uncertainty and multiplicity to the problem of random losses of information in telecommunication signals. Upon John von Neumann's suggestion, Shannon named this entity of missing information in analogous manner to its use in statistical mechanics as entropy, and gave birth to the field of information theory. This description has been identified as a universal definition of the concept of entropy. 

"""

In [None]:
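def pad_word_with_end_token(word: str, end_token: str = '#'):
    # append the end-of-word marker so merges can tell word-final characters apart
    if word is None:
        return None
    return word + end_token

# strip leading/trailing punctuation, then append the end-of-word marker
def remove_punctuation(word: str, punct=frozenset(".?,!")):
    # str.strip removes every leading and trailing character found in `punct`,
    # leaving characters inside the word untouched
    word = word.strip("".join(punct))
    if len(word) == 0:
        return None    # the token was punctuation only
    return pad_word_with_end_token(word)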

In [151]:
punct = set(".?,! \n")

class BytePair:
    def __init__(
        self,
        text: str,
        punct: set,
        vocab_size: int = 60,
        verbose: bool = False
        ):
        self.verbose = verbose
        self.vocab_size = vocab_size
        self.text = text
        self.punct = punct
        self.merges = {}                    # (c1, c2) -> c1 + c2, in the order the merges were learned
        self.build_byte_pairs()
        
    def pp_values(self):
        rep = 10
        print("\n", "-" * rep, " word_freq ", "-" * rep, "\n")
        pp(self.word_freq)
        print("\n", "-" * rep, " vocab ", "-" * rep, "\n")
        pp(self.vocab)
        print("\n", "-" * rep, " char-splits ", "-" * rep, "\n")
        pp(self.char_splits)
        print("\n", "-" * rep, " merges ", "-" * rep, "\n")
        pp(self.merges)
        # pair frequencies over the current splits
        pair_freq = self.compute_pair_freq()
        print("\n", "-" * rep, " bytepair-freq-table ", "-" * rep, "\n")
        pp(dict(pair_freq))
        if pair_freq:
            print("\n", "-" * rep, " most-frequent-pair ", "-" * rep, "\n")
            pp(self.find_most_freq_pair(pair_freq))
    
    def build_byte_pairs(self,):
        self.word_freq = self.build_corpus()
        self.vocab = self.extract_vocab()
        self.char_splits = self.extract_character_splits()
        
        while len(self.vocab) < self.vocab_size:
            if self.verbose:
                print(f"---- vocab size: {len(self.vocab)}")
            # compute pair frequencies to find the most common pair to merge
            pair_freq = self.compute_pair_freq()
            if not pair_freq:
                # every word is already a single token, so no merges are left
                break
            new_merge_pair = self.find_most_freq_pair(pair_freq)
            self.integrate_merge_pair(new_merge_pair)
    
    def integrate_merge_pair(self, merge_pair:tuple):
        joined_pair = "".join(merge_pair)
        self.merges[merge_pair] = joined_pair
        self.vocab.add(joined_pair)
        self.merge_pair_into_splits(merge_pair[0], merge_pair[1])
    
    def merge_pair_into_splits(self, c1:str, c2:str,):
        if self.verbose:
            print("----", f"new merge candidate: {(c1,c2)}", "----")
        for word in self.word_freq.keys():
            split = self.char_splits[word]
            if split is None or len(split) == 1:
                continue
            # rebuild the split left to right: an adjacent (c1, c2) becomes one token and
            # both positions are consumed, so pairs of identical characters such as ('l', 'l')
            # and several occurrences of the pair in one word are handled correctly
            new_split = []
            merge_performed = False             # only used for the verbose output
            i = 0
            while i < len(split):
                if i < len(split) - 1 and split[i] == c1 and split[i + 1] == c2:
                    new_split.append(c1 + c2)
                    merge_performed = True
                    i += 2
                else:
                    new_split.append(split[i])
                    i += 1
            if self.verbose and merge_performed:
                print("new split: --> ", word, new_split)
            self.char_splits[word] = new_split
            
    def old_merge_pair_into_splits(self, c1:str, c2:str,):
        """
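        Original version of the method above. It failed when the same merge pair occurred twice in the same word:
        after a merge `split` is shorter, but the loop still runs over the indices of the original list,
        so indexing can run past the end and raise "IndexError: list index out of range".
        The method above avoids this by rebuilding the split in a single left-to-right pass.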
        """
        if self.verbose:
            print("----", f"new merge candidate: {(c1,c2)}", "----")
        for word in self.word_freq.keys():
            split = self.char_splits[word]
            if len(split) == 1:
                continue
            if split is None:
                continue
            for i in range(len(split)-1):
                try:
                    if split[i]==c1 and split[i+1]==c2:
                        split = split[:i] + [c1 + c2] + split[i+2:]
                except Exception as e:
                    print(e)
                    print("*"*10, "failed at", word, split, f"merge-pair: {(c1,c2)}", f"index: {i}", "*"*10)
                    
                    if self.verbose:
                        print("new split: --> ", word, split)
            self.char_splits[word] = split
        
        
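    def build_corpus(self,) -> dict:
        tokens = [remove_punctuation(token) for token in self.text.lower().split()]
        # drop tokens that were punctuation only (remove_punctuation returns None for them)
        return Counter(t for t in tokens if t is not None)

    def extract_vocab(self,) -> set:
        # base vocabulary: every character of every padded word, including the end-of-word marker
        return {c for word in self.word_freq for c in word}

    def extract_character_splits(self) -> dict:
        return {word:[c for c in word] for word in self.word_freq.keys()}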
    
    def compute_pair_freq(self) -> dict:
        pair_freqs = defaultdict(int)
        for word,freq in self.word_freq.items(): 
            split = self.char_splits[word]
            
            if len(split) == 1:
                continue
            
            for pair in zip(split,split[1:]):
                pair_freqs[pair] += freq
                
        return pair_freqs
    
    @staticmethod
    def find_most_freq_pair(pair_freq: dict[tuple,int]) -> tuple:
        # highest-frequency pair; on a tie, the pair counted first wins
        return max(pair_freq, key=pair_freq.get)
            
    
    
    @staticmethod
    def build_bytepair_pair_freq(tokens:list, counter: Counter = None) -> dict:
        if counter is None:
            counter = Counter()
        for t in tokens:
            for char1,char2 in zip(t,t[1:]):
                counter.update(["".join(
                    (char1,char2)
                )])
        return counter
    

b = BytePair(lang, punct, verbose=True)
#b.pp_values()


---- vocab size: 38
---- new merge candidate: ('e', '#') ----
new split: -->  measurable# ['m', 'e', 'a', 's', 'u', 'r', 'a', 'b', 'l', 'e#']
new split: -->  state# ['s', 't', 'a', 't', 'e#']
new split: -->  the# ['t', 'h', 'e#']
new split: -->  are# ['a', 'r', 'e#']
new split: -->  diverse# ['d', 'i', 'v', 'e', 'r', 's', 'e#']
new split: -->  where# ['w', 'h', 'e', 'r', 'e#']
new split: -->  nature# ['n', 'a', 't', 'u', 'r', 'e#']
new split: -->  life# ['l', 'i', 'f', 'e#']
new split: -->  science# ['s', 'c', 'i', 'e', 'n', 'c', 'e#']
new split: -->  climate# ['c', 'l', 'i', 'm', 'a', 't', 'e#']
new split: -->  change# ['c', 'h', 'a', 'n', 'g', 'e#']
new split: -->  rankine# ['r', 'a', 'n', 'k', 'i', 'n', 'e#']
new split: -->  one# ['o', 'n', 'e#']
new split: -->  temperature# ['t', 'e', 'm', 'p', 'e', 'r', 'a', 't', 'u', 'r', 'e#']
new split: -->  he# ['h', 'e#']
new split: -->  structure# ['s', 't', 'r', 'u', 'c', 't', 'u', 'r', 'e#']
new split: -->  consequence# ['c', 'o', 'n', 's'

In [133]:
print(len(b.vocab))

60
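`BytePair` records every learned rule in `self.merges` as `{(c1, c2): c1 + c2}`, and Python dicts preserve insertion order, so the rules can be replayed in the order they were learned to tokenize a new word. The sketch below does that; `encode_word` is a hypothetical helper, and the small hand-written merge table stands in for `b.merges` so the example runs on its own.

```python
def encode_word(word: str, merges: dict, end_token: str = "#") -> list:
    """Tokenize one word by replaying learned merges in order (hypothetical helper)."""
    split = list(word.lower() + end_token)
    for (c1, c2), joined in merges.items():
        new_split, i = [], 0
        while i < len(split):
            if i + 1 < len(split) and split[i] == c1 and split[i + 1] == c2:
                new_split.append(joined)
                i += 2
            else:
                new_split.append(split[i])
                i += 1
        split = new_split
    return split

# hand-written stand-in for b.merges: {(c1, c2): c1 + c2}, in learned order
merges = {("e", "#"): "e#", ("t", "h"): "th", ("th", "e#"): "the#"}
print(encode_word("the", merges))    # ['the#']
print(encode_word("theme", merges))  # ['th', 'e', 'm', 'e#']
```

With the trained object the call would be `encode_word("entropy", b.merges)`. Characters never merged during training simply stay single-character tokens. Replaying every rule over every word is simple but slow; it is meant for inspection, not for large inputs.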


In [53]:
punct = list(".?,!")
tokens = [remove_punctuation(w) for w in lang.lower().split()]
corpus = Counter(tokens)
tokenized_corpus = Counter()

for word, freq in corpus.items():
    # character counts, weighted by how often each word occurs
    for c in word:
        tokenized_corpus[c] += freq

corpus, tokenized_corpus

iter_frequencies = Counter()

for w, freq in corpus.items():
    # adjacent pairs, also weighted by word frequency
    for merged_token in zip(w, w[1:]):
        iter_frequencies["".join(merged_token)] += freq

for byte_pair in iter_frequencies.most_common(1):
    # add the pair to the tokenized_corpus
    # remove the counts from the fragments of the byte pair; remove the entry if its count goes to 0
    ...
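The two comments in that loop describe an incremental update: once the top pair is merged, its count is added to the token counts and subtracted from both fragments. A minimal sketch of that bookkeeping follows. It assumes each word's current split is kept in a dict next to the word frequencies, which the cell above does not do yet; `merge_and_update`, `splits` and `token_counts` are hypothetical names. Counting the merges actually performed in each word keeps the update exact.

```python
from collections import Counter

def merge_and_update(pair, splits, word_freq, token_counts):
    """Merge `pair` in every split and update `token_counts` in place (hypothetical helper)."""
    c1, c2 = pair
    for word, split in splits.items():
        new_split, i, n_merged = [], 0, 0
        while i < len(split):
            if i + 1 < len(split) and split[i] == c1 and split[i + 1] == c2:
                new_split.append(c1 + c2)
                n_merged += 1
                i += 2
            else:
                new_split.append(split[i])
                i += 1
        splits[word] = new_split
        # every merge turns one c1 and one c2 into one c1+c2, once per occurrence of the word
        n = n_merged * word_freq[word]
        if n:
            token_counts[c1 + c2] += n
            token_counts[c1] -= n
            token_counts[c2] -= n
    # remove fragments whose count dropped to 0
    for token in [t for t, count in token_counts.items() if count <= 0]:
        del token_counts[token]

# toy data: "villain" twice, "hero" once, each padded with the end-of-word marker
word_freq = Counter(["villain#", "villain#", "hero#"])
splits = {w: list(w) for w in word_freq}
token_counts = Counter()
for w, f in word_freq.items():
    for t in splits[w]:
        token_counts[t] += f

merge_and_update(("l", "l"), splits, word_freq, token_counts)
print(splits["villain#"])                        # ['v', 'i', 'll', 'a', 'i', 'n', '#']
print(token_counts["ll"], "l" in token_counts)   # 2 False
```

When `c1 == c2`, as with `('l', 'l')`, both subtractions hit the same entry. That is correct, because each merge consumes two copies of the character.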
    

In [None]:
# byte pairs ordered from most to least frequent
sorted_freq = dict(sorted(iter_frequencies.items(), key=lambda kv: kv[1], reverse=True))

Example sentence: "anudeep went to the park"

corpus: {word: frequency}
- base dictionary; it is built once and never modified
- {"anudeep":1, "went":1, "to":1, "the":1, "park":1}

tokenized_corpus: {word: byte_pair_tokenized_word}
- starts as a character-level split of each word; the split is updated after every merge

base_tokens: set{characters_that_make_up_tokens}
- the set of base characters that make up all the tokens
- built once and never modified

byte_pair_freq_table: {byte_pair: freq}
- every adjacent byte pair and its frequency
- used to pick the next merge candidate
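A sketch that builds these four structures for the example sentence. To match the example `corpus` above it leaves out the `#` end-of-word marker, and it stores byte pairs as tuples rather than joined strings. Both are choices made for illustration.

```python
from collections import Counter

sentence = "anudeep went to the park"

# corpus: {word: frequency}; built once and never modified
corpus = Counter(sentence.lower().split())

# tokenized_corpus: {word: current tokenization}; starts at the character level
tokenized_corpus = {word: list(word) for word in corpus}

# base_tokens: every character that appears in some word
base_tokens = {c for word in corpus for c in word}

# byte_pair_freq_table: {(left, right): frequency}, weighted by word frequency
byte_pair_freq_table = Counter()
for word, freq in corpus.items():
    tokens = tokenized_corpus[word]
    for pair in zip(tokens, tokens[1:]):
        byte_pair_freq_table[pair] += freq

print(sorted(base_tokens))  # ['a', 'd', 'e', 'h', 'k', 'n', 'o', 'p', 'r', 't', 'u', 'w']
print(len(byte_pair_freq_table), set(byte_pair_freq_table.values()))  # 15 {1}
```

Every pair in this sentence occurs exactly once, so the first merge is decided purely by tie-breaking. Frequencies only become informative on a larger corpus.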

```
# Algorithm:

1) sort the byte pairs by frequency
2) use the last character, of the byte pair to 

for byte_pair in sorted(byte_pair_freq_table):
    if byte_pair[-1] == '#':
        continue
    else:
        for bp in byte_pair_freq_table:
            if bp[0] == byte_pair[-1]:
                

```


iter_freq: {merged_token: freq}

Algorithm:
- 