<a href="https://colab.research.google.com/github/ajayrfhp/LearningDeepLearning/blob/main/bytepairencoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Bytepair encoding
- Goal is to reimplement bytepair encoding from scratch and have output match with the tiktoken library

## GPT2 tokenizer

In [32]:
import tiktoken

input_str = "Hello, world!"
gpt2_tokenizer = tiktoken.encoding_for_model("gpt-2")
gpt2_tokenized = gpt2_tokenizer.encode(input_str)
gpt2_tokenized_intermediate = [gpt2_tokenizer.decode([token]) for token in gpt2_tokenized]
gpt2_tokenized_decoded = gpt2_tokenizer.decode(gpt2_tokenized)

print("Input string:", input_str)
print("GPT-2 tokenized:", gpt2_tokenized)
print("GPT-2 tokenized intermediate:", gpt2_tokenized_intermediate)
print("GPT-2 tokenized decoded:", gpt2_tokenized_decoded)





Input string: Hello, world!
GPT-2 tokenized: [15496, 11, 995, 0]
GPT-2 tokenized intermediate: ['Hello', ',', ' world', '!']
GPT-2 tokenized decoded: Hello, world!


- GPT 2 tokenizer has 50,000 merges and 256 unicode characters

In [33]:
gpt2_tokenizer.n_vocab

50257

## Implementation from scratch

Prepare vocab

In [34]:
from collections import defaultdict

max_vocab_size = 50257 


def init_vocab():
  vocab = {}
  vocab_size = 0

  for chr_code in range(256):
    vocab[chr_code] = chr(chr_code)
    vocab_size += 1
  return vocab, vocab_size

vocab, vocab_size = init_vocab()
len(vocab)

256

- Get pair count statistics
- Merge common tokens

In [35]:
def get_pair_count(text, counts=defaultdict(int)):
  most_frequent_pair = None
  most_frequent_pair_count = 0
  for i in range(len(text)-1):
    pair = tuple(text[i:i+2])
    counts[pair] += 1
    if counts[pair] > most_frequent_pair_count:
      most_frequent_pair = pair
      most_frequent_pair_count = counts[pair]
  return counts, most_frequent_pair, most_frequent_pair_count

text = "Hello, world!"
text_encoded = text.encode('utf-8')
pair_count, most_frequent_pair, most_frequent_pair_count = get_pair_count(text_encoded)
print(most_frequent_pair)

# Add symbol for most frequent pair in vocab and run encoding again to replace most frequent pair with new symbol.

def merge(text_encoded, pair, idx):
  i = 0
  text_encoded_merged = []
  while i < len(text_encoded):
    if not text_encoded[i]:
      continue
    elif i + 1 < len(text_encoded) and text_encoded[i] == pair[0] and text_encoded[i+1] == pair[1]:
      text_encoded_merged.append(idx)
      i += 2
    else:
      text_encoded_merged.append(text_encoded[i])
      i += 1
  return text_encoded_merged

text_encoded=[1, 2, 3, 4, 5, 5, 1, 2, 9,  9, 1, 2]
print(len(text_encoded))
text_encoded = merge(
    text_encoded,
    pair=[1, 2],
    idx=10
)
print(len(text_encoded))


(72, 101)
12
9


Grab big text

In [42]:
import requests
import regex as re
big_text_url = "https://raw.githubusercontent.com/dscape/spell/refs/heads/master/test/resources/big.txt"

big_text = requests.get(big_text_url).text
big_text = big_text[:1000000]
gpt2_pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
compiled_pattern = re.compile(gpt2_pattern)
big_text = re.findall(compiled_pattern, big_text)
big_text_encoded = [ list(chunk.encode("utf-8")) for chunk in big_text if chunk]
big_text_encoded = [item for item in big_text_encoded if len(item) > 0]

len(big_text_encoded)

220288

In [None]:
big_text_encoded[:10]

- Let's do 10 merges and profile

In [37]:
%load_ext line_profiler


def profile_merge(vocab, big_text_encoded_local, vocab_size):
  num_merges = vocab_size - 256
  i = 0
  while i < num_merges:
    most_frequent_pair = None 
    most_frequent_pair_count = 0
    counts = defaultdict(int)
    for chunk in big_text_encoded_local:
      counts, chunk_most_frequent_pair, chunk_most_frequent_pair_count = get_pair_count(chunk, counts)
      if chunk_most_frequent_pair_count > most_frequent_pair_count:
        most_frequent_pair = chunk_most_frequent_pair
        most_frequent_pair_count = chunk_most_frequent_pair_count
    
    if most_frequent_pair:
      big_text_encoded_local= [merge(chunk, most_frequent_pair, vocab_size) for chunk in big_text_encoded_local]
      print(big_text_encoded_local[:2])
      vocab_size += 1
      vocab[vocab_size] = ''.join(map(chr, most_frequent_pair))
      i += 1
      print(f"merge {i} {most_frequent_pair} to {vocab[vocab_size]}")
    else:
      print("No more pairs to merge.")
      print(chunk)
      break

  return vocab


vocab, vocab_size = init_vocab()

%lprun -f profile_merge profile_merge(vocab, big_text_encoded.copy(), vocab_size=276)


The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler
[[84, 104, 101], [32, 80, 114, 111, 106, 101, 99, 116]]
merge 1 (32, 116) to  t
[[84, 277], [32, 80, 114, 111, 106, 101, 99, 116]]
merge 2 (104, 101) to he
[[84, 277], [32, 80, 114, 111, 106, 101, 99, 116]]
merge 3 (32, 97) to  a
[[84, 277], [32, 80, 114, 111, 106, 101, 99, 116]]
merge 4 (105, 110) to in
[[84, 277], [32, 80, 114, 111, 106, 101, 99, 116]]
merge 5 (276, 277) to Ĕĕ
[[84, 277], [32, 80, 114, 111, 106, 101, 99, 116]]
merge 6 (32, 111) to  o
[[84, 277], [32, 80, 114, 111, 106, 101, 99, 116]]
merge 7 (114, 101) to re
[[84, 277], [32, 80, 114, 111, 106, 101, 99, 116]]
merge 8 (32, 119) to  w
[[84, 277], [32, 80, 114, 111, 106, 101, 99, 116]]
merge 9 (32, 115) to  s
[[84, 277], [32, 80, 114, 111, 106, 101, 99, 116]]
merge 10 (101, 114) to er
[[84, 277], [32, 80, 114, 111, 106, 101, 99, 116]]
merge 11 (111, 110) to on
[[84, 277], [32, 80, 114, 111, 106, 101, 99, 116]]
merge 12 (110, 100

Timer unit: 1e-09 s

Total time: 44.1768 s
File: /tmp/ipykernel_1151/2353498619.py
Function: profile_merge at line 4

Line #      Hits         Time  Per Hit   % Time  Line Contents
     4                                           def profile_merge(vocab, big_text_encoded_local, vocab_size):
     5         1       2100.0   2100.0      0.0    num_merges = vocab_size - 256
     6         1        500.0    500.0      0.0    i = 0
     7        21      13100.0    623.8      0.0    while i < num_merges:
     8        20       5900.0    295.0      0.0      most_frequent_pair = None 
     9        20       5700.0    285.0      0.0      most_frequent_pair_count = 0
    10        20    1175800.0  58790.0      0.0      counts = defaultdict(int)
    11   4405780  990247800.0    224.8      2.2      for chunk in big_text_encoded_local:
    12   4405760        2e+10   4274.7     42.6        counts, chunk_most_frequent_pair, chunk_most_frequent_pair_count = get_pair_count(chunk, counts)
    13   44057

- V2 with counter

In [None]:
print(len(big_text_encoded))

from collections import Counter
def get_pair_countv2(text, counts):
    pair_iterator = zip(text, text[1:])

    # update counts with pair_iterator
    for pair in pair_iterator:
        counts[pair] += 1

    most_frequent_pair = counts.most_common(1)
    return counts, most_frequent_pair[0][0], most_frequent_pair[0][1]

text = "Hello, world!"
text_encoded = text.encode('utf-8')
pair_count, most_frequent_pair, most_frequent_pair_count = get_pair_countv2(text_encoded, counts=Counter())
print(most_frequent_pair)

def mergev2(text_encoded, pair, idx):
    i = 0
    new_text_encoded = []
    while i < len(text_encoded):
        if i+1 < len(text_encoded) and text_encoded[i] == pair[0] and text_encoded[i+1] == pair[1]:
            new_text_encoded.append(idx)
            i += 2
        else:
            new_text_encoded.append(text_encoded[i])
            i += 1
    del text_encoded
    return new_text_encoded

mergev2(
    text_encoded=[1, 2, 3, 4, 5, 5, 1, 2, 1, 2],
    pair=[1, 2],
    idx=10
)

In [None]:
%load_ext line_profiler


def profile_mergev2(vocab, big_text_encoded_local, vocab_size):
  num_merges = vocab_size - 256
  i = 0
  while i < num_merges:
    most_frequent_pair = None 
    most_frequent_pair_count = 0
    counts = Counter()
    for chunk in big_text_encoded_local:
      counts, chunk_most_frequent_pair, chunk_most_frequent_pair_count = get_pair_countv2(chunk, counts)
      if chunk_most_frequent_pair_count > most_frequent_pair_count:
        most_frequent_pair = chunk_most_frequent_pair
        most_frequent_pair_count = chunk_most_frequent_pair_count
    if most_frequent_pair:
      big_text_encoded_local= [mergev2(chunk, most_frequent_pair, vocab_size) for chunk in big_text_encoded_local]
      vocab_size += 1
      vocab[vocab_size] = ''.join(map(chr, most_frequent_pair))
      i += 1
      print(f"merge {i} {most_frequent_pair} to {vocab[vocab_size]}")

  return vocab

vocab, vocab_size = init_vocab()
%lprun -f profile_mergev2 profile_mergev2(vocab, big_text_encoded.copy(), vocab_size=276)


- V3 with cython?

In [38]:
%load_ext cython


The cython extension is already loaded. To reload it, use:
  %reload_ext cython


In [46]:
%%cython
from collections import defaultdict
from libc.stdint cimport uint8_t

def get_paircountv3(list text, counts=None):
    if counts is None:
        counts = defaultdict(int)
    cdef int i
    cdef tuple most_frequent_pair = None
    cdef int most_frequent_pair_count = 0
    cdef tuple pair
    for i in range(len(text) - 1):
        pair = (text[i], text[i + 1])
        counts[pair] += 1
        if counts[pair] > most_frequent_pair_count:
            most_frequent_pair = pair
            most_frequent_pair_count = counts[pair]
    return counts, most_frequent_pair, most_frequent_pair_count

def mergev3(list text_encoded, tuple pair, int idx):
    cdef int i = 0
    cdef list text_encoded_merged = []
    while i < len(text_encoded):
        if i + 1 < len(text_encoded) and text_encoded[i] == pair[0] and text_encoded[i + 1] == pair[1]:
            text_encoded_merged.append(idx)
            i += 2
        else:
            text_encoded_merged.append(text_encoded[i])
            i += 1
    return text_encoded_merged

# Example usage
text_encoded = [72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33]
pair_count, most_frequent_pair, most_frequent_pair_count = get_paircountv3(text_encoded)
print(most_frequent_pair)

text_encoded = [1, 2, 3, 4, 5, 5, 1, 2, 9, 9, 1, 2]
print(len(text_encoded))
text_encoded = mergev3(
    text_encoded,
    pair=(1, 2),
    idx=10
)
print(len(text_encoded))


(72, 101)
12
9


In [48]:
%load_ext line_profiler

print(len(big_text_encoded))
def profile_mergev3(vocab, big_text_encoded_local, vocab_size):
  num_merges = vocab_size - 256
  i = 0
  while i < num_merges:
    most_frequent_pair = None 
    most_frequent_pair_count = 0
    counts = Counter()
    for chunk in big_text_encoded_local:
      counts, chunk_most_frequent_pair, chunk_most_frequent_pair_count = get_paircountv3(chunk, counts)
      if chunk_most_frequent_pair_count > most_frequent_pair_count:
        most_frequent_pair = chunk_most_frequent_pair
        most_frequent_pair_count = chunk_most_frequent_pair_count
    if most_frequent_pair:
      big_text_encoded_local= [mergev3(chunk, most_frequent_pair, vocab_size) for chunk in big_text_encoded_local]
      vocab_size += 1
      vocab[vocab_size] = ''.join(map(chr, most_frequent_pair))
      i += 1
      print(f"merge {i} {most_frequent_pair} to {vocab[vocab_size]}")

  return vocab
  

vocab, vocab_size = init_vocab()
%lprun -f profile_mergev3 profile_mergev3(vocab, big_text_encoded.copy(), vocab_size=276)


The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler
220288
merge 1 (32, 116) to  t
merge 2 (104, 101) to he
merge 3 (32, 97) to  a
merge 4 (105, 110) to in
merge 5 (276, 277) to Ĕĕ
merge 6 (32, 111) to  o
merge 7 (114, 101) to re
merge 8 (32, 119) to  w
merge 9 (32, 115) to  s
merge 10 (101, 114) to er
merge 11 (111, 110) to on
merge 12 (110, 100) to nd
merge 13 (104, 97) to ha
merge 14 (111, 117) to ou
merge 15 (105, 115) to is
merge 16 (101, 100) to ed
merge 17 (105, 116) to it
merge 18 (281, 102) to ęf
merge 19 (101, 110) to en
merge 20 (32, 99) to  c


Timer unit: 1e-09 s

Total time: 16.3104 s
File: /tmp/ipykernel_1151/3884388159.py
Function: profile_mergev3 at line 4

Line #      Hits         Time  Per Hit   % Time  Line Contents
     4                                           def profile_mergev3(vocab, big_text_encoded_local, vocab_size):
     5         1       1400.0   1400.0      0.0    num_merges = vocab_size - 256
     6         1        400.0    400.0      0.0    i = 0
     7        21      21100.0   1004.8      0.0    while i < num_merges:
     8        20       8700.0    435.0      0.0      most_frequent_pair = None 
     9        20      17700.0    885.0      0.0      most_frequent_pair_count = 0
    10        20    2627300.0 131365.0      0.0      counts = Counter()
    11   4405780 1187578500.0    269.6      7.3      for chunk in big_text_encoded_local:
    12   4405760 8176326000.0   1855.8     50.1        counts, chunk_most_frequent_pair, chunk_most_frequent_pair_count = get_paircountv3(chunk, counts)
    13   4405760

In [None]:
vocab, vocab_size = init_vocab()
vocab = profile_mergev3(vocab, big_text_encoded, vocab_size=1257)

In [None]:
reverse_vocab = {v: k for k, v in vocab.items()}
max_token_size = max(map(len, reverse_vocab.keys()))
max_token_size

In [None]:
def encode(text, reverse_vocab):
  i = 0
  text_encoded = []
  while i < len(text):
    for j in range(max_token_size, 0, -1):
      potential_token = text[i:i+j]
      if potential_token in reverse_vocab:
        text_encoded.append(reverse_vocab[potential_token])
        i += j
        break
  return text_encoded


def decode(text_encoded, vocab):
  text = ""
  text_list = []
  for code in text_encoded:
    text += vocab[code]
    text_list.append(vocab[code])
  return text, text_list

encoded_text = encode("Hello this is Ajay", reverse_vocab)
print(encoded_text)
decoded_text, decoded_text_list = decode(encoded_text, vocab)
print(decoded_text)
print(decoded_text_list)
