In [1]:
from transformers import AutoTokenizer
from collections import defaultdict

In [2]:
# Corpus for testing
corpus = '''
    Object raspberrypi functools dict kwargs. Gevent raspberrypi functools. Dunder raspberrypi decorator dict didn't lambda zip import pyramid, she lambda iterate?
    Kwargs raspberrypi diversity unit object gevent. Import fall integration decorator unit django yield functools twisted. Dunder integration decorator he she future.
    Python raspberrypi community pypy. 
    Kwargs integration beautiful test reduce gil python closure. Gevent he integration generator fall test kwargs raise didn't visor he itertools...
    Reduce integration coroutine bdfl he python. Cython didn't integration while beautiful list python didn't nit!
    Object fall diversity 2to3 dunder script. Python fall for: integration exception dict kwargs dunder pycon. Import raspberrypi beautiful test import six web. Future 
    integration mercurial self script web. Return raspberrypi community test she stable.
    Django raspberrypi mercurial unit import yield raspberrypi visual rocksdahouse. Dunder raspberrypi mercurial list reduce class test scipy helmet zip?
'''

In [3]:
print(corpus)


    Object raspberrypi functools dict kwargs. Gevent raspberrypi functools. Dunder raspberrypi decorator dict didn't lambda zip import pyramid, she lambda iterate?
    Kwargs raspberrypi diversity unit object gevent. Import fall integration decorator unit django yield functools twisted. Dunder integration decorator he she future.
    Python raspberrypi community pypy. 
    Kwargs integration beautiful test reduce gil python closure. Gevent he integration generator fall test kwargs raise didn't visor he itertools...
    Reduce integration coroutine bdfl he python. Cython didn't integration while beautiful list python didn't nit!
    Object fall diversity 2to3 dunder script. Python fall for: integration exception dict kwargs dunder pycon. Import raspberrypi beautiful test import six web. Future 
    integration mercurial self script web. Return raspberrypi community test she stable.
    Django raspberrypi mercurial unit import yield raspberrypi visual rocksdahouse. Dunder raspberrypi me

## Normalisation

The normalization step involves some general cleanup, such as removing needless whitespace, lowercasing, and/or removing accents. If you’re familiar with Unicode normalization (such as NFC or NFKC), this is also something the tokenizer may apply.

The Transformers tokenizer has an attribute called backend_tokenizer that provides access to the underlying tokenizer from the Tokenizers library:

In [4]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
print(type(tokenizer.backend_tokenizer))

<class 'tokenizers.Tokenizer'>


In [5]:
print(tokenizer.backend_tokenizer.normalizer.normalize_str(corpus))

     Object raspberrypi functools dict kwargs. Gevent raspberrypi functools. Dunder raspberrypi decorator dict didn't lambda zip import pyramid, she lambda iterate?     Kwargs raspberrypi diversity unit object gevent. Import fall integration decorator unit django yield functools twisted. Dunder integration decorator he she future.     Python raspberrypi community pypy.      Kwargs integration beautiful test reduce gil python closure. Gevent he integration generator fall test kwargs raise didn't visor he itertools...     Reduce integration coroutine bdfl he python. Cython didn't integration while beautiful list python didn't nit!     Object fall diversity 2to3 dunder script. Python fall for: integration exception dict kwargs dunder pycon. Import raspberrypi beautiful test import six web. Future      integration mercurial self script web. Return raspberrypi community test she stable.     Django raspberrypi mercurial unit import yield raspberrypi visual rocksdahouse. Dunder raspberrypi me

In [6]:
print(tokenizer.backend_tokenizer.normalizer.normalize_str("Héllò hôw are ü?"))

Héllò hôw are ü?
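The output above shows that the cased checkpoint's normalizer leaves accents untouched. For comparison, the sketch below shows what an uncased-style normalization (lowercasing plus accent stripping through Unicode NFD) could look like, using only the standard library. The function name `simple_normalize` is hypothetical. It only approximates the Tokenizers library's normalizers and is not their implementation.

```python
import unicodedata


def simple_normalize(text: str, lowercase: bool = True, strip_accents: bool = True) -> str:
    """Hypothetical helper: rough uncased-style normalization, stdlib only."""
    # Replace tabs and newlines with plain spaces (runs of spaces are not collapsed).
    text = "".join(" " if ch in "\t\n\r" else ch for ch in text)
    if lowercase:
        text = text.lower()
    if strip_accents:
        # NFD splits "é" into "e" + a combining accent (Unicode category "Mn"),
        # so dropping the "Mn" characters removes the accents.
        text = unicodedata.normalize("NFD", text)
        text = "".join(ch for ch in text if unicodedata.category(ch) != "Mn")
    return text


print(simple_normalize("Héllò hôw are ü?"))  # hello how are u?
print(simple_normalize("Héllò hôw are ü?", lowercase=False, strip_accents=False))  # Héllò hôw are ü?
```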


## Pre-tokenizer

A tokenizer cannot be trained on raw text alone. Instead, we first need to split the texts into small entities, like words. That’s where the pre-tokenization step comes in. A word-based tokenizer can simply split a raw text into words on whitespace and punctuation. Those words set the boundaries of the subtokens the tokenizer can learn during training: no learned subtoken will span two words.
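As a rough illustration, a word-based pre-tokenizer that splits on whitespace and treats every punctuation character as its own word can be sketched with a regular expression. This is an approximation of the BERT pre-tokenizer used in the next cell, not its implementation. For instance, Python's `\w` counts `_` as a word character. The helper name `pre_tokenize` is hypothetical.

```python
import re
from collections import Counter

# Runs of word characters, or any single character that is neither a word
# character nor whitespace (i.e. punctuation such as "'", "," or "?").
_WORD_OR_PUNCT = re.compile(r"\w+|[^\w\s]")


def pre_tokenize(text):
    """Hypothetical helper: return (word, (start, end)) pairs with character offsets."""
    return [(m.group(), m.span()) for m in _WORD_OR_PUNCT.finditer(text)]


pieces = pre_tokenize("Dunder didn't zip, she?")
print(pieces)
# [('Dunder', (0, 6)), ('didn', (7, 11)), ("'", (11, 12)), ('t', (12, 13)),
#  ('zip', (14, 17)), (',', (17, 18)), ('she', (19, 22)), ('?', (22, 23))]
print(Counter(word for word, _ in pieces))
```

Note how "didn't" becomes three words, which matches the 'didn', "'" and 't' entries counted in the cell below.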

In [7]:
word_freqs = defaultdict(int)
words_with_offsets = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(corpus.strip())
new_words = [word for word, offset in words_with_offsets]
for word in new_words:
    word_freqs[word] += 1
print(word_freqs)

defaultdict(<class 'int'>, {'Object': 2, 'raspberrypi': 10, 'functools': 3, 'dict': 3, 'kwargs': 3, '.': 17, 'Gevent': 2, 'Dunder': 3, 'decorator': 3, 'didn': 4, "'": 4, 't': 4, 'lambda': 2, 'zip': 2, 'import': 3, 'pyramid': 1, ',': 1, 'she': 3, 'iterate': 1, '?': 2, 'Kwargs': 2, 'diversity': 2, 'unit': 3, 'object': 1, 'gevent': 1, 'Import': 2, 'fall': 4, 'integration': 8, 'django': 1, 'yield': 2, 'twisted': 1, 'he': 4, 'future': 1, 'Python': 2, 'community': 2, 'pypy': 1, 'beautiful': 3, 'test': 5, 'reduce': 2, 'gil': 1, 'python': 3, 'closure': 1, 'generator': 1, 'raise': 1, 'visor': 1, 'itertools': 1, 'Reduce': 1, 'coroutine': 1, 'bdfl': 1, 'Cython': 1, 'while': 1, 'list': 2, 'nit': 1, '!': 1, '2to3': 1, 'dunder': 2, 'script': 2, 'for': 1, ':': 1, 'exception': 1, 'pycon': 1, 'six': 1, 'web': 2, 'Future': 1, 'mercurial': 3, 'self': 1, 'Return': 1, 'stable': 1, 'Django': 1, 'visual': 1, 'rocksdahouse': 1, 'class': 1, 'scipy': 1, 'helmet': 1})


## Model

Like BPE, WordPiece starts from a small vocabulary that includes the special tokens used by the model and the initial alphabet. Since it identifies subwords by adding a prefix (like ## for BERT), each word is initially split by adding that prefix to all the characters inside the word. The initial alphabet therefore contains both plain and ##-prefixed characters, as the cell below shows:

In [8]:
alphabet = []
for word in word_freqs.keys():
    if word[0] not in alphabet:
        alphabet.append(word[0])
    for letter in word[1:]:
        if f"##{letter}" not in alphabet:
            alphabet.append(f"##{letter}")

alphabet.sort()
print(alphabet)

['!', '##3', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##j', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##x', '##y', "'", ',', '.', '2', ':', '?', 'C', 'D', 'F', 'G', 'I', 'K', 'O', 'P', 'R', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z']


Thus, the initial alphabet contains all the characters present at the beginning of a word and the characters present inside a word preceded by the WordPiece prefix.
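To make the two kinds of entries concrete, here is a small sketch on a made-up toy corpus (the words and counts are illustrative only, not taken from the corpus above). It builds the same kind of alphabet with a set, which shows that a word-initial character such as h and a word-internal one such as ##u are separate vocabulary entries. The names `toy_word_freqs`, `initial_split` and `toy_alphabet` are hypothetical.

```python
# Hypothetical toy corpus: word -> frequency (illustrative values only).
toy_word_freqs = {"hug": 10, "pug": 5, "pun": 12, "bun": 4, "hugs": 5}


def initial_split(word):
    # First character kept as-is; every following character gets the "##" prefix.
    return [c if i == 0 else f"##{c}" for i, c in enumerate(word)]


toy_splits = {word: initial_split(word) for word in toy_word_freqs}
toy_alphabet = sorted({piece for split in toy_splits.values() for piece in split})

print(toy_splits["hugs"])  # ['h', '##u', '##g', '##s']
print(toy_alphabet)  # ['##g', '##n', '##s', '##u', 'b', 'h', 'p']
```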

We also add the special tokens used by the model at the beginning of that vocabulary. In the case of BERT, it’s the list ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]:

In [9]:
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + alphabet.copy()
print(vocab)

['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '!', '##3', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##j', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##x', '##y', "'", ',', '.', '2', ':', '?', 'C', 'D', 'F', 'G', 'I', 'K', 'O', 'P', 'R', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z']


This vocab list will be our output at the end, once every learned merge has been added to it.

Next, we need to split each word into characters, prefixing every character except the first with ##:

In [10]:
splits = {
    word: [c if i == 0 else f"##{c}" for i, c in enumerate(word)]
    for word in word_freqs.keys()
}

In [11]:
print(splits)

{'Object': ['O', '##b', '##j', '##e', '##c', '##t'], 'raspberrypi': ['r', '##a', '##s', '##p', '##b', '##e', '##r', '##r', '##y', '##p', '##i'], 'functools': ['f', '##u', '##n', '##c', '##t', '##o', '##o', '##l', '##s'], 'dict': ['d', '##i', '##c', '##t'], 'kwargs': ['k', '##w', '##a', '##r', '##g', '##s'], '.': ['.'], 'Gevent': ['G', '##e', '##v', '##e', '##n', '##t'], 'Dunder': ['D', '##u', '##n', '##d', '##e', '##r'], 'decorator': ['d', '##e', '##c', '##o', '##r', '##a', '##t', '##o', '##r'], 'didn': ['d', '##i', '##d', '##n'], "'": ["'"], 't': ['t'], 'lambda': ['l', '##a', '##m', '##b', '##d', '##a'], 'zip': ['z', '##i', '##p'], 'import': ['i', '##m', '##p', '##o', '##r', '##t'], 'pyramid': ['p', '##y', '##r', '##a', '##m', '##i', '##d'], ',': [','], 'she': ['s', '##h', '##e'], 'iterate': ['i', '##t', '##e', '##r', '##a', '##t', '##e'], '?': ['?'], 'Kwargs': ['K', '##w', '##a', '##r', '##g', '##s'], 'diversity': ['d', '##i', '##v', '##e', '##r', '##s', '##i', '##t', '##y'], 'unit':

Like BPE, WordPiece learns merge rules. The main difference is the way the pair to be merged is selected. Instead of selecting the most frequent pair, WordPiece computes a score for each pair, using the following formula:
$$
\text{score} = \frac{\text{freq\_of\_pair}}{\text{freq\_of\_first\_element} \times \text{freq\_of\_second\_element}}
$$
By dividing the frequency of the pair by the product of the frequencies of its two parts, the algorithm prioritizes merging pairs whose individual parts are less frequent in the corpus: a pair that appears often but is made of two very common parts gets a low score.
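To see the formula at work, here is a self-contained sketch that computes the scores on a small made-up corpus (illustrative counts, not the corpus above). The helper `pair_scores_for` is hypothetical. Unlike the notebook's compute_pair_scores defined in the next cell, it takes the word frequencies as an argument instead of reading a global variable.

```python
from collections import defaultdict


def pair_scores_for(word_freqs, splits):
    """Hypothetical helper: WordPiece score for every pair of adjacent units."""
    unit_freqs = defaultdict(int)  # frequency of each individual unit
    pair_freqs = defaultdict(int)  # frequency of each adjacent pair
    for word, freq in word_freqs.items():
        split = splits[word]
        for unit in split:
            unit_freqs[unit] += freq
        for first, second in zip(split, split[1:]):
            pair_freqs[(first, second)] += freq
    return {
        pair: freq / (unit_freqs[pair[0]] * unit_freqs[pair[1]])
        for pair, freq in pair_freqs.items()
    }


toy_word_freqs = {"hug": 10, "pug": 5, "pun": 12, "bun": 4, "hugs": 5}
toy_splits = {w: [c if i == 0 else f"##{c}" for i, c in enumerate(w)] for w in toy_word_freqs}
toy_scores = pair_scores_for(toy_word_freqs, toy_splits)

# ("##u", "##g") occurs 20 times, but "##u" (36) and "##g" (20) are both common:
# 20 / (36 * 20) = 1/36, about 0.028.
# ("##g", "##s") occurs only 5 times, but "##s" is rare: 5 / (20 * 5) = 0.05.
print(toy_scores[("##u", "##g")], toy_scores[("##g", "##s")])
```

The rarer pair wins even though the more frequent pair occurs four times as often.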

Now that we are ready for training, let’s write a function that computes the score of each pair. We’ll need to use this at each step of the training:

In [12]:
# Compute the score of each pair of successive elements in each word.
# Note: this reads the global word_freqs built during pre-tokenization.
def compute_pair_scores(splits):
    letter_freqs = defaultdict(int)  # frequency of each element (character or merged subword)
    pair_freqs = defaultdict(int)    # frequency of each pair of successive elements
    for word, freq in word_freqs.items():
        split = splits[word]
        if len(split) == 1:
            letter_freqs[split[0]] += freq
            continue
        for i in range(len(split) - 1):
            pair = (split[i], split[i + 1])
            letter_freqs[split[i]] += freq
            pair_freqs[pair] += freq
        letter_freqs[split[-1]] += freq

    # score = freq(pair) / (freq(first element) * freq(second element))
    scores = {
        pair: freq / (letter_freqs[pair[0]] * letter_freqs[pair[1]])
        for pair, freq in pair_freqs.items()
    }
    return scores

In [13]:
# Example: print the scores of the first six pairs
pair_scores = compute_pair_scores(splits)
for i, key in enumerate(pair_scores.keys()):
    print(f"{key}: {pair_scores[key]}")
    if i >= 5:
        break

('O', '##b'): 0.05555555555555555
('##b', '##j'): 0.03333333333333333
('##j', '##e'): 0.007142857142857143
('##e', '##c'): 0.002976190476190476
('##c', '##t'): 0.004934210526315789
('r', '##a'): 0.015714285714285715


Now, finding the pair with the best score only takes a quick loop:

In [15]:
best_pair = ""
max_score = None
for pair, score in pair_scores.items():
    if max_score is None or max_score < score:
        best_pair = pair
        max_score = score

print(best_pair, max_score)

('e', '##x') 0.5
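The same selection can also be written with the built-in max and a key function. The loop above only replaces the current best on a strictly greater score, so it keeps the first pair among ties; max also returns the first maximal element it encounters, so both versions pick the same pair. The dictionary below is a made-up example, not the notebook's scores.

```python
# Made-up scores for illustration; two pairs tie for the highest score.
example_scores = {("a", "##b"): 0.1, ("c", "##d"): 0.5, ("e", "##f"): 0.5}

example_best = max(example_scores, key=example_scores.get)
print(example_best, example_scores[example_best])  # ('c', '##d') 0.5
```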


So the first merge to learn is ('e', '##x') -> 'ex', and we add 'ex' to the vocabulary list we created earlier:

In [16]:
vocab.append("ex")

To continue, we need to apply that merge in our splits dictionary. Let’s write another function for this:

In [18]:
# Apply the merge of the pair (a, b) to every word in the splits dictionary
def merge_pair(a, b, splits):
    for word in word_freqs:
        split = splits[word]
        if len(split) == 1:
            continue  # nothing to merge in a single-element word
        i = 0
        while i < len(split) - 1:
            if split[i] == a and split[i + 1] == b:
                # The merged token drops the ## prefix of the second element
                merge = a + b[2:] if b.startswith("##") else a + b
                split = split[:i] + [merge] + split[i + 2 :]
            else:
                i += 1
        splits[word] = split
    return splits

In [19]:
# Look at the result of the first merge on a word containing that pair
splits = merge_pair("e", "##x", splits)
splits["exception"]

['ex', '##c', '##e', '##p', '##t', '##i', '##o', '##n']

In [20]:
len(vocab)

69

Now we have everything we need to loop until we have learned all the merges we want. Let’s aim for a vocab size of 100:

In [21]:
vocab_size = 100
while len(vocab) < vocab_size:
    scores = compute_pair_scores(splits)
    best_pair, max_score = "", None
    for pair, score in scores.items():
        if max_score is None or max_score < score:
            best_pair = pair
            max_score = score
    splits = merge_pair(*best_pair, splits)
    new_token = (
        best_pair[0] + best_pair[1][2:]
        if best_pair[1].startswith("##")
        else best_pair[0] + best_pair[1]
    )
    vocab.append(new_token)

In [22]:
print(vocab)

['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '!', '##3', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##j', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##x', '##y', "'", ',', '.', '2', ':', '?', 'C', 'D', 'F', 'G', 'I', 'K', 'O', 'P', 'R', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z', 'ex', 'kw', 'Kw', 'tw', 'Im', 'Ob', 'Obj', 'Dj', 'ob', 'obj', 'dj', 'exc', '##ck', 'py', 'Py', 'Cy', 'Du', 'Fu', 'Imp', 'sh', 'wh', '##cks', '##cksd', 'sc', '##pb', '##mb', '##mbd', '##yp', '##mp', 'imp', '##mm', '##py']
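For reference, the training steps above (initial split, scoring, picking the best pair, merging) can be packaged into one self-contained function. This sketch mirrors the notebook's cells under a few assumptions: the function name train_wordpiece and the toy corpus are hypothetical, it stops early when no pair is left to merge (the notebook's loop has no such guard), and it is not the Tokenizers library's WordPiece trainer.

```python
from collections import defaultdict


def train_wordpiece(word_freqs, special_tokens, vocab_size):
    """Hypothetical helper: a WordPiece-style training loop mirroring the cells above."""
    # Initial split: first character as-is, every following character prefixed with "##".
    splits = {w: [c if i == 0 else f"##{c}" for i, c in enumerate(w)] for w in word_freqs}
    alphabet = sorted({unit for split in splits.values() for unit in split})
    vocab = list(special_tokens) + alphabet
    merges = []

    while len(vocab) < vocab_size:
        # Count units and adjacent pairs, each weighted by its word's frequency.
        unit_freqs, pair_freqs = defaultdict(int), defaultdict(int)
        for word, freq in word_freqs.items():
            split = splits[word]
            for unit in split:
                unit_freqs[unit] += freq
            for pair in zip(split, split[1:]):
                pair_freqs[pair] += freq
        if not pair_freqs:  # every word is a single unit: nothing left to merge
            break
        # score = freq(pair) / (freq(first) * freq(second)); the first maximum wins ties.
        scores = {p: f / (unit_freqs[p[0]] * unit_freqs[p[1]]) for p, f in pair_freqs.items()}
        a, b = max(scores, key=scores.get)

        # The merged token drops the "##" prefix of the second unit.
        new_token = a + b[2:] if b.startswith("##") else a + b
        for split in splits.values():
            i = 0
            while i < len(split) - 1:
                if split[i] == a and split[i + 1] == b:
                    split[i : i + 2] = [new_token]  # replace the pair in place
                else:
                    i += 1
        merges.append((a, b))
        vocab.append(new_token)
    return vocab, merges


toy_word_freqs = {"hug": 10, "pug": 5, "pun": 12, "bun": 4, "hugs": 5}
toy_vocab, toy_merges = train_wordpiece(toy_word_freqs, ["[PAD]", "[UNK]"], vocab_size=10)
print(toy_merges)  # [('##g', '##s')]
print(toy_vocab)   # ['[PAD]', '[UNK]', '##g', '##n', '##s', '##u', 'b', 'h', 'p', '##gs']
```

Returning the list of merges alongside the vocabulary is a design choice: it records the order in which pairs were learned, which the vocabulary alone does not make explicit.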


To tokenize a new text, we pre-tokenize it, split it, then apply the tokenization algorithm to each word. That is, we look for the longest subword at the beginning of the first word that is in the vocabulary and split it off, then we repeat the process on the remaining part (now prefixed with ##), and so on for the rest of that word and for the following words in the text. If no prefix of what remains is in the vocabulary, the whole word is tokenized as [UNK]:
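The same greedy longest-match-first idea can be written with explicit positions instead of re-slicing the remaining word. This is a minimal sketch: the name encode_word_greedy and the toy vocabulary are hypothetical, and storing the vocabulary in a set (rather than the notebook's list) is a design choice that makes each membership test constant time on average.

```python
def encode_word_greedy(word, vocab, unk_token="[UNK]"):
    """Hypothetical helper: greedy longest-match-first WordPiece encoding."""
    tokens = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Shrink the candidate from the right until it is in the vocabulary.
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = f"##{candidate}"  # word-internal pieces carry "##"
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk_token]  # no prefix matches: the whole word is unknown
        tokens.append(piece)
        start = end
    return tokens


small_vocab = {"h", "p", "b", "##u", "##g", "##n", "##s", "hu", "##gs", "hug"}
print(encode_word_greedy("hugs", small_vocab))  # ['hug', '##s']
print(encode_word_greedy("bugs", small_vocab))  # ['b', '##u', '##gs']
print(encode_word_greedy("mug", small_vocab))   # ['[UNK]']
```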

In [23]:
def encode_word(word):
    tokens = []
    while len(word) > 0:
        i = len(word)
        while i > 0 and word[:i] not in vocab:
            i -= 1
        if i == 0:
            return ["[UNK]"]
        tokens.append(word[:i])
        word = word[i:]
        if len(word) > 0:
            word = f"##{word}"
    return tokens

In [24]:
print(encode_word("exception"))
print(encode_word("raspberrypi"))
print(encode_word("python"))
print(encode_word("tokenize"))

['exc', '##e', '##p', '##t', '##i', '##o', '##n']
['r', '##a', '##s', '##pb', '##e', '##r', '##r', '##yp', '##i']
['py', '##t', '##h', '##o', '##n']
['[UNK]']


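Note that "tokenize" is encoded as [UNK]: no word in our corpus contains a z after its first character, so ##z is not in the vocabulary. Now, let’s write a function that tokenizes a whole text: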

In [25]:
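def tokenize(text):
    # Use the public backend_tokenizer attribute (the same object as the private _tokenizer)
    pre_tokenize_result = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
    pre_tokenized_text = [word for word, offset in pre_tokenize_result]
    encoded_words = [encode_word(word) for word in pre_tokenized_text]
    # Flatten the per-word token lists into a single list
    return [token for tokens in encoded_words for token in tokens]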

In [26]:
tokenize("This is a simple demo of WordPiece tokenization")

['[UNK]',
 'i',
 '##s',
 '[UNK]',
 's',
 '##i',
 '##mp',
 '##l',
 '##e',
 'd',
 '##e',
 '##m',
 '##o',
 'o',
 '##f',
 '[UNK]',
 '[UNK]']