The core of WordPiece is that we will consider the following ratio to be the guide for the next tokens to merge : 

$$R = \frac{f(AB)}{f(A)f(B)}$$

where $f$  is the frequency

In [121]:
from collections import defaultdict
from typing_extensions import Tuple

In [122]:
training_corpus = "lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua ut enim ad minim veniam quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur excepteur sint occaecat cupidatat non proident sunt in culpa qui officia deserunt mollit anim id est laborum" 

In [123]:
words = training_corpus.split()

In [124]:
explode_word = lambda word: [word[0]] + ["##" + word[i] for i in range(1, len(word))]

In [125]:
words[0]

'lorem'

In [126]:
print("this is how we build the original vocaubulary")
explode_word(words[0])

this is how we build the original vocaubulary


['l', '##o', '##r', '##e', '##m']

In [127]:
def freq_map(encoded_training_corpus: list) -> dict:
    freq_map = {}
    for word in encoded_training_corpus:
        for token in word:
            freq_map[token] = freq_map.get(token, 0) + 1
    return freq_map

In [128]:
def encode(training_corpus: str) -> dict:
    words = training_corpus.split()
    return [explode_word(word) for word in words]

In [129]:
E = encode(training_corpus)
E[:2]

[['l', '##o', '##r', '##e', '##m'], ['i', '##p', '##s', '##u', '##m']]

In [130]:
F = freq_map(E)
list(F.items())[:5]

[('l', 4), ('##o', 27), ('##r', 21), ('##e', 27), ('##m', 14)]

In [131]:
vocab = set().union(*[set(word) for word in E])
list(vocab)[:5]

['t', 'v', '##a', '##n', 's']

We need $f$ frequency for each individual token, so that we can compute the denominator 

In [132]:
def make_bigrams(L:list) -> list:
    return [list(zip(sub, sub[1:])) for sub in L]
tuples = make_bigrams(E)
tuples[:2]

[[('l', '##o'), ('##o', '##r'), ('##r', '##e'), ('##e', '##m')],
 [('i', '##p'), ('##p', '##s'), ('##s', '##u'), ('##u', '##m')]]

In [133]:
def generate_merge_token(p,s):
    if p.startswith("##"):
        return p + "+" + s + " = ##"+ (p + s).replace("#","")
    return p + "+" + s + " = " + p + s.replace("#","")

print(generate_merge_token("##a","b"))

##a+b = ##ab


In [134]:
# second order frequency map
# ie frequency of bigrams of tokens in the corpus
SF = defaultdict(int)
for W in tuples : # word 
    for B in W : # bigrams in word
        SF[generate_merge_token(*B)] += 1
print("We are looking for which contiguous tokens we can merge")
print("frequency of merges : ", dict(SF))

We are looking for which contiguous tokens we can merge
frequency of merges :  {'l+##o = lo': 1, '##o+##r = ##or': 9, '##r+##e = ##re': 6, '##e+##m = ##em': 2, 'i+##p = ip': 1, '##p+##s = ##ps': 1, '##s+##u = ##su': 1, '##u+##m = ##um': 3, 'd+##o = do': 5, '##o+##l = ##ol': 6, '##l+##o = ##lo': 4, 's+##i = si': 2, '##i+##t = ##it': 6, 'a+##m = am': 1, '##m+##e = ##me': 1, '##e+##t = ##et': 2, 'c+##o = co': 3, '##o+##n = ##on': 4, '##n+##s = ##ns': 2, '##s+##e = ##se': 4, '##e+##c = ##ec': 2, '##c+##t = ##ct': 1, '##t+##e = ##te': 4, '##t+##u = ##tu': 2, '##u+##r = ##ur': 4, 'a+##d = ad': 2, '##d+##i = ##di': 2, '##i+##p = ##ip': 2, '##p+##i = ##pi': 2, '##i+##s = ##is': 5, '##s+##c = ##sc': 1, '##c+##i = ##ci': 4, '##i+##n = ##in': 3, '##n+##g = ##ng': 1, 'e+##l = el': 1, '##l+##i = ##li': 5, 's+##e = se': 1, '##e+##d = ##ed': 1, 'e+##i = ei': 1, '##i+##u = ##iu': 1, '##u+##s = ##us': 1, '##s+##m = ##sm': 1, '##m+##o = ##mo': 2, '##o+##d = ##od': 2, 't+##e = te': 1, '##m+##p = ##mp': 1

In [135]:
bi = [(merg,freq) for merg,freq in SF.items()]
bi = sorted(bi , key=lambda x:-x[1])
print("bigrams sorted by descending frequency")
print(bi)

bigrams sorted by descending frequency
[('##o+##r = ##or', 9), ('##a+##t = ##at', 8), ('##r+##e = ##re', 6), ('##o+##l = ##ol', 6), ('##i+##t = ##it', 6), ('d+##o = do', 5), ('##i+##s = ##is', 5), ('##l+##i = ##li', 5), ('##n+##t = ##nt', 5), ('##l+##o = ##lo', 4), ('##o+##n = ##on', 4), ('##s+##e = ##se', 4), ('##t+##e = ##te', 4), ('##u+##r = ##ur', 4), ('##c+##i = ##ci', 4), ('i+##n = in', 4), ('##i+##d = ##id', 4), ('##n+##i = ##ni', 4), ('##i+##a = ##ia', 4), ('##u+##i = ##ui', 4), ('##r+##u = ##ru', 4), ('##l+##l = ##ll', 4), ('##u+##m = ##um', 3), ('c+##o = co', 3), ('##i+##n = ##in', 3), ('##u+##n = ##un', 3), ('u+##t = ut', 3), ('l+##a = la', 3), ('##a+##b = ##ab', 3), ('##b+##o = ##bo', 3), ('##q+##u = ##qu', 3), ('##i+##m = ##im', 3), ('##e+##n = ##en', 3), ('e+##x = ex', 3), ('##e+##r = ##er', 3), ('##t+##a = ##ta', 3), ('##r+##i = ##ri', 3), ('##e+##m = ##em', 2), ('s+##i = si', 2), ('##e+##t = ##et', 2), ('##n+##s = ##ns', 2), ('##e+##c = ##ec', 2), ('##t+##u = ##tu', 2),

In [136]:
top_merge_byte_pair_encoding = bi[0]
print("This is the token that BPE would choose next" )
top_merge_byte_pair_encoding

This is the token that BPE would choose next


('##o+##r = ##or', 9)

Now for wordpiece we have to score each AB with the underlying A and B

In [137]:
total_counts = sum(F.values())
print("total number of tokens in the corpus : ", total_counts)

total number of tokens in the corpus :  369


In [138]:
def get_a_b_from_ab(code: str) -> Tuple[str, str]:
    return code.split("+")[0], code.split("+")[1].split(" = ")[0]

In [141]:
def compute_wp_score(merge_token:str, merge_count:int, F:dict) -> float:
    total_counts = sum(F.values())
    a,b = get_a_b_from_ab(merge_token)
    fa, fb = F.get(a)/total_counts, F.get(b)/total_counts
    fab = merge_count/total_counts
    return fab / (fa * fb)

In [142]:
merge_token, merge_count = top_merge_byte_pair_encoding
compute_wp_score(merge_token, merge_count, F)

5.857142857142858

In [146]:
wordpiece_scores = [compute_wp_score(x[0],x[1],F) for x in bi]

In [150]:
# token merging value score
bpe_versus_wordpiece_scores = list(zip(bi,wordpiece_scores))
bpe_versus_wordpiece_scores[:5]

[(('##o+##r = ##or', 9), 5.857142857142858),
 (('##a+##t = ##at', 8), 4.328445747800586),
 (('##r+##e = ##re', 6), 3.9047619047619055),
 (('##o+##l = ##ol', 6), 4.555555555555556),
 (('##i+##t = ##it', 6), 2.0405529953917054)]

In [158]:
wordpiece_preference_list = sorted(bpe_versus_wordpiece_scores, key=lambda x:-x[-1])
wordpiece_preference_list[:4]

[(('o+##f = of', 1), 92.25),
 (('##f+##f = ##ff', 1), 92.25),
 (('e+##x = ex', 3), 33.54545454545455),
 (('o+##c = oc', 1), 18.45)]

In [159]:
print(f"BPE selected the next token : {top_merge_byte_pair_encoding}")
print(f"wordpiece selected the next token : {wordpiece_preference_list[0]}")

BPE selected the next token : ('##o+##r = ##or', 9)
wordpiece selected the next token : (('o+##f = of', 1), 92.25)


In [None]:
Next we can use a trie