# Exercise

Build your own GPT-4 Tokenizer!

### Step 1

Write the `BasicTokenizer` class, with the following three core functions:

- `def train(self, text, vocab_size, verbose=False)`
- `def encode(self, text)`
- `def decode(self, ids)`

Train your tokenizer on whatever text you like and visualize the merged tokens. Do they look reasonable? One default test you may wish to use is the text file `tests/taylorswift.txt`.

In [43]:
from collections import Counter, defaultdict
from itertools import tee
import regex as re

class BasicTokenizer():

    def __init__(self, byte_permutation=None) -> None:
        """
        - byte_permutation: mapping of bytes (values from 0 to 255) that may be used for encoding purposes
        - special_tokens: str -> int dictionary of special tokens
          example: {'<|endoftext|>': 100257}
        """
        super().__init__()
        self.vocab = None
        self.merges = []  
        self.byte_permutation = byte_permutation
        self.special_tokens = {}
        

    def train(self, text, vocab_size, verbose=False, flag=False):
        """
        Train the tokenizer using Byte Pair Encoding (BPE).

        Args:
            text (str): The input text to train on.
            vocab_size (int): The desired vocabulary size.
            verbose (bool, optional): If True, print merge operations. Defaults to False.
            flag (bool, optional): If True, train based on the gpt4 regex pattern. Defaults to False.
        """

        # Step 1: Construct the base vocabulary
        self.vocab = set(self.get_vocab(text=text).keys())        
        if flag: 
            # Split text using the regex pattern
            tokens = re.findall(GPT4_SPLIT_PATTERN, text)
   
        tokens = [' '.join(list(word)) + ' </w>' for word in text.split()]

        # Iterate (steps 2-4) until a specified number of iterations are reached 
        while len(self.vocab) < vocab_size:
            
            # Step 2: Find most frequent pair
            pairs = self.get_stats(tokens)
            if not pairs:
                break
            max_freq_pair = self.get_best_pair(pairs)
            
            # Step 3: Merge pair
            tokens = self.merge_vocab(max_freq_pair, tokens)

            # Step 4: Update vocabulary 
            self.vocab.add(''.join(max_freq_pair))
            self.merges.append(max_freq_pair)

            if verbose: 
                print(f"Merge: {max_freq_pair}")

    def encode(self, text, flag):
        """Encode the input text into a sequence of tokens.

        Args:
            text (str): The input text to encode.
            flag (bool, optional): If True, train based on the gpt4 regex pattern. Defaults to False.

        Returns:
            list of str: List of encoded tokens.
        """
        if flag: 
            # Split text using the regex pattern
            tokens = re.findall(GPT4_SPLIT_PATTERN, text)
        # Initial tokenization
        tokens = [' '.join(list(word)) + ' </w>' for word in text.split()]
        
        # Applying BPE merges
        for pair in self.merges:
            tokens = self.merge_vocab(pair, tokens)
        
        # Token flattening
        encoded_text = [word.split() for word in tokens]
        encoded_ids = [token for sublist in encoded_text for token in sublist]
        
        if flag:
            # Apply byte permutation
            # Apply byte permutation only to single-character tokens
            permuted_ids = []
            for token in encoded_ids:
                if len(token) == 1:
                    byte_val = ord(token)
                    if 0 <= byte_val <= 255:
                        permuted_ids.append(self.byte_permutation[byte_val])
                    else:
                        permuted_ids.append(token)  # Keep non-byte characters as they are
                else:
                    permuted_ids.append(token)
            return permuted_ids

        return encoded_ids

    def decode(self, ids, flag):
        """Decode a sequence of tokens back into the original text.

        Args:
            ids (list of str): List of token IDs to decode.

        Returns:
            str: The decoded text.
        """
        if flag:
            inv_map = {v: k for k, v in self.byte_permutation.items()}
            reversed_ids = [chr(inv_map[id]) if isinstance(id, int) else id for id in ids]
            decoded_text = ''.join(reversed_ids).replace('</w>', ' ')
        else:
            decoded_text = ''.join(reversed_ids).replace(' </w>', '')
        return decoded_text.strip()

    def define_special_tokens(self, special_tokens):
        self.special_tokens = special_tokens
        self.vocab = self.vocab.add(self.special_tokens)

    def get_vocab(self, text):
        # Initialize vocabulary with frequency of each word in text
        base_vocab = Counter(text)
        return {word: freq for word, freq in base_vocab.items()}
    
    def get_set_chars(self, text):
        # Initialize vocabulary with the set of characters in text
        base_vocab = set()
        for byte_arr in text: 
            for byte in byte_arr:
                base_vocab.add(byte)
        return base_vocab
    
    def pairwise(self, iterable):
        #"s -> (s0,s1), (s1,s2), (s2, s3), ..."
        a, b = tee(iterable)
        next(b, None)
        return zip(a, b)
    
    def get_freq_of_pairs(self, text):
        pair_words = list(self.pairwise(text.split()))
        freq_pairs = Counter(pair_words)
        return  {tuple_: freq for tuple_, freq in freq_pairs.items()}
        
    def get_best_pair(self, pairs):
        # Get the most frequent pair
        best_pair = max(pairs, key=pairs.get)
        return best_pair
        
    def get_stats(self, tokens):
        """Compute the frequency of each pair of symbols in the tokenized text.

        Args:
            tokens (list of str): List of tokens to analyze.

        Returns:
            defaultdict: Dictionary with pairs of symbols as keys and their frequency as values.
        """
        pairs = defaultdict(int)
        for word in tokens:
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += 1
        return pairs

    def merge_vocab(self, pair, tokens):
        """Merge the most frequent pair of symbols in the tokens.

        Args:
            pair (tuple of str): The pair of symbols to merge.
            tokens (list of str): List of tokens to process.

        Returns:
            list of str: New list of tokens with the pair merged.
        """
        bigram = re.escape(' '.join(pair))
        pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        new_tokens = [pattern.sub(''.join(pair), word) for word in tokens]
        return new_tokens


In [5]:
from typing import List
import os

def open_file(path: str) -> List[str]:
    try:
        with open(file=path, mode="r", encoding="UTF-8") as file: 
            content = file.read()
    except FileNotFoundError:
        print("File not found!")
    return content


file_name = "taylorswift.txt"
cur_dir = os.getcwd()
abs_path = os.path.join(cur_dir, file_name)
print("The absolute path of the given file is: ", abs_path)

text = open_file(path=abs_path)

The absolute path of the given file is:  c:\Users\c.manara\Documents\GitHub\Hands-on-LLMs-NLP-Transformers-Training\Tokenization\GPT4 Tokenizer\taylorswift.txt


In [16]:
# Object of the class BasicTokenizer
tokenizer = BasicTokenizer()

# Train the tokenizer
vocab_size = 1000
tokenizer.train(text=text, vocab_size=vocab_size, verbose=True, flag=False)

# Encode the text
encoded_text = tokenizer.encode("Here is my first encoding!", flag=False)
print(f"Encoded: {encoded_text}")

# Decode the text
decoded_text = tokenizer.decode(ids=encoded_text, flag=False)
print(f"Decoded: {decoded_text}")

Merge: ('.', '</w>')
Merge: ('e', '</w>')
Merge: (',', '</w>')
Merge: ('d', '</w>')
Merge: ('r', '</w>')
Merge: ('2', '0')
Merge: ('s', '</w>')
Merge: ('i', 'n')
Merge: ('t', '</w>')
Merge: ('o', 'n')
Merge: ('r', 'i')
Merge: ('e', 'd</w>')
Merge: ('t', 'h')
Merge: ('a', 'n')
Merge: ('e', 'r</w>')
Merge: ('a', 'r')
Merge: ('y', '</w>')
Merge: ('a', 'l')
Merge: ('th', 'e</w>')
Merge: ('v', 'ed</w>')
Merge: ('w', 'i')
Merge: ('20', '1')
Merge: ('e', 'r')
Merge: ('on', '</w>')
Merge: ('wi', 'f')
Merge: ('R', 'e')
Merge: ('S', 'wif')
Merge: ('o', 'r</w>')
Merge: ('c', 'h')
Merge: ('o', 'm')
Merge: ('20', '2')
Merge: ('b', 'er</w>')
Merge: ('a', 'y')
Merge: ('e', 'n')
Merge: ('o', 'r')
Merge: ('al', '</w>')
Merge: ('e', 'm')
Merge: ('ri', 'e')
Merge: ('in', 'g')
Merge: ('t', 'i')
Merge: ('ay', 'l')
Merge: ('"', '.</w>')
Merge: ('l', 'l')
Merge: ('T', 'ayl')
Merge: ('t', 'rie')
Merge: ('t', 'o')
Merge: ('Re', 'trie')
Merge: ('Retrie', 'ved</w>')
Merge: ('Tayl', 'or</w>')
Merge: ('e', 's')
Me

### Step 2

Convert you `BasicTokenizer` into a `RegexTokenizer`, which takes a regex pattern and splits the text exactly as GPT-4 would. Process the parts separately as before, then concatenate the results. Retrain your tokenizer and compare the results before and after. You should see that you will now have no tokens that go across categories (numbers, letters, punctuation, more than one whitespace). Use the GPT-4 pattern:

```
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
```



In [23]:
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

### Step 3

You're now ready to load the merges from the GPT-4 tokenizer and show that your tokenizer produces the identical results for both `encode` and `decode`, matching [tiktoken](https://github.com/openai/tiktoken).

```
# match this
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("hello world!!!? (안녕하세요!) lol123 😉")
text = enc.decode(ids) # get the same text back
```

Unfortunately, you will run into two issues:

1. It is not trivial to recover the raw merges from the GPT-4 tokenizer. You can easily recover what we call `vocab` here, and what they call and store under `enc._mergeable_ranks`. Feel free to copy paste the `recover_merges` function in `minbpe/gpt4.py`, which takes these ranks and returns the raw merges. If you wish to know how this function works, read [this](https://github.com/openai/tiktoken/issues/60) and [this](https://github.com/karpathy/minbpe/issues/11#issuecomment-1950805306). Basically, under some conditions it is enough to only store the parent nodes (and their rank) and get rid of the precise details of which children merged up to any parent.
2. Second, the GPT-4 tokenizer for some reason permutes its raw bytes. It stores this permutation in the first 256 elements of the mergeable ranks, so you can recover this byte shuffle relatively simply as `byte_shuffle = {i: enc._mergeable_ranks[bytes([i])] for i in range(256)}`. In both your encode and decode, you'll have to shuffle bytes around accordingly. If you're stuck, reference the minbpe/gpt4.py` file for hints.

In [24]:
# match this
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("hello world!!!? (안녕하세요!) lol123 😉")
text = enc.decode(ids) # get the same text back
print("Encoding: ", ids)
print("After decoding: ", text)

Encoding:  [15339, 1917, 12340, 30, 320, 31495, 230, 75265, 243, 92245, 16715, 28509, 4513, 57037]
After decoding:  hello world!!!? (안녕하세요!) lol123 😉


In [25]:
# Object of the class BasicTokenizer
tokenizer = BasicTokenizer()
# Train the tokenizer
vocab_size = 1000
tokenizer.train(text=text, vocab_size=vocab_size, verbose=False, flag=True)

# Encode the text
encoded_text = tokenizer.encode("hello world!!!? (안녕하세요!) lol123 😉", flag=True)
print(f"Encoded: {encoded_text}")

# Decode the text
decoded_text = tokenizer.decode(ids=encoded_text, flag=False)
print(f"Decoded: {decoded_text}")

Encoded: ['hello</w>', 'world!!!?</w>', '(안녕하세요!)</w>', 'lol123</w>', '😉</w>']
Decoded: hello</w>world!!!?</w>(안녕하세요!)</w>lol123</w>😉</w>


In [26]:
# Source of these functions: https://github.com/karpathy/minbpe/blob/master/minbpe/gpt4.py
def bpe(mergeable_ranks, token, max_rank):
    # helper function used in get_gpt4_merges() to reconstruct the merge forest
    parts = [bytes([b]) for b in token]
    while True:
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        assert min_idx is not None
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts

def recover_merges(mergeable_ranks):
    # the `merges` are already the byte sequences in their merged state.
    # so we have to recover the original pairings. We can do this by doing 
    # a small BPE training run on all the tokens, in their order.
    # also see https://github.com/openai/tiktoken/issues/60
    # also see https://github.com/karpathy/minbpe/issues/11#issuecomment-1950805306
    merges = {}
    for token, rank in mergeable_ranks.items():
        if len(token) == 1:
            continue # skip raw bytes
        pair = tuple(bpe(mergeable_ranks, token, max_rank=rank))
        assert len(pair) == 2
        # recover the integer ranks of the pair
        ix0 = mergeable_ranks[pair[0]]
        ix1 = mergeable_ranks[pair[1]]
        merges[(ix0, ix1)] = rank

    return merges

In [45]:
# Match tiktoken output
enc = tiktoken.get_encoding("cl100k_base")
mergeable_ranks = enc._mergeable_ranks
byte_shuffle = {i: enc._mergeable_ranks[bytes([i])] for i in range(256)}
# print(byte_shuffle)

# Recover merges and byte permutation
merges = recover_merges(mergeable_ranks)
# print("Recovered merges: ", merges)

# Object of the class BasicTokenizer
tokenizer = BasicTokenizer(byte_shuffle)

# Encode the text
encoded_text = tokenizer.encode("hello world!!!? (안녕하세요!) lol123 😉", flag=True)
print(f"Encoded: {encoded_text}")

# Decode the text
decoded_text = tokenizer.decode(ids=encoded_text, flag=True)
print(f"Decoded: {decoded_text}")

Encoded: [71, 68, 75, 75, 78, '</w>', 86, 78, 81, 75, 67, 0, 0, 0, 30, '</w>', 7, '안', '녕', '하', '세', '요', 0, 8, '</w>', 75, 78, 75, 16, 17, 18, '</w>', '😉', '</w>']
Decoded: hello world!!!? (안녕하세요!) lol123 😉



### Step 4

(Optional, irritating, not obviously useful) Add the ability to handle special tokens. You'll then be able to match the output of tiktoken even when special tokens are present, e.g.:

```
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("<|endoftext|>hello world", allowed_special="all")
```

Without `allowed_special` tiktoken will error.

In [28]:
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("<|endoftext|>hello world", allowed_special="all")

In [44]:
enc = tiktoken.get_encoding("cl100k_base")
mergeable_ranks = enc._mergeable_ranks
byte_shuffle = {i: enc._mergeable_ranks[bytes([i])] for i in range(256)}

special_tokens = "<|endoftext|>"
tokenizer = BasicTokenizer(byte_shuffle)

# Train the tokenizer
vocab_size = 1000
tokenizer.train(text=text, vocab_size=vocab_size, verbose=False, flag=True)

tokenizer.define_special_tokens(special_tokens)

# Encode the text
encoded_text = tokenizer.encode("<|endoftext|>hello world", flag=True)
print(f"Encoded: {encoded_text}")

# Decode the text
decoded_text = tokenizer.decode(ids=encoded_text, flag=True)
print(f"Decoded: {decoded_text}")

Encoded: [27, 91, 68, 77, 67, 78, 69, 83, 68, 87, 83, 91, 29, 'hello</w>', 'world', '</w>']
Decoded: <|endoftext|>hello world


### Step 5

If you've made it this far, you're now a pro at LLM Tokenization! Sadly, you're not exactly done yet because a lot of LLMs outside of OpenAI (e.g. Llama, Mistral) use [sentencepiece](https://github.com/google/sentencepiece) instead. Primary difference being that sentencepiece runs BPE directly on Unicode code points instead of on UTF-8 encoded bytes. Feel free to explore sentencepiece on your own (good luck, it's not too pretty), and stretch goal if you really experience and suffer from the burden of time, re-write your BPE to be on Unicode code points and match the Llama 2 tokenizer.

In [16]:
import sentencepiece as spm
model = spm.SentencePieceTrainer.train(
      input=abs_path, 
      model_prefix='m', 
      vocab_size=1000, 
      user_defined_symbols=['<|endoftext|>'])
      
s = spm.SentencePieceProcessor(model_file='m.model')
for n in range(5):
      encoded_text = s.encode('New York', out_type=str, enable_sampling=True, alpha=0.1, nbest_size=-1)
      print(encoded_text )


['▁Ne', 'w', '▁', 'Y', 'or', 'k']
['▁New', '▁', 'Y', 'or', 'k']
['▁', 'New', '▁York']
['▁New', '▁York']
['▁', 'New', '▁York']
