In [1]:
import torch
import numpy as np
import pandas as pd 


In [2]:
import re, collections

def get_vocab(filename):
    vocab = collections.defaultdict(int)
    with open(filename, 'r', encoding='utf-8') as fhand:
        for line in fhand:
            words = line.strip().split()
            for word in words:
                vocab[' '.join(list(word)) + ' </w>'] += 1
    return vocab

def get_stats(vocab):
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
    return pairs

def merge_vocab(pair, v_in):
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

def get_tokens(vocab):
    tokens = collections.defaultdict(int)
    for word, freq in vocab.items():
        word_tokens = word.split()
        for token in word_tokens:
            tokens[token] += freq
    return tokens


In [4]:

# Get free book from Gutenberg
# !wget http://www.gutenberg.org/cache/epub/16457/pg16457.txt

In [3]:

# vocab = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

vocab = get_vocab('pg16457.txt')

print('==========')
print('Tokens Before BPE')
tokens = get_tokens(vocab)
print('Tokens: {}'.format(tokens))
print('Number of tokens: {}'.format(len(tokens)))
print('==========')




Tokens Before BPE
Tokens: defaultdict(<class 'int'>, {'T': 1610, 'h': 26094, 'e': 59147, '</w>': 101830, 'P': 780, 'r': 29545, 'o': 34988, 'j': 857, 'c': 13891, 't': 44253, 'G': 300, 'u': 13731, 'n': 32494, 'b': 7428, 'g': 8749, 'E': 901, 'B': 1163, 'k': 2726, 'f': 10469, 'A': 1381, 'l': 20632, 'd': 17576, 'M': 1206, ',': 8068, 'y': 8812, 'J': 80, 's': 28330, 'V': 104, 'i': 31434, 'a': 36693, 'w': 8134, 'm': 9812, 'v': 4880, '.': 4055, 'Y': 250, 'p': 8040, '-': 1128, 'L': 429, ':': 209, 'R': 369, 'D': 327, '6': 77, '2': 158, '0': 401, '5': 131, '[': 32, '#': 1, '1': 295, '4': 104, '7': 65, ']': 32, '*': 44, 'S': 860, 'O': 510, 'F': 422, 'H': 689, 'I': 1432, 'C': 863, 'U': 170, 'N': 796, 'K': 42, '/': 52, '"': 4086, '!': 1214, 'W': 579, '3': 105, "'": 1243, 'Q': 33, 'X': 49, 'Z': 10, '?': 651, '8': 75, '9': 38, '_': 1426, 'à': 3, 'x': 937, 'z': 365, '°': 41, 'q': 575, ';': 561, '(': 56, ')': 56, '{': 23, '}': 16, 'è': 2, 'é': 14, '+': 2, '=': 3, 'ö': 2, 'ê': 5, 'â': 1, 'ô': 1, 'Æ': 3, '

In [4]:
vocab

defaultdict(int,
            {'T h e </w>': 475,
             'P r o j e c t </w>': 78,
             'G u t e n b e r g </w>': 21,
             'E B o o k </w>': 2,
             'o f </w>': 3553,
             'A l l </w>': 29,
             'A r o u n d </w>': 7,
             't h e </w>': 6871,
             'M o o n , </w>': 98,
             'b y </w>': 672,
             'J u l e s </w>': 3,
             'V e r n e </w>': 3,
             'T h i s </w>': 91,
             'e B o o k </w>': 5,
             'i s </w>': 670,
             'f o r </w>': 545,
             'u s e </w>': 24,
             'a n y o n e </w>': 4,
             'a n y w h e r e </w>': 4,
             'a t </w>': 693,
             'n o </w>': 288,
             'c o s t </w>': 3,
             'a n d </w>': 1826,
             'w i t h </w>': 668,
             'a l m o s t </w>': 54,
             'r e s t r i c t i o n s </w>': 2,
             'w h a t s o e v e r . </w>': 3,
             'Y o u </w>': 39,
             '

In [6]:
# check what pairs looks like
pairs = get_stats(vocab)
pairs

defaultdict(int,
            {('T', 'h'): 1075,
             ('h', 'e'): 12609,
             ('e', '</w>'): 17749,
             ('P', 'r'): 492,
             ('r', 'o'): 2467,
             ('o', 'j'): 444,
             ('j', 'e'): 583,
             ('e', 'c'): 1945,
             ('c', 't'): 1737,
             ('t', '</w>'): 9268,
             ('G', 'u'): 121,
             ('u', 't'): 1792,
             ('t', 'e'): 4141,
             ('e', 'n'): 5095,
             ('n', 'b'): 100,
             ('b', 'e'): 2161,
             ('e', 'r'): 7607,
             ('r', 'g'): 306,
             ('g', '</w>'): 3039,
             ('E', 'B'): 5,
             ('B', 'o'): 30,
             ('o', 'o'): 1308,
             ('o', 'k'): 253,
             ('k', '</w>'): 583,
             ('o', 'f'): 3870,
             ('f', '</w>'): 4187,
             ('A', 'l'): 64,
             ('l', 'l'): 2971,
             ('l', '</w>'): 2144,
             ('A', 'r'): 534,
             ('o', 'u'): 4789,
             ('u',

In [8]:
num_merges = 1000
for i in range(num_merges):
    pairs = get_stats(vocab)
    if not pairs:
        break
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    print('Iter: {}'.format(i))
    print('Best pair: {}'.format(best))
    tokens = get_tokens(vocab)
    print('Tokens: {}'.format(tokens))
    print('Number of tokens: {}'.format(len(tokens)))
    print('==========')

Iter: 0
Best pair: ('e', '</w>')
Tokens: defaultdict(<class 'int'>, {'T': 1610, 'h': 26094, 'e</w>': 17749, 'P': 780, 'r': 29545, 'o': 34988, 'j': 857, 'e': 41398, 'c': 13891, 't': 44253, '</w>': 84081, 'G': 300, 'u': 13731, 'n': 32494, 'b': 7428, 'g': 8749, 'E': 901, 'B': 1163, 'k': 2726, 'f': 10469, 'A': 1381, 'l': 20632, 'd': 17576, 'M': 1206, ',': 8068, 'y': 8812, 'J': 80, 's': 28330, 'V': 104, 'i': 31434, 'a': 36693, 'w': 8134, 'm': 9812, 'v': 4880, '.': 4055, 'Y': 250, 'p': 8040, '-': 1128, 'L': 429, ':': 209, 'R': 369, 'D': 327, '6': 77, '2': 158, '0': 401, '5': 131, '[': 32, '#': 1, '1': 295, '4': 104, '7': 65, ']': 32, '*': 44, 'S': 860, 'O': 510, 'F': 422, 'H': 689, 'I': 1432, 'C': 863, 'U': 170, 'N': 796, 'K': 42, '/': 52, '"': 4086, '!': 1214, 'W': 579, '3': 105, "'": 1243, 'Q': 33, 'X': 49, 'Z': 10, '?': 651, '8': 75, '9': 38, '_': 1426, 'à': 3, 'x': 937, 'z': 365, '°': 41, 'q': 575, ';': 561, '(': 56, ')': 56, '{': 23, '}': 16, 'è': 2, 'é': 14, '+': 2, '=': 3, 'ö': 2, 'ê'