## What is BPE tokenization?

### BPE training phase

Now let's look at how a BPE tokenizer is trained. Assume we have a collection of documents $D$:
1. Transform each document $d \in D$ into a list of words *in some way*. *For instance, you may split each document on whitespace.*
2. Count the frequency of each word $w$ in $D$. The set of characters appearing in $D$ (plus `</w>`) forms the initial vocab.
3. Transform each word into a list of characters. We call this list a *split*. *For example, `highest -> h, i, g, h, e, s, t`*
4. Append `</w>` to each split. *e.g. `highest -> h, i, g, h, e, s, t, </w>`*
5. Repeat the following steps until the vocab reaches its size limit or the maximum number of iterations is reached:
    1. Find **the most frequent pair** of adjacent symbols, add it to the merge table, and add the merged symbol to the vocab.
    2. Update the splits of all words by merging every occurrence of that pair. *For example, if the most frequent pair is `(h, i)`, then `highest -> hi, g, h, e, s, t, </w>`*
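The training steps above can be condensed into a short, self-contained sketch. It assumes whitespace word splitting (step 1) and a toy corpus; `train_bpe`, `count_pairs` and `merge_pair` are hypothetical helper names, and ties are broken by whichever pair was counted first, since Python's `max` returns the first maximal key.

```python
from collections import Counter


def count_pairs(splits: dict[str, list[str]], word_freq: Counter) -> Counter:
    """Count adjacent symbol pairs, weighting each by its word's frequency."""
    pairs = Counter()
    for word, freq in word_freq.items():
        symbols = splits[word]
        # pairs are formed inside a single word only, never across words
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return pairs


def merge_pair(symbols: list[str], a: str, b: str) -> list[str]:
    """Replace every adjacent (a, b) in symbols with the merged symbol a + b."""
    out, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and symbols[i] == a and symbols[i + 1] == b:
            out.append(a + b)
            i += 2
        else:
            out.append(symbols[i])
            i += 1
    return out


def train_bpe(text: str, num_merges: int) -> list[tuple[str, str]]:
    """Return the merge table, in the order the merges were learned."""
    word_freq = Counter(text.split())                    # steps 1-2 (whitespace split assumed)
    splits = {w: list(w) + ["</w>"] for w in word_freq}  # steps 3-4
    merges = []
    for _ in range(num_merges):                          # step 5, capped by num_merges
        pairs = count_pairs(splits, word_freq)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)                 # 5.1: most frequent pair (first on ties)
        merges.append(best)
        splits = {w: merge_pair(s, *best) for w, s in splits.items()}  # 5.2
    return merges


toy_text = "low " * 5 + "lower " * 2 + "newest " * 6 + "widest " * 3
print(train_bpe(toy_text, num_merges=5))
# [('e', 's'), ('es', 't'), ('est', '</w>'), ('l', 'o'), ('lo', 'w')]
```

On this toy corpus the suffix `est</w>` is learned before `low`, because `newest` and `widest` together occur more often than `low` and `lower`.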

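You may have three questions:
1. Why count word frequencies? Because every occurrence of a word contains the same pairs, we can count pairs once per distinct word and weight them by the word's frequency, instead of scanning every occurrence in $D$. This makes finding the most frequent pair cheap.
2. Why append `</w>`? Because we want to be able to reconstruct the input later: `</w>` marks the end of a word, so word boundaries survive tokenization.
3. What if multiple pairs have the same frequency? How ties are broken varies between implementations, but in my opinion it *shouldn't* have much impact.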
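To illustrate the second point, here is a minimal decoding sketch, assuming the words were produced by splitting on whitespace; `decode` is a hypothetical helper. Each `</w>` turns back into a word separator, whereas without the marker the tokens `high`, `est`, `low`, `er` would not tell us whether the text was `highest lower` or `high estlower`.

```python
def decode(tokens: list[str]) -> str:
    """Rebuild whitespace-separated text from tokens carrying the </w> marker."""
    # every </w> marks the end of a word, so it becomes a single space
    return "".join(tokens).replace("</w>", " ").strip()


print(decode(["high", "est</w>", "low", "er</w>"]))  # highest lower
```

Only the word boundaries are recovered: the exact whitespace (tabs, newlines, repeated spaces) is lost, because it was discarded when splitting.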

> 💡 Notice that when the BPE algorithm merges the most frequent pair, it never merges across word boundaries: pairs are only counted within a single word.

### How to use a trained BPE?

After training a BPE tokenizer, we obtain a merge table and a vocab. Suppose we now need to tokenize the text `s`:

1. Using the same method as during training, split `s` into words, then divide each word into characters and append `</w>`.
2. Iterate through the merge table in order and apply each merge rule to the split of every word.

> 💡 An important detail here is that the merge rules are stored in the order they were learned, so the pair that was most frequent at each point of training comes first. Thus, by traversing the merge table sequentially, we *implicitly* give priority to merging the most frequent pairs.
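A minimal sketch of this procedure for a single word, assuming a hypothetical merge table `toy_merges` listed in the order it was learned; `apply_merges` is a hypothetical helper that follows the two steps above.

```python
def apply_merges(word: str, merges: list[tuple[str, str]]) -> list[str]:
    """Tokenize one word by applying the learned merges in training order."""
    symbols = list(word) + ["</w>"]
    for a, b in merges:  # earlier merges were learned first, so they are applied first
        merged, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and symbols[i] == a and symbols[i + 1] == b:
                merged.append(a + b)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols


# hypothetical merge table, in learned order
toy_merges = [("e", "s"), ("es", "t"), ("est", "</w>"), ("l", "o"), ("lo", "w")]
print(apply_merges("lowest", toy_merges))  # ['low', 'est</w>']
print(apply_merges("newer", toy_merges))   # ['n', 'e', 'w', 'e', 'r', '</w>']
```

Words never seen during training, such as `newer` here, still get tokenized: when no merge applies they simply stay as single characters.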

## BPE in practice

The API provided by Hugging Face's `tokenizers` library is quite simple. *You may notice that the class name contains `Char`: like the algorithm described above, it starts from individual characters and merges them into larger units.*

```python
from utils import get_code
from tqdm.auto import tqdm
from pprint import pprint
```

```python
# use the "test" dataset to speed up "training"
corpus = get_code("test", language="go")
```

To create a BPE tokenizer, we use the `CharBPETokenizer` provided by Hugging Face.

```python
from tokenizers import CharBPETokenizer

# Instantiate tokenizer
tokenizer = CharBPETokenizer()
```

```python
def batch_iterator(batch_size=256):
    # yield the corpus in chunks of `batch_size` documents
    for i in range(0, len(corpus), batch_size):
        yield corpus[i : i + batch_size]
```

Then we just need to call the `tokenizer.train_from_iterator` method :)

```python
tokenizer.train_from_iterator(
    batch_iterator(),
    vocab_size=50265,
    min_frequency=2,
    special_tokens=[
        "<s>",
        "<pad>",
        "</s>",
        "<unk>",
        "<mask>",
    ],
)
```

Now let's take the first code sample from the corpus and look at the tokenization result:

```python
sample = corpus[0]
print(sample)
```

```text
func mustWaitPinReady(t *testing.T, cli *clientv3.Client) {
	// TODO: decrease timeout after balancer rewrite!!!
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	_, err := cli.Get(ctx, "foo")
	cancel()
	if err != nil {
		t.Fatal(err)
	}
}
```


```python
print(tokenizer.encode(sample).tokens)
```

```text
['func</w>', 'must', 'Wait', 'Pin', 'Ready</w>', '(</w>', 't</w>', '*</w>', 'testing</w>', '.</w>', 'T</w>', ',</w>', 'cli</w>', '*</w>', 'clientv3</w>', '.</w>', 'Client</w>', ')</w>', '{</w>', '/</w>', '/</w>', 'TODO</w>', ':</w>', 'decrease</w>', 'timeout</w>', 'after</w>', 'balancer</w>', 'rewrite</w>', '!</w>', '!</w>', '!</w>', 'ctx</w>', ',</w>', 'cancel</w>', ':</w>', '=</w>', 'context</w>', '.</w>', 'WithTimeout</w>', '(</w>', 'context</w>', '.</w>', 'Background</w>', '(</w>', ')</w>', ',</w>', '10</w>', '*</w>', 'time</w>', '.</w>', 'Second</w>', ')</w>', '_</w>', ',</w>', 'err</w>', ':</w>', '=</w>', 'cli</w>', '.</w>', 'Get</w>', '(</w>', 'ctx</w>', ',</w>', '"</w>', 'foo</w>', '"</w>', ')</w>', 'cancel</w>', '(</w>', ')</w>', 'if</w>', 'err</w>', '!</w>', '=</w>', 'nil</w>', '{</w>', 't</w>', '.</w>', 'Fatal</w>', '(</w>', 'err</w>', ')</w>', '}</w>', '}</w>']
```


The tokenization result reveals some interesting properties of the BPE tokenizer:
- It learns to **split camelCase function/variable names automatically**. *`mustWaitPinReady` -> `['must', 'Wait', 'Pin', 'Ready</w>']`*
- It also **keeps meaningful Go keywords intact**. *e.g. `func`, which declares a function in Go, is a single token*

## Implement a BPE Tokenizer

To get a better understanding of the BPE tokenizer, let's implement one ourselves.

```python
from collections import Counter, defaultdict


class BPE:
    def __init__(
        self,
        corpus: list[str],
        vocab_size: int,
        max_iter: int | None = None,
        debug: bool = False,
    ):
        self.corpus = corpus
        self.vocab_size = vocab_size
        self.vocab = []
        self.word_freq = Counter()
        self.splits = {}  # e.g. highest: [high, est</w>]
        self.merges = {}  # e.g. (high, est</w>): highest</w>
        self.max_iter = max_iter
        self.debug = debug

    def train(self):
        """Train a BPE Tokenizer"""
        # count the word frequency
        for document in self.corpus:
            # split each document in corpus by whitespace
            words = document.split()
            self.word_freq.update(words)

        # initialize self.splits: each word starts as its characters plus </w>
        for word in self.word_freq:
            self.splits[word] = list(word) + ["</w>"]

        if self.debug:
            print(f"Init splits: {self.splits}")

        # the initial vocab is the alphabet of the corpus plus </w>
        alphabet = set()
        for word in self.word_freq:
            alphabet |= set(word)
        alphabet.add("</w>")

        self.vocab = sorted(alphabet)

        cnt = 0
        while len(self.vocab) < self.vocab_size:
            # stop once the iteration limit (if any) is reached
            if self.max_iter is not None and cnt >= self.max_iter:
                break

            # find the most frequent pair
            pair_freq = self.get_pairs_freq()

            if len(pair_freq) == 0:
                print("No pair available")
                break

            pair = max(pair_freq, key=pair_freq.get)

            self.update_splits(pair[0], pair[1])

            if self.debug:
                print(f"Updated splits: {self.splits}")

            # dicts keep insertion order, so self.merges preserves the training order
            self.merges[pair] = pair[0] + pair[1]

            self.vocab.append(pair[0] + pair[1])

            if self.debug:
                print(
                    f"Most frequent pair({pair_freq[pair]} times) "
                    f"is : {pair[0]}, {pair[1]}. Vocab size: {len(self.vocab)}"
                )

            cnt += 1

    def update_splits(self, lhs: str, rhs: str):
        """If we see lhs and rhs appear consecutively, we merge them"""
        for word, word_split in self.splits.items():
            new_split = []
            cursor = 0
            while cursor < len(word_split):
                if (
                    word_split[cursor] == lhs
                    and cursor + 1 < len(word_split)
                    and word_split[cursor + 1] == rhs
                ):
                    new_split.append(lhs + rhs)
                    cursor += 2
                else:
                    new_split.append(word_split[cursor])
                    cursor += 1
            self.splits[word] = new_split

    def get_pairs_freq(self) -> dict:
        """Compute the pair frequency, weighted by word frequency"""
        pairs_freq = defaultdict(int)
        for word, freq in self.word_freq.items():
            split = self.splits[word]
            # only adjacent symbols inside the same word form a pair
            for i in range(len(split) - 1):
                pairs_freq[(split[i], split[i + 1])] += freq

        return pairs_freq

    def tokenize(self, s: str) -> list[str]:
        """Split s on whitespace, then apply the merges in training order"""
        splits = [list(t) + ["</w>"] for t in s.split()]

        for lhs, rhs in self.merges:
            for idx, split in enumerate(splits):
                new_split = []
                cursor = 0
                while cursor < len(split):
                    if (
                        cursor + 1 < len(split)
                        and split[cursor] == lhs
                        and split[cursor + 1] == rhs
                    ):
                        new_split.append(lhs + rhs)
                        cursor += 2
                    else:
                        new_split.append(split[cursor])
                        cursor += 1
                assert "".join(new_split) == "".join(split)
                splits[idx] = new_split

        # flatten the per-word splits into one token list
        return [token for split in splits for token in split]
```

Since this pure-Python implementation is slow, let's train it on only the first 200 documents of `corpus`.

```python
bpe = BPE(corpus[:200], vocab_size=2000, debug=False)
bpe.train()
```

```python
print(bpe.tokenize(sample))
```

```text
['func</w>', 'must', 'W', 'a', 'it', 'P', 'in', 'Read', 'y', '(t</w>', '*', 'testing.T', ',</w>', 'cli', '</w>', '*', 'client', 'v3.Client', ')</w>', '{</w>', '//</w>', 'TODO:</w>', 'de', 'c', 'rea', 'se</w>', 'time', 'out</w>', 'after</w>', 'bal', 'an', 'c', 'er</w>', 're', 'write', '!', '!', '!', '</w>', 'ctx,</w>', 'cancel</w>', ':=</w>', 'context.With', 'Timeout', '(context.Background', '(),</w>', '1', '0', '*', 'time.Second)</w>', '_,</w>', 'err</w>', ':=</w>', 'cli', '.Get', '(ctx,</w>', '"', 'foo', '")</w>', 'cancel()</w>', 'if</w>', 'err</w>', '!=</w>', 'nil</w>', '{</w>', 't.Fatal(err)</w>', '}</w>', '}</w>']
```


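Let's compare it with the output of the Hugging Face tokenizer. Keep in mind that the two were not trained under the same conditions: our implementation saw only the first 200 documents and splits words on whitespace only, whereas `CharBPETokenizer` was trained on the whole `corpus` and, by default, also splits off punctuation before applying BPE. This is why none of its tokens mix letters with symbols such as `(` or `.`, while ours contain tokens like `t.Fatal(err)</w>`.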

```python
print(tokenizer.encode(sample).tokens)
```

```text
['func</w>', 'must', 'Wait', 'Pin', 'Ready</w>', '(</w>', 't</w>', '*</w>', 'testing</w>', '.</w>', 'T</w>', ',</w>', 'cli</w>', '*</w>', 'clientv3</w>', '.</w>', 'Client</w>', ')</w>', '{</w>', '/</w>', '/</w>', 'TODO</w>', ':</w>', 'decrease</w>', 'timeout</w>', 'after</w>', 'balancer</w>', 'rewrite</w>', '!</w>', '!</w>', '!</w>', 'ctx</w>', ',</w>', 'cancel</w>', ':</w>', '=</w>', 'context</w>', '.</w>', 'WithTimeout</w>', '(</w>', 'context</w>', '.</w>', 'Background</w>', '(</w>', ')</w>', ',</w>', '10</w>', '*</w>', 'time</w>', '.</w>', 'Second</w>', ')</w>', '_</w>', ',</w>', 'err</w>', ':</w>', '=</w>', 'cli</w>', '.</w>', 'Get</w>', '(</w>', 'ctx</w>', ',</w>', '"</w>', 'foo</w>', '"</w>', ')</w>', 'cancel</w>', '(</w>', ')</w>', 'if</w>', 'err</w>', '!</w>', '=</w>', 'nil</w>', '{</w>', 't</w>', '.</w>', 'Fatal</w>', '(</w>', 'err</w>', ')</w>', '}</w>', '}</w>']
```


## Wrap up

BPE tokenization is simple and practical, but **when you delve into its implementation, you will encounter several details. Yet it is precisely by working through these intricacies that your understanding of BPE deepens**.

Let's also discuss some limitations of BPE. For instance, we split text on whitespace, which **only works for languages that separate words with spaces**. In languages like Chinese, spaces do not mark word boundaries, which makes things more complex and calls for a better way to split text into words.
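A quick way to see the problem, assuming the same `str.split()` pre-tokenization used in our implementation: the Chinese sentence below ("I like machine learning") comes back as a single "word", so BPE would treat the whole sentence as one unit and `</w>` would only mark the end of the sentence.

```python
english = "I like machine learning"
chinese = "我喜欢机器学习"  # the same sentence in Chinese, written without spaces

# whitespace splitting, as in BPE.train / BPE.tokenize above
print(english.split())  # ['I', 'like', 'machine', 'learning']
print(chinese.split())  # ['我喜欢机器学习']
```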