# So! Let's train a tokenizer on the Taylor Swift Wikipedia page
> *To be honest, [Fortnight - from TTPD](https://www.youtube.com/watch?v=q3zqJs7JUCQ) is playing in my head right now ❤️‍🔥*.

👉🏻 Please refer to this page: [page](https://github.com/karpathy/minbpe/blob/master/exercise.md) for the instructions that I have followed in this notebook.

## Step 1
Write the `BasicTokenizer` class, with the following three core functions:

- `def train(self, text, vocab_size, verbose=False)`
- `def encode(self, text)`
- `def decode(self, ids)`

Train your tokenizer on whatever text you like and visualize the merged tokens. Do they look reasonable? One default test you may wish to use is the text file tests/taylorswift.txt.
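
Before the full class, here is a minimal sketch of what a single merge step looks like, assuming a made-up toy string instead of the real data. The `merge` helper and the new id `256` are illustrative names and values, not part of the exercise.

```python
from collections import Counter

def merge(ids, pair, new_id):
    """Replace every non-overlapping occurrence of `pair` with `new_id`."""
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

text = "aaabdaaabac"                      # toy string, not the Taylor Swift data
ids = list(text.encode("utf-8"))          # raw UTF-8 byte values

# count every adjacent pair and pick the most frequent one
pair_counts = Counter(zip(ids, ids[1:]))
top_pair, top_count = pair_counts.most_common(1)[0]

# the first new token gets the first id after the 256 raw bytes
merged = merge(ids, top_pair, 256)

print(top_pair, top_count)
print(len(ids), "->", len(merged))
```

Training is just this step repeated: count pairs, merge the winner, and do it again on the shortened list.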

In [1]:
# reading the data
with open("./data/ts.txt", "r", encoding="utf-8") as f:
    file = f.read() # the whole article as a single string

# 💪🏻 Let's do it.

In [2]:
from collections import defaultdict
from typing import List, Tuple, Dict

In [38]:
class BasicTokenizer:
    def __init__(self):
        # initialize the default vocab: one token per raw byte
        self.vocab = {idx:bytes([idx]) for idx in range(256)}
        self.trained=False
        
    
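    def find_most_repeated_pair(self, tokens) -> Tuple[Tuple, int, Dict]:
        # count how often each adjacent pair of tokens occurs
        counter = defaultdict(int)
        for pair in zip(tokens, tokens[1:]):
            counter[pair] += 1

        if not counter: # fewer than two tokens, so there is no pair at all
            return None, 0, counter

        max_pair = max(counter, key=counter.get)
        max_count = counter[max_pair]
        return max_pair, max_count, counter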

    def replace_pair_with_new_token(self, tokens, pair, new_idx) -> List:
        new_tokens = [] # this will hold the copy for the new tokens
        idx = 0
        while idx < len(tokens):
            if idx < len(tokens) - 1 and (tokens[idx] == pair[0]) and (tokens[idx + 1] == pair[1]): # this is a match!
                new_tokens.append(new_idx)
                idx += 2
            else: # this is not a match
                new_tokens.append(tokens[idx])
                idx += 1
        return new_tokens
        
    def train(self, blob, vocab_size=None) -> None:
        '''
        Train the tokenizer on the given text.
        
        1. blob: The training data, as a string.
        
        2. vocab_size: Despite the name, this is "how many new tokens you
            want to generate", i.e. the maximum number of merges.
            - `None` means no limit; merge until no pair repeats.
            - `int` means at most that many merges.
        '''
        self.vocab_size = vocab_size
        self.tokens = list(blob.encode("utf-8"))
        
        # every merge shortens the token list by at least one, so
        # len(self.tokens) is a safe upper bound when no limit is given
        max_merges = len(self.tokens) if vocab_size is None else vocab_size
        
        new_idx = 255
        merges = {}
        i = -1 # keeps `total_merges` defined even if the loop never runs
        for i in range(max_merges):
            pair, count, stats = self.find_most_repeated_pair(self.tokens)
            if count > 1:
                new_idx += 1
                self.tokens = self.replace_pair_with_new_token(self.tokens, pair, new_idx)
                merges[pair] = new_idx
            else: # no pair occurs more than once
                break
        # NOTE: after an early `break` this also counts the iteration that
        # found nothing to merge; `len(merges)` is the exact number of merges
        self.total_merges = i+1
        
        # training is done; now build the bytes for every new token
        for pair, idx in merges.items():
            self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
        self.merges = merges   
        self.trained = True
        
    def encode(self, text):
        '''
        The goal of this function is to encode the given text into the 
        tokens that are known to our `vocab`.
        
        So, we need to apply the merges in the order they were learned,
        from the oldest to the newest.
        
        The `order` of a dict **is not guaranteed** in older versions of
        Python (before 3.7), so we rely on the `idx`. The lower the idx
        is, the older that merge is!
        '''
        
        if not self.trained:
            raise NotImplementedError("Please first train the tokenizer!")
        
        tokens = list(text.encode("utf-8")) # always a list, even if nothing merges
        while len(tokens) >= 2:
            _, _, stats = self.find_most_repeated_pair(tokens)
            # we are not interested in the counts, just in which pairs exist.
            # among the pairs present in `tokens`, pick the one that was
            # merged earliest (lowest idx); unknown pairs rank as infinity
            pair_replace = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair_replace in self.merges:
                tokens = self.replace_pair_with_new_token(tokens, 
                                                 pair_replace,
                                                 self.merges[pair_replace])
            else: # none of the remaining pairs was ever merged
                break
        return tokens
    
    def decode(self, tokens):
        decoded_stream = [self.vocab[idx] for idx in tokens]
        text = b"".join(decoded_stream)
        return text.decode("utf-8")

In [137]:
tokenizer = BasicTokenizer()

In [138]:
tokenizer.train("This is the sentence that should test the current tokenizer.", vocab_size=100)

In [139]:
tokenizer.decode(tokenizer.encode("This is me now"))

'This is me now'
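
The round trip works here because the ids came from valid text. As a hedged side note, here is a small sketch of what happens when a token sequence ends in the middle of a multi-byte character. The standalone `decode` function and its `errors` parameter are assumptions for illustration; the class above always decodes strictly.

```python
# base vocab only: one token per raw byte, as in BasicTokenizer.__init__
vocab = {idx: bytes([idx]) for idx in range(256)}

def decode(ids, errors="strict"):
    """Join the token bytes and decode them as UTF-8."""
    return b"".join(vocab[i] for i in ids).decode("utf-8", errors=errors)

euro = list("€".encode("utf-8"))   # a three-byte character
assert decode(euro) == "€"         # complete sequence: fine

truncated = euro[:2]               # cut in the middle of the character
try:
    decode(truncated)
    strict_failed = False
except UnicodeDecodeError:
    strict_failed = True           # strict decoding refuses partial characters

# errors="replace" swaps the broken bytes for U+FFFD instead of raising
lenient = decode(truncated, errors="replace")
print(strict_failed, repr(lenient))
```

This only matters when the ids do not come straight from `encode`, for example when they are sampled from a model.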

In [141]:
tokenizer.total_merges

8

> The **base** version of the tokenizer is ready! 🎉🎉 *(Trust me, it was so easy that I didn't have to look at the actual code once!)*.
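
One detail from the `encode` docstring deserves a tiny example: merges should be replayed oldest first, no matter how often a pair shows up in the text being encoded. Below is a minimal sketch of that rule with a made-up merge table. The function names and the two merges are assumptions for illustration only.

```python
def merge(ids, pair, new_id):
    """Replace every non-overlapping occurrence of `pair` with `new_id`."""
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def encode(text, merges):
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        pairs = set(zip(ids, ids[1:]))
        # pick the pair that was learned earliest (lowest new id);
        # pairs that were never merged rank as infinity
        pair = min(pairs, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # nothing left that training ever merged
        ids = merge(ids, pair, merges[pair])
    return ids

# hypothetical merge table: "ab" was learned first, "bc" second
toy_merges = {(97, 98): 256, (98, 99): 257}

encoded = encode("ababc", toy_merges)
print(encoded)
```

Because `"ab"` is the older merge, it consumes both `b`s first, so the `"bc"` merge never gets a chance to fire. Picking `"bc"` first would give a different (and inconsistent with training) result.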

Let's now train that on the TS data.

In [143]:
tokenizer = BasicTokenizer()
tokenizer.train(file, vocab_size=50_000)

👆🏻 The above training took me around `2` minutes `30` seconds.

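> Note: `50,000` is the maximum number of merges; there can be fewer than that. The total is stored in the `total_merges` attribute. When training stops early, that number also counts the final iteration that found nothing left to merge, so the last real merge sits one index lower (`255 + 8156` below).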
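
To see why training can stop long before the cap, here is a minimal sketch on a toy string. `learn_merges` is a hypothetical standalone helper, not a method of the class above, but it uses the same stopping rule: quit as soon as no pair occurs more than once.

```python
from collections import Counter

def learn_merges(text, max_merges):
    """Greedy byte-pair merging that stops when no pair repeats."""
    ids = list(text.encode("utf-8"))
    merges = {}
    for _ in range(max_merges):
        counts = Counter(zip(ids, ids[1:]))
        if not counts:
            break  # fewer than two tokens left
        pair, count = counts.most_common(1)[0]
        if count < 2:
            break  # every remaining pair occurs only once
        new_id = 256 + len(merges)
        out, i = [], 0
        while i < len(ids):
            if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
                out.append(new_id)
                i += 2
            else:
                out.append(ids[i])
                i += 1
        ids = out
        merges[pair] = new_id
    return merges, ids

# the cap is 50, but this text runs out of repeated pairs much sooner
toy_merges, toy_ids = learn_merges("abababab", max_merges=50)
print(len(toy_merges), toy_ids)
```

Here `len(toy_merges)` is the exact number of merges that actually happened, which is the number to trust when the cap is not reached.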

In [144]:
tokenizer.total_merges

8157

Total `8K` merges! Cool! 🔥

### Wait, not cool 😵
We allowed up to `50,000` merges, and it ran out of repeated pairs after about `8K`. That means **we have merged many trash tokens**, and it also means we end up with very long tokens at the end.

In [166]:
tokenizer.vocab[255 + 8156]

b'Wikimedia Foundation'

In [167]:
tokenizer.vocab[255 + 8155]

b'Wikimedia '

I mean, WTF!? But it had to happen, right!? Now, let's make it better and cap it at, say, 512 merges.

In [168]:
tokenizer = BasicTokenizer()
tokenizer.train(file, vocab_size=512)
tokenizer.total_merges

512

In [170]:
tokenizer.vocab[255 + 511]

b'tim'

In [172]:
tokenizer.vocab[255 + 512]

b'ay '

Now, better 👍🏻 

In [173]:
# let's visualize the merges:
for pair in list(tokenizer.merges.keys())[:50]:
    print(tokenizer.vocab[tokenizer.merges[pair]])

b'e '
b', '
b'd '
b'. '
b'r '
b'20'
b's '
b'in'
b'on'
b'ri'
b't '
b'th'
b'ed '
b', 20'
b'an'
b'ar'
b'er '
b'y '
b'al'
b'the '
b'ved '
b'wi'
b'er'
b'on '
b'wif'
b'Re'
b'Swif'
b'or '
b'ch'
b', 201'
b'om'
b'ber '
b' the '
b'ay'
b'en'
b'or'
b'al '
b'em'
b'.\n'
b'rie'
b'ing'
b', 202'
b'ti'
b'ayl'
b'". '
b'll'
b'Tayl'
b'trie'
b'.\n '
b'to'


> We can see how the tokenizer has started to pick up pieces of "Taylor" and "Swift" (`b'Tayl'`, `b'Swif'`) as single tokens! **But also note** that many tokens glue letters, *numbers, punctuation and newlines* together (for example `b', 201'` or `b'.\n '`).

So, next up we will solve that by training the `RegexTokenizer`.

<img src="./images/regex-flow.png">
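
As a rough sketch of the idea, the snippet below splits a sentence into chunks first and only counts pairs inside each chunk. `SIMPLE_SPLIT` is a simplified, ASCII-only stand-in written for the standard library `re` module; it is an assumption for illustration and not the GPT-4 pattern used in the class below.

```python
import re
from collections import Counter

# simplified stand-in pattern: words (with an optional leading space),
# runs of up to three digits, punctuation, and whitespace
SIMPLE_SPLIT = re.compile(r" ?[A-Za-z]+|[0-9]{1,3}| ?[^\sA-Za-z0-9]+|\s+")

text = "Taylor Swift, 2023."
chunks = SIMPLE_SPLIT.findall(text)

# count pairs chunk by chunk, sharing one counter across all chunks
stats = Counter()
for chunk in chunks:
    ids = list(chunk.encode("utf-8"))
    stats.update(zip(ids, ids[1:]))

# ", " is adjacent in the raw text, but the comma and the space live in
# different chunks, so that pair can never be counted (or merged)
comma_space = (ord(","), ord(" "))
print(chunks)
print(comma_space in stats)
```

Since pairs never cross a chunk boundary, tokens that mix letters with punctuation or digits, like the ones in the list above, cannot be formed across those boundaries anymore.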

In [4]:
import regex as re

In [91]:
class RegexTokenizer:
    '''
    This is supposed to get a little crazy.
    
    Step 1: Split the text based on the regex pattern.
    Step 2: Now, we have the cleaned chunks (words, numbers, punctuation...).
    Step 3: Get the raw byte tokens of each chunk individually.
    Step 4: Don't join the chunks back together, because that would undo steps 1-3.
    Step 5: Count pairs (stats) within each chunk, while keeping one "common" stats dict across all chunks.
    Step 6: Find the most repeated pair.
    Step 7: Replace that pair in each token group.
    '''
    def __init__(self):
        # initialize the default vocab: one token per raw byte
        self.vocab = {idx:bytes([idx]) for idx in range(256)}
        self.trained=False
        self.GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
        self.GPT4_PATTERN_COMPILED = re.compile(self.GPT4_SPLIT_PATTERN)
    
    def find_most_repeated_pair(self, tokens, counter=None) -> Dict:
        '''
        This function is changed slightly: it only counts the pairs, and
        the max is calculated by the caller when needed.
        
        A `counter` can be passed in; it is updated in place and returned.
        Passing the same counter for every chunk gives us the global counts.
        '''
        counter = counter if counter is not None else defaultdict(int)
        for pair in zip(tokens, tokens[1:]):
            counter[pair] += 1
        return counter # useful when no counter was passed in

    def replace_pair_with_new_token(self, tokens, pair, new_idx) -> List:
        new_tokens = [] # this will hold the copy for the new tokens
        idx = 0
        while idx < len(tokens):
            if idx < len(tokens) - 1 and (tokens[idx] == pair[0]) and (tokens[idx + 1] == pair[1]): # this is a match!
                new_tokens.append(new_idx)
                idx += 2
            else: # this is not a match
                new_tokens.append(tokens[idx])
                idx += 1
        return new_tokens
        
    def train(self, blob, vocab_size=None) -> None:
        '''
        Train the tokenizer on the given text.
        
        1. blob: The training data, as a string.
        
        2. vocab_size: Despite the name, this is "how many new tokens you
            want to generate", i.e. the maximum number of merges.
            - `None` means no limit; merge until no pair repeats.
            - `int` means at most that many merges.
        '''
        self.vocab_size = vocab_size
        
        # First split the text into chunks with the regex
        chunks = self.GPT4_PATTERN_COMPILED.findall(blob)
        # Then turn every chunk into its own list of byte tokens
        self.tokens = [list(word.encode("utf-8")) for word in chunks]
        
        # every merge removes at least one token, so the total number of
        # tokens is a safe upper bound when no limit is given
        total_tokens = sum(len(token_group) for token_group in self.tokens)
        max_merges = total_tokens if vocab_size is None else vocab_size
        
        new_idx = 255
        merges = {}
        i = -1 # keeps `total_merges` defined even if the loop never runs
        for i in range(max_merges):
            stats = defaultdict(int)
            for token_group in self.tokens:
                # pass the stats, which will be updated in place
                self.find_most_repeated_pair(token_group, stats)
            
            if not stats: # every chunk is a single token; no pairs left
                break
            
            max_pair = max(stats, key=stats.get)
            max_count = stats[max_pair]
            
            if max_count > 1:
                new_idx += 1
                self.tokens = [self.replace_pair_with_new_token(token_group, max_pair, new_idx) for token_group in self.tokens]
                merges[max_pair] = new_idx
            else: # no pair occurs more than once
                break
        # NOTE: after an early `break` this also counts the iteration that
        # found nothing to merge; `len(merges)` is the exact number of merges
        self.total_merges = i+1
        
        # training is done; now build the bytes for every new token
        for pair, idx in merges.items():
            self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
        self.merges = merges   
        self.trained = True
        
    def encode(self, text):
        '''
        The goal of this function is to encode the given text into the 
        tokens that are known to our `vocab`.
        
        So, we need to apply the merges in the order they were learned,
        from the oldest to the newest, within each regex chunk.
        
        The `order` of a dict **is not guaranteed** in older versions of
        Python (before 3.7), so we rely on the `idx`. The lower the idx
        is, the older that merge is!
        '''
        
        if not self.trained:
            raise NotImplementedError("Please first train the tokenizer!")
        
        # split exactly like in training, then encode chunk by chunk
        split_words = self.GPT4_PATTERN_COMPILED.findall(text)
        split_tokens = [list(word.encode("utf-8")) for word in split_words]
        
        final_tokens = []
        for chunk in split_tokens:
            while len(chunk) >= 2:
                stats = self.find_most_repeated_pair(chunk)
                # we are not interested in the counts, just in which pairs exist.
                # among the pairs present in `chunk`, pick the one that was
                # merged earliest (lowest idx); unknown pairs rank as infinity
                pair_replace = min(stats, key=lambda p: self.merges.get(p, float("inf")))
                if pair_replace in self.merges:
                    chunk = self.replace_pair_with_new_token(chunk, 
                                                     pair_replace,
                                                     self.merges[pair_replace])
                else: # none of the remaining pairs was ever merged
                    break
            final_tokens.extend(chunk)
        return final_tokens
    
    def decode(self, tokens):
        decoded_stream = [self.vocab[idx] for idx in tokens]
        text = b"".join(decoded_stream)
        return text.decode("utf-8")

In [92]:
text = "This is the name of aayush shah 123"

In [93]:
tokenizer = RegexTokenizer()
tokenizer.train(file, vocab_size=512)
tokenizer.total_merges

512

🫱🏻‍🫲🏻 Let's now see the quality of merges.

In [57]:
# let's visualize the merges:
for pair in list(tokenizer.merges.keys())[:50]:
    print(tokenizer.vocab[tokenizer.merges[pair]])

b'er'
b'20'
b'or'
b'in'
b'ed'
b' t'
b'on'
b'he'
b' S'
b'ar'
b'an'
b' A'
b' the'
b'al'
b'ri'
b'ved'
b'st'
b'wi'
b' R'
b'201'
b' f'
b'202'
b' T'
b'ft'
b'ay'
b' "'
b'wift'
b'et'
b' Swift'
b'ch'
b'ber'
b'at'
b'om'
b'es'
b'en'
b'em'
b'".'
b' ('
b'.\n'
b'ing'
b'lor'
b' M'
b'ig'
b' on'
b'aylor'
b'll'
b'rie'
b' Ret'
b' Retrie'
b' Retrieved'


🔥 I think, it works great!!

In [124]:
list("This is me now123".encode("utf-8"))

[84, 104, 105, 115, 32, 105, 115, 32, 109, 101, 32, 110, 111, 119, 49, 50, 51]

In [125]:
tokenizer.encode("This is me now123")

[84, 104, 355, 549, 344, 101, 435, 397, 49, 486]

🔥🔥 Works!
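
A quick way to put a number on "works" is to compare how many raw bytes went in with how many tokens came out. The sketch below simply reuses the two lists printed above; calling the quotient a compression ratio is my own label here, not something from the exercise.

```python
# copied from the two outputs above
raw_bytes = [84, 104, 105, 115, 32, 105, 115, 32, 109, 101, 32, 110, 111, 119, 49, 50, 51]
token_ids = [84, 104, 355, 549, 344, 101, 435, 397, 49, 486]

# bytes per token: higher means the merges are compressing more
ratio = len(raw_bytes) / len(token_ids)
print(len(raw_bytes), len(token_ids), round(ratio, 2))
```

So for this short sentence, each token covers 1.7 bytes on average.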

# Aaand, I think we should be done here...
The next steps in the exercise are about comparing against the GPT-4 tokenizer and so on. But we can skip those, since we have accomplished most of what we set out to do here.

Let's train GPT from scratch with this new tokenizer!