# Byte-Pair Encoding (BPE) with WordPiece-Style Tokenization

...

In [1]:
import collections

# Base vocabulary: the 26 lowercase ASCII letters, then '_' (the end-of-word
# marker appended to every training word) and the '[UNK]' fallback token.
vocabs = [chr(code) for code in range(ord('a'), ord('z') + 1)] + ['_', '[UNK]']

...

In [2]:
# Toy training corpus: word (with '_' end-of-word marker) -> frequency.
original_words = {'low_': 5, 'lower_': 2, 'newest_': 6, 'widest_': 3}
# BPE operates on symbol sequences, so space-separate every word into its
# characters (e.g. 'low_' -> 'l o w _'); merges later fuse adjacent symbols.
words = {' '.join(word): freq for word, freq in original_words.items()}

...

In [3]:
def get_max_freq_pair(words):
    """Return the most frequent pair of adjacent symbols in ``words``.

    Args:
        words: Mapping from a space-separated symbol sequence
            (e.g. ``'l o w _'``) to how often that word occurs.

    Returns:
        The ``(left, right)`` symbol tuple with the highest total count,
        where each occurrence is weighted by its word's frequency. Ties go
        to the pair first encountered while scanning ``words`` in order.

    Raises:
        ValueError: If no word has two or more symbols, i.e. there is no
            pair left to merge.
    """
    pairs = collections.defaultdict(int)
    for word, freq in words.items():
        symbols = word.split()
        # Count every adjacent (left, right) unit, weighted by word frequency.
        for pair in zip(symbols, symbols[1:]):
            pairs[pair] += freq
    if not pairs:
        # max() would raise an opaque "empty sequence" error here.
        raise ValueError('no adjacent symbol pairs left to merge')
    return max(pairs, key=pairs.get)

...

In [4]:
def merge_vocab(max_freq_pair, words, vocabs):
    """Fuse every occurrence of ``max_freq_pair`` into a single symbol.

    Merging happens at symbol level: only two *whole* adjacent symbols equal
    to ``(left, right)`` are fused. A plain substring replace on the
    space-joined text would wrongly match across symbol boundaries
    (e.g. merging ('a', 'b') would turn symbols ``a bc`` into ``abc``).

    Args:
        max_freq_pair: ``(left, right)`` symbols to merge.
        words: Mapping from space-separated symbol sequence to frequency.
        vocabs: Vocabulary list; the merged token is appended in place
            unless it is already present.

    Returns:
        A new mapping with the pair merged in every word.
    """
    left, right = max_freq_pair
    merged = left + right
    # Different pairs can concatenate to the same string, e.g.
    # ('ab', 'c') and ('a', 'bc'); avoid duplicate vocabulary entries.
    if merged not in vocabs:
        vocabs.append(merged)
    words_out = {}
    for word, freq in words.items():
        symbols = word.split()
        new_symbols = []
        i = 0
        # Scan left to right, merging non-overlapping occurrences.
        while i < len(symbols):
            if i + 1 < len(symbols) and symbols[i] == left and symbols[i + 1] == right:
                new_symbols.append(merged)
                i += 2
            else:
                new_symbols.append(symbols[i])
                i += 1
        new_word = ' '.join(new_symbols)
        words_out[new_word] = words_out.get(new_word, 0) + freq
    return words_out

...

In [5]:
# Learn up to `num_merges` BPE merges, printing each chosen pair.
num_merges = 10
for i in range(num_merges):
    # Stop early once every word is a single symbol: no adjacent pairs
    # remain, and get_max_freq_pair would fail on an empty pair table.
    if all(len(word.split()) < 2 for word in words):
        break
    max_freq_pair = get_max_freq_pair(words)
    words = merge_vocab(max_freq_pair, words, vocabs)
    print("Merge #%d:" % (i + 1), max_freq_pair)

Merge #1: ('e', 's')
Merge #2: ('es', 't')
Merge #3: ('est', '_')
Merge #4: ('l', 'o')
Merge #5: ('lo', 'w')
Merge #6: ('n', 'e')
Merge #7: ('ne', 'w')
Merge #8: ('new', 'est_')
Merge #9: ('low', '_')
Merge #10: ('w', 'i')


...

In [6]:
# Compare each raw training word with its learned BPE segmentation.
for label, table in (("Words:", original_words), ("Wordpieces:", words)):
    print(label, [*table])

Words: ['low_', 'lower_', 'newest_', 'widest_']
Wordpieces: ['low_', 'low e r _', 'newest_', 'wi d est_']


...

In [7]:
# Show the base alphabet followed by the tokens appended by merge_vocab.
print(f"Vocabs: {vocabs}")

Vocabs: ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '_', '[UNK]', 'es', 'est', 'est_', 'lo', 'low', 'ne', 'new', 'newest_', 'low_', 'wi']


...

In [8]:
def wordpiece_tokenize(word, vocab):
    """Split ``word`` greedily into the longest pieces found in ``vocab``.

    Scanning left to right, take the longest prefix of the remaining text
    that is in ``vocab``. If no prefix matches at some position, the whole
    unmatched remainder is replaced by a single ``'[UNK]'`` token.

    Args:
        word: The string to tokenize.
        vocab: Container supporting ``in``; a set gives O(1) lookups.

    Returns:
        The list of pieces, possibly ending in ``'[UNK]'``.
    """
    pieces = []
    start = 0
    while start < len(word):
        # Try the longest candidate first, shrinking it from the right.
        for end in range(len(word), start, -1):
            if word[start:end] in vocab:
                pieces.append(word[start:end])
                start = end
                break
        else:
            # No prefix of the remainder is a known piece.
            pieces.append('[UNK]')
            break
    return pieces


inputs = ['slow_', 'slowest_']
# Build the lookup set once: `in` on a list is O(len(vocabs)) per candidate.
vocab_set = set(vocabs)
outputs = [' '.join(wordpiece_tokenize(word, vocab_set)) for word in inputs]
print('Words:', inputs)
print('Wordpieces:', outputs)

Words: ['slow_', 'slowest_']
Wordpieces: ['s low_', 's low est_']


## Summary
...

## References

[1] Wu, Y., Schuster, M., Chen, Z., Le, Q. V., Norouzi, M., Macherey, W., ... & Klingner, J. (2016). Google's neural machine translation system: Bridging the gap between human and machine translation. arXiv preprint arXiv:1609.08144.

[2] Sennrich, R., Haddow, B., & Birch, A. (2015). Neural machine translation of rare words with subword units. arXiv preprint arXiv:1508.07909.