# This creates a Byte Pair Encoding tokenizer from scratch.

It is based on the following notebook from the rasbt/LLMs-from-scratch GitHub repository:
https://github.com/rasbt/LLMs-from-scratch/blob/main/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb

In [3]:
from collections import Counter, deque
import json

In [4]:
class BPETokenizerSimple:
    """Minimal byte pair encoding (BPE) tokenizer trained from raw text.

    Spaces follow the GPT-2 convention: a space is stored as ``spc_char``
    ("Ġ") attached to the start of the following word, so "Ġcat" means " cat".
    """

    def __init__(self):
        # Token id -> token string.
        self.vocab = {}
        # Token string -> token id (reverse of ``vocab``).
        self.inverse_vocab = {}
        # Learned merges: (left_id, right_id) -> merged token id, kept in the
        # order in which they were learned.
        self.bpe_merges = {}

        # For the official OpenAI GPT-2 merges, use a rank dict of the form
        # {(string_A, string_B): rank}, where lower rank = higher priority.
        # Nothing in this class fills or reads it yet.
        self.bpe_ranks = {}

        # GPT-2 marker used in place of a space before a word.
        self.spc_char = "Ġ"

    def train(self, text, vocab_size, allowed_special=frozenset({"<|endoftext|>"})):
        """Learn a BPE vocabulary and merge table from ``text``.

        Any previously learned vocabulary and merges are discarded.

        Args:
            text: Training corpus.
            vocab_size: Target total vocabulary size (base characters, special
                tokens and learned merges). Training stops earlier if no
                adjacent pair of tokens is left to merge.
            allowed_special: Special token strings (e.g. "<|endoftext|>") to
                reserve ids for. They are added in sorted order so their ids
                do not depend on set iteration order.
        """
        # A space becomes the GPT-2 marker; a space at position 0 is dropped.
        processed_text = "".join(
            self.spc_char if char == " " else char
            for idx, char in enumerate(text)
            if not (idx == 0 and char == " ")
        )

        # Start from a clean slate so retraining does not keep stale merges.
        self.bpe_merges = {}

        # Base vocabulary: the 256 code points U+0000..U+00FF, followed by any
        # other character that occurs in the processed text.
        base_tokens = [chr(i) for i in range(256)]
        seen = set(base_tokens)
        for char in sorted(set(processed_text)):
            if char not in seen:
                base_tokens.append(char)
                seen.add(char)
        # Make sure the space marker is always representable.
        if self.spc_char not in seen:
            base_tokens.append(self.spc_char)
            seen.add(self.spc_char)
        # Special tokens. Sorting makes their ids reproducible, because the
        # iteration order of a set of strings can change between runs.
        if allowed_special:
            for special in sorted(allowed_special):
                if special not in seen:
                    base_tokens.append(special)
                    seen.add(special)

        self.vocab = dict(enumerate(base_tokens))
        self.inverse_vocab = {token: idx for idx, token in self.vocab.items()}

        # Character-level ids of the corpus.
        tokens = [self.inverse_vocab[char] for char in processed_text]

        # Repeatedly merge the most frequent adjacent pair into a new id.
        for new_id in range(len(self.vocab), vocab_size):
            pair_id = self.find_freq_pair(tokens, mode="most")
            if pair_id is None:
                break
            tokens = self.replace_pair(tokens, pair_id, new_id)
            self.bpe_merges[pair_id] = new_id

        # Add the merged strings to the vocabulary. Merges are visited in the
        # order they were learned, so both halves of each pair already exist.
        for (p0, p1), new_id in self.bpe_merges.items():
            merged_token = self.vocab[p0] + self.vocab[p1]
            self.vocab[new_id] = merged_token
            self.inverse_vocab[merged_token] = new_id

    def encode(self, text, allowed_especial=None, *, allowed_special=None):
        """Encode ``text`` into a list of token ids.

        Text is split on newlines and whitespace. The first word of each
        segment is emitted as-is; every other word (including the first word
        of later lines) gets the ``spc_char`` prefix. When special tokens are
        allowed, the text between them is encoded as independent segments.

        Args:
            text: Text to encode.
            allowed_especial: Special tokens (e.g. {"<|endoftext|>"}) that may
                appear in ``text`` and are each emitted as a single id. Kept
                under this name for backward compatibility.
            allowed_special: Preferred spelling of ``allowed_especial``
                (matches ``train``). If both are given, their union is used.

        Returns:
            list of token ids.

        Raises:
            ValueError: if an allowed special token is missing from the
                vocabulary, if a vocabulary token of the form "<|...|>" that
                is not allowed occurs in ``text`` (checked only when some
                special tokens are allowed), or if a character of ``text`` is
                not in the vocabulary.
        """
        import re

        allowed = set(allowed_especial or ()) | set(allowed_special or ())
        token_ids = []

        if allowed:
            # Reject special-looking vocabulary tokens that were not allowed,
            # anywhere in the text.
            disallowed = [
                tok for tok in self.inverse_vocab
                if tok.startswith("<|") and tok.endswith("|>")
                and tok in text and tok not in allowed
            ]
            if disallowed:
                raise ValueError(f"Disallowed special tokens found in text: {disallowed}")

            # Longest first so that a special token which is a prefix of
            # another one cannot shadow it in the alternation.
            special_pattern = (
                "(" + "|".join(re.escape(tok) for tok in sorted(allowed, key=len, reverse=True)) + ")"
            )
            last_index = 0
            for match in re.finditer(special_pattern, text):
                # Encode the ordinary text before this special token.
                token_ids.extend(self.encode(text[last_index:match.start()]))

                special_token = match.group(0)
                if special_token not in self.inverse_vocab:
                    raise ValueError(f"Special token {special_token} not found in vocabulary")
                token_ids.append(self.inverse_vocab[special_token])
                last_index = match.end()

            # Only the text after the last special token remains.
            text = text[last_index:]

        # Split into word-level string tokens, marking spaces with spc_char.
        str_tokens = []
        for line_no, line in enumerate(text.split("\n")):
            if line_no > 0:
                str_tokens.append("\n")
            for word_no, word in enumerate(line.split()):
                if line_no == 0 and word_no == 0:
                    # No space marker on the very first word of the text.
                    str_tokens.append(word)
                else:
                    str_tokens.append(self.spc_char + word)

        # Whole words found in the vocabulary map directly; others go through BPE.
        for token in str_tokens:
            if token in self.inverse_vocab:
                token_ids.append(self.inverse_vocab[token])
            else:
                token_ids.extend(self.tokenize_with_bpe(token))

        return token_ids

    def tokenize_with_bpe(self, token, method='own'):
        """Split one word-level ``token`` into BPE token ids.

        Args:
            token: A word, possibly prefixed with ``spc_char``.
            method: Merge strategy. Only ``'own'`` is supported: sweep left to
                right merging any adjacent pair present in ``bpe_merges``, and
                repeat until a sweep makes no merge. This greedy strategy does
                not apply merges in the order they were learned, so it may
                segment differently from training.

        Returns:
            list of token ids.

        Raises:
            ValueError: if ``method`` is unknown or a character of ``token`` is
                not in the vocabulary.
        """
        if method != 'own':
            raise ValueError(f"Unknown method {method!r}; only 'own' is supported.")

        tokens = [self.inverse_vocab.get(char, None) for char in token]
        if None in tokens:
            missing_chars = [char for char, tid in zip(token, tokens) if tid is None]
            raise ValueError(f"Characters not found in vocabulary: {missing_chars}")

        can_merge = True
        while can_merge and len(tokens) > 1:
            can_merge = False
            new_tokens = []

            i = 0
            while i < len(tokens) - 1:
                pair = (tokens[i], tokens[i + 1])
                if pair in self.bpe_merges:
                    new_tokens.append(self.bpe_merges[pair])
                    i += 2  # both halves consumed
                    can_merge = True
                else:
                    new_tokens.append(tokens[i])
                    i += 1

            # Keep the trailing token if it was not consumed by a merge.
            if i < len(tokens):
                new_tokens.append(tokens[i])

            tokens = new_tokens

        return tokens

    def decode(self, tokens):
        """Convert a list of token ids back into text.

        Tokens starting with ``spc_char`` are rendered as a space plus the rest
        of the token, and a space is inserted before each newline unless the
        text already ends with one. Whitespace is therefore not restored
        exactly (``encode`` also collapses runs of whitespace).

        Raises:
            ValueError: if an id is not in the vocabulary.
        """
        marker_len = len(self.spc_char)
        decoded_string = ""
        for token_id in tokens:
            if token_id not in self.vocab:
                raise ValueError(f"Token id {token_id} not found in vocab")

            token = self.vocab[token_id]
            if token == "\n":
                if decoded_string and not decoded_string.endswith(" "):
                    decoded_string += " "
                decoded_string += token
            elif token.startswith(self.spc_char):
                decoded_string += " " + token[marker_len:]
            else:
                decoded_string += token

        return decoded_string

    def save_vocab_and_merges(self, vocab_path, merges_path):
        """Write the vocabulary and merges as UTF-8 JSON files.

        JSON object keys are strings, so vocabulary ids are stored as strings;
        ``load_vocab_and_merges`` converts them back to ints. Merges are saved
        as a list of {"pair": [left_id, right_id], "new_id": id} in learned order.
        """
        with open(vocab_path, "w", encoding="utf-8") as file:
            json.dump(self.vocab, file, ensure_ascii=False, indent=2)

        with open(merges_path, "w", encoding="utf-8") as file:
            merges_list = [{"pair": list(pair), "new_id": new_id} for pair, new_id in self.bpe_merges.items()]
            json.dump(merges_list, file, ensure_ascii=False, indent=2)

    def load_vocab_and_merges(self, vocab_path, merges_path):
        """Load files written by ``save_vocab_and_merges``.

        Replaces the current vocabulary, reverse vocabulary and merges.
        """
        with open(vocab_path, "r", encoding="utf-8") as file:
            loaded_vocab = json.load(file)
        self.vocab = {int(k): v for k, v in loaded_vocab.items()}
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}

        with open(merges_path, "r", encoding="utf-8") as file:
            loaded_merges = json.load(file)
        # Replace rather than extend, so no stale merges survive the load.
        self.bpe_merges = {tuple(merge["pair"]): merge["new_id"] for merge in loaded_merges}

    @staticmethod
    def find_freq_pair(tokens, mode='most'):
        """Return the most ('most') or least ('least') frequent adjacent pair.

        Ties go to the pair seen first. Returns None when ``tokens`` has fewer
        than two elements; raises ValueError for an unknown ``mode``.
        """
        pairs = Counter(zip(tokens, tokens[1:]))

        if not pairs:
            return None

        if mode == 'most':
            return max(pairs.items(), key=lambda x: x[1])[0]
        elif mode == 'least':
            return min(pairs.items(), key=lambda x: x[1])[0]
        else:
            raise ValueError("Invalid mode. Choose 'most' or 'least'.")

    @staticmethod
    def replace_pair(tokens, pair_id, new_id):
        """Return a new list with each non-overlapping occurrence of ``pair_id``
        (scanning left to right) replaced by ``new_id``."""
        dq = deque(tokens)
        replaced = []

        while dq:
            current = dq.popleft()
            if dq and (current, dq[0]) == pair_id:
                replaced.append(new_id)
                dq.popleft()
            else:
                replaced.append(current)

        return replaced



## Training the tokenizer on "The Verdict"

In [5]:
# Load the training corpus (the path is relative to the current working directory)
with open("data/the-verdict.txt", "r", encoding="utf-8") as f:
    text = f.read()

In [6]:
# Build a tokenizer and learn merges until the vocabulary holds 1000 entries
tokenizer = BPETokenizerSimple()
tokenizer.train(
    text,
    vocab_size=1000,
    allowed_special={"<|endoftext|>"},
)

In [7]:
# Vocabulary size first, then the number of learned merges
for table in (tokenizer.vocab, tokenizer.bpe_merges):
    print(len(table))

1000
742


In [8]:
# Encode a sample sentence and display the resulting token IDs
input_text = "Jack embraced beauty through art and life."
token_ids = tokenizer.encode(text=input_text)
print(token_ids)

[424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46]


In [9]:
# Compare the raw character count with the number of token IDs
for label, count in (
    ("Number of characters:", len(input_text)),
    ("Number of token IDs:", len(token_ids)),
):
    print(label, count)

Number of characters: 42
Number of token IDs: 20


In [10]:
# Turn the token IDs back into text
decoded_text = tokenizer.decode(token_ids)
print(decoded_text)

Jack embraced beauty through art and life.


In [11]:
import os

vocab_path = "tokenizer/vocabulary.json"
merges_path = "tokenizer/merges.json"

# open(..., "w") does not create missing directories, so create them first
for path in (vocab_path, merges_path):
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

# save the tokenizer vocab and merges
tokenizer.save_vocab_and_merges(vocab_path, merges_path)

In [12]:
# Read the saved files back in to check that they load correctly
tokenizer.load_vocab_and_merges(vocab_path=vocab_path, merges_path=merges_path)