### Pre-Tokenization using regex

In [3]:
import tiktoken

model = "gpt-5"
enc = tiktoken.encoding_for_model(model)

# Show which encoding the model maps to, then the regex it uses to split text
# into pre-tokens (read from the private `_pat_str` attribute).
print(enc.name, enc._pat_str, sep="\n")

o200k_base
[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+


In [5]:
import regex as re
text = "Hello, world!"
regex_pattern = enc._pat_str
# Split the sample text with the model's pre-tokenization regex; the
# resulting list is the cell's displayed output.
re.compile(regex_pattern).findall(text)

['Hello', ',', ' world', '!']

In [375]:
from collections import Counter
from typing import Iterable, Optional
import regex as re

IntSeq = tuple[int, ...]  # the raw byte values that make up one token
Pair = tuple[int, int]    # two adjacent token ids considered for merging


class BPE:
    """Byte-level Byte-Pair-Encoding trainer.

    Text is first split into pre-tokens with ``pat_str``; merges are learned
    inside each pre-token and never cross pre-token boundaries. Ids 0-255 are
    the raw bytes, special tokens are registered next, and learned merges get
    the ids after that.

    Args:
        pat_str: Regex used to split text into pre-tokens.
        target_merges: Maximum number of merges a call to ``train`` performs.
        vocab_size: Upper bound on the vocabulary; caps the merges ``train``
            may perform given the ids already allocated.
        special_tokens: Strings registered as whole tokens at construction.
    """

    def __init__(self, pat_str: str, target_merges: int = 3, vocab_size: int = 512,
                 special_tokens: Optional[list[str]] = None):
        self.pat_str = pat_str
        self.regex = re.compile(self.pat_str)
        # Learned merge rules in training order: (pair, id of the merged token).
        self.merges: list[tuple[Pair, int]] = []
        # pair -> index of that merge in the training run that learned it.
        self.ranks: dict[Pair, int] = {}
        # Next unused token id; 0-255 are the single bytes.
        self.next_id: int = 256
        self.target_merges: int = target_merges
        self.vocab_size: int = vocab_size
        # Byte sequence <-> token id, seeded with the 256 single bytes.
        self.encoder: dict[IntSeq, int] = {(i,): i for i in range(256)}
        self.decoder: dict[int, IntSeq] = {i: (i,) for i in range(256)}
        self.special_tokens: list[str] = special_tokens or []
        self.special_token_to_id: dict[str, int] = {}
        self.add_special_token()

    def _get_or_add_token(self, token: IntSeq) -> int:
        """Return the id of the byte sequence ``token``, allocating the next
        free id (and registering it in encoder/decoder) if it is new."""
        token_id = self.encoder.get(token)
        if token_id is None:
            token_id = self.next_id
            self.encoder[token] = token_id
            self.decoder[token_id] = token
            self.next_id += 1
        return token_id

    def add_special_token(self) -> None:
        """Register every string in ``self.special_tokens`` as a single token.

        A special token whose UTF-8 bytes already have an id (e.g. a one-byte
        string) reuses that id instead of allocating a new one.
        """
        for token_str in self.special_tokens:
            token_bytes: IntSeq = tuple(token_str.encode("utf-8"))
            self.special_token_to_id[token_str] = self._get_or_add_token(token_bytes)

    def decode_to_bytes(self, token_ids: list[int]) -> bytes:
        """Concatenate the byte sequences of ``token_ids``.

        Raises:
            KeyError: if an id is not in the vocabulary.
        """
        return b"".join(bytes(self.decoder[token_id]) for token_id in token_ids)

    def decode_text(self, token_ids: list[int]) -> str:
        """Decode ``token_ids`` to text.

        Raises:
            UnicodeDecodeError: if the concatenated bytes are not valid UTF-8.
        """
        return self.decode_to_bytes(token_ids).decode("utf-8", errors="strict")

    def pre_tokenize(self, text: str) -> list[str]:
        """Split ``text`` into the full matches of the pre-tokenization regex."""
        # finditer + group(0) (rather than findall) stays correct even if
        # pat_str contains capturing groups.
        return [m.group(0) for m in self.regex.finditer(text)]

    @staticmethod
    def _get_all_pairs(tokens: list[int]) -> Iterable[Pair]:
        """Return an iterator over each adjacent ``(left, right)`` pair, in order."""
        return zip(tokens, tokens[1:])

    @staticmethod
    def text_to_ints(text: str) -> list[int]:
        """Return the UTF-8 byte values of ``text``."""
        return list(text.encode('utf-8'))

    def _replace_pair(self, chunks: list[list[int]], pair: Pair,
                      new_id: Optional[int] = None) -> list[list[int]]:
        """Return new chunks where each occurrence of ``pair`` becomes ``new_id``.

        Each chunk is scanned left to right and matches do not overlap, so
        ``[a, a, a]`` with pair ``(a, a)`` becomes ``[new_id, a]``.

        Args:
            chunks: Token-id sequences, one per pre-token (not modified).
            pair: Adjacent ids to merge.
            new_id: Id of the merged token. Defaults to the id registered in
                ``self.encoder`` for the concatenated bytes of ``pair``.

        Raises:
            KeyError: if ``new_id`` is omitted and the merged bytes have no id.
        """
        if new_id is None:
            new_id = self.encoder[self.decoder[pair[0]] + self.decoder[pair[1]]]
        first, second = pair

        replaced_chunks: list[list[int]] = []
        for chunk in chunks:
            merged: list[int] = []
            i = 0
            while i < len(chunk):
                if i + 1 < len(chunk) and chunk[i] == first and chunk[i + 1] == second:
                    merged.append(new_id)
                    i += 2  # both halves of the pair are consumed
                else:
                    merged.append(chunk[i])
                    i += 1
            replaced_chunks.append(merged)
        return replaced_chunks

    def _count_pairs_frequency(self, chunks: list[list[int]]) -> Counter[Pair]:
        """Count adjacent id pairs over all chunks; pairs never span two chunks."""
        counts: Counter[Pair] = Counter()
        for chunk in chunks:
            # Chunks shorter than 2 contribute no pairs.
            counts.update(self._get_all_pairs(chunk))
        return counts

    def _get_most_frequent_pair(self, chunks: list[list[int]]) -> Optional[Pair]:
        """Return the most frequent adjacent pair, or None when there is none.

        Ties are broken deterministically: first by the lexicographically
        smallest merged byte sequence, then by the smaller pair of ids.
        """
        counts = self._count_pairs_frequency(chunks)
        if not counts:
            return None

        def sort_key(item: tuple[Pair, int]) -> tuple[int, IntSeq, Pair]:
            pair, freq = item
            return (-freq, self.decoder[pair[0]] + self.decoder[pair[1]], pair)

        best_pair, _ = min(counts.items(), key=sort_key)
        return best_pair

    def _register_merge(self, pair: Pair, merges_done: int) -> int:
        """Record ``pair`` as merge number ``merges_done`` and return its token id.

        If the merged byte sequence already has an id (a special token, or the
        same bytes produced earlier by a different pair), that id is reused
        rather than allocating a new one. The rule is recorded either way, so
        ``merges``/``ranks`` list every merge ``train`` applied.
        """
        merged_token: IntSeq = self.decoder[pair[0]] + self.decoder[pair[1]]
        token_id = self._get_or_add_token(merged_token)
        self.merges.append((pair, token_id))
        self.ranks[pair] = merges_done
        return token_id

    def train(self, corpus: str) -> None:
        """Learn merge rules from ``corpus``, starting from its raw bytes.

        Performs at most ``min(target_merges, vocab_size - next_id)`` merges
        and stops early once no adjacent pair remains. The configured
        ``target_merges`` is not modified.
        """
        chunks: list[list[int]] = [self.text_to_ints(token)
                                   for token in self.pre_tokenize(corpus)]

        # Ids already allocated (bytes, special tokens, earlier merges) count
        # against vocab_size.
        merge_budget = min(self.target_merges, max(0, self.vocab_size - self.next_id))

        merges_done = 0
        while merges_done < merge_budget:
            best_pair = self._get_most_frequent_pair(chunks)
            if best_pair is None:
                break  # every chunk is already a single token
            # Use the id actually registered for this merge; it may be an
            # existing id rather than the most recently allocated one.
            merged_id = self._register_merge(best_pair, merges_done)
            chunks = self._replace_pair(chunks, best_pair, merged_id)
            merges_done += 1

        print("Training complete."
              f" Total merges done: {merges_done}."
              f" Vocabulary size: {len(self.decoder)}.")

In [376]:
# Train a tiny tokenizer with the GPT-5 encoding's split pattern on a toy corpus.
tokenizer = BPE(enc._pat_str)
tokenizer.train("Helloll, world!")

Training complete. Total merges done: 3. Vocabulary size: 259.


In [None]:
# "Hello" split into single-byte tokens: the byte-level starting point of BPE.
[bytes([byte]) for byte in b"Hello"]