## Byte-Pair Encoding (BPE)

In [1]:
import os
import json
import regex as re
import requests

import torch

In [2]:
# encoding is simply converting text into numbers so they can be fed into a neural network
# the first thing we need is a corpus of text to train our encoding

corpus = [
    "You will rejoice to hear that no disaster has accompanied the commencement of an enterprise which you have regarded with such evil forebodings.",
    "I arrived here yesterday, and my first task is to assure my dear sister of my welfare and increasing confidence in the success of my undertaking.",
    "I am already far north of London, and as I walk in the streets of Petersburgh, I feel a cold northern breeze play upon my cheeks, which braces my nerves and fills me with delight.",
    "Do you understand this feeling?",
    "This breeze, which has travelled from the regions towards which I am advancing, gives me a foretaste of those icy climes.",
    "Inspirited by this wind of promise, my daydreams become more fervent and vivid.",
    "I try in vain to be persuaded that the pole is the seat of frost and desolation; it ever presents itself to my imagination as the region of beauty and delight.",
    "There, Margaret, the sun is forever visible, its broad disk just skirting the horizon and diffusing a perpetual splendour.",
    "There—for with your leave, my sister, I will put some trust in preceding navigators—there snow and frost are banished; and, sailing over a calm sea, we may be wafted to a land surpassing in wonders and in beauty every region hitherto discovered on the habitable globe.",
    "Its productions and features may be without example, as the phenomena of the heavenly bodies undoubtedly are in those undiscovered solitudes."
]

In [3]:
# there's a pre-tokenization step where we compute the frequency of each word in the corpus and get the alphabet

from collections import defaultdict

word_freqs = defaultdict(int)

for text in corpus:
    for word in text.split():
        word_freqs[word] += 1

print(word_freqs)

defaultdict(<class 'int'>, {'You': 1, 'will': 2, 'rejoice': 1, 'to': 5, 'hear': 1, 'that': 2, 'no': 1, 'disaster': 1, 'has': 2, 'accompanied': 1, 'the': 12, 'commencement': 1, 'of': 10, 'an': 1, 'enterprise': 1, 'which': 4, 'you': 2, 'have': 1, 'regarded': 1, 'with': 3, 'such': 1, 'evil': 1, 'forebodings.': 1, 'I': 7, 'arrived': 1, 'here': 1, 'yesterday,': 1, 'and': 11, 'my': 9, 'first': 1, 'task': 1, 'is': 3, 'assure': 1, 'dear': 1, 'sister': 1, 'welfare': 1, 'increasing': 1, 'confidence': 1, 'in': 7, 'success': 1, 'undertaking.': 1, 'am': 2, 'already': 1, 'far': 1, 'north': 1, 'London,': 1, 'as': 3, 'walk': 1, 'streets': 1, 'Petersburgh,': 1, 'feel': 1, 'a': 5, 'cold': 1, 'northern': 1, 'breeze': 1, 'play': 1, 'upon': 1, 'cheeks,': 1, 'braces': 1, 'nerves': 1, 'fills': 1, 'me': 2, 'delight.': 2, 'Do': 1, 'understand': 1, 'this': 2, 'feeling?': 1, 'This': 1, 'breeze,': 1, 'travelled': 1, 'from': 1, 'regions': 1, 'towards': 1, 'advancing,': 1, 'gives': 1, 'foretaste': 1, 'those': 2, 'i

In [4]:
alphabet = sorted(list(set(letter for word in word_freqs for letter in word)))

In [5]:
print(*alphabet, sep="|")

,|.|;|?|D|I|L|M|P|T|Y|a|b|c|d|e|f|g|h|i|j|k|l|m|n|o|p|r|s|t|u|v|w|x|y|z|—


In [6]:
vocab = ["<|endoftext|>"] + alphabet.copy()

In [7]:
splits = {word: [c for c in word] for word in word_freqs.keys()}

In [8]:
def compute_pair_freqs(splits):
    pair_freqs = defaultdict(int)
    for word, freq in word_freqs.items():
        split = splits[word]
        if len(split) == 1:
            continue
        for i in range(len(split) - 1):
            pair = (split[i], split[i + 1])
            pair_freqs[pair] += freq
    return pair_freqs

In [14]:
pair_freqs = compute_pair_freqs(splits)

for i, key in enumerate(pair_freqs.keys()):
    if pair_freqs[key] > 10:
        print(f"{key}: {pair_freqs[key]}")

('r', 'e'): 29
('h', 'e'): 23
('e', 'a'): 12
('a', 'r'): 11
('t', 'h'): 26
('i', 's'): 16
('a', 's'): 11
('s', 't'): 12
('e', 'r'): 25
('a', 'n'): 18
('e', 'd'): 12
('v', 'e'): 14
('d', 'e'): 11
('i', 't'): 11
('i', 'n'): 21
('e', 's'): 11
('n', 'd'): 22
('o', 'n'): 13


In [41]:
best_pair = ""
max_freq = None

for pair, freq in pair_freqs.items():
    if max_freq is None or max_freq < freq:
        best_pair = pair
        max_freq = freq

print(best_pair, max_freq)

('r', 'e') 29


In [16]:
merges = defaultdict(int)

In [17]:
def merge_pair(a, b, splits):
    for word in word_freqs:
        split = splits[word]
        if len(split) == 1:
            continue
        i = 0
        while i < (len(split) - 1):
            if split[i] == a and split[i + 1] == b:
                split = split[:i] + [a + b] + split[i + 2 :]
            else:
                i += 1
        splits[word] = split
    return splits

In [18]:
vocab_size = 50

while len(vocab) < vocab_size:
    print('.', end="")
    pair_freqs = compute_pair_freqs(splits)
    best_pair = ""
    max_freq = None
    for pair, freq in pair_freqs.items():
        if max_freq is None or max_freq < freq:
            best_pair = pair
            max_freq = freq
    splits = merge_pair(*best_pair, splits)
    merges[best_pair] = best_pair[0] + best_pair[1]
    vocab.append(best_pair[0] + best_pair[1])

............

In [19]:
print(merges)

defaultdict(<class 'int'>, {('r', 'e'): 're', ('t', 'h'): 'th', ('n', 'd'): 'nd', ('i', 'n'): 'in', ('e', 'r'): 'er', ('i', 's'): 'is', ('a', 'nd'): 'and', ('th', 'e'): 'the', ('a', 's'): 'as', ('o', 'n'): 'on', ('e', 'd'): 'ed', ('o', 'f'): 'of'})


In [20]:
print(vocab)

['<|endoftext|>', ',', '.', ';', '?', 'D', 'I', 'L', 'M', 'P', 'T', 'Y', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '—', 're', 'th', 'nd', 'in', 'er', 'is', 'and', 'the', 'as', 'on', 'ed', 'of']


In [22]:
def tokenize(text):
    word_split = text.split(' ')
    splits = [[l for l in word] for word in word_split]
    for pair, merge in merges.items():
        for idx, split in enumerate(splits):
            i = 0
            while i < len(split) - 1:
                if split[i] == pair[0] and split[i + 1] == pair[1]:
                    split = split[:i] + [merge] + split[i + 2 :]
                else:
                    i += 1
            splits[idx] = split

    return sum(splits, [])

In [24]:
tokenize("This is not the token.")

['T', 'h', 'is', 'is', 'n', 'o', 't', 'the', 't', 'o', 'k', 'e', 'n', '.']