# Sources:
1 - https://karpathy.ai/zero-to-hero.html  
2 - https://www.youtube.com/watch?v=kCc8FmEb1nY  
3 - https://github.com/karpathy/minbpe/tree/master  

# Import libraries

In [93]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm
import time

# Parameters

In [94]:
# Dataset parameters
# Fraction of the selected articles that goes to the training split.
train_ratio = 0.9
# Number of articles to sample; a negative value selects the whole dataset.
sample_size = 10

# Seed passed to the train/test split so the sampling is reproducible.
SEED = 15

# Load files

In [95]:
from datasets import load_dataset

# Name of the dataset configuration to load
# (presumably the 2023-11-01 French Wikipedia dump -- confirm on the dataset card)
dataset_name = "20231101.fr"

# Load the dataset and report how many rows its 'train' split contains
raw_dataset = load_dataset("wikimedia/wikipedia", dataset_name)
print(f"Dataset size: {raw_dataset['train'].num_rows}")

Resolving data files: 100%|██████████| 17/17 [00:01<00:00,  9.05it/s]


Dataset size: 2564646


# Data preparation

In [96]:
# Create training and evaluation datasets.
# A negative `sample_size` means "use every row"; otherwise only `sample_size`
# rows are drawn from the corpus.
total_sample = raw_dataset['train'].num_rows if sample_size < 0 else sample_size

# Derive the test size by subtraction instead of rounding both shares
# independently: two separate round() calls can add up to more rows than are
# available (e.g. 15 rows at 0.9 -> 14 + 2), and `1 - train_ratio` carries
# floating-point error.
train_sample = round(total_sample * train_ratio)
test_sample = total_sample - train_sample

ds_train_test = raw_dataset['train'].train_test_split(train_size=train_sample, test_size=test_sample, seed=SEED)

# Concatenate the training articles into one string for tokenizer training.
train_text = ''.join(t['text'] for t in ds_train_test['train'])

OSError: [WinError 1224] L’opération demandée n’a pu s’accomplir sur un fichier ayant une section mappée utilisateur ouverte

# Step 1
Create a simple Byte Pair Encoding (BPE) tokenizer.

In [97]:
from collections import Counter, OrderedDict

class BasicTokenizer():
    """Character-level Byte Pair Encoding (BPE) tokenizer.

    The initial vocabulary is the set of distinct characters of the training
    text. Training then repeatedly merges the most frequent pair of adjacent
    tokens into a new token until the requested vocabulary size is reached
    or no pair is left to merge.
    """

    def __init__(self):
        self.text_to_int = {}  # {text: int}
        self.int_to_text = {}  # {int: text}
        self.merged_dict = OrderedDict()  # {(token, token): token}

    def train(self, text, max_vocab_size, verbose=False):
        """Learn the vocabulary and the merge rules from `text`.

        Args:
            text: Training corpus as a single string.
            max_vocab_size: Training stops once the vocabulary holds this
                many tokens. It stops earlier if no adjacent pair remains.
            verbose: When true, print statistics before the first merge and
                after every merge.
        """
        self.text_to_int, self.int_to_text = self._create_initial_vocabulary(text)
        self.current_vocab_size = len(self.text_to_int)
        # Merges learned by a previous call refer to token ids of the previous
        # vocabulary; start from a clean slate so they cannot leak in.
        self.merged_dict = OrderedDict()

        train_tokens = self.encode(text)

        if verbose:
            self.print_stats(train_tokens)

        while self.current_vocab_size < max_vocab_size:
            token_pairs = self._get_pairs(train_tokens)
            if not token_pairs:
                # Fewer than two tokens left: nothing more can be merged.
                break
            most_freq = self._get_most_common_token_pair(token_pairs)

            # The new token gets the next free id and represents the
            # concatenation of the texts of the two merged tokens.
            new_token = self.current_vocab_size
            str_pair = self.int_to_text[most_freq[0]] + self.int_to_text[most_freq[1]]
            self.text_to_int[str_pair] = new_token
            self.int_to_text[new_token] = str_pair
            self.merged_dict[most_freq] = new_token
            self.current_vocab_size += 1

            train_tokens = self._replace_token_pairs(train_tokens, most_freq, new_token)

            if verbose:
                self.print_stats(train_tokens, str_pair, most_freq)

    def print_stats(self, train_tokens, str_pair="", token_pair=None):
        """Print the vocabulary size, the sequence length and the last merge."""
        print("Updated vocab size:", self.current_vocab_size,
              "Updated sequence length:", len(train_tokens),
              "New vocab:", str_pair,
              "Token pair:", token_pair)

    def _get_pairs(self, tokens):
        """Return the list of all adjacent token pairs of `tokens`."""
        return list(zip(tokens, tokens[1:]))

    def _create_initial_vocabulary(self, text):
        """Map each distinct character of `text` to an id, in sorted order.

        Returns:
            A `(text_to_int, int_to_text)` tuple of dictionaries.
        """
        sorted_text = sorted(set(text))
        text_to_int = {c: i for i, c in enumerate(sorted_text)}
        int_to_text = {i: c for i, c in enumerate(sorted_text)}
        return text_to_int, int_to_text

    def _get_most_common_token_pair(self, token_pairs):
        """Return the most frequent pair of a non-empty `token_pairs` list."""
        return Counter(token_pairs).most_common(1)[0][0]

    def _replace_token_pairs(self, tokens, target_pair, new_token):
        """
        Replace all occurrences of `target_pair` in `tokens` by `new_token`.

        Occurrences are matched left to right without overlap. A new list is
        returned; `tokens` is left unchanged.
        """
        new_tokens = []
        i = 0
        while i < len(tokens) - 1:
            if (tokens[i], tokens[i+1]) == target_pair:
                new_tokens.append(new_token)
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        # The last token is still pending unless it was consumed by a merge.
        if i < len(tokens):
            new_tokens.append(tokens[i])
        return new_tokens

    def encode(self, text):
        """Convert `text` into a list of token ids.

        The text is first mapped character by character, then the learned
        merges are applied in the order in which they were learned.

        Raises:
            KeyError: If `text` contains a character that is absent from the
                training text.
        """
        tokens = [self.text_to_int[char] for char in text]
        for pair, new_token in self.merged_dict.items():
            tokens = self._replace_token_pairs(tokens, pair, new_token)
        return tokens

    def decode(self, tokens):
        """Convert a list of token ids back into the text it encodes."""
        return "".join([self.int_to_text[i] for i in tokens])


In [98]:
# Train the character-level BPE tokenizer on the training text; merging
# continues while the vocabulary holds fewer than 150 tokens.
tokenizer = BasicTokenizer()
tokenizer.train(train_text, max_vocab_size=150, verbose=True)

# Round-trip a sample sentence and compare its length in characters with its
# length in tokens. Encoding looks every character up in the trained
# vocabulary, so a character absent from the training text raises KeyError.
text = "Bonjour le monde, comment va ta journée?"
code = tokenizer.encode(text)
print(f"len original text: {len(text)}, len bpe text: {len(code)}") 
print(code)
print(tokenizer.decode(code))

Updated vocab size: 94 Updated sequence length: 35611 New vocab:  Token pair: None
Updated vocab size: 95 Updated sequence length: 34076 New vocab: e  Token pair: (54, 2)
Updated vocab size: 96 Updated sequence length: 33269 New vocab: s  Token pair: (68, 2)
Updated vocab size: 97 Updated sequence length: 32752 New vocab: on Token pair: (64, 63)
Updated vocab size: 98 Updated sequence length: 32259 New vocab: t  Token pair: (69, 2)
Updated vocab size: 99 Updated sequence length: 31824 New vocab: de  Token pair: (53, 94)
Updated vocab size: 100 Updated sequence length: 31391 New vocab: an Token pair: (50, 63)
Updated vocab size: 101 Updated sequence length: 30960 New vocab: en Token pair: (54, 63)
Updated vocab size: 102 Updated sequence length: 30598 New vocab: es  Token pair: (54, 95)
Updated vocab size: 103 Updated sequence length: 30280 New vocab: ti Token pair: (69, 58)
Updated vocab size: 104 Updated sequence length: 29982 New vocab: ,  Token pair: (6, 2)
Updated vocab size: 105 U

# Step 2
Improve the BPE tokenizer from step 1 by first splitting the text into words and then applying BPE to each word separately. This prevents tokens that span several words or that mix heterogeneous content (for example a word followed by a number).

In [99]:
import regex as re

class RegexTokenizer():
    """Byte Pair Encoding (BPE) tokenizer that merges only inside text chunks.

    The text is first cut into chunks with a regular expression. Pairs are
    counted and merged within each chunk only, so no token ever spans a
    chunk boundary.
    """

    def __init__(self, pattern=None):
        """Create an untrained tokenizer.

        Args:
            pattern: Regular expression used to cut the text into chunks.
                When None, the module-level `GPT4_SPLIT_PATTERN` is looked up
                each time text is split.
        """
        self.text_to_int = {}  # {text: int}
        self.int_to_text = {}  # {int: text}
        self.merged_dict = OrderedDict()  # {(token, token): token}
        self.pattern = pattern

    def train(self, text, max_vocab_size, verbose=False):
        """Learn the vocabulary and the merge rules from `text`.

        Args:
            text: Training corpus as a single string.
            max_vocab_size: Training stops once the vocabulary holds this
                many tokens. It stops earlier if no chunk contains an
                adjacent pair any more.
            verbose: When true, print statistics before the first merge and
                after every merge.
        """
        self.text_to_int, self.int_to_text = self._create_initial_vocabulary(text)
        self.current_vocab_size = len(self.text_to_int)
        # Merges learned by a previous call refer to token ids of the previous
        # vocabulary; start from a clean slate so they cannot leak in.
        self.merged_dict = OrderedDict()

        tokens_splits = self.encode(text)

        if verbose:
            self.print_stats(tokens_splits)

        while self.current_vocab_size < max_vocab_size:
            # Pairs are collected per chunk so that none straddles a boundary.
            token_pairs = []
            for tokens_split in tokens_splits:
                token_pairs.extend(self._get_pairs(tokens_split))

            if not token_pairs:
                # Every chunk is down to a single token: nothing left to merge.
                break

            most_freq = self._get_most_common_token_pair(token_pairs)

            # The new token gets the next free id and represents the
            # concatenation of the texts of the two merged tokens.
            new_token = self.current_vocab_size
            str_pair = self.int_to_text[most_freq[0]] + self.int_to_text[most_freq[1]]
            self.text_to_int[str_pair] = new_token
            self.int_to_text[new_token] = str_pair
            self.merged_dict[most_freq] = new_token

            tokens_splits = [self._replace_token_pairs(t, most_freq, new_token) for t in tokens_splits]
            self.current_vocab_size += 1

            if verbose:
                self.print_stats(tokens_splits, str_pair, most_freq)

    def print_stats(self, tokens_splits, str_pair="", token_pair=None):
        """Print the vocabulary size, the total token count and the last merge."""
        print("Updated vocab size:", self.current_vocab_size,
              "Updated sequence length:", sum([len(split) for split in tokens_splits]),
              "New vocab:", str_pair,
              "Token pair:", token_pair)

    def _get_pairs(self, tokens):
        """Return the list of all adjacent token pairs of `tokens`."""
        return list(zip(tokens, tokens[1:]))

    def _create_initial_vocabulary(self, text):
        """Map each distinct character of `text` to an id, in sorted order.

        Returns:
            A `(text_to_int, int_to_text)` tuple of dictionaries.
        """
        sorted_text = sorted(set(text))
        text_to_int = {c: i for i, c in enumerate(sorted_text)}
        int_to_text = {i: c for i, c in enumerate(sorted_text)}
        return text_to_int, int_to_text

    def _get_most_common_token_pair(self, token_pairs):
        """Return the most frequent pair of a non-empty `token_pairs` list."""
        return Counter(token_pairs).most_common(1)[0][0]

    def _replace_token_pairs(self, tokens, target_pair, new_token):
        """
        Replace all occurrences of `target_pair` in `tokens` by `new_token`.

        Occurrences are matched left to right without overlap. A new list is
        returned; `tokens` is left unchanged.
        """
        new_tokens = []
        i = 0
        while i < len(tokens) - 1:
            if (tokens[i], tokens[i+1]) == target_pair:
                new_tokens.append(new_token)
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        # The last token is still pending unless it was consumed by a merge.
        if i < len(tokens):
            new_tokens.append(tokens[i])
        return new_tokens

    def _simple_encode(self, text):
        """Encode one chunk: map its characters, then apply merges in order.

        Raises:
            KeyError: If the chunk contains a character that is absent from
                the training text.
        """
        tokens = [self.text_to_int[char] for char in text]
        for pair, new_token in self.merged_dict.items():
            tokens = self._replace_token_pairs(tokens, pair, new_token)
        return tokens

    def encode(self, text):
        """Convert `text` into a list of token-id lists, one per chunk."""
        text_splits = self._split_text(text)
        tokens_splits = [self._simple_encode(text_split) for text_split in text_splits]
        return tokens_splits

    def _split_text(self, text):
        """Cut `text` into the chunks matched by the split pattern."""
        # Fall back to the module-level pattern when none was given at
        # construction time; it is resolved here, at call time.
        pattern = GPT4_SPLIT_PATTERN if self.pattern is None else self.pattern
        return re.findall(pattern, text)

    def decode(self, tokens_splits):
        """Convert a list of token-id lists back into the text it encodes."""
        texts = []
        for tokens in tokens_splits:
            texts.append("".join([self.int_to_text[i] for i in tokens]))
        return "".join(texts)


In [102]:
# Split pattern read by RegexTokenizer._split_text. Its alternatives match, in
# order: apostrophe contractions ('s, 'd, 'm, 't, 'll, 've, 're), runs of
# letters with an optional leading non-letter/non-digit character, groups of
# one to three digits, runs of punctuation with an optional leading space, and
# several kinds of whitespace runs.
# NOTE(review): \p{L}/\p{N} look like they need the third-party `regex` module
# rather than the standard `re` -- confirm before swapping the import.
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
# Train the chunk-aware tokenizer; merging continues while the vocabulary
# holds fewer than 150 tokens.
tokenizer = RegexTokenizer()
tokenizer.train(train_text, max_vocab_size=150, verbose=True)

# Round-trip a sample sentence. `encode` returns one token list per chunk, so
# the token count is the sum of the chunk lengths.
text = "Bonjour le monde, comment va ta journée?"
code = tokenizer.encode(text)
print(f"len original text: {len(text)}, len bpe text: {sum([len(c) for c in code])}")
print(code)
print(tokenizer.decode(code))

Updated vocab size: 94 Updated sequence length: 35611 New vocab:  Token pair: None
Updated vocab size: 95 Updated sequence length: 34713 New vocab:  d Token pair: (2, 53)
Updated vocab size: 96 Updated sequence length: 34164 New vocab: es Token pair: (54, 68)
Updated vocab size: 97 Updated sequence length: 33632 New vocab:  l Token pair: (2, 61)
Updated vocab size: 98 Updated sequence length: 33115 New vocab: on Token pair: (64, 63)
Updated vocab size: 99 Updated sequence length: 32676 New vocab:  de Token pair: (94, 54)
Updated vocab size: 100 Updated sequence length: 32243 New vocab: an Token pair: (50, 63)
Updated vocab size: 101 Updated sequence length: 31812 New vocab: en Token pair: (54, 63)
Updated vocab size: 102 Updated sequence length: 31451 New vocab:  p Token pair: (2, 65)
Updated vocab size: 103 Updated sequence length: 31133 New vocab: ti Token pair: (69, 58)
Updated vocab size: 104 Updated sequence length: 30823 New vocab:  c Token pair: (2, 52)
Updated vocab size: 105 U