## Wordy Motivation

A Byte Pair Encoding (BPE) tokenizer splits text into subword units learned from the most frequent adjacent symbol pairs in a corpus,\
which keeps the vocabulary at a fixed size while still representing rare words as sequences of known subwords. Training starts with individual bytes or characters \
and repeatedly merges the most frequent adjacent pair into a new token until the target vocabulary size is reached.
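A single training step can be sketched in a few lines. This is a minimal illustration, assuming byte-level input. It uses `collections.Counter` to count adjacent pairs in a toy string. `most_common_pair` and `merge_pair` are hypothetical helper names, not part of any library.

```python
from collections import Counter

def most_common_pair(seq):
    """Return (pair, count) for the most frequent adjacent pair in `seq`."""
    pair_counts = Counter(zip(seq, seq[1:]))
    return pair_counts.most_common(1)[0]

def merge_pair(seq, pair, new_id):
    """Replace non-overlapping occurrences of `pair` with `new_id`, left to right."""
    out, i = [], 0
    while i < len(seq):
        if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(seq[i])
            i += 1
    return out

seq = list(b"aaabdaaabac")           # raw bytes as ints 0..255
pair, count = most_common_pair(seq)  # (97, 97), i.e. b'a' + b'a', seen 4 times
seq = merge_pair(seq, pair, 256)     # the first new token gets id 256
print(pair, count, seq)
```

Repeating the count-and-merge step, each time with the next free id, grows the vocabulary by one token per iteration.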

Modern BPE tokenizers used for training and inference in large language models typically apply regex-based pretokenization first. \
This step splits text into linguistically or visually meaningful chunks (like words, numbers, or punctuation groups), and merges are never allowed \
to cross chunk boundaries. Without it, BPE could learn spurious tokens that glue a word to its punctuation (e.g., learning "dog!" as a token unrelated to "dog").
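The effect of chunking can be seen with the standard library alone. The pattern below is a hypothetical ASCII-only simplification of the GPT-2 pattern. Stdlib `re` has no `\p{L}` / `\p{N}` classes, so letters and digits are approximated by `[A-Za-z]` and `[0-9]`. It is meant to show the idea, not to tokenize real multilingual text.

```python
import re

# Hypothetical ASCII-only approximation of the GPT-2 split pattern.
ASCII_SPLIT_PAT = re.compile(
    r"'(?:[sdmt]|ll|ve|re)| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+"
)

chunks = ASCII_SPLIT_PAT.findall("dog dog! Let's go")
print(chunks)  # ['dog', ' dog', '!', ' Let', "'s", ' go']

# Pairs are only counted inside a chunk, so 'g' + '!' never becomes a merge candidate.
pairs = {(a, b) for chunk in chunks for a, b in zip(chunk, chunk[1:])}
print(("g", "!") in pairs)  # False
```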

Regex pretokenization also makes frequency counting cheaper: if a word like “text” appears 10 times, we store it once with a count of 10 and increment its pair counts (like ('t', 'e')) by 10 directly.\
When a merge occurs (e.g., ('t', 'e', 'x', 't') → ('te', 'x', 't')), only the keys of the frequency dictionary change; the count attached to each key stays the same, \
which keeps the BPE merge step simple and fast.
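A minimal sketch of the weighted counting and of a merge that only rewrites keys. The counts are made up for illustration: “text” appears 10 times, as in the paragraph, and a hypothetical “ten” appears 3 times. Symbols are plain strings here rather than byte ids to keep it readable.

```python
from collections import Counter

# Toy pretoken counts: key = tuple of symbols, value = occurrences in the corpus.
pretoken_counts = {("t", "e", "x", "t"): 10, ("t", "e", "n"): 3}

# Each adjacent pair is incremented by the pretoken's count, not by 1.
pair_counts = Counter()
for symbols, cnt in pretoken_counts.items():
    for pair in zip(symbols, symbols[1:]):
        pair_counts[pair] += cnt

print(pair_counts[("t", "e")])  # 13: 10 from "text" + 3 from "ten"

# Merging ('t', 'e') rewrites the keys; the counts attached to them are untouched.
merged = {}
for symbols, cnt in pretoken_counts.items():
    out, i = [], 0
    while i < len(symbols):
        if symbols[i:i + 2] == ("t", "e"):
            out.append("te")
            i += 2
        else:
            out.append(symbols[i])
            i += 1
    merged[tuple(out)] = cnt

print(merged)  # {('te', 'x', 't'): 10, ('te', 'n'): 3}
```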

In [2]:
import regex as re # third-party `regex` package (pip install regex); stdlib re does not support \p{L} / \p{N}

In [None]:
# pattern from: https://github.com/openai/tiktoken/pull/234/files
# with re.finditer, each match is one pretoken (the pattern has no capturing groups)
GPT2_SPLIT_PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
pattern = re.compile(GPT2_SPLIT_PAT)

## Pieces of Training

In [18]:
# mix of languages + emoji
text = "Привет, world! 😄 Let's go, 今日は."

In [29]:
# a bit larger piece of text
text = """Alice looked at the glowing sign: “Добро пожаловать!” — it blinked beneath a line of Chinese characters: 欢迎光临.

She typed quickly: `hello_世界123! :)` — mixing English, symbols, digits, and emojis into her message.  
The response came instantly: "Принято. ✅"  
She smiled, whispered «行吧», and pressed Send.
"""

In [42]:
# and another one, from the CS336 assignment
text = """ low low low low low lower lower widest widest widest newest newest newest newest newest newest"""

In [None]:
pattern.findall(text)

[' low',
 ' low',
 ' low',
 ' low',
 ' low',
 ' lower',
 ' lower',
 ' widest',
 ' widest',
 ' widest',
 ' newest',
 ' newest',
 ' newest',
 ' newest',
 ' newest',
 ' newest']

In [187]:
# as described above, this dictionary keeps pretoken counts
# keys are tuples of byte values (ints 0..255), e.g. tuple('Привет'.encode('utf-8')) -> (208, 159, ...)
pretokens = dict()
for mt in pattern.finditer(text):
  pt = mt.group() # -> str; each match is one pretoken
  pt = tuple(pt.encode('utf-8'))
  pretokens[pt] = pretokens.get(pt, 0) + 1

next_ix = 256

In [188]:
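merges = [bytes([i]) for i in range(256)] # token id -> bytes; ids 0..255 are the raw bytes (int.to_bytes(i) without arguments needs Python 3.11+)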

In [252]:
# note: this output and the ones below were captured after several merges had already run
for pt, cnt in pretokens.items():
  print([merges[i] for i in pt], ':', cnt)

[b' ', b'w', b'i', b'd', b'est'] : 3
[b' low'] : 5
[b' low', b'e', b'r'] : 2
[b' ', b'ne', b'west'] : 6


In [253]:
pretokens

{(32, 119, 105, 100, 257): 3, (260,): 5, (260, 101, 114): 2, (32, 262, 261): 6}
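Keys such as `(32, 262, 261)` are sequences of token ids, and `merges` works as an id → bytes table, so a key decodes by concatenating the bytes of its ids. A minimal sketch, assuming the merge order seen in this notebook's run (st, est, ow, low, ' low', west, ne). `merge_pairs`, `vocab`, and `decode` are hypothetical names.

```python
# Merge order recovered from this notebook's run: new token 256 + k joins pair k.
merge_pairs = [
    (ord("s"), ord("t")),  # 256 -> b'st'
    (ord("e"), 256),       # 257 -> b'est'
    (ord("o"), ord("w")),  # 258 -> b'ow'
    (ord("l"), 258),       # 259 -> b'low'
    (ord(" "), 259),       # 260 -> b' low'
    (ord("w"), 257),       # 261 -> b'west'
    (ord("n"), ord("e")),  # 262 -> b'ne'
]

vocab = {i: bytes([i]) for i in range(256)}
for k, (a, b) in enumerate(merge_pairs):
    vocab[256 + k] = vocab[a] + vocab[b]

def decode(key):
    """Concatenate the bytes of every token id in a pretoken key."""
    return b"".join(vocab[i] for i in key)

pretokens_snapshot = {(32, 119, 105, 100, 257): 3, (260,): 5, (260, 101, 114): 2, (32, 262, 261): 6}
for key, cnt in pretokens_snapshot.items():
    print(decode(key), cnt)
```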

In [254]:
# we need to iterate through pretokens to find pair frequencies
pair_counts = dict()
for pt, cnt in pretokens.items():
  for p in zip(pt, pt[1:]):
    pair_counts[p] = pair_counts.get(p, 0) + cnt

In [255]:
for p, cnt in pair_counts.items():
  print(merges[p[0]], '+', merges[p[1]], ':', cnt)

b' ' + b'w' : 3
b'w' + b'i' : 3
b'i' + b'd' : 3
b'd' + b'est' : 3
b' low' + b'e' : 2
b'e' + b'r' : 2
b' ' + b'ne' : 6
b'ne' + b'west' : 6


In [256]:
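# find the most frequent pair; break ties by the lexicographically greater pair of byte strings
# (comparing ids is not equivalent: merged tokens get ids >= 256 regardless of their bytes)
top_pair, top_cnt = max(pair_counts.items(), key=lambda it: (it[1], (merges[it[0][0]], merges[it[0][1]])))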
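Tie-breaking on token ids and tie-breaking on the underlying bytes can disagree. A merged token gets the next free id, so its id says nothing about how its bytes sort. A minimal sketch with a hypothetical two-pair tie, assuming the intended rule is "prefer the lexicographically greater pair of byte strings":

```python
# Hypothetical tiny vocabulary: id 256 was created by merging b's' + b't'.
vocab = {i: bytes([i]) for i in range(256)}
vocab[256] = b"st"

# Two candidate pairs with the same frequency.
pair_counts = {(256, ord("e")): 4, (ord("t"), ord("e")): 4}

by_ids = max(pair_counts.items(), key=lambda it: (it[1], it[0]))[0]
by_bytes = max(
    pair_counts.items(),
    key=lambda it: (it[1], (vocab[it[0][0]], vocab[it[0][1]])),
)[0]

print([vocab[i] for i in by_ids])    # [b'st', b'e']: id 256 beats id 116
print([vocab[i] for i in by_bytes])  # [b't', b'e']: b't' sorts after b'st'
```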

In [257]:
top_pair

(262, 261)

In [258]:
print(merges[top_pair[0]], ',', merges[top_pair[1]], '->', top_cnt)

b'ne' , b'west' -> 6


In [259]:
# replace every occurrence of `pair` in `seq` with `new_ix`, scanning left to right
def merge(seq, pair, new_ix):
  new_seq = []
  i = 0
  while i < len(seq):
    # still in range, and do the next two elements form the pair?
    if i+1 < len(seq) and (seq[i], seq[i+1]) == pair:
      new_seq.append(new_ix)
      i += 2 # skip both elements of the merged pair
    else:
      new_seq.append(seq[i]) # keep the current element as is
      i += 1
  return tuple(new_seq)
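A few calls show how `merge` behaves at the edges. It is greedy from the left, so overlapping occurrences are not both merged, and a sequence without the pair comes back with the same content. The function is copied here so the snippet runs on its own. The ids are arbitrary.

```python
def merge(seq, pair, new_ix):
    # same logic as the notebook's merge(): greedy, left to right
    new_seq = []
    i = 0
    while i < len(seq):
        if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
            new_seq.append(new_ix)
            i += 2
        else:
            new_seq.append(seq[i])
            i += 1
    return tuple(new_seq)

print(merge((1, 2, 3, 1, 2), (1, 2), 99))  # (99, 3, 99)
print(merge((7, 7, 7), (7, 7), 99))        # (99, 7): occurrences don't overlap
print(merge((1, 3), (1, 2), 99))           # (1, 3): no match, same content
```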

In [260]:
# Why keys can't collide:
# - each merge introduces a brand-new id (next_ix) that no existing key contains yet,
#   e.g. until ('t', 'e') is merged into 'te', 'te' cannot appear as a unit inside any key;
# - only keys with the pair in adjacent positions change, and every changed key now contains next_ix;
# - replacing each next_ix with the pair recovers the original key, so merge() is injective here.
# Hence a changed key can match neither an unchanged key nor another changed key.
for pt in list(pretokens): # iterate over a snapshot of the keys: mutating a dict while iterating it raises RuntimeError
  new_pt = merge(pt, top_pair, next_ix)
  if new_pt != pt: # update only keys that actually changed
    # the argument above says this can't happen; the assertion guards against implementation bugs
    assert new_pt not in pretokens, f"Collision: {new_pt} already in pretokens"
    pretokens[new_pt] = pretokens.pop(pt) # move the count to the new key
  
# update merges
merges.append(merges[top_pair[0]] + merges[top_pair[1]])
next_ix += 1
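The no-collision argument can also be checked mechanically. Because the new id is fresh, expanding every occurrence of it back into the pair undoes `merge`. So no two keys can map to the same new key, and a changed key can never equal an unchanged one. A minimal sketch with a hypothetical `unmerge` helper. The keys are the ones this run would hold just before the b'n' + b'e' merge, reconstructed from the token list above.

```python
def merge(seq, pair, new_ix):
    # same logic as the notebook's merge(): greedy, left to right
    new_seq, i = [], 0
    while i < len(seq):
        if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
            new_seq.append(new_ix)
            i += 2
        else:
            new_seq.append(seq[i])
            i += 1
    return tuple(new_seq)

def unmerge(seq, pair, new_ix):
    """Hypothetical inverse of merge(): expand every new_ix back into the pair."""
    out = []
    for t in seq:
        out.extend(pair if t == new_ix else (t,))
    return tuple(out)

keys = [(32, 119, 105, 100, 257), (32, 110, 101, 261), (260,), (260, 101, 114)]
pair, new_ix = (110, 101), 262  # b'n' + b'e' -> b'ne'

assert all(new_ix not in k for k in keys)  # the fresh id appears in no key yet
for k in keys:
    new_k = merge(k, pair, new_ix)
    assert unmerge(new_k, pair, new_ix) == k  # invertible, hence injective
    if new_k != k:
        assert new_k not in keys  # changed keys contain new_ix, old keys never do
        print(k, "->", new_k)     # (32, 110, 101, 261) -> (32, 262, 261)
```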

In [262]:
# take a look at the newly formed tokens
for i, bp in enumerate(merges[256:], 256):
  print(i, '->', bp)

256 -> b'st'
257 -> b'est'
258 -> b'ow'
259 -> b'low'
260 -> b' low'
261 -> b'west'
262 -> b'ne'
263 -> b'newest'


## Let's Put It All Together

In [None]:
# replace every occurrence of `pair` in `seq` with `new_ix`, scanning left to right
def merge(seq, pair, new_ix):
  new_seq = []
  i = 0
  while i < len(seq):
    # still in range, and do the next two elements form the pair?
    if i+1 < len(seq) and (seq[i], seq[i+1]) == pair:
      new_seq.append(new_ix)
      i += 2 # skip both elements of the merged pair
    else:
      new_seq.append(seq[i]) # keep the current element as is
      i += 1
  return tuple(new_seq)

In [263]:
merges = [bytes([i]) for i in range(256)] # token id -> bytes; ids 0..255 are the raw bytes

# as described above, this dictionary keeps pretoken counts
# keys are tuples of byte values (ints 0..255), e.g. tuple('Привет'.encode('utf-8'))
pretokens = dict()
for mt in pattern.finditer(text):
  pt = mt.group() # -> str; each match is one pretoken
  pt = tuple(pt.encode('utf-8'))
  pretokens[pt] = pretokens.get(pt, 0) + 1

next_ix = 256
num_merges = 10

In [264]:
sep = "==================================="
for _ in range(num_merges):
  print(sep)
  # show pretokens
  for pt, cnt in pretokens.items():
    print([merges[i] for i in pt], ':', cnt)
  
  # iterate through pretokens to find pair frequencies
  pair_counts = dict()
  for pt, cnt in pretokens.items():
    for p in zip(pt, pt[1:]):
      pair_counts[p] = pair_counts.get(p, 0) + cnt
  if not pair_counts: # every pretoken is already a single token; nothing left to merge
    break
  # find the most frequent pair; break ties by the lexicographically greater pair of byte strings
  top_pair, top_cnt = max(pair_counts.items(), key=lambda it: (it[1], (merges[it[0][0]], merges[it[0][1]])))
  print("top pair", merges[top_pair[0]], ',', merges[top_pair[1]], '->', top_cnt)

  # merge
  for pt in list(pretokens): # iterate over a snapshot of the keys: mutating a dict while iterating it raises RuntimeError
    new_pt = merge(pt, top_pair, next_ix)
    if new_pt != pt: # update only keys that actually changed
      # collisions are impossible (see the argument in the step-by-step section); the assertion guards against bugs
      assert new_pt not in pretokens, f"Collision: {new_pt} already in pretokens"
      pretokens[new_pt] = pretokens.pop(pt) # move the count to the new key
  
  # update merges
  merges.append(merges[top_pair[0]] + merges[top_pair[1]])
  next_ix += 1

print(sep)
# take a look at the newly formed tokens
for i, bp in enumerate(merges[256:], 256):
  print(i, '->', bp)

[b' ', b'l', b'o', b'w'] : 5
[b' ', b'l', b'o', b'w', b'e', b'r'] : 2
[b' ', b'w', b'i', b'd', b'e', b's', b't'] : 3
[b' ', b'n', b'e', b'w', b'e', b's', b't'] : 6
top pair b's' , b't' -> 9
[b' ', b'l', b'o', b'w'] : 5
[b' ', b'l', b'o', b'w', b'e', b'r'] : 2
[b' ', b'w', b'i', b'd', b'e', b'st'] : 3
[b' ', b'n', b'e', b'w', b'e', b'st'] : 6
top pair b'e' , b'st' -> 9
[b' ', b'l', b'o', b'w'] : 5
[b' ', b'l', b'o', b'w', b'e', b'r'] : 2
[b' ', b'w', b'i', b'd', b'est'] : 3
[b' ', b'n', b'e', b'w', b'est'] : 6
top pair b'o' , b'w' -> 7
[b' ', b'w', b'i', b'd', b'est'] : 3
[b' ', b'n', b'e', b'w', b'est'] : 6
[b' ', b'l', b'ow'] : 5
[b' ', b'l', b'ow', b'e', b'r'] : 2
top pair b'l' , b'ow' -> 7
[b' ', b'w', b'i', b'd', b'est'] : 3
[b' ', b'n', b'e', b'w', b'est'] : 6
[b' ', b'low'] : 5
[b' ', b'low', b'e', b'r'] : 2
top pair b' ' , b'low' -> 7
[b' ', b'w', b'i', b'd', b'est'] : 3
[b' ', b'n', b'e', b'w', b'est'] : 6
[b' low'] : 5
[b' low', b'e', b'r'] : 2
top pair b'w' , b'est' -> 6
[b' 