```text
character -> Unicode code point -> UTF-8 encoding -> bytes (ints 0..255)
```

Byte-pair encoding: repeatedly replace the most frequent adjacent pair with a new symbol.

```text
aaabdaaabac
ZabdZabac    # Z=aa
ZYdZYac      # Y=ab
XdXac        # X=ZY
```

In [1]:
# Read the training corpus. The encoding is passed explicitly: without it open()
# uses the platform's locale encoding, which is not UTF-8 everywhere (e.g. Windows),
# while everything below treats the text as UTF-8.
with open('../data/cs.txt', 'r', encoding='utf-8') as f:
  text = f.read()

# Raw UTF-8 bytes as ints 0..255 -- the base alphabet the merges are built on.
tokens = list(text.encode('utf-8'))

In [2]:
# returns byte pairs with their counts
def get_stats(tokens):
  toks = {}
  for pair in zip(tokens, tokens[1:]):
    toks[pair] = toks.get(pair, 0) + 1
  return toks

# merge byte pairs and replace them with new idx
def merge(ids, pair, idx):
  i = 0
  new = []
  while i < len(ids):
    if i < len(ids)-1 and (ids[i], ids[i+1]) == pair:
      new.append(idx)
      i += 2
    else:
      new.append(ids[i])
      i += 1
  return new

# bytes -> pair -> bytes (compressed)
def get_merges(tokens, num_merges):
  # tokens -> keep merging -> new_tokens
  merges = {}
  for i in range(num_merges):
    stats = get_stats(tokens)
    pair = max(stats, key=stats.get)
    idx = 256+i
    tokens = merge(tokens, pair, idx)
    merges[pair] = idx
  return tokens, merges

def create_vocab(merges):
  """Build the id -> bytes lookup table used for decoding.

  Ids 0..255 map to their single raw byte. Each merged id maps to the
  concatenation of its two parts. ``merges`` is walked in iteration order,
  so a merge's parts must already appear earlier in that order (or be raw bytes).
  """
  table = {}
  for byte in range(256):
    table[byte] = bytes((byte,))
  for pair, new_id in merges.items():
    left, right = pair
    table[new_id] = table[left] + table[right]
  return table

In [3]:
vocab_sz = 2500
num_merges = vocab_sz - 256  # ids 0..255 are taken by the raw bytes

new_toks, merges = get_merges(tokens, num_merges)
vocab = create_vocab(merges)
# Percentage by which merging shortened the token sequence (a reduction, not an x:1 ratio).
f"compression ratio {100 - len(new_toks) / len(tokens) * 100:.2f}%"

'compression ratio 77.80%'

In [4]:
vocab  # display the id -> bytes table returned by create_vocab(merges)

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'

In [5]:
# output -> decode -> text
def decode(tokens):
  toks = b''.join(vocab[i] for i in tokens)
  return toks.decode('utf-8', errors='replace')

# text -> utf-8 -> encode -> tokens
def encode(text):
  tokens = list(text.encode('utf-8'))
  while len(tokens) >= 2: # fix for encode("a") single character encoding
    stats = get_stats(tokens)
    pair = min(stats, key=lambda p: merges.get(p, float("inf")))
    if pair not in merges:
      break
    idx = merges[pair]
    tokens = merge(tokens, pair, idx)
  return tokens

In [6]:
# testing: decode(encode(s)) must give back s exactly; each check prints True on success
for sample in ('b', 'bob the builder'):
  text2 = decode(encode(sample))
  print(text2 == sample)

valtext = "Anomalous tokens: a mysterious failure mode for GPT (which reliably insulted Matthew) We have found a set of anomalous tokens which result in a previously undocumented failure mode for GPT-2 and GPT-3 models. (The 'instruct' models “are particularly deranged” in this context, as janus has observed.) Many of these tokens reliably break determinism in the OpenAI GPT-3 playground at temperature 0 (which theoretically shouldn't happen)."
valtext2 = decode(encode(valtext))
print(valtext2 == valtext)

True
True
True


In [7]:
# next up: improve this with regex-based pre-tokenization (split the text into chunks before learning merges)

In [26]:
import regex as re

# Regex pre-tokenization: split the text into chunks before BPE. Alternatives, in order:
# contractions ('s, 't, 're, 've, 'm, 'll, 'd; case-insensitive); a letter run with at
# most one leading non-letter/non-digit/non-newline char (typically a space); numbers
# of 1-3 digits; a punctuation run with an optional leading space and trailing
# newlines; whitespace ending in newline(s); a whitespace run not followed by a
# non-space char; any remaining whitespace.
LLAMA3_SPLIT_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

compiled_pattern = re.compile(LLAMA3_SPLIT_PATTERN)
text_chunks = compiled_pattern.findall(text)
# One list of UTF-8 byte values per chunk. NOTE: this rebinds `tokens`, which was
# the flat byte list of the whole text, to a list of lists.
tokens = [list(chunk.encode('utf-8')) for chunk in text_chunks]

# TODO: get_merges() expects one flat list of ints. Passing this list of lists
# raises TypeError, because get_stats() would use tuples of lists (unhashable) as
# dict keys. Pair counts must be summed across chunks and each merge applied per chunk.

yes
