# Byte-Pair Encoding tokenization

Based on the Huggingface tutorial https://huggingface.co/learn/nlp-course/chapter6/5

In [95]:
corpus = [
    "NLP course",
    "NLP: Natural Language Processing",
    "Attention is all you need.",
    "Transformer: A Novel Neural Network Architecture for Language Understanding.",
    "The Annotated Transformer.",
    "The Illustrated Transformer.",
    "Layer Normalization",
    "Dropout: A Simple Way to Prevent Neural Networks from Overfitting.",
    "Improving Transformer Optimization Through Better Initialization",
]

## Pre-tokenization

See `bytes_to_unicode()` from: https://github.com/karpathy/minGPT/blob/master/mingpt/bpe.py

In [96]:
def bytes_to_unicode():
    """
    Every possible byte (really an integer 0..255) gets mapped by OpenAI to a unicode
    character that represents it visually. Some bytes have their appearance preserved
    because they don't cause any trouble. These are defined in list bs. For example:
    chr(33) returns "!", so in the returned dictionary we simply have d[33] -> "!".
    However, chr(0), for example, is '\x00', which looks ugly. So OpenAI maps these
    bytes, into new characters in a range where chr() returns a single nice character.
    So in the final dictionary we have d[0] -> 'Ā' instead, which is just chr(0 + 2**8).
    In particular, the space character is 32, which we can see by ord(' '). Instead,
    this function will shift space (32) by 256 to 288, so d[32] -> 'Ġ'.
    So this is just a simple one-to-one mapping of bytes 0..255 into unicode characters
    that "look nice", either in their original form, or a funny shifted character
    like 'Ā', or 'Ġ', etc.
    """
    # the 188 integers that render fine in their original form and need no shifting
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:] # all integers b in bs will simply map to chr(b) in the output dict
    # now get the representations of the other 68 integers that do need shifting
    # each will get mapped chr(256 + n), where n will grow from 0...67 in the loop
    n = 0
    for b in range(2**8):
        if b not in bs:
            # if this byte is "ugly" then map it to the next available "nice" character
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    d = dict(zip(bs, cs))
    return d

We will cheat a little bit and use the `transformers` package to do the pre-tokenization phase.

In [97]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [100]:
from collections import defaultdict

word_freqs = defaultdict(int)

for text in corpus:
    words_with_offsets = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
    new_words = [word for word, offset in words_with_offsets]
    for word in new_words:
        word_freqs[word] += 1

print(word_freqs)

defaultdict(<class 'int'>, {'NLP': 2, 'Ġcourse': 1, ':': 3, 'ĠNatural': 1, 'ĠLanguage': 2, 'ĠProcessing': 1, 'Attention': 1, 'Ġis': 1, 'Ġall': 1, 'Ġyou': 1, 'Ġneed': 1, '.': 5, 'Transformer': 1, 'ĠA': 2, 'ĠNovel': 1, 'ĠNeural': 2, 'ĠNetwork': 1, 'ĠArchitecture': 1, 'Ġfor': 1, 'ĠUnderstanding': 1, 'The': 2, 'ĠAnnotated': 1, 'ĠTransformer': 3, 'ĠIllustrated': 1, 'Layer': 1, 'ĠNormalization': 1, 'Dropout': 1, 'ĠSimple': 1, 'ĠWay': 1, 'Ġto': 1, 'ĠPrevent': 1, 'ĠNetworks': 1, 'Ġfrom': 1, 'ĠOverfitting': 1, 'Improving': 1, 'ĠOptimization': 1, 'ĠThrough': 1, 'ĠBetter': 1, 'ĠInitialization': 1})


In [105]:
alphabet = list(
    set([letter 
            for word in word_freqs.keys() 
                 for letter in word
        ])
    )

alphabet.sort()
print(alphabet)

['.', ':', 'A', 'B', 'D', 'I', 'L', 'N', 'O', 'P', 'S', 'T', 'U', 'W', 'a', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z', 'Ġ']


In [106]:
vocab = ["<|endoftext|>"] + alphabet.copy()
print(vocab)

['<|endoftext|>', '.', ':', 'A', 'B', 'D', 'I', 'L', 'N', 'O', 'P', 'S', 'T', 'U', 'W', 'a', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z', 'Ġ']


## BPE "training" step

In [107]:
splits = {word: [c for c in word] for word in word_freqs.keys()}
print(splits)

{'NLP': ['N', 'L', 'P'], 'Ġcourse': ['Ġ', 'c', 'o', 'u', 'r', 's', 'e'], ':': [':'], 'ĠNatural': ['Ġ', 'N', 'a', 't', 'u', 'r', 'a', 'l'], 'ĠLanguage': ['Ġ', 'L', 'a', 'n', 'g', 'u', 'a', 'g', 'e'], 'ĠProcessing': ['Ġ', 'P', 'r', 'o', 'c', 'e', 's', 's', 'i', 'n', 'g'], 'Attention': ['A', 't', 't', 'e', 'n', 't', 'i', 'o', 'n'], 'Ġis': ['Ġ', 'i', 's'], 'Ġall': ['Ġ', 'a', 'l', 'l'], 'Ġyou': ['Ġ', 'y', 'o', 'u'], 'Ġneed': ['Ġ', 'n', 'e', 'e', 'd'], '.': ['.'], 'Transformer': ['T', 'r', 'a', 'n', 's', 'f', 'o', 'r', 'm', 'e', 'r'], 'ĠA': ['Ġ', 'A'], 'ĠNovel': ['Ġ', 'N', 'o', 'v', 'e', 'l'], 'ĠNeural': ['Ġ', 'N', 'e', 'u', 'r', 'a', 'l'], 'ĠNetwork': ['Ġ', 'N', 'e', 't', 'w', 'o', 'r', 'k'], 'ĠArchitecture': ['Ġ', 'A', 'r', 'c', 'h', 'i', 't', 'e', 'c', 't', 'u', 'r', 'e'], 'Ġfor': ['Ġ', 'f', 'o', 'r'], 'ĠUnderstanding': ['Ġ', 'U', 'n', 'd', 'e', 'r', 's', 't', 'a', 'n', 'd', 'i', 'n', 'g'], 'The': ['T', 'h', 'e'], 'ĠAnnotated': ['Ġ', 'A', 'n', 'n', 'o', 't', 'a', 't', 'e', 'd'], 'ĠTransfo

In [108]:
import random

def sample_from_dict(d, sample=10):
    keys = random.sample(list(d), sample)
    values = [d[k] for k in keys]
    return dict(zip(keys, values))

In [123]:
def compute_pair_freqs(splits):
    pair_freqs = defaultdict(int)
    for word, freq in word_freqs.items():
        split = splits[word]
        if len(split) == 1:
            continue
        for i in range(len(split) - 1):
            pair = (split[i], split[i + 1])
            pair_freqs[pair] += freq
    return pair_freqs

pair_freqs = compute_pair_freqs(splits)

In [130]:
print(sample_from_dict(pair_freqs))

{('e', 'n'): 2, ('e', 'd'): 3, ('l', 'u'): 1, ('e', 'r'): 8, ('Ġ', 'c'): 1, ('i', 't'): 3, ('c', 'o'): 1, ('f', 'i'): 1, ('u', 'a'): 2, ('I', 'l'): 1}


In [131]:
best_pair = max(pair_freqs, key=pair_freqs.get)
print(best_pair, pair_freqs[best_pair])

('r', 'a') 8


In [84]:
merges = { best_pair: best_pair[0] + best_pair[1]}
vocab.append(best_pair[0] + best_pair[1])
print(merges)
print(vocab)

{('r', 'a'): 'ra'}
['<|endoftext|>', '.', ':', 'A', 'B', 'D', 'I', 'L', 'N', 'O', 'P', 'S', 'T', 'U', 'W', 'a', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z', 'Ġ', 'ra']


In [85]:
def merge_pair(a, b, splits):
    for word in word_freqs:
        split = splits[word]
        if len(split) == 1:
            continue

        i = 0
        while i < len(split) - 1:
            if split[i] == a and split[i + 1] == b:
                split = split[:i] + [a + b] + split[i + 2 :]
            else:
                i += 1
        splits[word] = split
    return splits

In [86]:
splits = merge_pair('o','r',splits)
print(splits['ĠTransformer'])

['Ġ', 'T', 'r', 'a', 'n', 's', 'f', 'or', 'm', 'e', 'r']


In [87]:
vocab_size = 50

while len(vocab) < vocab_size:
    pair_freqs = compute_pair_freqs(splits)
    for pair, freq in pair_freqs.items():
        best_pair = max(pair_freqs, key=pair_freqs.get)
        splits = merge_pair(*best_pair, splits)
    merges[best_pair] = best_pair[0] + best_pair[1]
    vocab.append(best_pair[0] + best_pair[1])

In [88]:
print(merges)

{('r', 'a'): 'ra', ('e', 'r'): 'er', ('Ġ', 'N'): 'ĠN', ('t', 'i'): 'ti', ('n', 'g'): 'ng', ('r', 'o'): 'ro', ('f', 'or'): 'for', ('t', 'e'): 'te', ('ti', 'o'): 'tio', ('tio', 'n'): 'tion', ('T', 'ra'): 'Tra'}


In [89]:
print(vocab)

['<|endoftext|>', '.', ':', 'A', 'B', 'D', 'I', 'L', 'N', 'O', 'P', 'S', 'T', 'U', 'W', 'a', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z', 'Ġ', 'ra', 'ra', 'er', 'ĠN', 'ti', 'ng', 'ro', 'for', 'te', 'tio', 'tion', 'Tra']


In [90]:
def tokenize(text):
    pre_tokenize_result = tokenizer._tokenizer.pre_tokenizer.pre_tokenize_str(text)
    pre_tokenized_text = [word for word, offset in pre_tokenize_result]
    splits = [[l for l in word] for word in pre_tokenized_text]
    for pair, merge in merges.items():
        for idx, split in enumerate(splits):
            i = 0
            while i < len(split) - 1:
                if split[i] == pair[0] and split[i + 1] == pair[1]:
                    split = split[:i] + [merge] + split[i + 2 :]
                else:
                    i += 1
            splits[idx] = split

    return sum(splits, [])

In [133]:
print(tokenize("attentionisallyouneed"))
print(tokenize("attentionmasters"))
print(tokenize("TransformersChatbotsInDisguise"))

['a', 't', 'te', 'n', 'tion', 'i', 's', 'a', 'l', 'l', 'y', 'o', 'u', 'n', 'e', 'e', 'd']
['a', 't', 'te', 'n', 'tion', 'm', 'a', 's', 't', 'er', 's']
['Tra', 'n', 's', 'f', 'o', 'r', 'm', 'er', 's', 'C', 'h', 'a', 't', 'b', 'o', 't', 's', 'I', 'n', 'D', 'i', 's', 'g', 'u', 'i', 's', 'e']


In [134]:
def tokens_to_index(tokens):
    idx_list = []
    for tok in tokens:
        idx_list.append(tokens.index(tok))
    return idx_list

print(tokenize("attentionisallyouneed"))
print(tokens_to_index(tokenize("attentionisallyouneed")))

['a', 't', 'te', 'n', 'tion', 'i', 's', 'a', 'l', 'l', 'y', 'o', 'u', 'n', 'e', 'e', 'd']
[0, 1, 2, 3, 4, 5, 6, 0, 8, 8, 10, 11, 12, 3, 14, 14, 16]


In [92]:
from IPython.core.display import HTML

def css_styling():
    styles = open("../css/notebook.css", "r").read()
    return HTML(styles)
css_styling()