We want to implement a tokenizer from scratch.

In [1]:
# ord() returns the Unicode code point of a single character,
# shown here for an emoji and for a CJK character.
print(ord("😯"))
print(ord("你"))

128559
20320


In [2]:
# Encode a short sample string as UTF-8 and show the resulting byte values;
# each of the two characters becomes three bytes (see the output below).
text = "你好"
print(list(text.encode("utf-8")))

[228, 189, 160, 229, 165, 189]


In [3]:
# count how often each pair of adjacent ids occurs in the raw UTF-8 bytes

def get_stats(ids):
    """Count every pair of neighbouring elements in ``ids``.

    Args:
        ids: a sliceable sequence of token ids (e.g. a list of ints).

    Returns:
        A dict mapping each ``(left, right)`` pair of adjacent elements to the
        number of times it occurs. Keys appear in order of first occurrence.
        Sequences with fewer than two elements give an empty dict.
    """
    pair_counts = {}
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1
    return pair_counts

# Count the adjacent byte pairs in the UTF-8 encoding of `text` and show them.
stats = get_stats(list(text.encode('utf-8')))
print(stats)

{(228, 189): 1, (189, 160): 1, (160, 229): 1, (229, 165): 1, (165, 189): 1}


In [4]:
# Pick the most frequent pair: max() iterates over the dict's keys and ranks
# them by their count (stats.get). When several pairs share the highest count,
# max() returns the first one it encounters.

top_token = max(stats, key=stats.get)
print(top_token)

(228, 189)


In [15]:
# replace every occurrence of a pair with its new token id

def merge(ids, pair, idx):
    """Return a copy of ``ids`` with each occurrence of ``pair`` replaced by ``idx``.

    The sequence is scanned left to right and matches do not overlap: once a
    pair has been replaced, scanning resumes after its second element.

    Args:
        ids: list of token ids to scan (not modified).
        pair: two-element sequence ``(first, second)`` to look for.
        idx: the token id that replaces each matched pair.

    Returns:
        A new list with the replacements applied.
    """
    merged = []
    total = len(ids)
    pos = 0
    while pos < total:
        # A match needs a following element, so the last position can never start one.
        has_next = pos + 1 < total
        if has_next and ids[pos] == pair[0] and ids[pos + 1] == pair[1]:
            merged.append(idx)
            pos += 2
            continue
        merged.append(ids[pos])
        pos += 1
    return merged


In [6]:
text = "Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to “support Unicode” in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don’t blame programmers for still finding the whole thing mysterious, even 30 years after Unicode’s inception."
# Train the BPE merge table on the UTF-8 bytes of `text`.
tokens = list(text.encode('utf-8'))  # raw byte values, 0..255

vocab_size = 276  # target vocabulary size: 256 raw byte values + learned merges
merge_number = vocab_size - 256  # how many merges to learn
merges = {}  # (id, id) -> new token id, filled in the order the merges are learned
ids = list(tokens)  # copy it, so `tokens` keeps the unmerged bytes

for i in range(merge_number):
    stats = get_stats(ids)
    if not stats:
        # Fewer than two ids remain, so there is no pair left to merge;
        # max() on an empty dict would raise ValueError.
        break
    top_token = max(stats, key=stats.get)  # most frequent pair (first one on ties)
    idx = 256 + i  # new ids start right after the 256 byte values
    print(f"merging {top_token} into {idx}")
    ids = merge(ids, top_token, idx)
    merges[top_token] = idx


merging (101, 32) into 256
merging (240, 159) into 257
merging (226, 128) into 258
merging (105, 110) into 259
merging (115, 32) into 260
merging (97, 110) into 261
merging (116, 104) into 262
merging (257, 133) into 263
merging (257, 135) into 264
merging (97, 114) into 265
merging (239, 189) into 266
merging (258, 140) into 267
merging (267, 264) into 268
merging (101, 114) into 269
merging (111, 114) into 270
merging (116, 32) into 271
merging (259, 103) into 272
merging (115, 116) into 273
merging (261, 100) into 274
merging (32, 262) into 275


In [7]:
# The decoder needs a vocabulary mapping every token id to the bytes it stands for.

# ids 0..255 are the raw byte values themselves
vocab = {idx: bytes([idx]) for idx in range(256)}
# Each merged id is the concatenation of the bytes of its two parts. This relies
# on `merges` being iterated in insertion order (guaranteed for dict since
# Python 3.7): a merge may refer to an id created by an earlier merge, which
# must already be in `vocab` when it is looked up.
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]

# NOTE(review): this prints the whole vocabulary, which makes for a very long output.
print(vocab)

def decode(ids):
    """Turn a sequence of token ids back into a Python string.

    Each id is looked up in the module-level ``vocab`` (id -> bytes); the
    pieces are concatenated and decoded as UTF-8.

    Args:
        ids: iterable of token ids; every id must be a key of ``vocab``.

    Returns:
        The decoded text. Byte sequences that are not valid UTF-8 are replaced
        by U+FFFD instead of raising, because ``errors='replace'`` is used.
    """
    pieces = [vocab[token_id] for token_id in ids]
    raw = b"".join(pieces)
    return raw.decode('utf-8', errors='replace')

# Decode the merged ids left over from the training loop; this is expected to
# print the training text again.
print(decode(ids))

{0: b'\x00', 1: b'\x01', 2: b'\x02', 3: b'\x03', 4: b'\x04', 5: b'\x05', 6: b'\x06', 7: b'\x07', 8: b'\x08', 9: b'\t', 10: b'\n', 11: b'\x0b', 12: b'\x0c', 13: b'\r', 14: b'\x0e', 15: b'\x0f', 16: b'\x10', 17: b'\x11', 18: b'\x12', 19: b'\x13', 20: b'\x14', 21: b'\x15', 22: b'\x16', 23: b'\x17', 24: b'\x18', 25: b'\x19', 26: b'\x1a', 27: b'\x1b', 28: b'\x1c', 29: b'\x1d', 30: b'\x1e', 31: b'\x1f', 32: b' ', 33: b'!', 34: b'"', 35: b'#', 36: b'$', 37: b'%', 38: b'&', 39: b"'", 40: b'(', 41: b')', 42: b'*', 43: b'+', 44: b',', 45: b'-', 46: b'.', 47: b'/', 48: b'0', 49: b'1', 50: b'2', 51: b'3', 52: b'4', 53: b'5', 54: b'6', 55: b'7', 56: b'8', 57: b'9', 58: b':', 59: b';', 60: b'<', 61: b'=', 62: b'>', 63: b'?', 64: b'@', 65: b'A', 66: b'B', 67: b'C', 68: b'D', 69: b'E', 70: b'F', 71: b'G', 72: b'H', 73: b'I', 74: b'J', 75: b'K', 76: b'L', 77: b'M', 78: b'N', 79: b'O', 80: b'P', 81: b'Q', 82: b'R', 83: b'S', 84: b'T', 85: b'U', 86: b'V', 87: b'W', 88: b'X', 89: b'Y', 90: b'Z', 91: b'[',

In [8]:
# Byte 128 on its own is not valid UTF-8. Because decode() passes
# errors='replace', it prints the replacement character instead of raising.
print(decode([128]))

�


In [13]:
# now build the encoder

def encode(text):
    """Tokenize ``text`` using the learned merge table.

    The text is turned into its UTF-8 byte values, then merges from the
    module-level ``merges`` dict are applied repeatedly. In each round the pair
    present in the sequence with the lowest merge id is replaced, until no
    present pair is in ``merges`` or fewer than two tokens remain.

    Args:
        text: the string to encode.

    Returns:
        A list of integer token ids.
    """
    ids = list(text.encode("utf-8"))

    def rank(candidate):
        # Pairs that were never learned rank last (infinity).
        return merges.get(candidate, float("inf"))

    while len(ids) > 1:
        present_pairs = get_stats(ids)
        best = min(present_pairs, key=rank)
        if best in merges:
            ids = merge(ids, best, merges[best])
        else:
            # nothing to be merged
            break
    return ids

# Quick checks of the encoder: a short sentence, and a single character
# (one byte, so the merge loop in encode() never runs).
print(encode("hello world!"))
print(encode("h"))

merging (111, 114) into 270
[104, 101, 108, 108, 111, 32, 119, 270, 108, 100, 33]
[104]


In [16]:
# Round-trip check: decoding the encoded ids should give back the original string.
print(decode(encode("hello world!")))

hello world!
