In [1]:
from transformers import AutoTokenizer

# Load the pretrained "gpt2" tokenizer. The cells below call its
# pre-tokenizer to split raw text into words before learning merges.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [2]:
from collections import defaultdict

# Tiny training corpus for the toy BPE trainer.
corpus = [
    "The quick brown fox jumps over the lazy dog.",
    "Pack my box with five dozen liquor jugs.",
    "The five boxing wizards jump quickly.",
]

# Count how often each pre-tokenized word occurs across the corpus.
# pre_tokenize_str yields (word, offsets) pairs; only the word is needed.
word_freqs = defaultdict(int)
for sentence in corpus:
    pre_tokenized = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(sentence)
    for piece, _span in pre_tokenized:
        word_freqs[piece] += 1
print(word_freqs)

defaultdict(<class 'int'>, {'The': 2, 'Ġquick': 1, 'Ġbrown': 1, 'Ġfox': 1, 'Ġjumps': 1, 'Ġover': 1, 'Ġthe': 1, 'Ġlazy': 1, 'Ġdog': 1, '.': 3, 'Pack': 1, 'Ġmy': 1, 'Ġbox': 1, 'Ġwith': 1, 'Ġfive': 2, 'Ġdozen': 1, 'Ġliquor': 1, 'Ġjugs': 1, 'Ġboxing': 1, 'Ġwizards': 1, 'Ġjump': 1, 'Ġquickly': 1})


In [3]:
# Base alphabet: every distinct character that appears in any corpus word,
# collected through a set and returned as a sorted list.
alphabet = sorted({symbol for token in word_freqs for symbol in token})
print(alphabet)

['.', 'P', 'T', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'Ġ']


In [4]:
# The vocabulary starts with the special end-of-text token followed by the
# base alphabet; learned merges are appended to it later.
vocab = ["<|endoftext|>", *alphabet]
# Every word begins split into its individual characters.
splits = {token: list(token) for token in word_freqs}

def compute_pair_freqs(splits, freqs=None):
    """Count how often each adjacent symbol pair occurs, weighted by word frequency.

    Args:
        splits: Mapping from word to its current list of symbols.
        freqs: Mapping from word to its corpus frequency. When omitted, the
            notebook-level ``word_freqs`` is used, which keeps existing
            one-argument calls working unchanged.

    Returns:
        A ``defaultdict(int)`` mapping ``(left, right)`` symbol tuples to the
        summed frequency of every word in which that pair is adjacent.
    """
    if freqs is None:
        freqs = word_freqs

    pair_freqs = defaultdict(int)
    for word, freq in freqs.items():
        symbols = splits[word]
        # zip over a one-symbol word yields nothing, so such words add no pairs.
        for pair in zip(symbols, symbols[1:]):
            pair_freqs[pair] += freq
    return pair_freqs

pair_freqs = compute_pair_freqs(splits)

# Preview the first six pairs (in insertion order) with their counts.
for key in list(pair_freqs)[:6]:
    print(f"{key}: {pair_freqs[key]}")

('T', 'h'): 2
('h', 'e'): 3
('Ġ', 'q'): 2
('q', 'u'): 3
('u', 'i'): 2
('i', 'c'): 2


In [5]:
# Pick the most frequent pair. max() returns the first maximal item, so ties
# resolve to the earliest-inserted pair; the default covers an empty table.
best_pair, max_freq = max(
    pair_freqs.items(),
    key=lambda item: item[1],
    default=("", None),
)
print(best_pair, max_freq)

('h', 'e') 3


In [6]:
# Record the first learned merge rule and add its merged token to the vocabulary.
first_merge = "he"
merges = {("h", "e"): first_merge}
vocab.append(first_merge)

In [7]:
def merge_pair(a, b, splits, words=None):
    """Replace every adjacent ``a, b`` in the word splits with the merged token ``a + b``.

    Args:
        a: Left symbol of the pair to merge.
        b: Right symbol of the pair to merge.
        splits: Mapping from word to its current list of symbols. Entries are
            reassigned in place on this mapping.
        words: Iterable of words whose splits should be updated. When omitted,
            the keys of the notebook-level ``word_freqs`` are used, which keeps
            existing three-argument calls working unchanged.

    Returns:
        The same ``splits`` mapping that was passed in, after updating.
    """
    if words is None:
        words = word_freqs

    merged = a + b
    for word in words:
        split = splits[word]
        if len(split) == 1:
            continue

        i = 0
        while i < len(split) - 1:
            if split[i] == a and split[i + 1] == b:
                # Build a new list rather than mutating the old one; i is not
                # advanced, so the loop re-examines the merged position.
                split = split[:i] + [merged] + split[i + 2 :]
            else:
                i += 1
        splits[word] = split
    return splits
 
# Apply the first merge rule to every word and check its effect on "The".
splits = merge_pair(a="h", b="e", splits=splits)
print(splits["The"])

['T', 'he']


In [8]:
# Target vocabulary size (special token + alphabet + learned merges).
vocab_size = 50

# Keep learning merges until the vocabulary reaches the target size.
while len(vocab) < vocab_size:
    pair_freqs = compute_pair_freqs(splits)
    if not pair_freqs:
        # Every word is already a single token, so no pair is left to merge.
        # Stop here rather than calling merge_pair without a pair to unpack.
        break

    # max() returns the first maximal item, so ties resolve to the
    # earliest-inserted pair, matching a strict "greater than" scan.
    best_pair, max_freq = max(pair_freqs.items(), key=lambda item: item[1])
    new_token = best_pair[0] + best_pair[1]

    splits = merge_pair(*best_pair, splits)
    merges[best_pair] = new_token
    vocab.append(new_token)
print(merges)

{('h', 'e'): 'he', ('q', 'u'): 'qu', ('c', 'k'): 'ck', ('Ġ', 'b'): 'Ġb', ('Ġ', 'f'): 'Ġf', ('o', 'x'): 'ox', ('Ġ', 'j'): 'Ġj', ('Ġj', 'u'): 'Ġju', ('v', 'e'): 've', ('T', 'he'): 'The', ('Ġ', 'qu'): 'Ġqu', ('Ġqu', 'i'): 'Ġqui', ('Ġqui', 'ck'): 'Ġquick', ('Ġju', 'm'): 'Ġjum', ('Ġjum', 'p'): 'Ġjump', ('Ġ', 'l'): 'Ġl', ('Ġ', 'd'): 'Ġd', ('Ġd', 'o'): 'Ġdo', ('Ġb', 'ox'): 'Ġbox'}


In [9]:
# Show the final vocabulary: special token, base alphabet, then the merged
# tokens in the order they were appended.
print(vocab)

['<|endoftext|>', '.', 'P', 'T', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'Ġ', 'he', 'qu', 'ck', 'Ġb', 'Ġf', 'ox', 'Ġj', 'Ġju', 've', 'The', 'Ġqu', 'Ġqui', 'Ġquick', 'Ġjum', 'Ġjump', 'Ġl', 'Ġd', 'Ġdo', 'Ġbox']


In [10]:
def tokenize(text):
    """Tokenize ``text`` by replaying the learned merge rules.

    The text is pre-tokenized into words with the loaded tokenizer's
    pre-tokenizer, each word is split into single characters, and every rule
    in the notebook-level ``merges`` is applied in insertion order.

    Args:
        text: The string to tokenize.

    Returns:
        A flat list of token strings covering all words of ``text`` in order.
    """
    # Use the public ``backend_tokenizer`` accessor, as the corpus-counting
    # cell does, rather than the private ``_tokenizer`` attribute.
    pre_tokenize_result = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
    splits = [list(word) for word, _offsets in pre_tokenize_result]

    for pair, merge in merges.items():
        for idx, split in enumerate(splits):
            i = 0

            while i < len(split) - 1:
                if split[i] == pair[0] and split[i + 1] == pair[1]:
                    split = split[:i] + [merge] + split[i + 2 :]
                else:
                    i += 1
            splits[idx] = split

    # Flatten in one pass; sum(splits, []) copies the accumulated list on every step.
    return [token for split in splits for token in split]
# Try the learned merges on a sentence that is not in the training corpus.
print(tokenize("Sphinx of black quartz, judge my vow."))

['S', 'p', 'h', 'i', 'n', 'x', 'Ġ', 'o', 'f', 'Ġb', 'l', 'a', 'ck', 'Ġqu', 'a', 'r', 't', 'z', ',', 'Ġju', 'd', 'g', 'e', 'Ġ', 'm', 'y', 'Ġ', 'v', 'o', 'w', '.']
