# Demonstration of byte-pair encoding

Code adapted from https://leimao.github.io/blog/Byte-Pair-Encoding/


In [1]:
import collections
import re
from nltk.corpus import brown

In [2]:
def get_vocab(word_gen):
    """Build the initial character-level vocabulary from a stream of words.

    Each word is turned into a key consisting of its characters separated
    by single spaces and terminated by the end-of-word marker ``</w>``
    (e.g. ``'low'`` becomes ``'l o w </w>'``).

    Returns a ``defaultdict(int)`` mapping each such key to the number of
    times the corresponding word was seen in ``word_gen``.
    """
    counts = collections.defaultdict(int)
    for word in word_gen:
        spaced = ' '.join(word)
        counts[spaced + ' </w>'] += 1
    return counts

def get_stats(vocab):
    """Count how often each pair of adjacent symbols occurs.

    ``vocab`` maps space-separated symbol sequences to word frequencies.
    Every adjacent pair of symbols inside a key contributes that key's
    frequency to the pair's total.

    Returns a ``defaultdict(int)`` keyed by ``(left, right)`` tuples.
    """
    pair_counts = collections.defaultdict(int)
    for entry, count in vocab.items():
        tokens = entry.split()
        # Pair every symbol with its right-hand neighbour.
        for left, right in zip(tokens, tokens[1:]):
            pair_counts[left, right] += count
    return pair_counts

def merge_vocab(pair, v_in):
    """Merge every occurrence of a symbol pair in the vocabulary keys.

    ``pair`` is a 2-tuple of symbols; ``v_in`` maps space-separated symbol
    sequences to frequencies. Wherever the two symbols appear next to each
    other as whole symbols, they are replaced by their concatenation.

    Returns a new plain ``dict`` with the rewritten keys and the original
    frequencies; ``v_in`` is not modified.
    """
    merged = ''.join(pair)
    bigram = re.escape(' '.join(pair))
    # The lookarounds require whitespace (or the string boundary) on both
    # sides, so only whole symbols match, never the tail or head of a
    # longer symbol.
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    v_out = {}
    for word, freq in v_in.items():
        # Use a callable replacement: a plain string would be treated as a
        # replacement template, so a backslash inside a symbol would be
        # read as an escape / group reference instead of literal text.
        w_out = p.sub(lambda _match: merged, word)
        v_out[w_out] = freq
    return v_out

def get_tokens_from_vocab(vocab):
    """Collect token frequencies and per-word segmentations from a vocabulary.

    ``vocab`` maps space-separated token sequences to word frequencies.

    Returns a pair ``(tokens_frequencies, vocab_tokenization)``:

    * ``tokens_frequencies`` is a ``defaultdict(int)`` giving, for each
      token, the summed frequency of the entries it appears in (counted
      once per occurrence).
    * ``vocab_tokenization`` maps each word with its tokens joined back
      together (end-of-word marker included) to its list of tokens.
    """
    segmentation = {}
    counts = collections.defaultdict(int)
    for entry, count in vocab.items():
        pieces = entry.split()
        segmentation[''.join(pieces)] = pieces
        for piece in pieces:
            counts[piece] += count
    return counts, segmentation

def measure_token_length(token):
    """Return the length of ``token``, counting a trailing ``</w>`` as one symbol.
    """
    if not token.endswith('</w>'):
        return len(token)
    # The four-character end-of-word marker counts as a single symbol.
    return len(token) - 3


In [3]:
# Build the starting vocabulary from the Brown corpus: keep only purely
# alphabetic tokens and lower-case them.
vocab = get_vocab(word.lower() for word in brown.words() if word.isalpha())

# Report the symbol inventory before any merge has been applied.
print('==========')
print('Tokens Before BPE')
tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)
print('All tokens: {}'.format(tokens_frequencies.keys()))
print('Number of tokens: {}'.format(len(tokens_frequencies)))
print('==========')


Tokens Before BPE
All tokens: dict_keys(['t', 'h', 'e', '</w>', 'f', 'u', 'l', 'o', 'n', 'c', 'y', 'g', 'r', 'a', 'd', 'j', 's', 'i', 'v', 'p', 'm', 'k', 'x', 'w', 'b', 'z', 'q'])
Number of tokens: 27


In [4]:
def step_bpe(v):
    """Perform a single BPE merge step on vocabulary ``v``.

    The most frequent adjacent symbol pair (as counted by ``get_stats``)
    is merged throughout the vocabulary with ``merge_vocab``.

    Returns a 2-tuple ``(vocab, best)``: the updated vocabulary and the
    pair that was merged. When no pair is left to merge, the vocabulary is
    returned unchanged and ``best`` is ``None``.
    """
    pairs = get_stats(v)
    if not pairs:
        # Always return a 2-tuple so that callers unpacking
        # ``vocab, best = step_bpe(vocab)`` keep working once every word
        # has collapsed to a single symbol.
        return v, None
    best = max(pairs, key=pairs.get)
    return merge_vocab(best, v), best

In [5]:
# Run a few merges, reporting the merged pair and the resulting symbol
# inventory after each one.
num_merges = 10
for i in range(num_merges):
    print('Iter: {}'.format(i))
    vocab, best = step_bpe(vocab)
    tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)
    print(
        'Best pair: {}'.format(best),
        'All tokens: {}'.format(tokens_frequencies.keys()),
        'Number of tokens: {}'.format(len(tokens_frequencies)),
        '==========',
        sep='\n',
    )

Iter: 0
Best pair: ('e', '</w>')
All tokens: dict_keys(['t', 'h', 'e</w>', 'f', 'u', 'l', 'o', 'n', '</w>', 'c', 'y', 'g', 'r', 'a', 'd', 'j', 's', 'i', 'v', 'e', 'p', 'm', 'k', 'x', 'w', 'b', 'z', 'q'])
Number of tokens: 28
Iter: 1
Best pair: ('t', 'h')
All tokens: dict_keys(['th', 'e</w>', 'f', 'u', 'l', 't', 'o', 'n', '</w>', 'c', 'y', 'g', 'r', 'a', 'd', 'j', 's', 'i', 'v', 'e', 'p', 'm', 'k', 'x', 'w', 'h', 'b', 'z', 'q'])
Number of tokens: 29
Iter: 2
Best pair: ('s', '</w>')
All tokens: dict_keys(['th', 'e</w>', 'f', 'u', 'l', 't', 'o', 'n', '</w>', 'c', 'y', 'g', 'r', 'a', 'd', 'j', 's', 'i', 'v', 'e', 'p', 'm', 's</w>', 'k', 'x', 'w', 'h', 'b', 'z', 'q'])
Number of tokens: 30
Iter: 3
Best pair: ('d', '</w>')
All tokens: dict_keys(['th', 'e</w>', 'f', 'u', 'l', 't', 'o', 'n', '</w>', 'c', 'y', 'g', 'r', 'a', 'd</w>', 'j', 's', 'i', 'd', 'v', 'e', 'p', 'm', 's</w>', 'k', 'x', 'w', 'h', 'b', 'z', 'q'])
Number of tokens: 31
Iter: 4
Best pair: ('t', '</w>')
All tokens: dict_keys(['t

In [6]:
# Print every token currently in the inventory together with its frequency.
for tok, freq in tokens_frequencies.items():
    print(tok, freq)

the</w> 70025
f 108098
u 124718
l 188286
t 202416
on 59265
</w> 463087
c 142440
o 290817
n 107396
y 78777
g 89353
r 206423
an 71278
d</w> 103902
j 7215
s 174596
a 301633
i 248484
d 77033
in 87565
v 45674
e 302602
t</w> 91543
p 92998
m 115244
e</w> 130997
th 63074
s</w> 122162
k 29857
er 73869
x 9052
w 86528
h 119094
b 70812
z 4316
q 4977


In [7]:
# Apply many more merges, logging progress only on every 100th iteration.
num_merges = 500
for i in range(num_merges):
    vocab, best = step_bpe(vocab)
    if i % 100:
        continue
    print(
        'Iter: {}'.format(i),
        'Best pair: {}'.format(best),
        '==========',
        sep='\n',
    )

Iter: 0
Best pair: ('y', '</w>')
Iter: 100
Best pair: ('f', 'or')
Iter: 200
Best pair: ('m', '</w>')
Iter: 300
Best pair: ('th', 'ing</w>')
Iter: 400
Best pair: ('o', 'f')


In [8]:
# Rank the learned tokens: longest first (the end-of-word marker counts as
# one symbol), ties broken by higher frequency.
tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)
sorted_tokens_tuple = sorted(
    tokens_frequencies.items(),
    key=lambda item: (measure_token_length(item[0]), item[1]),
    reverse=True,
)
sorted_tokens = [entry[0] for entry in sorted_tokens_tuple]

print(sorted_tokens[:10])


['through</w>', 'ations</w>', 'tional</w>', 'before</w>', 'ation</w>', 'which</w>', 'there</w>', 'would</w>', 'their</w>', 'other</w>']


In [9]:
# Number of distinct tokens in the inventory after the merges run so far.
len(tokens_frequencies)

537

In [10]:
# Show how this word is segmented into learned tokens (None if the word is
# not present in the vocabulary).
vocab_tokenization.get("mountains</w>")

['m', 'oun', 'ta', 'in', 's</w>']

In [11]:
# Show how this word is segmented into learned tokens (None if the word is
# not present in the vocabulary).
vocab_tokenization.get("fictional</w>")

['f', 'ic', 'tional</w>']

In [12]:
# Show how this word is segmented into learned tokens (None if the word is
# not present in the vocabulary).
vocab_tokenization.get("liveliness</w>")

['li', 'v', 'el', 'in', 'ess</w>']