In [43]:
# Sample text used to train the tokenizer.
text = '''Karpathy was born in Bratislava, Czechoslovakia (now Slovakia)[9][10][11][12] and moved with his family to Toronto when he was 15.[13] He completed his Computer Science and Physics bachelor's degrees at University of Toronto in 2009[14] and his master's degree at University of British Columbia in 2011,[14] where he worked on physically-simulated figures (for example, a simulated runner or a simulated person in a crowd).'''
# Raw UTF-8 bytes as a list of ints in 0..255. Iterating over a bytes object
# already yields ints, so no extra map(int, ...) step is needed.
tokens = list(text.encode("utf-8"))

In [44]:
def get_stats(ids, counts=None):
    """Count occurrences of every adjacent pair of token ids.

    Args:
        ids: sliceable sequence of token ids (e.g. a list of ints).
        counts: optional existing ``{(int, int): int}`` dict to accumulate
            into. It is updated in place, which lets callers sum pair
            statistics over several separate chunks of ids. When omitted,
            a fresh dict is created.

    Returns:
        The dict mapping each consecutive pair ``(ids[i], ids[i + 1])`` to
        how many times it appears. New pairs are inserted in first-seen
        order. If ``ids`` has fewer than two elements, no pairs are added.
    """
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

In [45]:
# Most frequent adjacent pair in the raw byte sequence; on ties, the pair
# seen first in the text wins.
stats = get_stats(tokens)
top_pair = max(stats.items(), key=lambda item: item[1])[0]
top_pair

(115, 32)

In [46]:
def merge(ids, pair, idx):
    """Replace each non-overlapping occurrence of ``pair`` in ``ids`` with ``idx``.

    The scan runs left to right. After a match is replaced, scanning resumes
    after both matched elements. For example, merging (1, 1) in [1, 1, 1]
    gives [idx, 1].

    Args:
        ids: sequence of token ids; it is not modified.
        pair: two-element sequence ``(first, second)`` to look for.
        idx: token id that replaces each matched pair.

    Returns:
        A new list with the replacements applied.
    """
    result = []
    n = len(ids)
    pos = 0
    while pos < n:
        at_match = (
            pos + 1 < n
            and ids[pos] == pair[0]
            and ids[pos + 1] == pair[1]
        )
        if at_match:
            result.append(idx)
            pos += 2
            continue
        result.append(ids[pos])
        pos += 1
    return result

In [47]:
# Sanity check: only the adjacent (6, 7) is replaced; the leading 6 is untouched.
merge_example = merge([5, 6, 6, 7, 9, 1], (6, 7), 99)
assert merge_example == [5, 6, 99, 9, 1], merge_example
merge_example

[5, 6, 99, 9, 1]


In [48]:
# Train the BPE merge table on the sample text.
# Target vocabulary = 256 raw byte values + num_merges learned tokens.
vocab_size = 276
num_merges = vocab_size - 256
ids = list(tokens)  # copy so `tokens` keeps the original byte sequence

merges = {}  # (int, int) -> int, in the order the merges were learned
for i in range(num_merges):
    stats = get_stats(ids)
    if not stats:
        # Fewer than two tokens remain, so there is nothing left to merge.
        # Stop here instead of letting max() raise on an empty dict.
        break
    pair = max(stats, key=stats.get)  # most frequent pair; first-seen wins ties
    idx = 256 + i
    print(f"merging {pair} into a new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx

merging (115, 32) into a new token 256
merging (100, 32) into a new token 257
merging (110, 32) into a new token 258
merging (97, 116) into a new token 259
merging (111, 114) into a new token 260
merging (101, 114) into a new token 261
merging (115, 105) into a new token 262
merging (91, 49) into a new token 263
merging (101, 257) into a new token 264
merging (101, 32) into a new token 265
merging (97, 32) into a new token 266
merging (121, 32) into a new token 267
merging (105, 258) into a new token 268
merging (93, 32) into a new token 269
merging (118, 97) into a new token 270
merging (93, 263) into a new token 271
merging (97, 110) into a new token 272
merging (272, 257) into a new token 273
merging (104, 105) into a new token 274
merging (274, 256) into a new token 275


In [49]:
# Average number of raw bytes represented by each token after training.
print(f"compression ratio: {len(tokens) / len(ids):.2f}X")

compression ration: 1.34X


In [60]:
# Token id -> the bytes it stands for. Ids 0..255 are the single raw bytes.
# Each learned id is the concatenation of the two tokens it was merged from.
# This relies on `merges` iterating in creation order (dicts keep insertion
# order), so both children of a pair are already present in `vocab`.
vocab = dict((idx, bytes([idx])) for idx in range(256))
for (p0, p1), idx in merges.items():
    vocab[idx] = b"".join((vocab[p0], vocab[p1]))

def decode(ids):
    """Turn a sequence of token ids back into a Python string.

    Each id is looked up in the module-level ``vocab`` table (id -> bytes).
    The byte strings are concatenated and decoded as UTF-8. Byte sequences
    that are not valid UTF-8 become U+FFFD instead of raising.

    Raises:
        KeyError: if an id has no entry in ``vocab``.
    """
    raw_bytes = b"".join([vocab[token_id] for token_id in ids])
    return raw_bytes.decode("utf-8", errors="replace")


def encode(text):
    """Encode a string into BPE token ids using the learned ``merges`` table.

    The text is first converted to its UTF-8 bytes (ids 0..255). Then, on
    each pass, the adjacent pair with the lowest merge id among those present
    is merged. This replays merges in the order they were learned, so merges
    that depend on earlier-created tokens apply correctly. The loop stops
    when no present pair appears in ``merges``.

    Args:
        text: the string to encode.

    Returns:
        list[int]: the token ids. Returns ``[]`` for an empty string.
    """
    tokens = list(text.encode("utf-8"))
    # With fewer than two tokens there are no pairs. get_stats() would return
    # an empty dict and min() would raise ValueError (e.g. for "" or "a").
    while len(tokens) >= 2:
        stats = get_stats(tokens)
        # Pairs absent from `merges` map to +inf, so they are only chosen
        # when no mergeable pair is left.
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # nothing else can be merged
        idx = merges[pair]
        tokens = merge(tokens, pair, idx)
    return tokens