# Implementation of BPE algorithm

In [1]:
text = """
Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. 

Duis autem vel eum iriure dolor in hendrerit in vulputate velit esse molestie consequat, vel illum dolore eu
"""

In [2]:
words = text.strip().split(" ")

print(f"Vocabulary size: {len(words)}")

Vocabulary size: 100


In [None]:
#!wget http://www.gutenberg.org/cache/epub/16457/pg16457.txt


In [11]:
import re, collections

def get_vocab_text(text):
    vocab = collections.defaultdict(int)
    words = text.strip().split()
    for word in words:
        vocab[' '.join(list(word)) + ' </w>'] += 1
    return vocab

def get_vocab(filename):
    vocab = collections.defaultdict(int)
    with open(filename, 'r', encoding='utf-8') as fhand:
        for line in fhand:
            words = line.strip().split()
            for word in words:
                vocab[' '.join(list(word)) + ' </w>'] += 1
    return vocab

def get_pairs(vocab):
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
    return pairs

def merge_vocab(pair, v_in):
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

def get_tokens(vocab):
    tokens = collections.defaultdict(int)
    for word, freq in vocab.items():
        word_tokens = word.split()
        for token in word_tokens:
            tokens[token] += freq
    return tokens

Tokens Before BPE
Tokens: defaultdict(<class 'int'>, {'\ufeff': 1, 'T': 1610, 'h': 26094, 'e': 59147, '</w>': 101830, 'P': 780, 'r': 29545, 'o': 34988, 'j': 857, 'c': 13891, 't': 44253, 'G': 300, 'u': 13731, 'n': 32494, 'b': 7428, 'g': 8749, 'E': 901, 'B': 1163, 'k': 2726, 'f': 10469, 'A': 1381, 'l': 20632, 'd': 17576, 'M': 1206, ',': 8068, 'y': 8812, 'J': 80, 's': 28330, 'V': 104, 'i': 31434, 'a': 36693, 'w': 8134, 'm': 9812, 'v': 4880, '.': 4055, 'Y': 250, 'p': 8040, '-': 1128, 'L': 429, ':': 209, 'R': 369, 'D': 327, '6': 77, '2': 158, '0': 401, '5': 131, '[': 32, '#': 1, '1': 295, '4': 104, '7': 65, ']': 32, '*': 44, 'S': 860, 'O': 510, 'F': 422, 'H': 689, 'I': 1432, 'C': 863, 'U': 170, 'N': 796, 'K': 42, '/': 52, '"': 4086, '!': 1214, 'W': 579, '3': 105, "'": 1243, 'Q': 33, 'X': 49, 'Z': 10, '?': 651, '8': 75, '9': 38, '_': 1426, 'à': 3, 'x': 937, 'z': 365, '°': 41, 'q': 575, ';': 561, '(': 56, ')': 56, '{': 23, '}': 16, 'è': 2, 'é': 14, '+': 2, '=': 3, 'ö': 2, 'ê': 5, 'â': 1, 'ô':

In [12]:
vocab = get_vocab('pg16457.txt')

print('Tokens Before BPE')
tokens = get_tokens(vocab)
vocab = get_vocab('pg16457.txt')
print('==========')
print('Tokens Before BPE')
tokens = get_tokens(vocab)
print('Tokens: {}'.format(tokens))
print('Number of tokens: {}'.format(len(tokens)))
print('==========')

num_merges = 100
for i in range(num_merges):
    pairs = get_pairs(vocab)

    if not pairs:
        break

    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    tokens = get_tokens(vocab)

    print('Iter: {}\n Best Pair {}\n Tokens: {}\n Number of tokens: {}\n========='.format(i, best, tokens, len(tokens)))
print('Tokens: {}'.format(tokens))
print('Number of tokens: {}'.format(len(tokens)))
print('==========')

num_merges = 100
for i in range(num_merges):
    pairs = get_pairs(vocab)
    if not pairs:
        break
    best = max(pairs, key=pairs.get)
    print(best)
    vocab = merge_vocab(best, vocab)
    print('Iter: {}'.format(i))
    print('Best pair: {}'.format(best))
    tokens = get_tokens(vocab)
    print('Tokens: {}'.format(tokens))
    print('Number of tokens: {}'.format(len(tokens)))
    print('==========')

Tokens Before BPE
Tokens Before BPE
Tokens: defaultdict(<class 'int'>, {'\ufeff': 1, 'T': 1610, 'h': 26094, 'e': 59147, '</w>': 101830, 'P': 780, 'r': 29545, 'o': 34988, 'j': 857, 'c': 13891, 't': 44253, 'G': 300, 'u': 13731, 'n': 32494, 'b': 7428, 'g': 8749, 'E': 901, 'B': 1163, 'k': 2726, 'f': 10469, 'A': 1381, 'l': 20632, 'd': 17576, 'M': 1206, ',': 8068, 'y': 8812, 'J': 80, 's': 28330, 'V': 104, 'i': 31434, 'a': 36693, 'w': 8134, 'm': 9812, 'v': 4880, '.': 4055, 'Y': 250, 'p': 8040, '-': 1128, 'L': 429, ':': 209, 'R': 369, 'D': 327, '6': 77, '2': 158, '0': 401, '5': 131, '[': 32, '#': 1, '1': 295, '4': 104, '7': 65, ']': 32, '*': 44, 'S': 860, 'O': 510, 'F': 422, 'H': 689, 'I': 1432, 'C': 863, 'U': 170, 'N': 796, 'K': 42, '/': 52, '"': 4086, '!': 1214, 'W': 579, '3': 105, "'": 1243, 'Q': 33, 'X': 49, 'Z': 10, '?': 651, '8': 75, '9': 38, '_': 1426, 'à': 3, 'x': 937, 'z': 365, '°': 41, 'q': 575, ';': 561, '(': 56, ')': 56, '{': 23, '}': 16, 'è': 2, 'é': 14, '+': 2, '=': 3, 'ö': 2, 'ê