<a href="https://colab.research.google.com/github/ajayrfhp/LearningDeepLearning/blob/main/bytepairencoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Bytepair encoding
- Goal is to reimplement bytepair encoding from scratch and have output match with the tiktoken library

## GPT2 tokenizer

In [1]:
import tiktoken

input_str = "Hello, world!"
gpt2_tokenizer = tiktoken.encoding_for_model("gpt-2")
gpt2_tokenized = gpt2_tokenizer.encode(input_str)
gpt2_tokenized_intermediate = [gpt2_tokenizer.decode([token]) for token in gpt2_tokenized]
gpt2_tokenized_decoded = gpt2_tokenizer.decode(gpt2_tokenized)

print("Input string:", input_str)
print("GPT-2 tokenized:", gpt2_tokenized)
print("GPT-2 tokenized intermediate:", gpt2_tokenized_intermediate)
print("GPT-2 tokenized decoded:", gpt2_tokenized_decoded)





Input string: Hello, world!
GPT-2 tokenized: [15496, 11, 995, 0]
GPT-2 tokenized intermediate: ['Hello', ',', ' world', '!']
GPT-2 tokenized decoded: Hello, world!


- GPT 2 tokenizer has 50,000 merges and 256 unicode characters

In [2]:
gpt2_tokenizer.n_vocab

50257

## Implementation from scratch

Prepare vocab

In [3]:
from collections import defaultdict

max_vocab_size = 50257

vocab = {}
vocab_size = 0

for chr_code in range(256):
  vocab[chr_code] = chr(chr_code)
  vocab_size += 1


len(vocab)

256

Get pair count

In [4]:
def get_pair_count(text):
  counts = defaultdict(int)
  most_frequent_pair = None
  most_frequent_pair_count = 0
  for i in range(len(text)-1):
    pair = tuple(text[i:i+2])
    counts[pair] += 1
    if counts[pair] > most_frequent_pair_count:
      most_frequent_pair = pair
      most_frequent_pair_count = counts[pair]
  return counts, most_frequent_pair, most_frequent_pair_count

text = "Hello, world!"
text_encoded = text.encode('utf-8')
pair_count, most_frequent_pair, most_frequent_pair_count = get_pair_count(text_encoded)
most_frequent_pair

(72, 101)

In [5]:
from collections import Counter
def get_pair_countv2(text):
    pair_iterator = zip(text, text[1:])
    counts = Counter(pair_iterator)
    most_frequent_pair = counts.most_common(1)
    return counts, most_frequent_pair[0][0], most_frequent_pair[0][1]

text = "Hello, world!"
text_encoded = text.encode('utf-8')
pair_count, most_frequent_pair, most_frequent_pair_count = get_pair_countv2(text_encoded)
most_frequent_pair

(72, 101)

In [9]:
text_encoded

b'Hello, world!'

- Merge common tokens

In [7]:
# Add symbol for most frequent pair in vocab and run encoding again to replace most frequent pair with new symbol.

def merge(text_encoded, pair, idx):
  i = 0

  while i < len(text_encoded) - 1:
    if text_encoded[i] == pair[0] and text_encoded[i+1] == pair[1]:
      text_encoded[i] = idx
      text_encoded.pop(i+1)
    else:
      i += 1
  return text_encoded

merge(
    text_encoded=[1, 2, 3, 4, 5, 5],
    pair=[1, 2],
    idx=10
)


[10, 3, 4, 5, 5]

In [8]:
def mergev2(text_encoded, pair, idx):
    i = 0
    new_text_encoded = []
    while i < len(text_encoded):
        if i+1 < len(text_encoded) and text_encoded[i] == pair[0] and text_encoded[i+1] == pair[1]:
            new_text_encoded.append(idx)
            i += 2
        else:
            new_text_encoded.append(text_encoded[i])
            i += 1
    del text_encoded
    return new_text_encoded

mergev2(
    text_encoded=[1, 2, 3, 4, 5, 5],
    pair=[1, 2],
    idx=10
)

[10, 3, 4, 5, 5]

- Had copilot rewrite the above functions in cython

In [11]:
%load_ext Cython

In [27]:
%%cython

from collections import defaultdict
import cython # Allows using @cython decorators if needed, and type hints

# cpdef makes the function available to Python and optimized for C calls
# We type inputs/outputs. Assuming 'data' is bytes as in BPE context.
# Returns a Python tuple: (dict, tuple | None, int)
cpdef tuple get_paircountv3(list data):

    # --- C Type Declarations ---
    cdef Py_ssize_t i, n
    cdef int p0, p1
    cdef int count, max_count = 0
    cdef tuple pair_key
    cdef tuple max_pair = None
    counts = defaultdict(int)
    # ---------------------------

    n = len(data)

    if n < 2:
        return counts, None, 0

    # --- Counting Loop ---
    for i in range(n - 1):
        p0 = data[i]
        p1 = data[i+1]
        pair_key = (p0, p1)
        counts[pair_key] += 1
    # ---------------------

    # --- Find Maximum After Loop ---
    # Iterate using the standard .items() method
    for pair_key, count in counts.items(): # CORRECTED LINE
        if count > max_count:
            max_count = count
            max_pair = pair_key
    # ---------------------------

    return counts, max_pair, max_count

In [32]:
%%cython

# Import necessary types if needed (often optional for basic types)
# cimport cython # Uncomment if using @cython decorators later

# Use 'cpdef' for a function callable from both Python and C (Cython) code
# Add type declarations for variables using 'cdef' and in the signature
# We assume inputs are Python lists/int, output is Python list
# Typing loop variables (i, n) and known types (idx) gives most benefit here
cpdef list mergev3(list original_text_encoded, tuple pair_to_replace, int replacement_idx):
    # --- C variable declarations ---
    cdef Py_ssize_t i = 0  # Py_ssize_t is preferred for indexing
    cdef Py_ssize_t n = len(original_text_encoded)
    cdef list new_list = [] # Output remains a standard Python list
    # Assume pair elements are integers for comparison
    # Type checking happens when accessing pair_to_replace[0]
    cdef int p0 = pair_to_replace[0]
    cdef int p1 = pair_to_replace[1]
    # ------------------------------

    while i < n:
        # Accessing list elements (original_text_encoded[i]) still involves
        # Python object overhead as it's a Python list.
        # For max speed, inputs would ideally be memoryviews or arrays.
        if i + 1 < n and original_text_encoded[i] == p0 and original_text_encoded[i+1] == p1:
            new_list.append(replacement_idx) # Append the typed int
            i += 2
        else:
            new_list.append(original_text_encoded[i]) # Append existing Python object
            i += 1

    # No need for 'del text_encoded' as the original list wasn't modified
    return new_list

Grab big text

In [18]:
import requests
big_text_url = "https://raw.githubusercontent.com/dscape/spell/refs/heads/master/test/resources/big.txt"

big_text = requests.get(big_text_url).text
big_text = big_text[:1000000]
big_text_encoded = list(big_text.encode('utf-8'))
len(big_text_encoded)

1000000

- Let's do 10 merges and profile

In [19]:
%load_ext line_profiler


def profile_merge(big_text_encoded_local, vocab_size):
  num_merges = vocab_size - 256
  for i in range(num_merges):
    _, most_frequent_pair, _ = get_pair_count(big_text_encoded_local)
    merge(big_text_encoded_local, most_frequent_pair, idx=vocab_size)
    vocab_size += 1
    vocab[vocab_size] = ''.join(map(chr, most_frequent_pair))
    print(f"merge {i} {most_frequent_pair} to {vocab[vocab_size]}")

%lprun -f profile_merge profile_merge(big_text_encoded, vocab_size=276)


merge 0 (101, 32) to e 
merge 1 (116, 104) to th
merge 2 (100, 32) to d 
merge 3 (115, 32) to s 
merge 4 (116, 32) to t 
merge 5 (105, 110) to in
merge 6 (101, 114) to er
merge 7 (97, 110) to an
merge 8 (44, 32) to , 
merge 9 (277, 276) to ĕĔ
merge 10 (111, 110) to on
merge 11 (121, 32) to y 
merge 12 (101, 110) to en
merge 13 (111, 117) to ou
merge 14 (111, 32) to o 
merge 15 (102, 32) to f 
merge 16 (111, 114) to or
merge 17 (46, 32) to . 
merge 18 (101, 278) to eĖ
merge 19 (111, 291) to oģ


Timer unit: 1e-09 s

Total time: 35.5335 s
File: /tmp/ipykernel_18169/4291962890.py
Function: profile_merge at line 4

Line #      Hits         Time  Per Hit   % Time  Line Contents
     4                                           def profile_merge(big_text_encoded_local, vocab_size):
     5         1       1100.0   1100.0      0.0    num_merges = vocab_size - 256
     6        21      16400.0    781.0      0.0    for i in range(num_merges):
     7        20 9578063600.0    5e+08     27.0      _, most_frequent_pair, _ = get_pair_count(big_text_encoded_local)
     8        20        3e+10    1e+09     73.0      merge(big_text_encoded_local, most_frequent_pair, idx=vocab_size)
     9        20      24500.0   1225.0      0.0      vocab_size += 1
    10        20     178900.0   8945.0      0.0      vocab[vocab_size] = ''.join(map(chr, most_frequent_pair))
    11        20    2750300.0 137515.0      0.0      print(f"merge {i} {most_frequent_pair} to {vocab[vocab_size]}")

In [20]:
len(big_text_encoded)

769057

In [29]:
%load_ext line_profiler


def profile_mergev2(big_text_encoded_local, vocab_size):
  num_merges = vocab_size - 256
  for i in range(num_merges):
    _, most_frequent_pair, _ = get_pair_countv2(big_text_encoded_local)
    big_text_encoded_local = mergev2(big_text_encoded_local, most_frequent_pair, vocab_size)
    vocab_size += 1
    vocab[vocab_size] = ''.join(map(chr, most_frequent_pair))
    print(f"merge {i} {most_frequent_pair} to {vocab[vocab_size]}")

%lprun -f profile_mergev2 profile_mergev2(big_text_encoded, vocab_size=276)


The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler
merge 0 (97, 114) to ar
merge 1 (32, 32) to   
merge 2 (114, 101) to re
merge 3 (283, 278) to ěĖ
merge 4 (116, 105) to ti
merge 5 (97, 280) to aĘ
merge 6 (116, 290) to tĢ
merge 7 (281, 103) to ęg
merge 8 (283, 32) to ě 
merge 9 (97, 108) to al
merge 10 (104, 105) to hi
merge 11 (115, 116) to st
merge 12 (97, 32) to a 
merge 13 (10, 10) to 


merge 14 (97, 279) to aė
merge 15 (281, 32) to ę 
merge 16 (282, 32) to Ě 
merge 17 (101, 115) to es
merge 18 (286, 32) to Ğ 
merge 19 (111, 109) to om


Timer unit: 1e-09 s

Total time: 9.10709 s
File: /tmp/ipykernel_18169/1458884573.py
Function: profile_mergev2 at line 4

Line #      Hits         Time  Per Hit   % Time  Line Contents
     4                                           def profile_mergev2(big_text_encoded_local, vocab_size):
     5         1        600.0    600.0      0.0    num_merges = vocab_size - 256
     6        21      17500.0    833.3      0.0    for i in range(num_merges):
     7        20 1376918400.0    7e+07     15.1      _, most_frequent_pair, _ = get_pair_countv2(big_text_encoded_local)
     8        20 7727178200.0    4e+08     84.8      big_text_encoded_local = mergev2(big_text_encoded_local, most_frequent_pair, vocab_size)
     9        20      31700.0   1585.0      0.0      vocab_size += 1
    10        20     228600.0  11430.0      0.0      vocab[vocab_size] = ''.join(map(chr, most_frequent_pair))
    11        20    2717900.0 135895.0      0.0      print(f"merge {i} {most_frequent_pair} to {vocab[vocab

In [30]:
len(big_text_encoded)

769057

In [33]:
%load_ext line_profiler


def profile_mergev3(big_text_encoded_local, vocab_size):
  num_merges = vocab_size - 256
  for i in range(num_merges):
    _, most_frequent_pair, _ = get_paircountv3(big_text_encoded_local)
    big_text_encoded_local = mergev3(big_text_encoded_local, most_frequent_pair, vocab_size)
    vocab_size += 1
    vocab[vocab_size] = ''.join(map(chr, most_frequent_pair))
    print(f"merge {i} {most_frequent_pair} to {vocab[vocab_size]}")

%lprun -f profile_mergev3 profile_mergev3(big_text_encoded, vocab_size=276)


The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler
merge 0 (97, 114) to ar
merge 1 (32, 32) to   
merge 2 (114, 101) to re
merge 3 (283, 278) to ěĖ
merge 4 (116, 105) to ti
merge 5 (97, 280) to aĘ
merge 6 (116, 290) to tĢ
merge 7 (281, 103) to ęg
merge 8 (283, 32) to ě 
merge 9 (97, 108) to al
merge 10 (104, 105) to hi
merge 11 (115, 116) to st
merge 12 (97, 32) to a 
merge 13 (10, 10) to 


merge 14 (97, 279) to aė
merge 15 (281, 32) to ę 
merge 16 (282, 32) to Ě 
merge 17 (101, 115) to es
merge 18 (286, 32) to Ğ 
merge 19 (111, 109) to om


Timer unit: 1e-09 s

Total time: 1.92666 s
File: /tmp/ipykernel_18169/3404624777.py
Function: profile_mergev3 at line 4

Line #      Hits         Time  Per Hit   % Time  Line Contents
     4                                           def profile_mergev3(big_text_encoded_local, vocab_size):
     5         1        500.0    500.0      0.0    num_merges = vocab_size - 256
     6        21      17100.0    814.3      0.0    for i in range(num_merges):
     7        20 1709547400.0    9e+07     88.7      _, most_frequent_pair, _ = get_paircountv3(big_text_encoded_local)
     8        20  215326800.0    1e+07     11.2      big_text_encoded_local = mergev3(big_text_encoded_local, most_frequent_pair, vocab_size)
     9        20      50100.0   2505.0      0.0      vocab_size += 1
    10        20     224100.0  11205.0      0.0      vocab[vocab_size] = ''.join(map(chr, most_frequent_pair))
    11        20    1492200.0  74610.0      0.1      print(f"merge {i} {most_frequent_pair} to {vocab[vocab_

In [None]:
reverse_vocab = {v: k for k, v in vocab.items()}
max_token_size = max(map(len, reverse_vocab.keys()))
max_token_size

In [None]:
def encode(text, reverse_vocab):
  i = 0
  text_encoded = []
  while i < len(text):
    for j in range(max_token_size, 0, -1):
      potential_token = text[i:i+j]
      if potential_token in reverse_vocab:
        text_encoded.append(reverse_vocab[potential_token])
        i += j
        break
  return text_encoded


def decode(text_encoded, vocab):
  text = ""
  for code in text_encoded:
    text += vocab[code]
  return text

encoded_text = encode("Hello this is Ajay", reverse_vocab)
print(encoded_text)
decoded_text = decode(encoded_text, vocab)
print(decoded_text)
