# Sources:
1 - https://karpathy.ai/zero-to-hero.html  
2 - https://www.youtube.com/watch?v=kCc8FmEb1nY  
3 - https://github.com/karpathy/minbpe/tree/master  

# Import libraries

In [6]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm
import time

# Parameters

In [7]:
# Dataset parameters
train_ratio = 0.9
sample_size = 10

SEED = 15

# Load files

In [9]:
from datasets import load_dataset

# Define the path to the dataset
dataset_name = "20231101.fr"

# Load the dataset
raw_dataset = load_dataset("wikimedia/wikipedia", dataset_name)
print(f"Dataset size: {raw_dataset['train'].num_rows}")

  from .autonotebook import tqdm as notebook_tqdm
Resolving data files: 100%|██████████| 17/17 [00:01<00:00,  8.75it/s]


Dataset size: 2564646


# Data preparation

In [10]:
# Create training and evaluation datasets
if sample_size < 0:
    train_sample = round(raw_dataset['train'].num_rows * train_ratio)
    test_sample = round(raw_dataset['train'].num_rows * (1- train_ratio))
else:
    train_sample = round(sample_size * train_ratio)
    test_sample = round(sample_size * (1- train_ratio))

ds_train_test = raw_dataset['train'].train_test_split(train_size=train_sample, test_size=test_sample, seed=SEED)
train_text = ''.join([t['text'] for t in ds_train_test['train']])

# Step 1

In [11]:
from collections import Counter

class BasicTokenizer():
    def __init__(self):
        self.text_to_int = {}
        self.int_to_text = {}
        
    def train(self, text, max_vocab_size, verbose=False):
        self.text_to_int, self.int_to_text = self._create_init_vocab(text)
        train_tokens = self.encode(text)   
        current_vocab_size = len(self.text_to_int)
        
        if verbose: 
            print("Initial vocab size:", current_vocab_size, "Initial sequence length:", len(train_tokens))
        
        while current_vocab_size < max_vocab_size:
            token_pairs = list(zip(train_tokens, train_tokens[1:]))
            most_freq = self._get_most_common_token_pair(token_pairs)
            
            str_pair = ''.join([self.int_to_text[most_freq[0]], self.int_to_text[most_freq[1]]])
            self.text_to_int[str_pair] = current_vocab_size
            self.int_to_text[current_vocab_size] = str_pair
            current_vocab_size += 1

            new_train_tokens = self._replace_token_pairs(train_tokens, most_freq, self.text_to_int[str_pair])

            train_tokens = new_train_tokens

            if verbose:
                print("Updated vocab size:", current_vocab_size, "Updated sequence length:", len(train_tokens))

    def _create_init_vocab(self, text):
        sorted_text = sorted(set(text))
        text_to_int = {c: i for i, c in enumerate(sorted_text)}
        int_to_text = {i: c for i, c in enumerate(sorted_text)}
        return text_to_int, int_to_text

    def _get_most_common_token_pair(self, token_pairs):
        counter = Counter(token_pairs)
        most_freq = counter.most_common(1)[0][0]
        return most_freq
    
    def _replace_token_pairs(self, tokens, target_pair, new_token):
        new_tokens = []
        i = 0
        while i < len(tokens) - 1:
            if (tokens[i], tokens[i+1]) == target_pair:
                new_tokens.append(new_token)
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        if i < len(tokens):
            new_tokens.append(tokens[i])
        return new_tokens

    def encode(self, text):
        return [self.text_to_int[char] for char in text]
    
    def decode(self, tokens):
        return "".join([self.int_to_text[i] for i in tokens])


In [12]:
tokenizer = BasicTokenizer()
tokenizer.train(train_text, max_vocab_size=150, verbose=True)

In [13]:
tokenizer.encode('Bonjour')

[25, 64, 63, 59, 64, 70, 67]

In [14]:
tokenizer.decode([25, 64, 63, 59, 64, 70, 67])

'Bonjour'

# Step 2

In [53]:
from collections import Counter
import regex as re

class RegexTokenizer():
    def __init__(self):
        self.text_to_int = {}
        self.int_to_text = {}
        
    def train(self, text, max_vocab_size, verbose=False):
        self.text_to_int, self.int_to_text = self._create_init_vocab(text)
        split_tokens = self.encode(text)
        self.current_vocab_size = len(self.text_to_int)
        
        if verbose:
            self.print_stats(split_tokens)
            
        while self.current_vocab_size < max_vocab_size:
            token_pairs = []
            for t in split_tokens:
                if len(t) > 1:
                    token_pairs.extend(self._get_token_pairs(t))

            most_freq = self._get_most_common_token_pair(token_pairs)
            most_freq_str = ''.join([self.int_to_text[most_freq[0]], self.int_to_text[most_freq[1]]])
            
            self.text_to_int[most_freq_str] = self.current_vocab_size
            self.int_to_text[self.current_vocab_size] = most_freq_str
            
            new_split_tokens = [self._replace_token_pairs(t, most_freq, self.current_vocab_size) for t in split_tokens]
            split_tokens = new_split_tokens
            self.current_vocab_size += 1
            
            if verbose:
                self.print_stats(split_tokens, most_freq_str)

    def print_stats(self, split_tokens, new_token = ""):
        print("Initial vocab size:", self.current_vocab_size, "Initial sequence length:", sum([len(t) for t in split_tokens]), "New token:", new_token)
        
    def _get_token_pairs(self, tokens):
        return list(zip(tokens, tokens[1:]))
    
    def _split_text(self, text):
        split_text = re.findall(GPT4_SPLIT_PATTERN, text)
        return split_text
    
    def _create_init_vocab(self, text):
        sorted_text = sorted(set(text))
        text_to_int = {c: i for i, c in enumerate(sorted_text)}
        int_to_text = {i: c for i, c in enumerate(sorted_text)}
        return text_to_int, int_to_text

    def _get_most_common_token_pair(self, token_pairs):
        counter = Counter(token_pairs)
        most_freq = counter.most_common(1)[0][0]
        return most_freq
    
    def _replace_token_pairs(self, tokens, target_pair, new_token):
        new_tokens = []
        i = 0
        while i < len(tokens) - 1:
            if (tokens[i], tokens[i+1]) == target_pair:
                new_tokens.append(new_token)
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        if i < len(tokens):
            new_tokens.append(tokens[i])
        return new_tokens

    def encode(self, text):
        split_text = self._split_text(text)
        encoded_splits = []
        for split in split_text:
            encoded_splits.append([self.text_to_int[t] for t in split])
        return encoded_splits
    
    def decode(self, tokens_list):
        texts = []
        for tokens in tokens_list:
            texts.append("".join([self.int_to_text[i] for i in tokens]))
        return "".join(texts)


In [54]:
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
tokenizer = RegexTokenizer()
tokenizer.train(train_text, max_vocab_size=150, verbose=True)

Initial vocab size: 94 Initial sequence length: 35611 New token: 
Initial vocab size: 95 Initial sequence length: 34713 New token:  d
Initial vocab size: 96 Initial sequence length: 34164 New token: es
Initial vocab size: 97 Initial sequence length: 33632 New token:  l
Initial vocab size: 98 Initial sequence length: 33115 New token: on
Initial vocab size: 99 Initial sequence length: 32676 New token:  de
Initial vocab size: 100 Initial sequence length: 32243 New token: an
Initial vocab size: 101 Initial sequence length: 31812 New token: en
Initial vocab size: 102 Initial sequence length: 31451 New token:  p
Initial vocab size: 103 Initial sequence length: 31133 New token: ti
Initial vocab size: 104 Initial sequence length: 30823 New token:  c
Initial vocab size: 105 Initial sequence length: 30527 New token: in
Initial vocab size: 106 Initial sequence length: 30242 New token: er
Initial vocab size: 107 Initial sequence length: 29985 New token:  s
Initial vocab size: 108 Initial sequence 

In [55]:
tokenizer.encode("Bonjour le monde")

[[25, 64, 63, 59, 64, 70, 67], [2, 61, 54], [2, 62, 64, 63, 53, 54]]

In [56]:
generated_ids = [[25, 64, 63, 59, 64, 70, 67], [2, 61, 54], [2, 62, 64, 63, 53, 54]]
tokenizer.decode(generated_ids)

'Bonjour le monde'