In [1]:
# Pre-tokenization pattern (appears to be the GPT-2 style splitter — confirm
# against the tokenizer being reproduced). Alternatives are tried in order:
#   '(?:[sdmt]|ll|ve|re)   contraction suffixes such as 's, 't, 'll, 've, 're
#    ?\p{L}+               a run of letters, optionally preceded by one space
#    ?\p{N}+               a run of numeric characters, same optional space
#    ?[^\s\p{L}\p{N}]+     a run of other non-space symbols, same optional space
#   \s+(?!\S)              whitespace that is not followed by a non-space
#                          character (the last space before a word is left for
#                          the word alternatives above to pick up)
#   \s+                    any remaining whitespace
# \p{L} / \p{N} are Unicode property classes: the third-party `regex` package
# understands them, the standard-library `re` module does not.
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

In [2]:
# NOTE: `re` here is the third-party `regex` package, not the standard library;
# PAT uses \p{...} classes that only `regex` supports.
import regex as re
# Punctuation comes out as its own piece; a leading space stays attached to the next word.
re.findall(PAT, "Hello, world!")

['Hello', ',', ' world', '!']

In [4]:
# A contraction and a hyphenated word: "'ll" and "-" are emitted as separate pre-tokens.
text = "Some text that I'll pre-tokenize"
re.findall(PAT, text)

['Some', ' text', ' that', ' I', "'ll", ' pre', '-', 'tokenize']

In [5]:
# Stream the matches one at a time (instead of building the full list) and
# report each pre-token together with its character span in `text`.
for m in re.finditer(PAT, text):
    piece, lo, hi = m.group(), m.start(), m.end()
    print(f"Match: '{piece}' at positions {lo}–{hi}")

Match: 'Some' at positions 0–4
Match: ' text' at positions 4–9
Match: ' that' at positions 9–14
Match: ' I' at positions 14–16
Match: ''ll' at positions 16–19
Match: ' pre' at positions 19–23
Match: '-' at positions 23–24
Match: 'tokenize' at positions 24–32


In [10]:
# str + encoding: the text "9" is encoded, giving the one byte for the digit character.
bytes('9', 'utf-8')

b'9'

In [19]:
# Iterable of ints: each int becomes one byte with that value (here tab, newline, form feed).
bytes([9, 10, 12])

b'\t\n\x0c'

In [11]:
# Bare int: this is NOT the byte with value 9 — it creates nine zero bytes.
bytes(9)

b'\x00\x00\x00\x00\x00\x00\x00\x00\x00'

In [6]:
# Starting vocabulary: each of the 256 possible byte values is its own token,
# keyed by the one-byte `bytes` object and mapped to the byte value as its id.
vocab = dict((bytes([b]), b) for b in range(256))
print(vocab[bytes([8])])

In [19]:
from collections import defaultdict

# Step 1: Start from the byte-level vocabulary (token id == byte value)
vocab = dict((bytes([b]), b) for b in range(256))

# Special tokens are keyed by their str form and take the next free ids
special_tokens = ["<|endoftext|>"] # ["<|endoftext|>", "<|pad|>"]
for token in special_tokens:
    vocab[token] = len(vocab)

# Step 2: Split the toy corpus on whitespace, then UTF-8 encode every word
text = "low low low low low lower lower widest widest widest newest newest newest newest newest newest"

# split() with no separator already ignores leading/trailing whitespace
pre_tokens = text.split()
pre_tokens_bytes = [word.encode('utf-8') for word in pre_tokens]

In [20]:
print(pre_tokens)

# Tally how many times each distinct pre-token occurs in the corpus
pre_token_counts = defaultdict(int)
for word in pre_tokens:
    pre_token_counts[word] = pre_token_counts[word] + 1

print(pre_token_counts)

['low', 'low', 'low', 'low', 'low', 'lower', 'lower', 'widest', 'widest', 'widest', 'newest', 'newest', 'newest', 'newest', 'newest', 'newest']
defaultdict(<class 'int'>, {'low': 5, 'lower': 2, 'widest': 3, 'newest': 6})


In [21]:
def decode_token(token):
    """Render a token as human-readable text.

    A token may be:
      * an int      - a single byte value (0-255),
      * a tuple     - a merge of two (possibly nested) tokens,
      * bytes       - a raw byte string,
      * anything else (e.g. a special-token str) - rendered with ``str()``.

    Adjacent byte values are collected and decoded *together*, so a
    multi-byte UTF-8 character that is spread over several ints (for example
    ``(0xC3, 0xA9)`` for "é") is decoded correctly instead of turning into one
    replacement character per byte. Bytes that are still not valid UTF-8 are
    rendered as U+FFFD.

    Returns:
        str: the decoded text.

    Raises:
        ValueError: if an int token is outside the 0-255 byte range.
    """
    pieces = []            # finished text segments, in order
    pending = bytearray()  # raw bytes waiting to be decoded as one run

    def flush():
        # Decode the buffered byte run as a unit so multi-byte characters survive.
        if pending:
            pieces.append(bytes(pending).decode("utf-8", errors="replace"))
            pending.clear()

    def walk(tok):
        if isinstance(tok, int):
            pending.append(tok)
        elif isinstance(tok, tuple):
            # Merged token: visit the parts left to right
            for part in tok:
                walk(part)
        elif isinstance(tok, (bytes, bytearray)):
            pending.extend(tok)
        else:
            # Special tokens, represented as strings: emit verbatim, and do
            # not let bytes on either side be decoded across them.
            flush()
            pieces.append(str(tok))

    walk(token)
    flush()
    return "".join(pieces)

In [22]:
# Sanity check: the nested merge (A + B) + C should flatten to the text "ABC".
most_freq_pair = ((ord("A"), ord("B")), ord("C"))
print(decode_token(most_freq_pair))

ABC


In [23]:
def calc_pair_counts(tokens):
    """Count adjacent symbol pairs across a collection of token sequences.

    Args:
        tokens: iterable of sliceable sequences (``bytes`` objects or lists of
            symbols). Pairs are formed only inside a sequence, never across
            two neighbouring sequences.

    Returns:
        defaultdict(int) mapping ``(left, right)`` to the number of times the
        two symbols occur next to each other, in first-seen order.
    """
    tally = defaultdict(int)
    for seq in tokens:
        # Pair every symbol with its right-hand neighbour
        for left, right in zip(seq, seq[1:]):
            tally[(left, right)] += 1
    return tally

# Adjacent-byte pair frequencies over the raw (not yet merged) pre-tokens
pair_counts = calc_pair_counts(pre_tokens_bytes)
print(pair_counts)

defaultdict(<class 'int'>, {(108, 111): 7, (111, 119): 7, (119, 101): 8, (101, 114): 2, (119, 105): 3, (105, 100): 3, (100, 101): 3, (101, 115): 9, (115, 116): 9, (110, 101): 6, (101, 119): 6})


In [24]:
num_merges = 6

# Merged tokens receive ids after everything already in the vocabulary
next_token_id = len(vocab)
print("Next tooken id = ", next_token_id)


def _token_bytes(token):
    """Flatten a token (byte int, nested merge tuple, or special-token str) to raw bytes."""
    if isinstance(token, int):
        return bytes([token])
    if isinstance(token, tuple):
        return b"".join(_token_bytes(part) for part in token)
    return str(token).encode("utf-8")


tokens = pre_tokens_bytes
for merge_idx in range(num_merges):
    pair_counts = calc_pair_counts(tokens)
    if not pair_counts:
        # No sequence has two adjacent symbols left, so nothing can be merged
        break

    # Find most frequent pair (lexicographically greatest on ties).
    # The tie-break compares (left bytes, right bytes) as a tuple. Comparing
    # the concatenated decoded text instead would make distinct pairs such as
    # ("a", "bc") and ("ab", "c") indistinguishable, and would collapse every
    # non-ASCII byte to the same replacement character, leaving the winner to
    # depend on dict insertion order.
    most_freq_pair = max(
        pair_counts,
        key=lambda p: (pair_counts[p], _token_bytes(p[0]), _token_bytes(p[1])),
    )
    print(f"Merge #{merge_idx + 1}: {most_freq_pair}, Count: {pair_counts[most_freq_pair]}")
    t1 = decode_token(most_freq_pair[0])
    t2 = decode_token(most_freq_pair[1])
    print(f"{t1} {t2}, Count: {pair_counts[most_freq_pair]}")

    # The merged token is represented by the pair tuple itself
    vocab[most_freq_pair] = next_token_id
    next_token_id += 1

    # Rewrite every sequence, replacing occurrences of the chosen pair
    # greedily from left to right (matches never overlap).
    new_tokens = []
    for t in tokens:
        new_t = []
        i = 0
        while i < len(t):
            if i < (len(t) - 1) and (t[i], t[i + 1]) == most_freq_pair:
                new_t.append(most_freq_pair)
                i += 2
            else:
                new_t.append(t[i])
                i += 1

        new_tokens.append(new_t)
    tokens = new_tokens

Next tooken id =  257
Merge #1: (115, 116), Count: 9
s t, Count: 9
Merge #2: (101, (115, 116)), Count: 9
e st, Count: 9
Merge #3: (111, 119), Count: 7
o w, Count: 7
Merge #4: (108, (111, 119)), Count: 7
l ow, Count: 7
Merge #5: (119, (101, (115, 116))), Count: 6
w est, Count: 6
Merge #6: (110, 101), Count: 6
n e, Count: 6


In [25]:
# Sequence count before merging, sequence count after merging (merges shorten
# the sequences, not the list), and the vocabulary size after the merges.
print(len(pre_tokens_bytes), len(tokens), len(vocab), sep="\n")

16
16
263


In [2]:
from multiprocessing import Pool, cpu_count

# cpu_count() reports the number of CPUs in the machine, which can be larger
# than the number this process is actually allowed to use.
num_processes = cpu_count()
print(num_processes)

32


In [2]:
# NOTE(review): this binds `re` to the standard-library module. Earlier cells
# bind the same name to the third-party `regex` package (needed for \p{...}),
# so re-running those cells after this one would fail — consider distinct names.
import re


def build_split_pattern(tokens):
    """Build a regex alternation that matches any of the given literal tokens.

    Each token is escaped so it is matched literally. Tokens are ordered
    longest first (ties keep their input order) because regex alternation
    takes the first branch that matches: without this, a token that is a
    prefix of a longer token would always win and the longer one could never
    match as a whole.

    Args:
        tokens: iterable of special-token strings.

    Returns:
        str: the ``|``-joined pattern. An empty input gives an empty pattern,
        which matches at every position — check for that before splitting.
    """
    ordered = sorted(tokens, key=len, reverse=True)
    return "|".join(re.escape(token) for token in ordered)


special_tokens = ["<|endoftext|>", "<|pad|>"]
split_pattern = build_split_pattern(special_tokens)

print(split_pattern)

<\|endoftext\|>|<\|pad\|>


In [3]:
# Cut the text at every special token. The delimiters themselves are dropped
# because the pattern contains no capturing group.
text = "Hello<|endoftext|>world<|pad|>example"
docs = re.compile(split_pattern).split(text)

print(docs)

['Hello', 'world', 'example']
