<a href="https://colab.research.google.com/github/Firojpaudel/Demystifying_Language_Modeling/blob/main/Tokenizer/Tokenizer_Exercise.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Building our own GPT-4 Tokenizer
---


#### Step 0: Base Configs

In [8]:
import regex as re
import unicodedata
from collections import defaultdict

In [40]:
def get_status(ids):
    """Count how often each pair of adjacent elements occurs in ``ids``.

    Args:
        ids: A sliceable sequence (e.g. a list of token ids).

    Returns:
        A dict mapping ``(left, right)`` tuples to their occurrence count.
        Keys are inserted in order of first appearance, which callers that
        break ties with ``max``/``min`` over the dict rely on.
    """
    tally = {}
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        if key in tally:
            tally[key] += 1
        else:
            tally[key] = 1
    return tally

def merge(ids, pair, idx):
    """Return a copy of ``ids`` with every occurrence of ``pair`` replaced by ``idx``.

    The scan is left to right and non-overlapping: once a pair is replaced,
    both of its elements are consumed.

    Args:
        ids: Sequence of token ids.
        pair: Two-element sequence ``(first, second)`` to look for.
        idx: Id written in place of each matched pair.

    Returns:
        A new list; ``ids`` itself is not modified.
    """
    out = []
    pos = 0
    total = len(ids)
    while pos < total:
        has_next = pos + 1 < total
        if has_next and ids[pos] == pair[0] and ids[pos + 1] == pair[1]:
            # Matched: emit the merged id and skip both members of the pair.
            out.append(idx)
            pos += 2
            continue
        out.append(ids[pos])
        pos += 1
    return out

def replace_control_characters(s: str) -> str:
    """Escape characters whose Unicode category starts with "C".

    Each such character is replaced by a literal backslash-u escape of its
    code point (lowercase hex, zero-padded to at least four digits); all other
    characters are kept as they are.
    """
    return "".join(
        ch if unicodedata.category(ch)[0] != "C" else f"\\u{ord(ch):04x}"
        for ch in s
    )

def render_token(t: bytes) -> str:
    """Turn token bytes into a printable string.

    The bytes are decoded as UTF-8 (undecodable sequences become the
    replacement character) and the result is passed through
    ``replace_control_characters``.
    """
    decoded = t.decode('utf-8', errors='replace')
    return replace_control_characters(decoded)

class Tokenizer:
    """Base class for tokenizers: shared state plus save/load of that state.

    Subclasses provide ``train``, ``encode`` and ``decode``; the versions here
    only raise ``NotImplementedError``.
    """

    def __init__(self):
        # (int, int) pair -> id of the token produced by merging that pair
        self.merges = {}
        # split-pattern string; empty in the base class
        self.pattern = ""
        # special-token string -> id
        self.special_tokens = {}
        # id -> bytes, derived from the merges and special tokens above
        self.vocab = self._build_vocab()

    def train(self, text, vocab_size, verbose=False):
        """Learn merges from ``text``; must be implemented by subclasses."""
        raise NotImplementedError

    def encode(self, text):
        """Turn ``text`` into a list of ids; must be implemented by subclasses."""
        raise NotImplementedError

    def decode(self, ids):
        """Turn a list of ids back into text; must be implemented by subclasses."""
        raise NotImplementedError

    def _build_vocab(self):
        """Build the id -> bytes table from ``self.merges`` and ``self.special_tokens``.

        Ids 0..255 map to the single byte of that value. Merges are applied in
        dict iteration order, so both halves of a pair must already be in the
        table when the pair is reached.
        """
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def save(self, file_prefix):
        """Write ``<file_prefix>.model`` and ``<file_prefix>.vocab``.

        The ``.model`` file is what ``load`` reads back: a version line, the
        pattern, the number of special tokens, one ``"<token> <id>"`` line per
        special token, then one ``"<id1> <id2>"`` line per merge in dict order.
        The ``.vocab`` file is a rendering of ``self.vocab`` for inspection;
        ``load`` does not read it.
        """
        model_file = file_prefix + ".model"
        # Written as UTF-8 explicitly because load() reads the file as UTF-8;
        # relying on the platform default could break non-ASCII special tokens.
        with open(model_file, 'w', encoding="utf-8") as f:
            f.write("minbpe v1\n")
            f.write(f"{self.pattern}\n")
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # Only the pairs are stored; load() re-derives the merged ids
            # from the order of these lines.
            for (idx1, idx2), idx in self.merges.items():
                f.write(f"{idx1} {idx2}\n")
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                s = render_token(token)
                if idx in inverted_merges:
                    # Merged token: show the two children it was built from.
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        """Restore pattern, special tokens and merges from a ``.model`` file.

        Merge ids are not stored in the file: they are assigned consecutively
        starting at 256 in the order the merge lines appear, so the file must
        list merges in the order their ids were originally allocated.

        Raises:
            AssertionError: If the file name does not end in ``.model`` or the
                version line is not ``"minbpe v1"``.
        """
        assert model_file.endswith(".model")
        merges = {}
        special_tokens = {}
        idx = 256
        with open(model_file, 'r', encoding="utf-8") as f:
            version = f.readline().strip()
            assert version == "minbpe v1"
            self.pattern = f.readline().strip()
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                # Split on the LAST whitespace run only, so a special token
                # that itself contains spaces is kept intact.
                special, special_idx = f.readline().strip().rsplit(maxsplit=1)
                special_tokens[special] = int(special_idx)
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()

# Regex used to cut text into chunks before any merging. Alternatives, tried
# in this order at each position:
#   1. an apostrophe followed by s/d/m/t/ll/ve/re (case-insensitive)
#   2. a run of letters, optionally preceded by ONE character that is not
#      CR, LF, a letter or a number (e.g. a leading space)
#   3. one to three numeric characters
#   4. an optional space, then a run of characters that are neither
#      whitespace, letters nor numbers, plus any trailing CR/LF
#   5. whitespace ending in a CR or LF
#   6. whitespace that is not followed by a non-whitespace character
#   7. any remaining whitespace
# The \p{L} / \p{N} Unicode property classes are not supported by the stdlib
# `re` module; this pattern is meant for the `regex` package imported as `re`.
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

---
#### Step 1: Create a `BasicTokenizer` class that has:
- `train`, `encode`, and `decode` methods.

In [41]:
class BasicTokenizer(Tokenizer):
    """Tokenizer that merges adjacent regex chunks of lowercased text.

    Text is lowercased and split with ``GPT4_SPLIT_PATTERN``. The starting
    vocabulary is every distinct chunk (in first-seen order) followed by every
    single character of those chunks that is not already an entry. Merges are
    learned over the sequence of chunk ids for the whole text, so a merged
    token can span several chunks.

    After ``train`` the attributes have these shapes:
      * ``self.vocab``: token string -> id (unlike the id -> bytes table the
        base class builds)
      * ``self.vocab_idx_to_token``: id -> token string
      * ``self.merges``: (id, id) -> merged id

    NOTE(review): the inherited ``save`` passes ``self.vocab`` values to
    ``render_token``, which expects bytes; after ``train`` those values are
    ints, so ``save``/``load`` look incompatible with this subclass - confirm
    before relying on them.
    """

    def __init__(self):
        super().__init__()
        self.pattern = GPT4_SPLIT_PATTERN
        self.compiled_pattern = re.compile(self.pattern)
        # id -> token string; filled by train() and extended by encode()
        self.vocab_idx_to_token = {}

    def train(self, text, vocab_size, verbose=False):
        """Build the vocabulary and learn merges from ``text``.

        Args:
            text: Training text; it is lowercased before splitting.
            vocab_size: Target number of vocabulary entries. The number of
                merge rounds is ``vocab_size`` minus the size of the initial
                vocabulary (never below zero).
            verbose: If true, print the first chunks, every merge and a
                summary of the final vocabulary.

        Training stops early when no adjacent pair occurs at least twice with
        a combined length of at most 15 characters.
        """
        assert vocab_size >= 0, "Vocabulary size must be non-negative"

        # Lowercase, then split into regex chunks.
        pieces = self.compiled_pattern.findall(text.lower())
        if verbose:
            print(f"Initial tokens (sample): {pieces[:20]}")

        # Distinct chunks get ids 0..k-1 in order of first appearance.
        distinct = list(dict.fromkeys(pieces))
        token_to_id = {}
        id_to_token = {}
        for position, piece in enumerate(distinct):
            token_to_id[piece] = position
            id_to_token[position] = piece

        # Then every not-yet-seen single character, so encode() can fall back
        # to characters for chunks it has never seen.
        for piece in distinct:
            for ch in piece:
                if ch in token_to_id:
                    continue
                fresh = len(token_to_id)
                token_to_id[ch] = fresh
                id_to_token[fresh] = ch

        ids = [token_to_id[piece] for piece in pieces]

        learned = {}
        rounds = max(vocab_size - len(token_to_id), 0)
        freq_floor = 2
        length_cap = 15

        for step in range(rounds):
            pair_counts = get_status(ids)

            # Keep pairs that are frequent enough and not too long once
            # joined; dict order is preserved so ties resolve as first-seen.
            eligible = {}
            for candidate, freq in pair_counts.items():
                if freq < freq_floor:
                    continue
                joined = id_to_token[candidate[0]] + id_to_token[candidate[1]]
                if len(joined) > length_cap:
                    continue
                eligible[candidate] = freq

            if not eligible:
                if verbose:
                    print(f"No pairs with frequency >= {freq_floor} and length <= {length_cap}.")
                break

            best = max(eligible, key=eligible.get)
            new_id = len(token_to_id)
            ids = merge(ids, best, new_id)
            learned[best] = new_id

            merged_text = id_to_token[best[0]] + id_to_token[best[1]]
            token_to_id[merged_text] = new_id
            id_to_token[new_id] = merged_text

            if verbose:
                print(f"merge {step+1}/{rounds}: {best} -> {new_id} ({merged_text}) had {pair_counts[best]} occurrences")

        self.merges = learned
        self.vocab = token_to_id
        self.vocab_idx_to_token = id_to_token
        if verbose:
            print(f"Vocabulary size: {len(token_to_id)}")
            print(f"Sample vocab: {dict(list(token_to_id.items())[:10])}")

    def _split_unknown(self, token):
        """Return ids for a chunk that is not itself a vocabulary entry.

        Repeatedly takes the longest prefix of the remaining text that is in
        ``self.vocab``. A character with no entry at all is added to
        ``self.vocab`` and ``self.vocab_idx_to_token`` with the next free id.
        """
        out = []
        start = 0
        size = len(token)
        while start < size:
            for stop in range(size, start, -1):
                chunk = token[start:stop]
                if chunk in self.vocab:
                    out.append(self.vocab[chunk])
                    start = stop
                    break
            else:
                # Nothing matched, not even the single character.
                ch = token[start]
                if ch not in self.vocab:
                    fresh = len(self.vocab)
                    self.vocab[ch] = fresh
                    self.vocab_idx_to_token[fresh] = ch
                out.append(self.vocab[ch])
                start += 1
        return out

    def encode(self, text):
        """Convert ``text`` into a list of ids.

        The text is lowercased and split like in ``train``. Known chunks map
        straight to their id; unknown chunks are broken into the longest known
        prefixes, and characters never seen before are added to the vocabulary
        (so this method can grow ``self.vocab``). Learned merges are then
        applied, earliest-learned first, until none applies.
        """
        ids = []
        for piece in self.compiled_pattern.findall(text.lower()):
            if piece in self.vocab:
                ids.append(self.vocab[piece])
            else:
                ids.extend(self._split_unknown(piece))

        ranks = self.merges
        while len(ids) > 1:
            present = get_status(ids)
            # Lowest merged id == learned earliest; unknown pairs sort last.
            best = min(present, key=lambda p: ranks.get(p, float('inf')))
            if best not in ranks:
                break
            ids = merge(ids, best, ranks[best])

        return ids

    def decode(self, ids):
        """Join the token strings for ``ids``; unknown ids become U+FFFD."""
        lookup = self.vocab_idx_to_token
        return ''.join(lookup.get(token_id, '\ufffd') for token_id in ids)

---

##### Now we test:

In [None]:
# training_text = """

# """

In [44]:
# Train on a small repeated corpus, then round-trip one short sample.
tokenizer = BasicTokenizer()
text = (
    "Hello, world! This is a tokenizer drill! 🙌 "
    "Namaste! Welcome to the tokenizer. Let's learn more. "
    "Python is great for building models. Tokenization is fun! "
    "Hello again! Let's tokenize some text. 🙌 "
) * 100
tokenizer.train(text, vocab_size=500, verbose=True)

# Show the string that is actually encoded, so "Original" and "Decoded"
# can be compared directly (previously the training corpus was printed).
sample = "namaste!"
encoded = tokenizer.encode(sample)
decoded = tokenizer.decode(encoded)
print(f"Original: {sample}")
print(f"Encoded: {encoded}")
print(f"Decoded: {decoded}")

Initial tokens (sample): ['hello', ',', ' world', '!', ' this', ' is', ' a', ' tokenizer', ' drill', '!', ' 🙌', ' namaste', '!', ' welcome', ' to', ' the', ' tokenizer', '.', ' let', "'s"]
merge 1/443: (15, 16) -> 57 ( let's) had 200 occurrences
merge 2/443: (1, 2) -> 58 (, world) had 100 occurrences
merge 3/443: (58, 3) -> 59 (, world!) had 100 occurrences
merge 4/443: (59, 4) -> 60 (, world! this) had 100 occurrences
merge 5/443: (5, 6) -> 61 ( is a) had 100 occurrences
merge 6/443: (61, 7) -> 62 ( is a tokenizer) had 100 occurrences
merge 7/443: (8, 3) -> 63 ( drill!) had 100 occurrences
merge 8/443: (63, 9) -> 64 ( drill! 🙌) had 100 occurrences
merge 9/443: (10, 3) -> 65 ( namaste!) had 100 occurrences
merge 10/443: (11, 12) -> 66 ( welcome to) had 100 occurrences
merge 11/443: (66, 13) -> 67 ( welcome to the) had 100 occurrences
merge 12/443: (7, 14) -> 68 ( tokenizer.) had 100 occurrences
merge 13/443: (57, 17) -> 69 ( let's learn) had 100 occurrences
merge 14/443: (18, 14) -> 70