In [1]:
with open("../data/the_verdict.txt","r") as f:
    raw_text = f.read()

In [2]:
import regex as re
gpt4=re.compile(r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""")

In [3]:
print(re.findall(gpt4,"Building TOkenizer of Gpt4!"))

['Building', ' TOkenizer', ' of', ' Gpt', '4', '!']


In [4]:
preprocessed=re.findall(gpt4,raw_text)

In [5]:
len(preprocessed)

4538

In [6]:
def get_stats(words):
    counts={}
    for ids in words:
        for pair in zip(ids,ids[1:]):
            counts[pair]=counts.get(pair,0)+1
    return counts

In [7]:
get_stats(['Building', ' TOkenizer', ' of', ' Gpt', '4', '!'])

{('B', 'u'): 1,
 ('u', 'i'): 1,
 ('i', 'l'): 1,
 ('l', 'd'): 1,
 ('d', 'i'): 1,
 ('i', 'n'): 1,
 ('n', 'g'): 1,
 (' ', 'T'): 1,
 ('T', 'O'): 1,
 ('O', 'k'): 1,
 ('k', 'e'): 1,
 ('e', 'n'): 1,
 ('n', 'i'): 1,
 ('i', 'z'): 1,
 ('z', 'e'): 1,
 ('e', 'r'): 1,
 (' ', 'o'): 1,
 ('o', 'f'): 1,
 (' ', 'G'): 1,
 ('G', 'p'): 1,
 ('p', 't'): 1}

In [178]:
def merge(ids, pair, idx):
  newids = []
  for id in ids:
    i = 0
    sub_new_id=[]
    while i < len(id):
      if i < len(id) - 1 and id[i] == pair[0] and id[i+1] == pair[1]:
        sub_new_id.append(idx)
        i += 2
      else:
        sub_new_id.append(id[i])
        i += 1
    newids.append(sub_new_id)
  return newids

In [179]:
ids=list(preprocessed)
utf_ids=[]
for id in ids:
    utf_ids.append(list(id.encode()))

In [194]:
num_merges = 1100
vocabulary_size=256
merges = {} # (int, int) -> int
for i in range(num_merges):
  stats = get_stats(utf_ids)
  try:
    pair = max(stats, key=stats.get)
  except Exception as e:
    break
  idx = 256 + i
  print(f"merging {pair} into a new token {idx}")
  utf_ids = merge(utf_ids, pair, idx)
  merges[pair] = idx

merging (1255, 729) into a new token 256
merging (256, 401) into a new token 257
merging (737, 120) into a new token 258
merging (258, 769) into a new token 259
merging (302, 274) into a new token 260
merging (409, 926) into a new token 261
merging (279, 558) into a new token 262
merging (262, 372) into a new token 263
merging (34, 78) into a new token 264
merging (264, 948) into a new token 265
merging (108, 496) into a new token 266
merging (265, 957) into a new token 267
merging (65, 272) into a new token 268
merging (610, 116) into a new token 269
merging (269, 676) into a new token 270
merging (468, 295) into a new token 271
merging (256, 264) into a new token 272
merging (272, 259) into a new token 273
merging (273, 462) into a new token 274
merging (274, 732) into a new token 275
merging (262, 364) into a new token 276
merging (288, 947) into a new token 277
merging (511, 274) into a new token 278
merging (448, 108) into a new token 279
merging (279, 415) into a new token 280
me

In [195]:
import regex as re

class GPT4Tokenizer:
    def __init__(self, merge):
        self.byte_to_token = {}
        self.token_to_byte = {}
        self.mergepair_to_byte = {}
        self.byte_to_mergepair = {}

        for i in range(0, 256):
            self.byte_to_token[i] = i
            self.token_to_byte[i] = i

        for pair, new_byte in merge.items():
            self.byte_to_mergepair[new_byte] = pair
            self.mergepair_to_byte[pair] = new_byte
            self.byte_to_token[new_byte] = new_byte
            self.token_to_byte[new_byte] = new_byte

        special_token_id = len(self.byte_to_token)
        self.byte_to_token["<|endoftext|>"] = special_token_id
        self.token_to_byte[special_token_id] = "<|endoftext|>"
        self.special_token_id = special_token_id

        self.pat = re.compile(
            r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
        )

    def encode(self, text, add_special_token=False):
        words = re.findall(self.pat, text)
        words_in_bytes = [list(word.encode("utf-8")) for word in words]

        encoded_tokens = []

        for word in words_in_bytes:

            while len(word) > 1:

                best_pair = None
                best_idx = float('inf')
                
                for i in range(len(word) - 1):
                    pair = (word[i], word[i + 1])
                    if pair in self.mergepair_to_byte:

                        merge_idx = self.mergepair_to_byte[pair]
                        if merge_idx < best_idx:
                            best_idx = merge_idx
                            best_pair = pair
                
                if best_pair is None:
                    break
                
                new_word = []
                i = 0
                while i < len(word):
                    if (i < len(word) - 1 and 
                        word[i] == best_pair[0] and 
                        word[i + 1] == best_pair[1]):
                        new_word.append(self.mergepair_to_byte[best_pair])
                        i += 2
                    else:
                        new_word.append(word[i])
                        i += 1
                
                word = new_word

            encoded_tokens.extend(word)
        
        if add_special_token:
            encoded_tokens.append(self.special_token_id)
        
        return encoded_tokens

    def merged_token(self, pair, decode_ids=None):
        if decode_ids is None:
            decode_ids = []

        first, second = pair

        if first > 255:
            self.merged_token(self.byte_to_mergepair[first], decode_ids)
        else:
            decode_ids.append(first)

        if second > 255:
            self.merged_token(self.byte_to_mergepair[second], decode_ids)
        else:
            decode_ids.append(second)

        return decode_ids

    def decode(self, ids):
        decoded_bytes = []
        
        for id in ids:
            if id not in self.token_to_byte:
                raise ValueError(f"Unknown token ID: {id}")
            
            tid = self.token_to_byte[id]
            
            if isinstance(tid, str):
                continue  
            
            if tid > 255 and tid in self.byte_to_mergepair:
                pair = self.byte_to_mergepair[tid]
                base_utf = self.merged_token(pair)
                decoded_bytes.extend(base_utf)
            else:
                decoded_bytes.append(tid)
        
        return bytes(decoded_bytes).decode("utf-8", errors="replace")

In [196]:
tokenizer=GPT4Tokenizer(merges)

In [197]:
text = "Build GPT4Tokenizer From Sratch"
tokens = tokenizer.encode(text)
print(f"Encoded: {tokens}")
decoded = tokenizer.decode(tokens)
print(f"Decoded: {decoded}")

Encoded: [66, 117, 105, 108, 100, 32, 71, 80, 84, 52, 84, 111, 107, 101, 110, 105, 122, 101, 114, 32, 70, 114, 111, 109, 32, 83, 114, 97, 116, 99, 104]
Decoded: Build GPT4Tokenizer From Sratch
