The core of WordPiece is that we will consider the following ratio to be the guide for the next tokens to merge : 

$$R = \frac{f(AB)}{f(A)f(B)}$$

where $f$  is the frequency

In [31]:
from collections import defaultdict

In [19]:
training_corpus = "lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua ut enim ad minim veniam quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur excepteur sint occaecat cupidatat non proident sunt in culpa qui officia deserunt mollit anim id est laborum" 

In [20]:
words = training_corpus.split()

In [21]:
explode_word = lambda word: [word[0]] + ["##" + word[i] for i in range(1, len(word))]

In [22]:
words[0]

'lorem'

In [23]:
print("this is how we build the original vocaubulary")
explode_word(words[0])

this is how we build the original vocaubulary


['l', '##o', '##r', '##e', '##m']

In [24]:
def encode(training_corpus: str) -> dict:
    words = training_corpus.split()
    return [explode_word(word) for word in words]

In [25]:
E = encode(training_corpus)
E [:2]

[['l', '##o', '##r', '##e', '##m'], ['i', '##p', '##s', '##u', '##m']]

We need $f$ frequency for each individual token, so that we can compute the denominator 

In [26]:
def make_tuples(L:list) -> list:
    return [list(zip(sub, sub[1:])) for sub in L]
tuples = make_tuples(E)


In [27]:
def new_representation(pre,suff):
    if pre.startswith("##"):
        return "##"+ (pre + suff).replace("#","")
    return pre + suff.replace("#","")

In [30]:
tuples = make_tuples(E)
tuples[:2]

[[('l', '##o'), ('##o', '##r'), ('##r', '##e'), ('##e', '##m')],
 [('i', '##p'), ('##p', '##s'), ('##s', '##u'), ('##u', '##m')]]

In [32]:
# second order frequency map
# ie frequency of bigrams of tokens in the corpus
SF = defaultdict(int)
for W in tuples : # word 
    for T in W : # tokens in word
        SF[new_representation(*T)] += 1
SF

defaultdict(int,
            {'lo': 1,
             '##or': 9,
             '##re': 6,
             '##em': 2,
             'ip': 1,
             '##ps': 1,
             '##su': 1,
             '##um': 3,
             'do': 5,
             '##ol': 6,
             '##lo': 4,
             'si': 2,
             '##it': 6,
             'am': 1,
             '##me': 1,
             '##et': 2,
             'co': 3,
             '##on': 4,
             '##ns': 2,
             '##se': 4,
             '##ec': 2,
             '##ct': 1,
             '##te': 4,
             '##tu': 2,
             '##ur': 4,
             'ad': 2,
             '##di': 2,
             '##ip': 2,
             '##pi': 2,
             '##is': 5,
             '##sc': 1,
             '##ci': 4,
             '##in': 3,
             '##ng': 1,
             'el': 1,
             '##li': 5,
             'se': 1,
             '##ed': 1,
             'ei': 1,
             '##iu': 1,
             '##us': 1,
             '##sm'

If we want to be faster we have to use a trie

In [None]:
potential_merges = [ new_representation(p,s) for (p,s) in L for L in make_tuples(E)]


NameError: name 'L' is not defined

In [None]:
make_tuples(E)

[<zip at 0x1204ee140>,
 <zip at 0x1204ee300>,
 <zip at 0x1204ee880>,
 <zip at 0x1204ec500>,
 <zip at 0x1204ec180>,
 <zip at 0x1204eeb80>,
 <zip at 0x1204ef740>,
 <zip at 0x117c8ad00>,
 <zip at 0x117c88d80>,
 <zip at 0x1204ec880>,
 <zip at 0x117c8a640>,
 <zip at 0x117c8ac40>,
 <zip at 0x1117a6780>,
 <zip at 0x1117a6900>,
 <zip at 0x1117a6480>,
 <zip at 0x1117a7c40>,
 <zip at 0x1117a7400>,
 <zip at 0x1117a7600>,
 <zip at 0x1117a6540>,
 <zip at 0x1117a6bc0>,
 <zip at 0x12054bb80>,
 <zip at 0x12054a380>,
 <zip at 0x1204f2d40>,
 <zip at 0x1204f0e80>,
 <zip at 0x1204f1240>,
 <zip at 0x1204f3980>,
 <zip at 0x1204f1d40>,
 <zip at 0x1204f0f40>,
 <zip at 0x1204f3900>,
 <zip at 0x1204f0200>,
 <zip at 0x1204f33c0>,
 <zip at 0x1204f0900>,
 <zip at 0x1204f2d80>,
 <zip at 0x1204f1540>,
 <zip at 0x1204f1480>,
 <zip at 0x1204f1880>,
 <zip at 0x1204f3e40>,
 <zip at 0x1204f3c80>,
 <zip at 0x1204f1100>,
 <zip at 0x1204f06c0>,
 <zip at 0x1204f2140>,
 <zip at 0x1204f3a40>,
 <zip at 0x1204f1dc0>,
 <zip at 0x

In [None]:
def freq_map(encoded_training_corpus: list) -> dict:
    freq_map = {}
    for word in encoded_training_corpus:
        for token in word:
            if token in freq_map:
                freq_map[token] += 1
            else:
                freq_map[token] = 1
    return freq_map

NUM_STEPS = 1

E = encode(training_corpus)

for step in range(NUM_STEPS):
    # first order frequency
    F = freq_map(E)
    # second order frequency
#     F2 = freq_map(encode(" ".join(encode(training_corpus))))
#     for word in training_corpus.split():
#         for token in explode_word(word):
#             if F[token] < THRESHOLD:
#                 training_corpus = training_corpus.replace(word, token)
#                 break
# F = freq_map(encode(training_corpus))
F

{'l': 4,
 '##o': 27,
 '##r': 21,
 '##e': 27,
 '##m': 14,
 'i': 7,
 '##p': 9,
 '##s': 14,
 '##u': 25,
 'd': 7,
 '##l': 18,
 's': 4,
 '##i': 35,
 '##t': 31,
 'a': 7,
 'c': 6,
 '##n': 20,
 '##c': 10,
 '##d': 12,
 '##g': 3,
 'e': 11,
 't': 1,
 'u': 4,
 '##a': 22,
 '##b': 3,
 'm': 3,
 '##q': 3,
 'v': 3,
 'q': 2,
 'n': 4,
 '##x': 3,
 'r': 1,
 '##h': 1,
 'f': 1,
 'p': 2,
 'o': 2,
 '##f': 2}

In [None]:
def tokenize(input_seq, vocabulary):
    """ we will break playing into p, ##l

