## 3.1 Byte Pair Encoding

Implement Byte Pair Encoding (BPE) from scratch using only basic Python functionality. <br>
Test the correctness of your implementation by reproducing the example in chapter 2.5.2 of "Speech and Language Processing".

Reference: Jurafsky, Daniel, and James H. Martin. *Speech and Language Processing: An Introduction to Natural Language Processing, Computational Linguistics, and Speech Recognition with Language Models*, 3rd edition. Online manuscript released January 12, 2025. [Available here](https://web.stanford.edu/~jurafsky/slp3).

Pseudocode for the BPE algorithm is shown below:

## insert image here

In [130]:
from collections import Counter

def tokenizeCorpus(corpus: str) -> list[str]:
    """
    Function that splits a corpus into words and represents every word as a
    string of white-space-separated characters followed by the special
    end-of-word symbol '_'.

    Words are separated on any run of whitespace, so repeated, leading or
    trailing spaces (and an empty corpus) do not produce empty "words" that
    consist only of the end-of-word symbol.

    @param str corpus: Input corpus that should be split
    @return: List of tokenized words, e.g. 'low' becomes 'l o w _'
    """
    # str.split() without an argument drops empty strings, whereas split(' ')
    # returns '' for consecutive/leading/trailing spaces and for ''.
    return [' '.join(word) + ' _' for word in corpus.split()]

def MostFreqPair(C: list[str]) -> tuple[str, str]:
    """
    Returns the most frequent pair of adjacent tokens in C

    Pairs are counted inside each word only (never across word boundaries).
    If several pairs share the highest count, the one that was encountered
    first while scanning C is returned.

    @param list[str] C: A list of tokenized words (tokens separated by white space)
    @return: The most frequent adjacent token pair (as a tuple)
    @raises ValueError: If C contains no pair of adjacent tokens
    """
    pairs = Counter()
    for word in C:
        tokens = word.split()  # Split the word into tokens
        # zip with the shifted list yields every adjacent (left, right) pair
        pairs.update(zip(tokens, tokens[1:]))

    if not pairs:
        raise ValueError('corpus contains no adjacent token pairs to merge')

    # max() keeps the first maximal key in insertion order
    return max(pairs, key=pairs.get)

def bytePairEncoding(corpus:str, k: int):
    """
    Function that learns up to k BPE merges on the input corpus, prints every
    merge and returns the resulting vocabulary.

    @param corpus: The corpus
    @param integer k: The number of merges; fewer are performed when the
                      corpus runs out of adjacent token pairs

    @return set: The vocabulary, i.e. all initial tokens (characters and '_')
                 plus one merged token per performed merge
    """

    C = tokenizeCorpus(corpus)
    # Initialize vocabulary with individual characters
    vocab = set()
    for word in C:
        vocab.update(word.split())  # Add individual tokens (characters + '_') to vocab

    # Merge tokens k times
    for i in range(k):
        # Nothing left to merge once every word consists of a single token
        if all(len(word.split()) < 2 for word in C):
            break

        # Step 1: Find the most frequent pair of adjacent tokens
        tl, tr = MostFreqPair(C)

        # Step 2: Create a new token by merging the pair
        tn = tl + tr

        # Step 3: Add the new token to the vocabulary
        vocab.add(tn)

        # Step 4: Update the corpus with the new token.
        # The merge is done on whole tokens: a plain str.replace on the
        # space-joined word would also match inside longer tokens
        # (e.g. 'ab c' with pair ('b', 'c') would wrongly become 'abc').
        updated_corpus = []
        for word in C:
            tokens = word.split()
            merged = []
            j = 0
            while j < len(tokens):
                if j + 1 < len(tokens) and tokens[j] == tl and tokens[j + 1] == tr:
                    merged.append(tn)
                    j += 2  # skip both tokens of the merged pair
                else:
                    merged.append(tokens[j])
                    j += 1
            updated_corpus.append(' '.join(merged))

        # Update the corpus for the next iteration
        C = updated_corpus

        # Print the result of the current iteration
        prettyPrintMerge(C, i, list(vocab), [tl, tr])

    # Return the final vocabulary after the merges
    return vocab

def prettyPrintMerge(C: list[str], i: int, vocab: list[str], current:tuple[str], detailed:bool=False) -> None:
    """
    A function to pretty-print a single merge step of the BPE algorithm.

    The merged pair is always printed. With detailed=True the output also
    contains the merge counter, every distinct word of the corpus with its
    number of occurrences, and the sorted vocabulary.

    @param list[str] C: Corpus after the merge
    @param int i: Zero-based iteration number (printed as i+1)
    @param list[str] vocab: The vocabulary at the current iteration
    @param current: The pair that was merged; only its first two items are printed
    @param bool detailed: Whether to print corpus and vocabulary as well
    """
    if detailed:
        print(f'After {i+1} merge(s):')

    mergedPair = (current[0], current[1])
    print(f'merge: \t\t{mergedPair}')

    if not detailed:
        return

    print('Corpus:')
    # Counter keeps the words in order of first appearance
    for entry, count in Counter(C).items():
        print(f'{count} \t {entry}')
    print('Vocabulary \t' + ', '.join(sorted(vocab)))

### Reproducing example in chapter 2.5.2

In [112]:
# Example corpus: 5x 'low', 2x 'lowest', 6x 'newer', 3x 'wider', 2x 'new'
corpus = 'low low low low low lowest lowest newer newer newer newer newer newer wider wider wider new new'
# Number of BPE merges to perform
k = 8
# Run BPE; every merge is printed by the function itself
vocab = bytePairEncoding(corpus, k)

merge: 		('e', 'r')
merge: 		('er', '_')
merge: 		('n', 'e')
merge: 		('ne', 'w')
merge: 		('l', 'o')
merge: 		('lo', 'w')
merge: 		('new', 'er_')
merge: 		('low', '_')


## WordPiece

Implement *WordPiece* from scratch using only basic Python functionality.

Reference: https://huggingface.co/learn/nlp-course/en/chapter6/6

In [134]:
def tokenizeCorpusWP(corpus:str):
    """
    Function that splits a corpus into words and every word into WordPiece
    start tokens: the first character is kept as is, every following
    character is prefixed with '##'.

    @param str corpus: Input corpus, words separated by white space
    @return: List of words, each word being a list of tokens,
             e.g. 'low' becomes ['l', '##o', '##w']
    """
    return [
        [word[0]] + [f'##{letter}' for letter in word[1:]]
        for word in corpus.split()
    ]



def createVocab(tokenizedCorpus: list[list[str]]) -> dict[str, int]:
    """
    Function that returns a dictionary with token frequencies for WordPiece

    @param tokenizedCorpus: Corpus as a list of words, each word being a list of tokens
    @return: Dictionary mapping every token to its number of occurrences,
             keys in order of first appearance
    """
    tokenFreq = {}

    for tokens in tokenizedCorpus:
        for token in tokens:
            tokenFreq[token] = tokenFreq.get(token, 0) + 1

    return tokenFreq


def getPairs(tokenizedCorpus: list[list[str]]) -> dict[tuple[str, str], int]:
    """
    A simple function that counts all pairs of adjacent tokens in the corpus

    Pairs are formed inside each word only, never across word boundaries.

    @param tokenizedCorpus: Corpus as a list of words, each word being a list of tokens
    @return: Dictionary mapping every (left, right) token pair to its number
             of occurrences, keys in order of first appearance
    """
    pairs = {}
    for tokens in tokenizedCorpus:
        # zip with the shifted list yields every adjacent (left, right) pair
        for pair in zip(tokens, tokens[1:]):
            pairs[pair] = pairs.get(pair, 0) + 1

    return pairs
    


def computePairScore(pair, pairCount, tokenFreq)-> float:
    """
    Compute the score of a token pair: the pair count divided by the product
    of the frequencies of its two tokens.

    @param pair: The (left, right) token pair
    @param pairCount: Number of occurrences of the pair
    @param tokenFreq: Dictionary with the frequency of every token
    @return: The score, or 0 if either token is missing from tokenFreq
             or has frequency 0
    """
    first, second = pair

    # A missing token counts as frequency 0, which makes the product 0
    denominator = tokenFreq.get(first, 0) * tokenFreq.get(second, 0)
    if denominator == 0:
        return 0

    return pairCount / denominator


def computePairsScores(corpus: list[list[str]]) -> dict:
    """
    Compute the score of every adjacent token pair in a tokenized corpus.

    @param corpus: Tokenized corpus (list of words, each word being a list of
                   tokens); it is passed on to getPairs and createVocab
    @return: Dictionary mapping every (left, right) token pair to the score
             returned by computePairScore
    """
    pairCounts = getPairs(corpus)
    tokenFreq = createVocab(corpus)

    return {
        pair: computePairScore(pair, pairCount, tokenFreq)
        for pair, pairCount in pairCounts.items()
    }


def mergePair(pair, tokenizedCorpus, pairScores, tokenFreq):
    """
    Merge every occurrence of a token pair into a single token and recompute
    the token frequencies and pair scores.

    NOTE: the word lists inside tokenizedCorpus are modified in place; the
    returned corpus is a new outer list holding those same word lists.

    @param pair: The (left, right) token pair to merge
    @param tokenizedCorpus: Corpus as a list of words, each word being a list of tokens
    @param pairScores: Not read; the scores are recomputed from the merged corpus
    @param tokenFreq: Not read; the frequencies are recomputed from the merged corpus
    @return: Tuple (updatedCorpus, tokenFreq, pairScores)
    """
    left, right = pair
    # Strip the '##' continuation prefix of the right token only if it is
    # actually there; a blind right[2:] would cut real characters otherwise.
    newToken = left + (right[2:] if right.startswith('##') else right)

    updatedCorpus = []

    for word in tokenizedCorpus:
        i = 0
        while i < len(word) - 1:
            if word[i] == left and word[i + 1] == right:
                # Replace both tokens by the merged one; i is not advanced,
                # so the new token is compared with its successor next.
                word[i:i + 2] = [newToken]
            else:
                i += 1
        updatedCorpus.append(word)

    # Recompute token frequencies and pair scores for the updated corpus
    tokenFreq = createVocab(updatedCorpus)
    pairScores = computePairsScores(updatedCorpus)

    return updatedCorpus, tokenFreq, pairScores




def wordPiece(corpus:str, numMerges:int=10):
    """
    Function that learns up to numMerges WordPiece merges on the input corpus
    and prints every merged pair together with its score.

    In every iteration the pair with the highest score is merged; if several
    pairs share the highest score, the first one in the score dictionary wins.

    @param str corpus: The corpus, words separated by white space
    @param int numMerges: Maximum number of merges; the loop stops earlier
                          when no pair is left
    @return: Dictionary with the token frequencies after the last merge
    """
    tokenizedCorpus = tokenizeCorpusWP(corpus)
    tokenFreq = createVocab(tokenizedCorpus)
    pairScores = computePairsScores(tokenizedCorpus)

    for _ in range(numMerges):
        if not pairScores:
            print('no more pairs to merge')
            break

        bestPair = max(pairScores, key=pairScores.get)
        print(f'merging pair: {bestPair} with score {pairScores[bestPair]}')

        # Continue with the merged corpus explicitly instead of rebinding the
        # raw `corpus` string and relying on in-place mutation by mergePair.
        tokenizedCorpus, tokenFreq, pairScores = mergePair(bestPair, tokenizedCorpus, pairScores, tokenFreq)

    return tokenFreq

In [132]:
# Run WordPiece with at most 5 merges on the corpus defined above
wordPiece(corpus, 5)

merging pair: ('##s', '##t') with score 0.5
merging pair: ('w', '##i') with score 0.3333333333333333
merging pair: ('wi', '##d') with score 0.3333333333333333
merging pair: ('l', '##o') with score 0.14285714285714285
merging pair: ('lo', '##w') with score 0.06666666666666667
{'low': 7, '##e': 19, '##st': 2, 'n': 8, '##w': 8, '##r': 9, 'wid': 3}


### Comparing BPE with WordPiece

In [133]:
# BPE with 5 merges on the same corpus, for comparison with WordPiece
bytePairEncoding(corpus, 5)

merge: 		('e', 'r')
merge: 		('er', '_')
merge: 		('n', 'e')
merge: 		('ne', 'w')
merge: 		('l', 'o')


In [135]:
# WordPiece with at most 5 merges on the same corpus, for comparison with BPE
wordPiece(corpus, 5)

merging pair: ('##s', '##t') with score 0.5
merging pair: ('w', '##i') with score 0.3333333333333333
merging pair: ('wi', '##d') with score 0.3333333333333333
merging pair: ('l', '##o') with score 0.14285714285714285
merging pair: ('lo', '##w') with score 0.06666666666666667


### Try a different corpus to compare the two

In [136]:
# Rebind corpus to a second, shorter text for the next comparison
corpus = 'lorem ipsum dolor si amet'