### WordPiece guideline

* Init vocab :
    * The first character of each word
    * Every following character of each word, prefixed with '##'
    * Special tokens
    * Count the adjacent token pairs in each word, weighted by the word's frequency

* Objects : 
    * Vocab : list containing the vocab
    * Splits : dictionary mapping each word to its current list of tokens (updated after every merge)

* Iteration : 
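    * Merge the pair with the highest score $\text{score}(a,b) = \frac{\text{freq}(a,b)}{\text{freq}(a)\,\text{freq}(b)}$ (ties are broken arbitrarily)
    * Repeat until the vocabulary reaches vocab_size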

In [14]:
# Toy training corpus: four sentences that the WordPiece vocabulary is learned from.
corpus = list(
    (
        "This is the Hugging Face Course.",
        "This chapter is about tokenization.",
        "This section shows several tokenizer algorithms.",
        "Hopefully, you will be able to understand how they are trained and generate tokens.",
    )
)

### Code Review
**defaultdict** : Python dictionary that creates a default value for a missing key on first access, so a counter can be incremented without initializing the key first <br>
**set.update** : inserts multiple items into a set at once

In [15]:
from typing import List, Tuple
from collections import defaultdict as dd

def init_vocab(text: List[str]) -> Tuple[list, dict, dict]:
    """Build the initial WordPiece vocabulary, word counts and word splits.

    Each sentence is split on single spaces; empty strings produced by
    repeated, leading or trailing spaces are skipped. Punctuation stays
    attached to its word (no further pre-tokenization). A word is split into
    its first character followed by every remaining character prefixed
    with '##'.

    Args:
        text: sentences to learn the vocabulary from.

    Returns:
        (vocab, word_freq, splits) where
        - vocab is the special tokens followed by every distinct initial
          token, in first-seen order (deterministic across runs);
        - word_freq is a defaultdict(int) mapping word -> occurrence count;
        - splits is a defaultdict(list) mapping word -> list of tokens.
    """
    special_tokens = ['[CLS]', '[SEP]', '[PAD]', '[UNK]', '[MASK]']
    # A dict is used as an insertion-ordered set: unlike a plain set, its
    # iteration order does not depend on string hash randomization.
    seen_tokens = {}
    word_freq = dd(int)
    splits = dd(list)

    for sentence in text:
        for word in sentence.split(" "):
            if not word:
                # Consecutive/edge spaces yield "", which has no first character.
                continue
            word_freq[word] += 1
            if word not in splits:
                splits[word] = [word[0]] + [f"##{c}" for c in word[1:]]
                seen_tokens.update(dict.fromkeys(splits[word]))

    return special_tokens + list(seen_tokens), word_freq, splits

# Building the initial vocabulary is linear in the number of characters
# of the corpus: O(nb of letters in the text).
vocab, word_freq, splits = init_vocab(text=corpus)

### Code Review

* sorted() : the **key** parameter chooses the value each item is sorted by (here, the score)
* get a list of (key, value) pairs with **list(dict.items())**

In [16]:
from collections import defaultdict # for typing purposes

def compute_init_scores(splits: defaultdict, word_freq: defaultdict) -> Tuple[list, defaultdict]:
    """Count every token's frequency, then score all adjacent token pairs.

    Args:
        splits: word -> list of tokens.
        word_freq: word -> number of occurrences of the word.

    Returns:
        (scores, indiv_freq) where scores is the list returned by
        compute_scores (sorted ascending, best pair last) and indiv_freq is
        a defaultdict(int) mapping token -> occurrences weighted by word
        frequency.
    """
    indiv_freq = dd(int)
    for word, token_list in splits.items():
        # Iterating the tokens directly also copes with empty splits.
        for token in token_list:
            indiv_freq[token] += word_freq[word]

    # Scoring is shared with the merge loop, so delegate instead of duplicating it.
    return compute_scores(splits, word_freq, indiv_freq), indiv_freq


def compute_scores(splits: defaultdict, word_freq: defaultdict, indiv_freq: defaultdict) -> list:
    """Score every adjacent token pair currently present in ``splits``.

    score(a, b) = freq(a, b) / (freq(a) * freq(b)), where frequencies are
    weighted by word frequency and the token frequencies come from
    ``indiv_freq``.

    Args:
        splits: word -> list of tokens.
        word_freq: word -> number of occurrences of the word.
        indiv_freq: token -> weighted frequency; must be consistent with
            ``splits``, otherwise a score may divide by zero.

    Returns:
        List of ((token_a, token_b), score) sorted by ascending score; ties
        keep first-seen order (stable sort), so the best pair is last.
    """
    pair_freq = dd(int)
    for word, token_list in splits.items():
        for first, second in zip(token_list, token_list[1:]):
            pair_freq[(first, second)] += word_freq[word]

    scores = [
        (pair, freq / (indiv_freq[pair[0]] * indiv_freq[pair[1]]))
        for pair, freq in pair_freq.items()
    ]
    return sorted(scores, key=lambda item: item[1])


# One pass over every token plus a sort of the n_pair pair scores:
# O(nb of letters in the text + n_pair*log(n_pair)).
scores, indiv_freq = compute_init_scores(splits=splits, word_freq=word_freq)

In [None]:
def _merge_pair(tokens: List[str], token_a: str, token_b: str, new_token: str) -> Tuple[List[str], int]:
    """Return (tokens with each (token_a, token_b) occurrence merged, number of merges).

    Occurrences are matched left to right without overlap; scanning resumes
    right after each merged pair. A new list is built, so ``tokens`` is not
    modified.
    """
    merged = []
    n_merges = 0
    i = 0
    while i < len(tokens):
        if i + 1 < len(tokens) and tokens[i] == token_a and tokens[i + 1] == token_b:
            merged.append(new_token)
            n_merges += 1
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged, n_merges


def iterate(vocab: list, vocab_size: int, indiv_freq: defaultdict, word_freq: defaultdict, scores: list, splits: defaultdict) -> Tuple[list, defaultdict]:
    """Grow ``vocab`` by repeatedly merging the highest-scoring adjacent token pair.

    ``vocab``, ``indiv_freq`` and ``splits`` are updated in place. The loop
    stops when ``vocab`` reaches ``vocab_size`` or when no adjacent pair is
    left to merge (every word is a single token).

    Args:
        vocab: current vocabulary; each new merged token is appended once
            (a token already present is not appended again).
        vocab_size: target vocabulary size.
        indiv_freq: token -> frequency weighted by word frequency; kept
            consistent with ``splits`` after every merge.
        word_freq: word -> number of occurrences of the word.
        scores: ((token_a, token_b), score) list sorted ascending, as
            returned by compute_scores; the last entry is merged first.
        splits: word -> current list of tokens.

    Returns:
        (scores, splits) after the last merge.
    """
    while len(vocab) < vocab_size and scores:
        token_a, token_b = scores[-1][0]
        # Non-initial tokens carry a '##' prefix; drop token_b's prefix so the
        # merged token keeps only token_a's (e.g. '##h' + '##e' -> '##he').
        suffix = token_b[2:] if token_b.startswith("##") else token_b
        new_token = token_a + suffix
        if new_token not in vocab:
            vocab.append(new_token)

        for word, tokens in splits.items():
            merged, n_merges = _merge_pair(tokens, token_a, token_b, new_token)
            if not n_merges:
                continue
            # Each merge removes one token_a and one token_b and adds one new token.
            weight = n_merges * word_freq[word]
            indiv_freq[token_a] -= weight
            indiv_freq[token_b] -= weight
            indiv_freq[new_token] += weight
            # Re-binding an existing key is safe while iterating over items().
            splits[word] = merged

        scores = compute_scores(splits, word_freq, indiv_freq)
    return scores, splits


# Train: grow the vocabulary (special tokens included) toward 70 entries.
scores, splits = iterate(
    vocab=vocab,
    vocab_size=70,
    indiv_freq=indiv_freq,
    word_freq=word_freq,
    scores=scores,
    splits=splits,
)