# Train BPE tokenizer


In [1]:
import regex as re
from typing import Dict, Tuple, List
from collections import Counter

In [2]:
string = """
low low low low low <|endoftext|>
lower lower widest widest widest <|endoftext|>
newest newest newest newest newest newest 
"""
special_tokens = ["<|endoftext|>"]
PAT = r"\S+"

In [3]:
def initialize_vocab(special_tokens: List[bytes]) -> Dict[int, bytes]:
    vocab = {i: bytes([i]) for i in range(256)}  # ASCII characters
    for i, token in enumerate(special_tokens, start=256):
        vocab[i] = token

    return vocab

In [4]:
def word_to_bytes(word: str) -> List[bytes]:
    """
    Convert a word to bytes.
    """
    byte_ids = [bytes([b]) for b in word.encode("utf-8")]

    return byte_ids

In [5]:
def split_by_special_tokens(
    text: str, special_tokens: list[str], include_special: bool = False
) -> List[str]:
    special_tokens_sorted = sorted(special_tokens, key=len, reverse=True)
    pattern = "|".join(re.escape(t) for t in special_tokens_sorted)

    if include_special:
        special_chunks = re.split(f"({pattern})", text)
    else:
        # Split without capturing the special tokens
        special_chunks = re.split(pattern, text)

    return special_chunks

In [6]:
def pre_tokenize_string(
    s: str, special_tokens: list[str], include_special: bool = False
) -> Dict[Tuple[bytes], int]:
    """
    Pre-tokenize a string into bytes.
    """

    word_counter = Counter()
    special_chunks = split_by_special_tokens(s, special_tokens, include_special)

    for chunk in special_chunks:
        if chunk in special_tokens:
            if include_special:
                token = tuple(word_to_bytes(chunk))
                word_counter[token] += 1
        else:
            for match in re.finditer(PAT, chunk):
                word = match.group(0)
                token = tuple(word_to_bytes(word))
                word_counter[token] += 1

    return word_counter

In [7]:
def pair_counts(
    word_counter: Dict[Tuple[bytes], int],
) -> Dict[Tuple[bytes, bytes], int]:
    """
    Count pairs of bytes in the word counter.
    """
    pairs: Dict[Tuple[bytes, bytes], int] = {}
    for token, freq in word_counter.items():
        for i in range(len(token) - 1):
            pair = (token[i], token[i + 1])
            pairs[pair] = pairs.get(pair, 0) + freq

    return pairs


def get_most_frequent_pair(
    pairs: Dict[Tuple[bytes, bytes], int],
) -> Tuple[bytes, bytes]:
    max_freq = max(pairs.values())
    candidates = [pair for pair, freq in pairs.items() if freq == max_freq]
    res = max(candidates)

    return res

In [8]:
def add_pair_to_vocab(
    vocab: Dict[int, bytes], pair: Tuple[bytes, bytes], vocab_inv: Dict[bytes, int]
) -> int:
    """
    Add a new pair to the vocabulary.
    """
    index = len(vocab)
    s = vocab[vocab_inv[pair[0]]] + vocab[vocab_inv[pair[1]]]
    vocab[index] = s
    vocab_inv[vocab[index]] = index

    return index

In [9]:
from collections import Counter, defaultdict


def merge_pair(
    word_counter: Dict[Tuple[bytes], int], pair: Tuple[bytes, bytes]
) -> Tuple[Dict[Tuple[bytes], int], Dict]:
    """
    Merge a pair of bytes in the word counter.
    """
    new_word_counter = Counter()
    updated_pair_counts = defaultdict(int)

    for token, freq in word_counter.items():
        new_token = []
        i = 0
        while i < len(token):
            if i < len(token) - 1 and (token[i], token[i + 1]) == pair:
                new_token.append(token[i] + token[i + 1])
                i += 2
            else:
                new_token.append(token[i])
                i += 1

        new_word_counter[tuple(new_token)] += freq

        for j in range(len(new_token) - 1):
            new_pair = (new_token[j], new_token[j + 1])
            updated_pair_counts[new_pair] += freq

    return new_word_counter, updated_pair_counts

In [10]:
def check_and_convert_special_tokens(
    special_tokens: List[str] | List[bytes],
) -> List[bytes]:
    """
    Check if special tokens are in the vocabulary and convert them to bytes.
    """
    if not all(isinstance(token, bytes) for token in special_tokens):
        special_tokens_bytes = [
            token.encode("utf-8") for token in special_tokens if isinstance(token, str)
        ]

    return special_tokens_bytes

In [11]:
def train_bpe(
    string: str = string,
    vocab_size=263,
    special_tokens: List[str] = special_tokens,
):
    special_tokens_bytes = check_and_convert_special_tokens(special_tokens)

    vocab = initialize_vocab(special_tokens_bytes)
    vocab_inv = {v: k for k, v in vocab.items()}

    merges: List[Tuple[bytes, bytes]] = []

    word_counter = pre_tokenize_string(string, special_tokens, include_special=False)
    pairs_freqs = pair_counts(word_counter)

    num_merges = vocab_size - len(vocab)
    for _ in range(num_merges):

        most_common_pair = get_most_frequent_pair(pairs_freqs)

        new_index = add_pair_to_vocab(vocab, most_common_pair, vocab_inv)
        merges.append(most_common_pair)

        word_counter, pairs_freqs = merge_pair(word_counter, most_common_pair)

    return vocab, merges

In [18]:
vocab, merge = train_bpe(vocab_size=269)

In [19]:
merge

[(b's', b't'),
 (b'e', b'st'),
 (b'o', b'w'),
 (b'l', b'ow'),
 (b'w', b'est'),
 (b'n', b'e'),
 (b'ne', b'west'),
 (b'w', b'i'),
 (b'wi', b'd'),
 (b'wid', b'est'),
 (b'low', b'e'),
 (b'lowe', b'r')]

# BPE Tokenizer


In [1]:
vocab = {
    0: b" ",
    1: b"a",
    2: b"c",
    3: b"e",
    4: b"h",
    5: b"t",
    6: b"th",
    7: b" c",
    8: b" a",
    9: b"the",
    10: b"at",
}
string = "the cat ate"
merges = [(b"t", b"h"), (b" ", b"c"), (b" ", "a"), (b"th", b"e"), (b" a", b"t")]

In [None]:
def split_by_special_tokens(text: str, special_tokens: list[str]) -> List[str]:
    special_tokens_sorted = sorted(special_tokens, key=len, reverse=True)
    if not special_tokens_sorted:
        return [text]
    pattern = "|".join(re.escape(t) for t in special_tokens_sorted)
    special_chunks = re.split(f"({pattern})", text)

    return special_chunks

NameError: name 'List' is not defined

In [3]:
def word_to_bytes(word: str) -> List[bytes]:
    """
    Convert a word to bytes.
    """
    byte_ids = [bytes([b]) for b in word.encode("utf-8")]

    return byte_ids

NameError: name 'List' is not defined

In [None]:
class Tokenizer:
    def __init__(
        self,
        vocab: Dict[int, bytes],
        merges: List[Tuple[bytes, bytes]] = [],
        special_tokens: List[str] = [],
    ):
        self.vocab = vocab
        self.merges = merges
        self.special_tokens = {token: i for i, token in enumerate(special_tokens, start=len(vocab))} if special_tokens else {}

        self.vocab_inv = {v: k for k, v in vocab.items()}
    
    def _pre_tokenize(self, text) -> List[bytes]:
        """
        Pre-tokenize the input text into bytes.
        """
        parts = split_by_special_tokens(text, list(self.special_tokens.keys()))
        token_list = []
        
        for part in parts:
            if part in self.special_tokens.keys():
                token_list.append(word_to_bytes(part))
            else:
                tokens = re.findall(PAT, part)
                token_list.extend(word_to_bytes(token) for token in tokens)

        return token_list
        
    
    def encode(self, text: str) -> List[int]:
        
        # Pre-tokenize the input text into bytes

    def decode(self, ids: list[int]) -> str:
        """
        Decode a list of token IDs back to a string.
        """
        tokens = b"".join(self.vocab.get(i, b"\xef\xbf\xbd") for i in ids)
        return tokens.decode("utf-8", errors="replace")