# byte-pair encoding

In [1]:
import numpy as np

# initialize the vocabulary

In [2]:
# some text with lots of repititions
text = "like liker love lovely hug hugs hugging hearts"

chars = list(set(text))
chars.sort() # initial vocab is sorted

for l in chars:
    print(f"'{l}' appears {text.count(l)} times.")

' ' appears 7 times.
'a' appears 1 times.
'e' appears 5 times.
'g' appears 5 times.
'h' appears 4 times.
'i' appears 3 times.
'k' appears 2 times.
'l' appears 5 times.
'n' appears 1 times.
'o' appears 2 times.
'r' appears 2 times.
's' appears 2 times.
't' appears 1 times.
'u' appears 3 times.
'v' appears 2 times.
'y' appears 1 times.


In [3]:
# make a vocabulary
vocab = {word: i for i, word in enumerate(chars)}
vocab

{' ': 0,
 'a': 1,
 'e': 2,
 'g': 3,
 'h': 4,
 'i': 5,
 'k': 6,
 'l': 7,
 'n': 8,
 'o': 9,
 'r': 10,
 's': 11,
 't': 12,
 'u': 13,
 'v': 14,
 'y': 15}

In [4]:
# the text needs to be a list, not a string
# each element in the list is a token

original_text = list(text)
print(text)
print(original_text)

like liker love lovely hug hugs hugging hearts
['l', 'i', 'k', 'e', ' ', 'l', 'i', 'k', 'e', 'r', ' ', 'l', 'o', 'v', 'e', ' ', 'l', 'o', 'v', 'e', 'l', 'y', ' ', 'h', 'u', 'g', ' ', 'h', 'u', 'g', 's', ' ', 'h', 'u', 'g', 'g', 'i', 'n', 'g', ' ', 'h', 'e', 'a', 'r', 't', 's']


# find character pairs and merge the most frequent

In [5]:
token_pairs = dict()

# loop over tokens
for i in range(len(original_text)-1):
    # create a pair
    pair = original_text[i] + original_text[i+1]

    # increase pair frequencies
    if pair in token_pairs:
        token_pairs[pair] += 1
    else:
        token_pairs[pair] = 1


token_pairs

{'li': 2,
 'ik': 2,
 'ke': 2,
 'e ': 2,
 ' l': 3,
 'er': 1,
 'r ': 1,
 'lo': 2,
 'ov': 2,
 've': 2,
 'el': 1,
 'ly': 1,
 'y ': 1,
 ' h': 4,
 'hu': 3,
 'ug': 3,
 'g ': 2,
 'gs': 1,
 's ': 1,
 'gg': 1,
 'gi': 1,
 'in': 1,
 'ng': 1,
 'he': 1,
 'ea': 1,
 'ar': 1,
 'rt': 1,
 'ts': 1}

In [6]:
# find the most frequent pair
most_frequent_pair_idx = np.argmax(list(token_pairs.values()))
most_frequent_pair_char = list(token_pairs.keys())[most_frequent_pair_idx]
print(f'The most frequent character pair is "{most_frequent_pair_char}" with {list(token_pairs.values())[most_frequent_pair_idx]} appearances')

The most frequent character pair is " h" with 4 appearances


In [7]:
# update the vocab
vocab[most_frequent_pair_char] = max(vocab.values()) + 1
vocab

{' ': 0,
 'a': 1,
 'e': 2,
 'g': 3,
 'h': 4,
 'i': 5,
 'k': 6,
 'l': 7,
 'n': 8,
 'o': 9,
 'r': 10,
 's': 11,
 't': 12,
 'u': 13,
 'v': 14,
 'y': 15,
 ' h': 16}

# replace the token pair with one token

In [8]:
# initialize a new text list
new_text = []

# loop through the list
i = 0
while i < (len(original_text)-1):
    # test whether the pair of this and the following elements match the newly-created pair
    if (original_text[i] + original_text[i+1]) == most_frequent_pair_char:
        # append to the new version of the text
        new_text.append(most_frequent_pair_char)
        print(f"added '{most_frequent_pair_char}'")
        # skip the next character
        i += 2
    
    # this isn't a merged pair, so add this token to list
    else:
        new_text.append(original_text[i])
        # move to the next character
        i += 1

print("\n")
print(f"Original text: {original_text}")
print(f"Updated text: {new_text}")

print(f"\n\nOriginal text had {len(original_text)} tokens.")
print(f"\n\nNew text has {len(new_text)} tokens.")

added ' h'
added ' h'
added ' h'
added ' h'


Original text: ['l', 'i', 'k', 'e', ' ', 'l', 'i', 'k', 'e', 'r', ' ', 'l', 'o', 'v', 'e', ' ', 'l', 'o', 'v', 'e', 'l', 'y', ' ', 'h', 'u', 'g', ' ', 'h', 'u', 'g', 's', ' ', 'h', 'u', 'g', 'g', 'i', 'n', 'g', ' ', 'h', 'e', 'a', 'r', 't', 's']
Updated text: ['l', 'i', 'k', 'e', ' ', 'l', 'i', 'k', 'e', 'r', ' ', 'l', 'o', 'v', 'e', ' ', 'l', 'o', 'v', 'e', 'l', 'y', ' h', 'u', 'g', ' h', 'u', 'g', 's', ' h', 'u', 'g', 'g', 'i', 'n', 'g', ' h', 'e', 'a', 'r', 't']


Original text had 46 tokens.


New text has 41 tokens.


# fint the most common letter pairs (again!)

In [9]:
token_pairs = dict()

# loop over the new_text tokens (not the original)
for i in range(len(new_text)-1):
    # create a pair
    pair = new_text[i] + new_text[i+1]

    # increase pair frequencies
    if pair in token_pairs:
        token_pairs[pair] += 1
    else:
        token_pairs[pair] = 1

token_pairs

{'li': 2,
 'ik': 2,
 'ke': 2,
 'e ': 2,
 ' l': 3,
 'er': 1,
 'r ': 1,
 'lo': 2,
 'ov': 2,
 've': 2,
 'el': 1,
 'ly': 1,
 'y h': 1,
 ' hu': 3,
 'ug': 3,
 'g h': 2,
 'gs': 1,
 's h': 1,
 'gg': 1,
 'gi': 1,
 'in': 1,
 'ng': 1,
 ' he': 1,
 'ea': 1,
 'ar': 1,
 'rt': 1}

# now using functions

In [11]:
def get_pair_stats(text2pair):
    token_pairs = dict()

    # loop over tokens
    for i in range(len(text2pair)-1):
        # create a pair
        pair = text2pair[i] + text2pair[i+1]

        # increase pair frequencies
        if pair in token_pairs:
            token_pairs[pair] += 1
        else:
            token_pairs[pair] = 1
    
    return token_pairs


def update_vocab(token_pairs, vocab):
    # find the most frequent pair
    idx = np.argmax(list(token_pairs.values()))
    new_token = list(token_pairs.keys())[idx]

    # update the vocab
    vocab[new_token] = max(vocab.values()) + 1

    return vocab, new_token


def generate_new_token_sequence(prev_text, new_token):
    # initialize a new text list
    new_text = []

    # loop through the list
    i = 0
    while i < (len(prev_text)-1):
        # test whether the pair of this and the following element match the newly-created pair
        if (prev_text[i] + prev_text[i+1]) == new_token:
            new_text.append(new_token)
            i += 2 # skip the next character
        else:
            new_text.append(prev_text[i])
            i += 1
    
    if i < len(prev_text):
        new_text.append(prev_text[i])

    return new_text

In [12]:
# re-initialize the vocab
vocab = {word: i for i, word in enumerate(chars)}
print(f"Vocab has {len(vocab)} tokens.")

Vocab has 16 tokens.


In [13]:
# do one iteration

# find and count pairs
pairs = get_pair_stats(original_text)

# update the dictionary
vocab, new_token = update_vocab(pairs, vocab)

# get a new list of tokens
updated_text = generate_new_token_sequence(original_text, new_token)

print(f"Vocab has {len(vocab)} tokens.")

Vocab has 17 tokens.


In [14]:
updated_text

['l',
 'i',
 'k',
 'e',
 ' ',
 'l',
 'i',
 'k',
 'e',
 'r',
 ' ',
 'l',
 'o',
 'v',
 'e',
 ' ',
 'l',
 'o',
 'v',
 'e',
 'l',
 'y',
 ' h',
 'u',
 'g',
 ' h',
 'u',
 'g',
 's',
 ' h',
 'u',
 'g',
 'g',
 'i',
 'n',
 'g',
 ' h',
 'e',
 'a',
 'r',
 't',
 's']

In [15]:
## do a second iteration
pairs = get_pair_stats(updated_text)

# update the dictionary
vocab,newtoken = update_vocab(pairs,vocab)

# get a new list of tokens
updated_text = generate_new_token_sequence(updated_text,newtoken)
print(f'Vocab has {len(vocab)} tokens.')

Vocab has 18 tokens.


In [16]:
updated_text

['l',
 'i',
 'k',
 'e',
 ' l',
 'i',
 'k',
 'e',
 'r',
 ' l',
 'o',
 'v',
 'e',
 ' l',
 'o',
 'v',
 'e',
 'l',
 'y',
 ' h',
 'u',
 'g',
 ' h',
 'u',
 'g',
 's',
 ' h',
 'u',
 'g',
 'g',
 'i',
 'n',
 'g',
 ' h',
 'e',
 'a',
 'r',
 't',
 's']

In [17]:
vocab

{' ': 0,
 'a': 1,
 'e': 2,
 'g': 3,
 'h': 4,
 'i': 5,
 'k': 6,
 'l': 7,
 'n': 8,
 'o': 9,
 'r': 10,
 's': 11,
 't': 12,
 'u': 13,
 'v': 14,
 'y': 15,
 ' h': 16,
 ' l': 17}