The core of WordPiece is that we will consider the following ratio to be the guide for the next tokens to merge : 

$$R = \frac{f(AB)}{f(A)f(B)}$$

where $f$  is the frequency

In [63]:
from collections import defaultdict

In [64]:
training_corpus = "lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua ut enim ad minim veniam quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur excepteur sint occaecat cupidatat non proident sunt in culpa qui officia deserunt mollit anim id est laborum" 

In [65]:
words = training_corpus.split()

In [66]:
explode_word = lambda word: [word[0]] + ["##" + word[i] for i in range(1, len(word))]

In [67]:
words[0]

'lorem'

In [68]:
print("this is how we build the original vocaubulary")
explode_word(words[0])

this is how we build the original vocaubulary


['l', '##o', '##r', '##e', '##m']

In [69]:
def encode(training_corpus: str) -> dict:
    words = training_corpus.split()
    return [explode_word(word) for word in words]

In [70]:
E = encode(training_corpus)
E[:2]

[['l', '##o', '##r', '##e', '##m'], ['i', '##p', '##s', '##u', '##m']]

In [71]:
vocab = set().union(*[set(word) for word in E])
vocab

{'##a',
 '##b',
 '##c',
 '##d',
 '##e',
 '##f',
 '##g',
 '##h',
 '##i',
 '##l',
 '##m',
 '##n',
 '##o',
 '##p',
 '##q',
 '##r',
 '##s',
 '##t',
 '##u',
 '##x',
 'a',
 'c',
 'd',
 'e',
 'f',
 'i',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v'}

We need $f$ frequency for each individual token, so that we can compute the denominator 

In [72]:
def make_bigrams(L:list) -> list:
    return [list(zip(sub, sub[1:])) for sub in L]
tuples = make_bigrams(E)

In [73]:
tuples = make_bigrams(E)
tuples[:2]

[[('l', '##o'), ('##o', '##r'), ('##r', '##e'), ('##e', '##m')],
 [('i', '##p'), ('##p', '##s'), ('##s', '##u'), ('##u', '##m')]]

In [74]:
def generate_merge_token(p,s):
    if p.startswith("##"):
        return p + "+" + s + " = ##"+ (p + s).replace("#","")
    return p + "+" + s + " = " + p + s.replace("#","")

print(generate_merge_token("##a","b"))

##a+b = ##ab


In [75]:
# second order frequency map
# ie frequency of bigrams of tokens in the corpus
SF = defaultdict(int)
for W in tuples : # word 
    for B in W : # bigrams in word
        SF[generate_merge_token(*B)] += 1
print("We are looking for which contiguous tokens we can merge")
print("frequency of merges : ", dict(SF))

We are looking for which contiguous tokens we can merge
frequency of merges :  {'l+##o = lo': 1, '##o+##r = ##or': 9, '##r+##e = ##re': 6, '##e+##m = ##em': 2, 'i+##p = ip': 1, '##p+##s = ##ps': 1, '##s+##u = ##su': 1, '##u+##m = ##um': 3, 'd+##o = do': 5, '##o+##l = ##ol': 6, '##l+##o = ##lo': 4, 's+##i = si': 2, '##i+##t = ##it': 6, 'a+##m = am': 1, '##m+##e = ##me': 1, '##e+##t = ##et': 2, 'c+##o = co': 3, '##o+##n = ##on': 4, '##n+##s = ##ns': 2, '##s+##e = ##se': 4, '##e+##c = ##ec': 2, '##c+##t = ##ct': 1, '##t+##e = ##te': 4, '##t+##u = ##tu': 2, '##u+##r = ##ur': 4, 'a+##d = ad': 2, '##d+##i = ##di': 2, '##i+##p = ##ip': 2, '##p+##i = ##pi': 2, '##i+##s = ##is': 5, '##s+##c = ##sc': 1, '##c+##i = ##ci': 4, '##i+##n = ##in': 3, '##n+##g = ##ng': 1, 'e+##l = el': 1, '##l+##i = ##li': 5, 's+##e = se': 1, '##e+##d = ##ed': 1, 'e+##i = ei': 1, '##i+##u = ##iu': 1, '##u+##s = ##us': 1, '##s+##m = ##sm': 1, '##m+##o = ##mo': 2, '##o+##d = ##od': 2, 't+##e = te': 1, '##m+##p = ##mp': 1

In [76]:
bi = [(merg,freq) for merg,freq in SF.items()]
bi = sorted(bi , key=lambda x:-x[1])
print("bigrams sorted by descending frequency")
print(bi)

bigrams sorted by descending frequency
[('##o+##r = ##or', 9), ('##a+##t = ##at', 8), ('##r+##e = ##re', 6), ('##o+##l = ##ol', 6), ('##i+##t = ##it', 6), ('d+##o = do', 5), ('##i+##s = ##is', 5), ('##l+##i = ##li', 5), ('##n+##t = ##nt', 5), ('##l+##o = ##lo', 4), ('##o+##n = ##on', 4), ('##s+##e = ##se', 4), ('##t+##e = ##te', 4), ('##u+##r = ##ur', 4), ('##c+##i = ##ci', 4), ('i+##n = in', 4), ('##i+##d = ##id', 4), ('##n+##i = ##ni', 4), ('##i+##a = ##ia', 4), ('##u+##i = ##ui', 4), ('##r+##u = ##ru', 4), ('##l+##l = ##ll', 4), ('##u+##m = ##um', 3), ('c+##o = co', 3), ('##i+##n = ##in', 3), ('##u+##n = ##un', 3), ('u+##t = ut', 3), ('l+##a = la', 3), ('##a+##b = ##ab', 3), ('##b+##o = ##bo', 3), ('##q+##u = ##qu', 3), ('##i+##m = ##im', 3), ('##e+##n = ##en', 3), ('e+##x = ex', 3), ('##e+##r = ##er', 3), ('##t+##a = ##ta', 3), ('##r+##i = ##ri', 3), ('##e+##m = ##em', 2), ('s+##i = si', 2), ('##e+##t = ##et', 2), ('##n+##s = ##ns', 2), ('##e+##c = ##ec', 2), ('##t+##u = ##tu', 2),

In [77]:
top_merge_byte_pair_encoding = bi[0]
print("This is the token that BPE would choose next" )
top_merge_byte_pair_encoding

This is the token that BPE would choose next


('##o+##r = ##or', 9)

Now for wordpiece we have to score each AB with the underlying A and B

In [82]:
def get_a_b_from_ab(code:str) -> tuple:
    return code.split("+")[0] , code.split("+")[1].split(" = ")[0]

get_a_b_from_ab(top_merge_byte_pair_encoding[0])

('##o', '##r')

In [83]:
F

NameError: name 'F' is not defined

In [80]:
get_a_b_from_ab_operands = get_a_b_from_ab.__code__.co_varnames

NameError: name 'get_a_b_from_ab' is not defined

In [78]:
new_representation

<function __main__.new_representation(pre, suff)>

In [79]:
apply_merge(encoded_corpus, top_merge):
    return [new_representation(*T) for T in encoded_corpus]


SyntaxError: invalid syntax (1255235206.py, line 1)

In [None]:
E

[['l', '##o', '##r', '##e', '##m'],
 ['i', '##p', '##s', '##u', '##m'],
 ['d', '##o', '##l', '##o', '##r'],
 ['s', '##i', '##t'],
 ['a', '##m', '##e', '##t'],
 ['c', '##o', '##n', '##s', '##e', '##c', '##t', '##e', '##t', '##u', '##r'],
 ['a', '##d', '##i', '##p', '##i', '##s', '##c', '##i', '##n', '##g'],
 ['e', '##l', '##i', '##t'],
 ['s', '##e', '##d'],
 ['d', '##o'],
 ['e', '##i', '##u', '##s', '##m', '##o', '##d'],
 ['t', '##e', '##m', '##p', '##o', '##r'],
 ['i', '##n', '##c', '##i', '##d', '##i', '##d', '##u', '##n', '##t'],
 ['u', '##t'],
 ['l', '##a', '##b', '##o', '##r', '##e'],
 ['e', '##t'],
 ['d', '##o', '##l', '##o', '##r', '##e'],
 ['m', '##a', '##g', '##n', '##a'],
 ['a', '##l', '##i', '##q', '##u', '##a'],
 ['u', '##t'],
 ['e', '##n', '##i', '##m'],
 ['a', '##d'],
 ['m', '##i', '##n', '##i', '##m'],
 ['v', '##e', '##n', '##i', '##a', '##m'],
 ['q', '##u', '##i', '##s'],
 ['n', '##o', '##s', '##t', '##r', '##u', '##d'],
 ['e',
  '##x',
  '##e',
  '##r',
  '##c',
  '##i'

If we want to be faster we have to use a trie

In [None]:
potential_merges = [ new_representation(p,s) for (p,s) in L for L in make_tuples(E)]


NameError: name 'L' is not defined

In [None]:
make_tuples(E)

[<zip at 0x1204ee140>,
 <zip at 0x1204ee300>,
 <zip at 0x1204ee880>,
 <zip at 0x1204ec500>,
 <zip at 0x1204ec180>,
 <zip at 0x1204eeb80>,
 <zip at 0x1204ef740>,
 <zip at 0x117c8ad00>,
 <zip at 0x117c88d80>,
 <zip at 0x1204ec880>,
 <zip at 0x117c8a640>,
 <zip at 0x117c8ac40>,
 <zip at 0x1117a6780>,
 <zip at 0x1117a6900>,
 <zip at 0x1117a6480>,
 <zip at 0x1117a7c40>,
 <zip at 0x1117a7400>,
 <zip at 0x1117a7600>,
 <zip at 0x1117a6540>,
 <zip at 0x1117a6bc0>,
 <zip at 0x12054bb80>,
 <zip at 0x12054a380>,
 <zip at 0x1204f2d40>,
 <zip at 0x1204f0e80>,
 <zip at 0x1204f1240>,
 <zip at 0x1204f3980>,
 <zip at 0x1204f1d40>,
 <zip at 0x1204f0f40>,
 <zip at 0x1204f3900>,
 <zip at 0x1204f0200>,
 <zip at 0x1204f33c0>,
 <zip at 0x1204f0900>,
 <zip at 0x1204f2d80>,
 <zip at 0x1204f1540>,
 <zip at 0x1204f1480>,
 <zip at 0x1204f1880>,
 <zip at 0x1204f3e40>,
 <zip at 0x1204f3c80>,
 <zip at 0x1204f1100>,
 <zip at 0x1204f06c0>,
 <zip at 0x1204f2140>,
 <zip at 0x1204f3a40>,
 <zip at 0x1204f1dc0>,
 <zip at 0x

In [None]:
def freq_map(encoded_training_corpus: list) -> dict:
    freq_map = {}
    for word in encoded_training_corpus:
        for token in word:
            if token in freq_map:
                freq_map[token] += 1
            else:
                freq_map[token] = 1
    return freq_map

NUM_STEPS = 1

E = encode(training_corpus)

for step in range(NUM_STEPS):
    # first order frequency
    F = freq_map(E)
    # second order frequency
#     F2 = freq_map(encode(" ".join(encode(training_corpus))))
#     for word in training_corpus.split():
#         for token in explode_word(word):
#             if F[token] < THRESHOLD:
#                 training_corpus = training_corpus.replace(word, token)
#                 break
# F = freq_map(encode(training_corpus))
F

{'l': 4,
 '##o': 27,
 '##r': 21,
 '##e': 27,
 '##m': 14,
 'i': 7,
 '##p': 9,
 '##s': 14,
 '##u': 25,
 'd': 7,
 '##l': 18,
 's': 4,
 '##i': 35,
 '##t': 31,
 'a': 7,
 'c': 6,
 '##n': 20,
 '##c': 10,
 '##d': 12,
 '##g': 3,
 'e': 11,
 't': 1,
 'u': 4,
 '##a': 22,
 '##b': 3,
 'm': 3,
 '##q': 3,
 'v': 3,
 'q': 2,
 'n': 4,
 '##x': 3,
 'r': 1,
 '##h': 1,
 'f': 1,
 'p': 2,
 'o': 2,
 '##f': 2}

In [None]:
def tokenize(input_seq, vocabulary):
    """ we will break playing into p, ##l

