In [1]:
# Module-level state; the next cell reads the corpora into `text` and appends byte values to `tokens`.
tokens = []
text = None

In [2]:
# Load each language's corpus, UTF-8 encode it, and append its raw byte values to
# `tokens`, producing one combined training sequence for all languages.
tokens = []  # reset here so re-running this cell does not append every corpus a second time
for lang in ["en", "hi", "ka"]:
    path = f"tok_corpus_{lang}.txt"
    try:
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows) cannot
        # decode many non-Latin scripts.
        with open(path, "r", encoding="utf-8") as file:
            text = file.read().strip()
    except FileNotFoundError as err:
        # Raise instead of exit(): exit() with no argument ends with status 0 (success)
        # and, in Jupyter, shuts the kernel down; the old message also hid which file
        # was missing.
        raise FileNotFoundError(f"Tokenization corpus file not found: {path}") from err
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not text:
        raise ValueError(f"Tokenization corpus file is empty: {path}")

    print(f"Lang: {lang}")

    text = text.encode("utf-8")
    print(f"len(text) = {len(text)}")

    lang_tokens = list(text)  # iterating a bytes object yields ints in 0..255
    # NOTE: this is the count for this language only, not the running total in `tokens`.
    print(f"len(tokens) = {len(lang_tokens)}")

    tokens += lang_tokens

    print("---------")

Lang: en
len(text) = 6423
len(tokens) = 6423
---------
Lang: hi
len(text) = 91330
len(tokens) = 91330
---------
Lang: ka
len(text) = 108961
len(tokens) = 108961
---------


In [3]:
# Total byte-level token count across all loaded corpora (the sequence length before any BPE merges).
len(tokens)

206714

In [4]:
def get_stats(ids):
    """Count how often each adjacent pair of ids occurs.

    Args:
        ids: an indexable sequence of token ids (list of ints or a bytes object).

    Returns:
        dict mapping (left_id, right_id) -> occurrence count. Keys appear in the order
        each pair is first seen, which decides ties when the caller takes max()/min().
        Sequences shorter than 2 give an empty dict.
    """
    counts = {}
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        counts[key] = counts.get(key, 0) + 1
    return counts

def merge(ids, pair, idx):
    """Return a new list where each non-overlapping occurrence of `pair` is replaced by `idx`.

    The scan runs left to right and consumes both elements of a match, so merging
    (a, a) in [a, a, a] yields [idx, a]. The input sequence is not modified.

    Args:
        ids: an indexable sequence of token ids (list of ints or a bytes object).
        pair: the (left_id, right_id) pair to replace.
        idx: the token id that replaces each matched pair.
    """
    merged = []
    n = len(ids)
    pos = 0
    while pos < n:
        is_match = pos + 1 < n and ids[pos] == pair[0] and ids[pos + 1] == pair[1]
        if is_match:
            merged.append(idx)
            pos += 2
            continue
        merged.append(ids[pos])
        pos += 1
    return merged

# Train BPE: repeatedly replace the most frequent adjacent pair with a new token id.
# Target vocabulary = 256 raw byte values + one new id per merge.
vocab_size = 1024*8
num_merges = vocab_size - 256
ids = list(tokens)  # work on a copy so `tokens` keeps the raw byte sequence
merges = {}  # (left_id, right_id) -> new token id, in the order the merges were learned
# NOTE: every iteration rescans the whole sequence (get_stats + merge), so total cost
# grows with num_merges * len(ids).
for i in range(num_merges):
    if i % 100 == 0:
        print(f"Merging {i} of {num_merges}")
    stats = get_stats(ids)
    if not stats:
        # Fewer than two ids remain, so no pair is left to merge; max() on an empty
        # dict would raise ValueError.
        print(f"Stopping early after {i} merges: no adjacent pairs left")
        break
    # Most frequent pair; on a tie max() keeps the first key, i.e. the pair seen earliest.
    pair = max(stats, key=stats.get)
    idx = 256 + i
    ids = merge(ids, pair, idx)
    merges[pair] = idx

print("Merging complete..")

Merging 0 of 7936
Merging 100 of 7936
Merging 200 of 7936
Merging 300 of 7936
Merging 400 of 7936
Merging 500 of 7936
Merging 600 of 7936
Merging 700 of 7936
Merging 800 of 7936
Merging 900 of 7936
Merging 1000 of 7936
Merging 1100 of 7936
Merging 1200 of 7936
Merging 1300 of 7936
Merging 1400 of 7936
Merging 1500 of 7936
Merging 1600 of 7936
Merging 1700 of 7936
Merging 1800 of 7936
Merging 1900 of 7936
Merging 2000 of 7936
Merging 2100 of 7936
Merging 2200 of 7936
Merging 2300 of 7936
Merging 2400 of 7936
Merging 2500 of 7936
Merging 2600 of 7936
Merging 2700 of 7936
Merging 2800 of 7936
Merging 2900 of 7936
Merging 3000 of 7936
Merging 3100 of 7936
Merging 3200 of 7936
Merging 3300 of 7936
Merging 3400 of 7936
Merging 3500 of 7936
Merging 3600 of 7936
Merging 3700 of 7936
Merging 3800 of 7936
Merging 3900 of 7936
Merging 4000 of 7936
Merging 4100 of 7936
Merging 4200 of 7936
Merging 4300 of 7936
Merging 4400 of 7936
Merging 4500 of 7936
Merging 4600 of 7936
Merging 4700 of 7936
Merg

In [5]:
# Base vocabulary: one token per raw byte value.
vocab = {idx: bytes([idx]) for idx in range(256)}
# A merged token's bytes are its two children's bytes concatenated. During training a
# merge's children always have smaller ids than the merge itself, so iterating in
# ascending id order guarantees both children exist before the parent, even if
# `merges` was not built (or reloaded) in learning order.
for (p0, p1), idx in sorted(merges.items(), key=lambda item: item[1]):
    vocab[idx] = vocab[p0] + vocab[p1]

def decode(ids):
    """Convert a sequence of token ids back into a string.

    Each id is looked up in the module-level `vocab` (unknown ids raise KeyError) and
    the byte strings are concatenated. Bytes that do not form valid UTF-8 are decoded
    as U+FFFD instead of raising.
    """
    raw_bytes = b"".join(vocab[token_id] for token_id in ids)
    return raw_bytes.decode("utf-8", errors="replace")

def encode(text):
    """Encode a string into BPE token ids using the learned module-level `merges`.

    The text is UTF-8 encoded into byte values (0..255). Then, repeatedly, the adjacent
    pair with the lowest merge id (the earliest-learned merge) is replaced by that id,
    until no adjacent pair has a learned merge.

    Returns:
        A list of ints, including for inputs too short or too unusual to merge.
    """
    # list(), not raw bytes: otherwise an input with no applicable merge (e.g. "" or a
    # single character) was returned as a `bytes` object instead of a list of ids.
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        stats = get_stats(tokens)
        # Apply merges in the order training learned them; unmerged pairs rank as +inf.
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # no remaining adjacent pair was ever merged during training
        tokens = merge(tokens, pair, merges[pair])
    return tokens

In [6]:
text = "My name is Yaseen and I'm a Deep Learning Engineer"

# Baseline: length in raw UTF-8 bytes, i.e. the token count before any BPE merges.
len(bytes(text, "utf-8"))

50

In [7]:
# Number of BPE tokens for the same sentence (compare with its raw UTF-8 byte count above).
len(encode(text))

27

In [8]:
# Round-trip check: decoding the encoded ids must reproduce the input exactly.
# Asserted so a mismatch stops "Run All" instead of just displaying False.
roundtrip_ok = decode(encode(text)) == text
assert roundtrip_ok, "decode(encode(text)) did not reproduce text"
roundtrip_ok

True

In [9]:
import json

# JSON object keys must be strings, so each (left, right) pair key is written as "left,right".
# Dict insertion order (the order merges were learned) is preserved in the output.
merges_serializable = {
    f"{pair[0]},{pair[1]}": new_id for pair, new_id in merges.items()
}

with open("merges.json", "w") as f:
    f.write(json.dumps(merges_serializable))

In [10]:
import json

with open("merges.json", "r") as f:
    loaded_merges = json.load(f)

# JSON keys are "left,right" strings; convert them back to (int, int) tuples.
merges_reloaded = {tuple(map(int, k.split(','))): v for k, v in loaded_merges.items()}

# `==` on dicts ignores order, but rebuilding the vocab walks merges in order and looks
# up both children of each pair, so check the ordered items as well.
assert list(merges_reloaded.items()) == list(merges.items()), "merges.json did not round-trip"
merges == merges_reloaded

True
