Byte Pair Encoding (BPE) (Gage, 1994) is a simple data compression technique that iteratively replaces the most frequent pair of bytes in a sequence with a single, unused byte. We adapt this algorithm for word segmentation. Instead of merging frequent pairs of bytes, we merge characters or character sequences.

Algorithm: 

1. Initialize the symbol vocabulary with the character vocabulary, and represent each word as a sequence of characters, plus a special end-of-word symbol ‘<\w>’, which allows us to restore the original tokenization after translation.
2. Iteratively count all symbol pairs and replace each occurrence of the most frequent pair (‘A’, ‘B’) with a new symbol ‘AB’.
3. Each merge operation produces a new symbol which represents a character n-gram. Frequent character n-grams (or whole words) are eventually merged into a single symbol, thus BPE requires no shortlist.
4. The final symbol vocabulary size is equal to the size of the initial vocabulary, plus the number of merge operations – the latter is the only hyperparameter of the algorithm.

In practice, we increase efficiency by indexing all pairs, and updating data structures incrementally.

In [57]:
import re, collections

In [71]:
vocab = {
        "l o w </w>": 5,
        "l o w e r </w>": 2,
        "w i d e s t </w>": 3,
        "n e w e s t </w>": 6
        }

In [59]:
def get_stats(vocab):
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        #print("symbols:", symbols)
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
            #print(pairs)
        #print("\n")
    return pairs

In [60]:
get_stats(vocab)

defaultdict(int,
            {('d', 'e'): 3,
             ('e', 'r'): 2,
             ('e', 's'): 9,
             ('e', 'w'): 6,
             ('i', 'd'): 3,
             ('l', 'o'): 7,
             ('n', 'e'): 6,
             ('o', 'w'): 7,
             ('r', '</w>'): 2,
             ('s', 't'): 9,
             ('t', '</w>'): 9,
             ('w', '</w>'): 5,
             ('w', 'e'): 8,
             ('w', 'i'): 3})

In [61]:
# The maximum length of pair is 2 because we used .split() which splits words into characters.

In [62]:
def merge_vocab(pair, v_in):
    v_out = {}
    bigram_pattern = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram_pattern + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

In [63]:
pairs = get_stats(vocab)

In [64]:
pairs

defaultdict(int,
            {('d', 'e'): 3,
             ('e', 'r'): 2,
             ('e', 's'): 9,
             ('e', 'w'): 6,
             ('i', 'd'): 3,
             ('l', 'o'): 7,
             ('n', 'e'): 6,
             ('o', 'w'): 7,
             ('r', '</w>'): 2,
             ('s', 't'): 9,
             ('t', '</w>'): 9,
             ('w', '</w>'): 5,
             ('w', 'e'): 8,
             ('w', 'i'): 3})

In [65]:
best = max(pairs, key=pairs.get)

In [66]:
best

('e', 's')

In [67]:
vocab = merge_vocab(best, vocab)

In [68]:
print(vocab)

{'l o w </w>': 5, 'l o w e r </w>': 2, 'w i d es t </w>': 3, 'n e w es t </w>': 6}


In [72]:
num_merges = 20
for i in range(num_merges):
    pairs = get_stats(vocab)
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    print(vocab)

{'l o w </w>': 5, 'l o w e r </w>': 2, 'w i d es t </w>': 3, 'n e w es t </w>': 6}
{'l o w </w>': 5, 'l o w e r </w>': 2, 'w i d est </w>': 3, 'n e w est </w>': 6}
{'l o w </w>': 5, 'l o w e r </w>': 2, 'w i d est</w>': 3, 'n e w est</w>': 6}
{'lo w </w>': 5, 'lo w e r </w>': 2, 'w i d est</w>': 3, 'n e w est</w>': 6}
{'low </w>': 5, 'low e r </w>': 2, 'w i d est</w>': 3, 'n e w est</w>': 6}
{'low </w>': 5, 'low e r </w>': 2, 'w i d est</w>': 3, 'ne w est</w>': 6}
{'low </w>': 5, 'low e r </w>': 2, 'w i d est</w>': 3, 'new est</w>': 6}
{'low </w>': 5, 'low e r </w>': 2, 'w i d est</w>': 3, 'newest</w>': 6}
{'low</w>': 5, 'low e r </w>': 2, 'w i d est</w>': 3, 'newest</w>': 6}
{'low</w>': 5, 'low e r </w>': 2, 'wi d est</w>': 3, 'newest</w>': 6}
{'low</w>': 5, 'low e r </w>': 2, 'wid est</w>': 3, 'newest</w>': 6}
{'low</w>': 5, 'low e r </w>': 2, 'widest</w>': 3, 'newest</w>': 6}
{'low</w>': 5, 'lowe r </w>': 2, 'widest</w>': 3, 'newest</w>': 6}
{'low</w>': 5, 'lower </w>': 2, 'widest</

ValueError: max() arg is an empty sequence

In [None]:
# The final symbol vocabulary size is equal to the size of the initial vocabulary