In [None]:
import torch
import jax.numpy as jnp
from jax import grad
from jax import jit
import trio
import tiktoken
import regex as re

# Byte Pair Encoding Tokenization

Based on Andrej Karpathy's YouTube video "Let's build the GPT Tokenizer".

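Byte Pair Encoding (BPE) starts from the 256 possible byte values of UTF-8-encoded text and repeatedly replaces the most frequent adjacent pair of token ids with a brand-new token id, growing the vocabulary by one token per merge. Encoding replays the learned merges in the order they were learned; decoding maps each token id back to its bytes.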

In [2]:
# Sample text used to train and exercise the tokenizer. It is converted to its
# UTF-8 bytes, and then to a plain list of ints in the range 0-255.
text = "This is a test string. It has no meaning and serves for testing purposes only. I'll also add the common sentence with every letter next. The quick brown fox jumps over the lazy dog."
encoded_input = list(text.encode("utf-8"))

In [4]:
def find_pairs(tokens, counts=None):
    """Count occurrences of each adjacent pair of token ids.

    Args:
        tokens: sliceable sequence of token ids (e.g. a list of ints or a
            bytes object).
        counts: optional existing dict of pair counts to update in place.
            This lets callers accumulate statistics across several token
            chunks, such as the pieces produced by a pre-tokenization regex.
            A fresh dict is created when omitted.

    Returns:
        dict mapping (left_id, right_id) -> count, in first-seen order.
        When `counts` is supplied, that same dict object is returned.
    """
    # Use None as the default rather than {} so calls never share state.
    if counts is None:
        counts = {}
    for pair in zip(tokens, tokens[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

In [None]:
# Pair statistics for the raw byte sequence, plus the single most common pair.
# On ties, max() keeps the pair that was seen first.
pairs = find_pairs(encoded_input)
most_freq_pair = max(pairs, key=lambda candidate: pairs[candidate])

def merge_most_frequent(id_list, pair, new_token):
    """Replace occurrences of `pair` in `id_list` with the single id `new_token`.

    Despite the name, this merges whatever pair it is given; choosing the
    most frequent pair is the caller's job. The scan runs left to right and
    consumes both ids of a match. Overlapping occurrences therefore merge
    greedily from the left, e.g. (6, 6) in [6, 6, 6] -> [new, 6].

    Args:
        id_list: sequence of token ids to scan (not modified).
        pair: (left_id, right_id) to look for.
        new_token: id that replaces each matched pair.

    Returns:
        A new list of token ids.
    """
    merged = []
    pos = 0
    total = len(id_list)
    while pos < total:
        at_pair = (
            pos + 1 < total
            and id_list[pos] == pair[0]
            and id_list[pos + 1] == pair[1]
        )
        if at_pair:
            merged.append(new_token)
            pos += 2
            continue
        merged.append(id_list[pos])
        pos += 1
    return merged

# Quick check: the single (6, 7) occurrence collapses to 99,
# so this should print [5, 6, 99, 9, 1].
print(merge_most_frequent(id_list=[5, 6, 6, 7, 9, 1], pair=(6, 7), new_token=99))

In [None]:
# Train the BPE merges. The final vocabulary holds the 256 raw byte tokens
# plus one new token per merge.
vocab_size = 276
num_merges = vocab_size - 256
id_list = list(encoded_input)  # copy so we don't destroy the original list

merges = {}  # (int, int) -> int
for i in range(num_merges):
    stats = find_pairs(id_list)
    if not stats:
        # Fewer than two tokens remain, so there is nothing left to merge
        # (max() would raise on an empty dict).
        break
    pair = max(stats, key=stats.get)
    new_token = 256 + i
    print(f"merging {pair} into a new token {new_token}")
    # Rebind id_list so the next round counts pairs over the merged sequence.
    # Assigning the result to a different name would recount the untouched
    # list, re-pick the same pair every round, and learn only one merge.
    id_list = merge_most_frequent(id_list, pair, new_token)
    merges[pair] = new_token

# Alias kept so any later cell that reads `ids` sees the final merged sequence.
ids = id_list

In [None]:
# Map every token id to the bytes it represents. Ids 0-255 are the raw single
# bytes, and each merged id is the concatenation of its two parent tokens.
# merges is filled in training order, so both parents of a merge are already
# in vocab when that merge is processed.
vocab = dict((byte_id, bytes([byte_id])) for byte_id in range(256))
for (left_id, right_id), merged_id in merges.items():
    vocab[merged_id] = vocab[left_id] + vocab[right_id]

def decode(ids):
    """Convert a list of token ids back into a Python string.

    Each id is looked up in the module-level `vocab` (id -> bytes). The pieces
    are concatenated and decoded as UTF-8. Byte sequences that are not valid
    UTF-8 become U+FFFD instead of raising, and an unknown id raises KeyError.
    """
    raw = bytes().join([vocab[token_id] for token_id in ids])
    return raw.decode("utf-8", errors="replace")

# Byte 128 (0x80) on its own is not valid UTF-8, so decode's errors="replace"
# turns it into the U+FFFD replacement character instead of raising.
print(decode(ids=[128]))

In [None]:
def encode(text):
    """Convert a string into a list of BPE token ids.

    Starts from the UTF-8 bytes of `text`. Among the adjacent pairs currently
    present, it repeatedly applies the learned merge with the lowest new-token
    id. Ids were assigned in increasing order during training, so that is the
    earliest-learned merge. It stops when fewer than two tokens remain or no
    present pair has a learned merge.

    Relies on the module-level `merges`, `find_pairs` and `merge_most_frequent`.
    """
    ids = list(text.encode("utf-8"))
    never = float("inf")  # rank for pairs that were never learned
    while len(ids) > 1:
        candidates = find_pairs(ids)
        best = min(candidates, key=lambda p: merges.get(p, never))
        if best not in merges:
            break  # no remaining pair was ever learned
        ids = merge_most_frequent(ids, best, merges[best])
    return ids

# Edge case: an empty string has no bytes to merge, so this prints [].
print(encode(text=""))

In [None]:
# GPT-2's pre-tokenization pattern splits text into chunks before BPE runs:
#   - lowercase English contractions ('s, 't, 're, 've, 'm, 'll, 'd);
#   - runs of letters, of digits, and of other non-space characters, each with
#     an optional leading space;
#   - runs of whitespace. \s+(?!\S) stops one character short of the next word,
#     so that space can attach to the following chunk.
# \p{L} / \p{N} need the third-party `regex` module (imported as `re` above).
gpt2pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

print(gpt2pat.findall("Hello've world123 how's are you!!!?"))

In [None]:
# Compare how two reference tokenizers handle a run of leading spaces:
# GPT-2 ("gpt2") does not merge spaces, GPT-4 ("cl100k_base") merges them.
for encoding_name in ("gpt2", "cl100k_base"):
    enc = tiktoken.get_encoding(encoding_name)
    print(enc.encode("    hello world!!!"))