<a href="https://colab.research.google.com/github/ajayrfhp/LearningDeepLearning/blob/main/bytepairencoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Bytepair encoding
- Goal is to reimplement bytepair encoding from scratch and have output match with the tiktoken library

## GPT2 tokenizer

In [6]:
import tiktoken

input_str = "Hello, world!"
gpt2_tokenizer = tiktoken.encoding_for_model("gpt-2")
gpt2_tokenized = gpt2_tokenizer.encode(input_str)
gpt2_tokenized_intermediate = [gpt2_tokenizer.decode([token]) for token in gpt2_tokenized]
gpt2_tokenized_decoded = gpt2_tokenizer.decode(gpt2_tokenized)

print("Input string:", input_str)
print("GPT-2 tokenized:", gpt2_tokenized)
print("GPT-2 tokenized intermediate:", gpt2_tokenized_intermediate)
print("GPT-2 tokenized decoded:", gpt2_tokenized_decoded)





Input string: Hello, world!
GPT-2 tokenized: [15496, 11, 995, 0]
GPT-2 tokenized intermediate: ['Hello', ',', ' world', '!']
GPT-2 tokenized decoded: Hello, world!


- GPT 2 tokenizer has 50,000 merges and 256 unicode characters

In [7]:
gpt2_tokenizer.n_vocab

50257

## Implementation from scratch

Prepare vocab

In [8]:
from collections import defaultdict

max_vocab_size = 50257


def init_vocab():
  vocab = {}
  vocab_size = 0

  for chr_code in range(256):
    vocab[chr_code] = chr(chr_code)
    vocab_size += 1
  return vocab, vocab_size

vocab, vocab_size = init_vocab()
len(vocab)

256

Get pair count

In [21]:
def get_pair_count(text, counts=defaultdict(int)):
  most_frequent_pair = None
  most_frequent_pair_count = 0
  for i in range(len(text)-1):
    pair = tuple(text[i:i+2])
    counts[pair] += 1
    if counts[pair] > most_frequent_pair_count:
      most_frequent_pair = pair
      most_frequent_pair_count = counts[pair]
  return counts, most_frequent_pair, most_frequent_pair_count

text = "Hello, world!"
text_encoded = text.encode('utf-8')
pair_count, most_frequent_pair, most_frequent_pair_count = get_pair_count(text_encoded)
most_frequent_pair

(72, 101)

- Merge common tokens

In [42]:
# Add symbol for most frequent pair in vocab and run encoding again to replace most frequent pair with new symbol.

def merge(text_encoded, pair, idx):
  i = 0
  text_encoded_merged = []
  while i < len(text_encoded) - 1:
    if not text_encoded[i]:
      continue
    elif text_encoded[i] == pair[0] and text_encoded[i+1] == pair[1]:
      text_encoded_merged.append(idx)
      i += 2
    else:
      text_encoded_merged.append(text_encoded[i])
      i += 1
  return text_encoded_merged

text_encoded=[1, 2, 3, 4, 5, 5, 1, 2, 9,  9, 1, 2]
print(len(text_encoded))
text_encoded = merge(
    text_encoded,
    pair=[1, 2],
    idx=10
)
print(len(text_encoded))


12
9


In [32]:
from collections import Counter
def get_pair_countv2(text):
    pair_iterator = zip(text, text[1:])
    counts = Counter(pair_iterator)
    most_frequent_pair = counts.most_common(1)
    return counts, most_frequent_pair[0][0], most_frequent_pair[0][1]

text = "Hello, world!"
text_encoded = text.encode('utf-8')
pair_count, most_frequent_pair, most_frequent_pair_count = get_pair_countv2(text_encoded)
most_frequent_pair

(72, 101)

In [33]:
def mergev2(text_encoded, pair, idx):
    i = 0
    new_text_encoded = []
    while i < len(text_encoded):
        if i+1 < len(text_encoded) and text_encoded[i] == pair[0] and text_encoded[i+1] == pair[1]:
            new_text_encoded.append(idx)
            i += 2
        else:
            new_text_encoded.append(text_encoded[i])
            i += 1
    del text_encoded
    return new_text_encoded

mergev2(
    text_encoded=[1, 2, 3, 4, 5, 5, 1, 2, 1, 2],
    pair=[1, 2],
    idx=10
)

[10, 3, 4, 5, 5, 10, 10]

- Had copilot rewrite the above functions in cython

In [15]:
%load_ext Cython

In [34]:
%%cython

from collections import defaultdict
import cython # Allows using @cython decorators if needed, and type hints

# cpdef makes the function available to Python and optimized for C calls
# We type inputs/outputs. Assuming 'data' is bytes as in BPE context.
# Returns a Python tuple: (dict, tuple | None, int)
cpdef tuple get_paircountv3(list data):

    # --- C Type Declarations ---
    cdef Py_ssize_t i, n
    cdef int p0, p1
    cdef int count, max_count = 0
    cdef tuple pair_key
    cdef tuple max_pair = None
    counts = defaultdict(int)
    # ---------------------------

    n = len(data)

    if n < 2:
        return counts, None, 0

    # --- Counting Loop ---
    for i in range(n - 1):
        p0 = data[i]
        p1 = data[i+1]
        pair_key = (p0, p1)
        counts[pair_key] += 1
    # ---------------------

    # --- Find Maximum After Loop ---
    # Iterate using the standard .items() method
    for pair_key, count in counts.items(): # CORRECTED LINE
        if count > max_count:
            max_count = count
            max_pair = pair_key
    # ---------------------------

    return counts, max_pair, max_count

In [35]:
%%cython

# Import necessary types if needed (often optional for basic types)
# cimport cython # Uncomment if using @cython decorators later

# Use 'cpdef' for a function callable from both Python and C (Cython) code
# Add type declarations for variables using 'cdef' and in the signature
# We assume inputs are Python lists/int, output is Python list
# Typing loop variables (i, n) and known types (idx) gives most benefit here
cpdef list mergev3(list original_text_encoded, tuple pair_to_replace, int replacement_idx):
    # --- C variable declarations ---
    cdef Py_ssize_t i = 0  # Py_ssize_t is preferred for indexing
    cdef Py_ssize_t n = len(original_text_encoded)
    cdef list new_list = [] # Output remains a standard Python list
    # Assume pair elements are integers for comparison
    # Type checking happens when accessing pair_to_replace[0]
    cdef int p0 = pair_to_replace[0]
    cdef int p1 = pair_to_replace[1]
    # ------------------------------

    while i < n:
        # Accessing list elements (original_text_encoded[i]) still involves
        # Python object overhead as it's a Python list.
        # For max speed, inputs would ideally be memoryviews or arrays.
        if i + 1 < n and original_text_encoded[i] == p0 and original_text_encoded[i+1] == p1:
            new_list.append(replacement_idx) # Append the typed int
            i += 2
        else:
            new_list.append(original_text_encoded[i]) # Append existing Python object
            i += 1

    # No need for 'del text_encoded' as the original list wasn't modified
    return new_list

Grab big text

In [55]:
import requests
import regex as re
big_text_url = "https://raw.githubusercontent.com/dscape/spell/refs/heads/master/test/resources/big.txt"

big_text = requests.get(big_text_url).text
big_text = big_text[:1000000]
gpt2_pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
compiled_pattern = re.compile(gpt2_pattern)
big_text = re.findall(compiled_pattern, big_text)
big_text_encoded = [ list(chunk.encode("utf-8")) for chunk in big_text]

len(big_text_encoded)

220288

In [56]:
big_text_encoded[0]

[84, 104, 101]

- Let's do 10 merges and profile

In [57]:
%load_ext line_profiler


def profile_merge(vocab, big_text_encoded_local, vocab_size):
  num_merges = vocab_size - 256
  for i in range(num_merges):
    most_frequent_pair = None 
    most_frequent_pair_count = 0
    counts = defaultdict(int)
    for chunk in big_text_encoded_local:
      counts, chunk_most_frequent_pair, chunk_most_frequent_pair_count = get_pair_count(chunk, counts)
      if chunk_most_frequent_pair_count > most_frequent_pair_count:
        most_frequent_pair = chunk_most_frequent_pair
        most_frequent_pair_count = chunk_most_frequent_pair_count
    if most_frequent_pair:
      big_text_encoded_local= [merge(chunk, most_frequent_pair, vocab_size) for chunk in big_text_encoded_local]
      vocab_size += 1
      vocab[vocab_size] = ''.join(map(chr, most_frequent_pair))
      print(f"merge {i} {most_frequent_pair} to {vocab[vocab_size]}")

  return vocab


vocab, vocab_size = init_vocab()

%lprun -f profile_merge profile_merge(vocab, big_text_encoded.copy(), vocab_size=276)


The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler
merge 0 (32, 116) to  t
merge 1 (276, 104) to Ĕh
merge 2 (32, 97) to  a
merge 3 (32, 115) to  s
merge 4 (32, 99) to  c
merge 5 (32, 32) to   
merge 6 (32, 112) to  p
merge 7 (32, 105) to  i
merge 8 (281, 281) to ęę
merge 9 (284, 284) to ĜĜ
merge 10 (280, 111) to Ęo
merge 11 (32, 67) to  C
merge 12 (32, 77) to  M
merge 13 (32, 114) to  r
merge 14 (32, 82) to  R
merge 15 (32, 117) to  u
merge 16 (10, 10) to 


merge 17 (32, 100) to  d


Timer unit: 1e-09 s

Total time: 11.2314 s
File: /tmp/ipykernel_2290/1040820679.py
Function: profile_merge at line 4

Line #      Hits         Time  Per Hit   % Time  Line Contents
     4                                           def profile_merge(vocab, big_text_encoded_local, vocab_size):
     5         1       1300.0   1300.0      0.0    num_merges = vocab_size - 256
     6        21      15700.0    747.6      0.0    for i in range(num_merges):
     7        20       5600.0    280.0      0.0      most_frequent_pair = None 
     8        20       3300.0    165.0      0.0      most_frequent_pair_count = 0
     9        20     266100.0  13305.0      0.0      counts = defaultdict(int)
    10   4405780  654318800.0    148.5      5.8      for chunk in big_text_encoded_local:
    11   4405760 4655131100.0   1056.6     41.4        counts, chunk_most_frequent_pair, chunk_most_frequent_pair_count = get_pair_count(chunk, counts)
    12   4405760  659914100.0    149.8      5.9        if chunk_m

In [None]:
len(big_text_encoded)

In [45]:
%load_ext line_profiler


def profile_mergev2(vocab, big_text_encoded_local, vocab_size):
  num_merges = vocab_size - 256
  for i in range(num_merges):
    _, most_frequent_pair, _ = get_pair_countv2(big_text_encoded_local)
    big_text_encoded_local = mergev2(big_text_encoded_local, most_frequent_pair, vocab_size)
    vocab_size += 1
    vocab[vocab_size] = ''.join(map(chr, most_frequent_pair))
    print(f"merge {i} {most_frequent_pair} to {vocab[vocab_size]}")

vocab, vocab_size = init_vocab()
%lprun -f profile_mergev2 profile_mergev2(vocab, big_text_encoded.copy(), vocab_size=276)


The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


TypeError: 'bytes' object cannot be interpreted as an integer

In [None]:
len(big_text_encoded)

In [None]:
%load_ext line_profiler


def profile_mergev3(vocab, big_text_encoded_local, vocab_size):
  num_merges = vocab_size - 256
  for i in range(num_merges):
    _, most_frequent_pair, _ = get_paircountv3(big_text_encoded_local)
    big_text_encoded_local = mergev3(big_text_encoded_local, most_frequent_pair, vocab_size)
    vocab_size += 1
    vocab[vocab_size] = ''.join(map(chr, most_frequent_pair))
    if i % 100 == 0:
      print(f"merge {i} {most_frequent_pair} to {vocab[vocab_size]}")
  return vocab
  

vocab, vocab_size = init_vocab()
%lprun -f profile_mergev3 profile_mergev3(vocab, big_text_encoded.copy(), vocab_size=276)


In [None]:
vocab, vocab_size = init_vocab()
vocab = profile_mergev3(vocab, big_text_encoded, vocab_size=1257)

In [None]:
reverse_vocab = {v: k for k, v in vocab.items()}
max_token_size = max(map(len, reverse_vocab.keys()))
max_token_size

In [None]:
def encode(text, reverse_vocab):
  i = 0
  text_encoded = []
  while i < len(text):
    for j in range(max_token_size, 0, -1):
      potential_token = text[i:i+j]
      if potential_token in reverse_vocab:
        text_encoded.append(reverse_vocab[potential_token])
        i += j
        break
  return text_encoded


def decode(text_encoded, vocab):
  text = ""
  text_list = []
  for code in text_encoded:
    text += vocab[code]
    text_list.append(vocab[code])
  return text, text_list

encoded_text = encode("Hello this is Ajay", reverse_vocab)
print(encoded_text)
decoded_text, decoded_text_list = decode(encoded_text, vocab)
print(decoded_text)
print(decoded_text_list)
