## Implement Byte Pair Encoding (BPE) Algorithm

In [4]:
import urllib.request

# Retrieve Text from url
url = "https://olki.loria.fr/cerisara/lexres/nh.txt"

response = urllib.request.urlopen(url)
data = response.read() 
text = data.decode('utf-8')

# Save text to file
with open('nh.txt', 'w') as f:
    f.write(text)

In [6]:
# Read text from file
with open('nh.txt', 'r') as f:
    text = f.read()

In [7]:

tokens = text.encode('utf-8')

# Convert text to list of integers
tokens = list(map(int, tokens))

# print length of tokens
print(len(tokens))

# print first 50 Unigram tokens
print(tokens[:50])


10246792
[61, 95, 61, 95, 32, 70, 105, 108, 101, 58, 87, 105, 107, 105, 46, 112, 110, 103, 10, 84, 104, 105, 115, 32, 102, 97, 118, 105, 99, 111, 110, 32, 115, 104, 111, 119, 115, 32, 97, 32, 32, 102, 114, 111, 109, 32, 78, 101, 116, 72]


In [8]:
def get_statistics(tokens):
    counts = {}
    for pair in zip(tokens, tokens[1:]):
        # pair is a tuple of two integers
        counts[pair] = counts.get(pair, 0) + 1
    return counts

In [9]:
def merge(tokens, pair, new_token):
    new_tokens = []
    i = 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i+1]) == pair:
            new_tokens.append(new_token)
            i += 2
        else:
            new_tokens.append(tokens[i])
            i += 1
    return new_tokens

In [10]:
vocab_size = 300
num_merges = vocab_size - 256

tokens_clone = list(tokens)

merges = {} # map from pair to new token
print(f"num_merges = {num_merges}")
for i in range(num_merges):
    counts = get_statistics(tokens_clone)
    pair = max(counts, key=counts.get)
    new_token = max(tokens_clone) + 1
    print(f"merge {pair} -> {new_token}")
    tokens_clone = merge(tokens_clone, pair, new_token)
    merges[pair] = new_token

merge (101, 32) -> 241
merge (32, 116) -> 242
merge (32, 97) -> 243
merge (105, 110) -> 244
merge (101, 114) -> 245
merge (115, 32) -> 246
merge (242, 104) -> 247
merge (111, 110) -> 248
merge (116, 32) -> 249
merge (111, 117) -> 250
merge (100, 32) -> 251
merge (97, 110) -> 252
merge (111, 114) -> 253
merge (247, 241) -> 254
merge (61, 95) -> 255
merge (101, 110) -> 256
merge (115, 116) -> 257
merge (44, 32) -> 258
merge (108, 101) -> 259
merge (105, 116) -> 260
merge (114, 101) -> 261
merge (32, 32) -> 262
merge (244, 103) -> 263
merge (111, 32) -> 264
merge (121, 32) -> 265
merge (97, 116) -> 266
merge (97, 114) -> 267
merge (108, 108) -> 268
merge (111, 102) -> 269
merge (46, 32) -> 270
merge (97, 99) -> 271
merge (243, 110) -> 272
merge (243, 32) -> 273
merge (99, 101) -> 274
merge (101, 116) -> 275
merge (105, 99) -> 276
merge (114, 111) -> 277
merge (255, 255) -> 278
merge (278, 32) -> 279
merge (10, 279) -> 280
merge (250, 114) -> 281
merge (97, 108) -> 282
merge (105, 115) -> 

In [11]:
print(f"tokens length = {len(tokens)}")
print(f"tokens_clone length = {len(tokens_clone)}")
print(f"compression rate = {len(tokens) / len(tokens_clone):.2f}")

tokens length = 10246792
tokens_clone length = 7342917
compression rate = 1.40


In [12]:
print(f"vocab_size = {len(set(tokens_clone))}")
print(f"merges = {merges}")

vocab_size = 231
merges = {(101, 32): 241, (32, 116): 242, (32, 97): 243, (105, 110): 244, (101, 114): 245, (115, 32): 246, (242, 104): 247, (111, 110): 248, (116, 32): 249, (111, 117): 250, (100, 32): 251, (97, 110): 252, (111, 114): 253, (247, 241): 254, (61, 95): 255, (101, 110): 256, (115, 116): 257, (44, 32): 258, (108, 101): 259, (105, 116): 260, (114, 101): 261, (32, 32): 262, (244, 103): 263, (111, 32): 264, (121, 32): 265, (97, 116): 266, (97, 114): 267, (108, 108): 268, (111, 102): 269, (46, 32): 270, (97, 99): 271, (243, 110): 272, (243, 32): 273, (99, 101): 274, (101, 116): 275, (105, 99): 276, (114, 111): 277, (255, 255): 278, (278, 32): 279, (10, 279): 280, (250, 114): 281, (97, 108): 282, (105, 115): 283, (245, 32): 284}
