### Custom regex for BPE tokenization

In [1]:
# wanna handle en, hi, kn languages.
tests = [
    "आज तो बहुत थक गया हूँ, ಸ್ವಲ್ಪ विश्रಾಂತಿ ಬೇಕು।",
    "मौसम कितना अच्छा है! ನೀವೂ ಹೊರಗೆ ಬನ್ನಿ, let's enjoy together.",
    "स्वल्पा adjust करो, बैंगलोर का ट्रैफिक ऐसा ही है।",
    "ನೀವು ಚಹಾ ಕುಡಿತೀರಾ? मुझे एक cup चाहिए।",
    "आज का काम पूरा करो, ನಾಳೆ ಎಲ್ಲಿಂದ ಆರಂಭಿಸೋದು ನೋಡಿ।",
    "ಪಾರ್ಟಿ ಹೇಗೆ ಇತ್ತು? मुझे तो बहुत मजा आया!",
    "ನಮ್ಮ ಚೂರು ಸಹನಶೀಲತೆಯನ್ನು ತೋರಿಸಿ, ये थोड़ी देर का मसला है।",
    "ಸಮಯ ನಿಲ್ಲುತ್ತಿಲ್ಲ, जिंदगी में स्वल्पा मज़ा भी जरूरी है।",
    "My name is Jeff Bezos, and I'm the owner of Amazon.<|endoftext|>"
]

In [2]:
import regex as re

def test_regex(regex, text):
    print(text)
    print(re.findall(regex, text))
    print()

# default gpt4 regex
pattern = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
regex = re.compile(pattern)

for test in tests:
    test_regex(regex, test)

आज तो बहुत थक गया हूँ, ಸ್ವಲ್ಪ विश्रಾಂತಿ ಬೇಕು।
['आज', ' त', 'ो', ' बह', 'ुत', ' थक', ' गय', 'ा', ' ह', 'ूँ,', ' ಸ', '್ವಲ', '್ಪ', ' व', 'िश', '्र', 'ಾಂ', 'ತ', 'ಿ', ' ಬ', 'ೇಕ', 'ು।']

मौसम कितना अच्छा है! ನೀವೂ ಹೊರಗೆ ಬನ್ನಿ, let's enjoy together.
['म', 'ौसम', ' क', 'ितन', 'ा', ' अच', '्छ', 'ा', ' ह', 'ै!', ' ನ', 'ೀವ', 'ೂ', ' ಹ', 'ೊರಗ', 'ೆ', ' ಬನ', '್ನ', 'ಿ,', ' let', "'s", ' enjoy', ' together', '.']

स्वल्पा adjust करो, बैंगलोर का ट्रैफिक ऐसा ही है।
['स', '्वल', '्प', 'ा', ' adjust', ' कर', 'ो,', ' ब', 'ैं', 'गल', 'ोर', ' क', 'ा', ' ट', '्र', 'ैफ', 'िक', ' ऐस', 'ा', ' ह', 'ी', ' ह', 'ै।']

ನೀವು ಚಹಾ ಕುಡಿತೀರಾ? मुझे एक cup चाहिए।
['ನ', 'ೀವ', 'ು', ' ಚಹ', 'ಾ', ' ಕ', 'ುಡ', 'ಿತ', 'ೀರ', 'ಾ?', ' म', 'ुझ', 'े', ' एक', ' cup', ' च', 'ाह', 'िए', '।']

आज का काम पूरा करो, ನಾಳೆ ಎಲ್ಲಿಂದ ಆರಂಭಿಸೋದು ನೋಡಿ।
['आज', ' क', 'ा', ' क', 'ाम', ' प', 'ूर', 'ा', ' कर', 'ो,', ' ನ', 'ಾಳ', 'ೆ', ' ಎಲ', '್ಲ', 'ಿಂ', 'ದ', ' ಆರ', 'ಂಭ', 'ಿಸ', 'ೋದ', 'ು', ' ನ', 'ೋಡ', 'ಿ।']

ಪಾರ್ಟಿ ಹೇಗೆ ಇತ್ತು? मुझे तो बहुत मजा आया!
['ಪ', 'ಾರ', '್

In [3]:
# custom regex
pattern = r"""(?i) 's|'t|'re|'ve|'m|'ll|'d| ?\b[\p{L}\u0900-\u0963|\u0966-\u097F]+\b| ?\b[\p{L}\u0C80-\u0C9E|\u0CA0-\u0CFF]+\b| ?[\p{N}]+| ?[.,!?;:'\"-]| ?[\u0964-\u0965]| ?[\u0C9E-\u0C9F]| ?[^\s\p{L}\p{N}\u0900-\u097F\u0C80-\u0CFF]+| \s+(?!\S)| \s+"""
regex = re.compile(pattern)

for test in tests:
    test_regex(regex, test)

आज तो बहुत थक गया हूँ, ಸ್ವಲ್ಪ विश्रಾಂತಿ ಬೇಕು।
['आज', ' तो', ' बहुत', ' थक', ' गया', ' हूँ', ',', ' ಸ್ವಲ್ಪ', ' ಬೇಕು', '।']

मौसम कितना अच्छा है! ನೀವೂ ಹೊರಗೆ ಬನ್ನಿ, let's enjoy together.
['मौसम', ' कितना', ' अच्छा', ' है', '!', ' ನೀವೂ', ' ಹೊರಗೆ', ' ಬನ್ನಿ', ',', ' let', "'", 's', ' enjoy', ' together', '.']

स्वल्पा adjust करो, बैंगलोर का ट्रैफिक ऐसा ही है।
['स्वल्पा', ' adjust', ' करो', ',', ' बैंगलोर', ' का', ' ट्रैफिक', ' ऐसा', ' ही', ' है', '।']

ನೀವು ಚಹಾ ಕುಡಿತೀರಾ? मुझे एक cup चाहिए।
['ನೀವು', ' ಚಹಾ', ' ಕುಡಿತೀರಾ', '?', ' मुझे', ' एक', ' cup', ' चाहिए', '।']

आज का काम पूरा करो, ನಾಳೆ ಎಲ್ಲಿಂದ ಆರಂಭಿಸೋದು ನೋಡಿ।
['आज', ' का', ' काम', ' पूरा', ' करो', ',', ' ನಾಳೆ', ' ಎಲ್ಲಿಂದ', ' ಆರಂಭಿಸೋದು', ' ನೋಡಿ', '।']

ಪಾರ್ಟಿ ಹೇಗೆ ಇತ್ತು? मुझे तो बहुत मजा आया!
['ಪಾರ್ಟಿ', ' ಹೇಗೆ', ' ಇತ್ತು', '?', ' मुझे', ' तो', ' बहुत', ' मजा', ' आया', '!']

ನಮ್ಮ ಚೂರು ಸಹನಶೀಲತೆಯನ್ನು ತೋರಿಸಿ, ये थोड़ी देर का मसला है।
['ನಮ್ಮ', ' ಚೂರು', ' ಸಹನಶೀಲತೆಯನ್ನು', ' ತೋರಿಸಿ', ',', ' ये', ' थोड़ी', ' देर', ' का', ' मसला', ' है', '।']

ಸಮಯ ನಿಲ

### Dataset

In [None]:
from datasets import load_dataset, load_from_disk
import random
import os

# Load a subset of a dataset and cache it locally
def download_dataset_and_load_subset(dataset_name, name=None, data_dir=None, split="train", num_rows=1000, save_dir="/home/yaseen/hf_datasets"):
    # Create a directory path that includes the dataset name and split
    dataset_path = f"{dataset_name.replace('/', '_')}_{name.replace('/', '_') if name else ''}_{data_dir.replace('/', '_') if data_dir else ''}_{split}"
    save_path = os.path.join(save_dir, dataset_path)
    
    try:
        dataset = load_from_disk(save_path)
        print(f"Dataset {dataset_name} loaded from {save_path}..")
    
    except FileNotFoundError:
        # Load the dataset and cache it in the specified directory
        dataset = load_dataset(
            dataset_name, 
            name=name, 
            data_dir=data_dir, 
            split=split
        )
        
        # Ensure the base save directory exists
        os.makedirs(save_dir, exist_ok=True)
        
        # Save the dataset to disk
        dataset.save_to_disk(save_path)
        print(f"Dataset {dataset_name} saved to {save_path}..")
    
    # Select a random subset of the dataset
    indices = list(range(len(dataset)))
    random.shuffle(indices)
    selected_indices = indices[:num_rows]
    return dataset.select(selected_indices)

# Load subsets of each dataset with caching
dataset_en = download_dataset_and_load_subset("HuggingFaceFW/fineweb-edu", name="sample-10BT", num_rows=600)
dataset_hin_deva = download_dataset_and_load_subset("ai4bharat/sangraha", data_dir="synthetic/hin_Deva", num_rows=200)
dataset_kan_knda = download_dataset_and_load_subset("ai4bharat/sangraha", data_dir="synthetic/kan_Knda", num_rows=200)

# Print the first row of each subset to verify
print(dataset_en[0])
print(dataset_hin_deva[0])
print(dataset_kan_knda[0])

In [None]:
# Concatenate all texts into a single list
all_texts = []

# Collect texts from each dataset
en_texts = [doc["text"].strip().replace("\n", " ") for doc in dataset_en]
hin_deva_texts = [doc["text"].strip().replace("\n", " ") for doc in dataset_hin_deva]
kan_knda_texts = [doc["text"].strip().replace("\n", " ") for doc in dataset_kan_knda]

# Add all texts to a single list
all_texts.extend(en_texts)
all_texts.extend(hin_deva_texts)
all_texts.extend(kan_knda_texts)

# Shuffle the combined texts
random.shuffle(all_texts)

print(f"Total number of texts: {len(all_texts)}")
print("\nFirst few texts after random shuffling:")
print(all_texts[:3])

corpus = "\n".join(all_texts)

# Save the combined texts to a file
with open("tok_corpus_small.txt", "w", encoding="utf-8") as file:
    file.write(corpus)

### Utility functions

In [4]:
import unicodedata

def replace_control_characters(s: str) -> str:
    # we don't want to print control characters
    # which distort the output (e.g. \n or much worse)
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
    # http://www.unicode.org/reports/tr44/#GC_Values_Table
    chars = []
    for ch in s:
        if unicodedata.category(ch)[0] != "C":
            chars.append(ch) # this character is ok
        else:
            chars.append(f"\\u{ord(ch):04x}") # escape
    return "".join(chars)

def render_token(t: bytes) -> str:
    # pretty print a token, escaping control characters
    s = t.decode('utf-8', errors='replace')
    s = replace_control_characters(s)
    return s

# utility function to visualise tokens
def visualise_tokens(token_values: list[bytes], end="\n") -> None:
    background = [f"\u001b[48;5;{i}m" for i in [167, 179, 185, 77, 80, 68, 134]]
    # If token boundaries do not occur at unicode character boundaries, it's unclear how best to
    # visualise the token. Here, we'll just use the unicode replacement character to represent some
    # fraction of a character.
    unicode_token_values = [x.decode("utf-8", errors="replace") for x in token_values]

    running_length = 0
    last_color = None
    for token in unicode_token_values:
        color = background[running_length % len(background)]
        if color == last_color:
            color = background[(running_length + 1) % len(background)]
            assert color != last_color
        last_color = color
        running_length += len(token)
        print(color + token, end="")
    print("\u001b[0m", end=end)

### Training BPE

In [5]:
# Read the corpus file
with open("tok_corpus_small.txt", "r", encoding="utf-8") as file:
    corpus = file.read()

# Get character and byte lengths
char_length = len(corpus)
byte_length = len(corpus.encode('utf-8'))

print(f"Corpus length in characters: {char_length:,}")
print(f"Corpus length in bytes: {byte_length:,}")


Corpus length in characters: 4,182,226
Corpus length in bytes: 6,164,138


In [6]:
from tqdm import tqdm

def get_stats(ids, freq):
    for pair in zip(ids[:-1], ids[1:]):
        freq[pair] = freq.get(pair, 0) + 1

def merge(ids, pair, idx):
  newids = []
  i = 0
  while i < len(ids):
    if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
      newids.append(idx)
      i += 2
    else:
      newids.append(ids[i])
      i += 1
  return newids

# ---
# tokens = corpus.encode('utf-8')
vocab_size = 3256 # the desired final vocabulary size
num_merges = vocab_size - 256
# ids = list(tokens) # copy so we don't destroy the original list

text_chunks = re.findall(regex, corpus)
tokens = [list(ch.encode("utf-8")) for ch in text_chunks]
ids = [list(ch.encode("utf-8")) for ch in text_chunks]

merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)}

for i in tqdm(range(num_merges), desc="Training BPE", unit="merge"):
    stats = {}
    for chunk_ids in ids:
        get_stats(chunk_ids, stats)
    pair = max(stats, key=stats.get)
    idx = 256 + i
    ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
    merges[pair] = idx
    vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

tok_len = 0
for chunk_toks in tokens:
    tok_len += len(chunk_toks)
print("tokens length:", tok_len)

ids_len = 0
for chunk_ids in ids:
    ids_len += len(chunk_ids)
print("ids length:", ids_len)

print(f"compression ratio: {tok_len / ids_len:.2f}X")

Training BPE: 100%|██████████| 3000/3000 [56:08<00:00,  1.12s/merge]  

tokens length: 6142083
ids length: 1507403
compression ratio: 4.07X





In [8]:
def _build_vocab():
    vocab = {idx: bytes([idx]) for idx in range(256)}
    for (p0, p1), idx in merges.items():
        vocab[idx] = vocab[p0] + vocab[p1]
    return vocab

special_tokens = {'<|endoftext|>': len(vocab)}

def save(file_prefix):
    """
    Saves two files: file_prefix.vocab and file_prefix.model
    This is inspired (but not equivalent to!) sentencepiece's model saving:
    - model file is the critical one, intended for load()
    - vocab file is just a pretty printed version for human inspection only
    """
    # write the model: to be used in load() later
    model_file = file_prefix + ".model"
    with open(model_file, 'w') as f:
        # write the version, pattern and merges, that's all that's needed
        f.write("minbpe v1\n")
        f.write(f"{pattern}\n")
        # write the special tokens, first the number of them, then each one
        f.write(f"{len(special_tokens)}\n")
        for special, idx in special_tokens.items():
            f.write(f"{special} {idx}\n")
        # the merges dict
        for idx1, idx2 in merges:
            f.write(f"{idx1} {idx2}\n")
    # write the vocab: for the human to look at
    vocab_file = file_prefix + ".vocab"
    inverted_merges = {idx: pair for pair, idx in merges.items()}
    with open(vocab_file, "w", encoding="utf-8") as f:
        for idx, token in vocab.items():
            # note: many tokens may be partial utf-8 sequences
            # and cannot be decoded into valid strings. Here we're using
            # errors='replace' to replace them with the replacement char.
            # this also means that we couldn't possibly use .vocab in load()
            # because decoding in this way is a lossy operation!
            s = render_token(token)
            # find the children of this token, if any
            if idx in inverted_merges:
                # if this token has children, render it nicely as a merge
                idx0, idx1 = inverted_merges[idx]
                s0 = render_token(vocab[idx0])
                s1 = render_token(vocab[idx1])
                f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
            else:
                # otherwise this is leaf token, just print it
                # (this should just be the first 256 tokens, the bytes)
                f.write(f"[{s}] {idx}\n")

save("bpe_tok")

def load(model_file):
    """Inverse of save() but only for the model file"""
    assert model_file.endswith(".model")
    # read the model file
    merges = {}
    special_tokens = {}
    idx = 256
    with open(model_file, 'r', encoding="utf-8") as f:
        # read the version
        version = f.readline().strip()
        assert version == "minbpe v1"
        # read the pattern
        pattern = f.readline().strip()
        # read the special tokens
        num_special = int(f.readline().strip())
        for _ in range(num_special):
            special, special_idx = f.readline().strip().split()
            special_tokens[special] = int(special_idx)
        # read the merges
        for line in f:
            idx1, idx2 = map(int, line.split())
            merges[(idx1, idx2)] = idx
            idx += 1
    vocab = _build_vocab()
    return merges, special_tokens, vocab

merges, special_tokens, vocab = load("bpe_tok.model")

In [9]:
def decode(ids) -> str:
  part_bytes = []
  for id in ids:
      if id in vocab:
          part_bytes.append(vocab[id]) # id can be > 256 after merging
      elif id in special_tokens:
          part_bytes.append(special_tokens[id])
      else:
          raise ValueError(f"id={id} not in vocab or special_tokens")
  text_bytes = b"".join(part_bytes)
  text = text_bytes.decode(encoding="utf-8", errors="replace")
  return text

print(decode([128]))

�


In [10]:
merges

{(224, 178): 256,
 (224, 164): 257,
 (224, 179): 258,
 (32, 257): 259,
 (224, 165): 260,
 (32, 256): 261,
 (32, 116): 262,
 (258, 141): 263,
 (263, 256): 264,
 (32, 97): 265,
 (104, 101): 266,
 (105, 110): 267,
 (256, 191): 268,
 (257, 190): 269,
 (114, 101): 270,
 (268, 256): 271,
 (262, 266): 272,
 (32, 111): 273,
 (258, 129): 274,
 (260, 135): 275,
 (101, 114): 276,
 (257, 176): 277,
 (32, 115): 278,
 (256, 190): 279,
 (260, 141): 280,
 (279, 256): 281,
 (97, 116): 282,
 (111, 110): 283,
 (32, 119): 284,
 (110, 100): 285,
 (32, 99): 286,
 (257, 191): 287,
 (101, 110): 288,
 (101, 115): 289,
 (105, 115): 290,
 (280, 257): 291,
 (287, 257): 292,
 (259, 149): 293,
 (258, 134): 294,
 (105, 116): 295,
 (256, 176): 296,
 (111, 114): 297,
 (269, 257): 298,
 (32, 98): 299,
 (32, 102): 300,
 (32, 112): 301,
 (97, 110): 302,
 (101, 100): 303,
 (273, 102): 304,
 (267, 103): 305,
 (97, 108): 306,
 (111, 117): 307,
 (97, 114): 308,
 (257, 130): 309,
 (32, 267): 310,
 (32, 109): 311,
 (265, 285):

In [14]:
def _encode_chunk(chunk_bytes: bytes, verbose=False) -> list[int]:
    tokens = list(chunk_bytes)
    while len(tokens) >= 2:
        if verbose:
            visualise_tokens([vocab[token] for token in tokens]) # token can be > 256 after merging
        stats = {}
        get_stats(tokens, stats)
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if not pair in merges:
            break
        idx = merges[pair]
        tokens = merge(tokens, pair, idx)
    return tokens

def encode_ordinary(text, verbose=False) -> list[int]:
    chunk_texts = re.findall(regex, text)
    ids_list = []
    for i, text in enumerate(chunk_texts):
        if verbose:
            print()
            print(f"encoding chunk {i+1}/{len(chunk_texts)}: {text}")
        chunk_bytes = text.encode("utf-8") # raw bytes
        ids = _encode_chunk(chunk_bytes, verbose)
        ids_list.extend(ids)
    return ids_list

def encode(text, verbose=False, allowed_special="none") -> list[int]:
    special = {}
    if allowed_special == "all":
        special = special_tokens
    elif allowed_special == "none":
        special = {}
    elif allowed_special == "none_raise":
        special = {}
        assert all(token not in text for token in special_tokens), "Text contains special tokens that are not allowed"
    elif isinstance(allowed_special, set):
        special = {k: v for k, v in special_tokens.items() if k in allowed_special}
    else:
        raise ValueError(f"allowed_special={allowed_special} not understood.")
    if not special:
        return encode_ordinary(text, verbose)
    special_pattern = "(" + "|".join(re.escape(token) for token in special) + ")"
    parts = re.split(special_pattern, text)
    ids = []
    for part in parts:
        if part in special:
            ids.append(special[part])
        else:
            ids.extend(encode_ordinary(part, verbose))
    return ids

In [23]:
for test in tests:
    print(test)
    print(encode(test, allowed_special={"<|endoftext|>"}))
    print()

आज तो बहुत थक गया हूँ, ಸ್ವಲ್ಪ विश्रಾಂತಿ ಬೇಕು।
[2637, 665, 666, 320, 443, 441, 2459, 531, 332, 752, 374, 435, 257, 129, 44, 1919, 2649, 462, 1462, 274, 389]

मौसम कितना अच्छा है! ನೀವೂ ಹೊರಗೆ ಬನ್ನಿ, let's enjoy together.
[392, 872, 369, 392, 293, 577, 873, 458, 725, 291, 155, 269, 466, 33, 1268, 321, 486, 2662, 933, 462, 383, 268, 44, 2127, 39, 115, 678, 106, 1302, 2510, 46]

स्वल्पा adjust करो, बैंगलोर का ट्रैफिक ऐसा ही है।
[369, 692, 354, 1156, 269, 756, 106, 648, 631, 320, 44, 443, 379, 1538, 354, 1751, 620, 877, 2627, 840, 510, 259, 144, 369, 269, 2581, 466, 389]

ನೀವು ಚಹಾ ಕುಡಿತೀರಾ? मुझे एक cup चाहिए।
[317, 464, 750, 833, 596, 279, 2112, 381, 488, 464, 296, 279, 63, 345, 580, 157, 275, 583, 286, 1989, 776, 1378, 642, 389]

आज का काम पूरा करो, ನಾಳೆ ಎಲ್ಲಿಂದ ಆರಂಭಿಸೋದು ನೋಡಿ।
[2637, 665, 620, 3169, 3092, 277, 269, 631, 320, 44, 424, 2810, 294, 555, 418, 764, 1986, 1529, 406, 428, 3004, 424, 2826, 268, 389]

ಪಾರ್ಟಿ ಹೇಗೆ ಇತ್ತು? मुझे तो बहुत मजा आया!
[515, 417, 526, 268, 452, 384, 933, 603, 4

In [25]:
def test_encoding(text, verbose=False, allowed_special=None):
    print(f"Text: {text}")
    test_ids = encode(text, verbose=verbose, allowed_special=allowed_special)
    print(f"Tokens: {test_ids}")
    print("")
    print(f"Unmerged length: {len(text.encode('utf-8'))}")
    print(f"Merged length: {len(test_ids)}")
    print("-"*50)

for test in tests:
    test_encoding(test, allowed_special="all")

Text: आज तो बहुत थक गया हूँ, ಸ್ವಲ್ಪ विश्रಾಂತಿ ಬೇಕು।
Tokens: [2637, 665, 666, 320, 443, 441, 2459, 531, 332, 752, 374, 435, 257, 129, 44, 1919, 2649, 462, 1462, 274, 389]

Unmerged length: 117
Merged length: 21
--------------------------------------------------
Text: मौसम कितना अच्छा है! ನೀವೂ ಹೊರಗೆ ಬನ್ನಿ, let's enjoy together.
Tokens: [392, 872, 369, 392, 293, 577, 873, 458, 725, 291, 155, 269, 466, 33, 1268, 321, 486, 2662, 933, 462, 383, 268, 44, 2127, 39, 115, 678, 106, 1302, 2510, 46]

Unmerged length: 120
Merged length: 31
--------------------------------------------------
Text: स्वल्पा adjust करो, बैंगलोर का ट्रैफिक ऐसा ही है।
Tokens: [369, 692, 354, 1156, 269, 756, 106, 648, 631, 320, 44, 443, 379, 1538, 354, 1751, 620, 877, 2627, 840, 510, 259, 144, 369, 269, 2581, 466, 389]

Unmerged length: 117
Merged length: 28
--------------------------------------------------
Text: ನೀವು ಚಹಾ ಕುಡಿತೀರಾ? मुझे एक cup चाहिए।
Tokens: [317, 464, 750, 833, 596, 279, 2112, 381, 488, 464, 296, 279, 63