# WordPiece Tokenization

## 1 语料库

In [126]:
# Toy training corpus: the four sentences from which the WordPiece vocabulary
# in this notebook is built.
corpus = [
    "This is the Hugging Face Course.",
    "This chapter is about tokenization.",
    "This section shows several tokenizer algorithms.",
    "Hopefully, you will be able to understand how they are trained and generate tokens.",
]

## 2 预分词

WordPiece分词算法是Google为预训练BERT模型而开发的，由于我们将会手撸一个WordPiece分词器，所以我们首先需要引入`bert-base-cased`预训练模型对应的分词器来预分词。

In [127]:
from transformers import AutoTokenizer

# Tokenizer of the `bert-base-cased` checkpoint. The cells in this notebook use
# it through `backend_tokenizer.pre_tokenizer` to split text into words.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

In [128]:
# Pre-tokenize every sentence and flatten the results into a single word list.
# `pre_tokenize_str` returns pairs; only the first item (the word) is kept.
words = [
    word
    for sent in corpus
    for word, _ in tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(sent)
]
words[:10]

['This',
 'is',
 'the',
 'Hugging',
 'Face',
 'Course',
 '.',
 'This',
 'chapter',
 'is']

## 3 统计词频

In [129]:
from collections import defaultdict

# Tally how often each pre-tokenized word occurs in the corpus.
word_freqs = defaultdict(int)
for token in words:
    word_freqs[token] = word_freqs[token] + 1
word_freqs

defaultdict(int,
            {'This': 3,
             'is': 2,
             'the': 1,
             'Hugging': 1,
             'Face': 1,
             'Course': 1,
             '.': 4,
             'chapter': 1,
             'about': 1,
             'tokenization': 1,
             'section': 1,
             'shows': 1,
             'several': 1,
             'tokenizer': 1,
             'algorithms': 1,
             'Hopefully': 1,
             ',': 1,
             'you': 1,
             'will': 1,
             'be': 1,
             'able': 1,
             'to': 1,
             'understand': 1,
             'how': 1,
             'they': 1,
             'are': 1,
             'trained': 1,
             'and': 1,
             'generate': 1,
             'tokens': 1})

## 4 初始化词汇表

按照WordPiece的规定，初始词汇表由三部分构成：语料库中所有单词的首字母集合、所有以“##”前缀标记的词内字母集合，以及BERT模型指定的特殊词元（special tokens）集合：

In [130]:
# Gather the initial symbols: the first character of each word as-is, and every
# later character tagged with the "##" continuation prefix.
chars = set()
for word in words:
    chars.add(word[0])
    chars.update(f"##{c}" for c in word[1:])

# The alphabet is the sorted list of those unique symbols.
alphabet = sorted(chars)
alphabet

['##a',
 '##b',
 '##c',
 '##d',
 '##e',
 '##f',
 '##g',
 '##h',
 '##i',
 '##k',
 '##l',
 '##m',
 '##n',
 '##o',
 '##p',
 '##r',
 '##s',
 '##t',
 '##u',
 '##v',
 '##w',
 '##y',
 '##z',
 ',',
 '.',
 'C',
 'F',
 'H',
 'T',
 'a',
 'b',
 'c',
 'g',
 'h',
 'i',
 's',
 't',
 'u',
 'w',
 'y']

In [131]:
# Special tokens; they are placed at the front of the vocabulary in the next cell.
special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]

In [132]:
# Initial vocabulary: special tokens first, then the sorted alphabet. This is a
# new list, so appending to `vocab` leaves `special_tokens` and `alphabet` intact.
vocab = [*special_tokens, *alphabet]
vocab[:10]

['[PAD]',
 '[UNK]',
 '[CLS]',
 '[SEP]',
 '[MASK]',
 '##a',
 '##b',
 '##c',
 '##d',
 '##e']

## 5 分割单词

接下来我们需要将语料库中的每个单词拆分为单个字符：首字母保持不变，其余字母都加上“##”前缀：

In [133]:
# Map each word to its character-level pieces: the first character is kept
# unchanged and every following character gets the "##" prefix.
splits = {}
for word in words:
    splits[word] = [c if i == 0 else f"##{c}" for i, c in enumerate(word)]
splits

{'This': ['T', '##h', '##i', '##s'],
 'is': ['i', '##s'],
 'the': ['t', '##h', '##e'],
 'Hugging': ['H', '##u', '##g', '##g', '##i', '##n', '##g'],
 'Face': ['F', '##a', '##c', '##e'],
 'Course': ['C', '##o', '##u', '##r', '##s', '##e'],
 '.': ['.'],
 'chapter': ['c', '##h', '##a', '##p', '##t', '##e', '##r'],
 'about': ['a', '##b', '##o', '##u', '##t'],
 'tokenization': ['t',
  '##o',
  '##k',
  '##e',
  '##n',
  '##i',
  '##z',
  '##a',
  '##t',
  '##i',
  '##o',
  '##n'],
 'section': ['s', '##e', '##c', '##t', '##i', '##o', '##n'],
 'shows': ['s', '##h', '##o', '##w', '##s'],
 'several': ['s', '##e', '##v', '##e', '##r', '##a', '##l'],
 'tokenizer': ['t', '##o', '##k', '##e', '##n', '##i', '##z', '##e', '##r'],
 'algorithms': ['a',
  '##l',
  '##g',
  '##o',
  '##r',
  '##i',
  '##t',
  '##h',
  '##m',
  '##s'],
 'Hopefully': ['H', '##o', '##p', '##e', '##f', '##u', '##l', '##l', '##y'],
 ',': [','],
 'you': ['y', '##o', '##u'],
 'will': ['w', '##i', '##l', '##l'],
 'be': ['b', '##e

## 6 计算分数

WordPiece合并相邻的字符对/子词对/词元对的依据是为每对字符/子词/词元计算一个分数：

$$
score = \frac{freq\_of\_pair}{freq\_of\_first\_element \times freq\_of\_second\_element}
$$

然后，选取分数最大的Pair来进行合并。

In [134]:
def bigram():
    """Print the adjacent token pairs of the first entry in the global ``splits``.

    Only the first word is shown. Everything is written on one line with no
    trailing newline. Nothing is printed when ``splits`` is empty.
    """
    for word, tokens in splits.items():
        print(f"{word}:", end=" ")
        for left, right in zip(tokens, tokens[1:]):
            print(f"{(left, right)}", end=", ")
        # Stop after the first word; this is only a quick illustration.
        return

bigram()

This: ('T', '##h'), ('##h', '##i'), ('##i', '##s'), 

In [135]:
def compute_scores(word_freqs, splits):
    """Score every adjacent token pair as a WordPiece merge candidate.

    score(pair) = freq(pair) / (freq(first) * freq(second)), where each
    occurrence is weighted by the frequency of the word it appears in.

    Args:
        word_freqs: mapping of word -> number of occurrences.
        splits: mapping of word -> current list of tokens for that word.

    Returns:
        dict mapping (first, second) token pairs to their score, ordered by
        first appearance of the pair.
    """
    # First pass: weighted frequencies of single tokens and of adjacent pairs.
    token_counts = defaultdict(int)
    pair_counts = defaultdict(int)
    for word, count in word_freqs.items():
        pieces = splits[word]
        last = pieces[-1]
        for left, right in zip(pieces, pieces[1:]):
            pair_counts[(left, right)] += count
            token_counts[left] += count
        # The final token is never the left element of a pair, so count it here.
        token_counts[last] += count

    # Second pass: turn the raw counts into scores.
    scores = {}
    for (left, right), count in pair_counts.items():
        scores[(left, right)] = count / (token_counts[left] * token_counts[right])
    return scores

In [136]:
scores = compute_scores(word_freqs, splits)
# Preview only the first six pair scores.
for pair in list(scores)[:6]:
    print(f"{pair}: {scores[pair]}")

('T', '##h'): 0.125
('##h', '##i'): 0.03409090909090909
('##i', '##s'): 0.02727272727272727
('i', '##s'): 0.1
('t', '##h'): 0.03571428571428571
('##h', '##e'): 0.011904761904761904


## 7 合并子词

In [137]:
def merge_pair(first, second, splits):
    """Merge every adjacent (first, second) token pair in ``splits``, in place.

    Each word's token list is scanned left to right and every non-overlapping
    occurrence of ``first`` immediately followed by ``second`` is replaced by
    the merged token. The result is built up incrementally, so a word that
    contains the pair several times gets all of its occurrences merged.

    Args:
        first: left token of the pair.
        second: right token of the pair; a leading "##" is stripped when the
            two tokens are concatenated.
        splits: mapping of word -> list of tokens. Modified in place.

    Returns:
        None.
    """
    if second.startswith("##"):
        merged = first + second[2:]
    else:
        merged = first + second

    for word, split in splits.items():
        if len(split) == 1:
            continue
        new_split = []
        i = 0
        while i < len(split):
            if i < len(split) - 1 and split[i] == first and split[i + 1] == second:
                new_split.append(merged)
                i += 2
            else:
                new_split.append(split[i])
                i += 1
        # Re-assigning an existing key does not resize the dict, so this is
        # safe while iterating over splits.items().
        splits[word] = new_split

In [138]:
# Demo: merge the pair ("a", "##b"). This call mutates the global `splits` in
# place, so the merge remains in effect for every later cell.
merge_pair("a", "##b", splits)
splits["about"]

['ab', '##o', '##u', '##t']

In [139]:
def merge_pair_v2(first, second, splits):
    """Merge every adjacent (first, second) token pair in ``splits``, in place.

    A fresh result list is started for every word, so words that do not
    contain the pair keep their own tokens, and words that contain it several
    times get every non-overlapping occurrence merged (left to right).

    Args:
        first: left token of the pair.
        second: right token of the pair; a leading "##" is stripped when the
            two tokens are concatenated.
        splits: mapping of word -> list of tokens. Modified in place.

    Returns:
        None.
    """
    pair = (first + second[2:]) if second.startswith("##") else (first + second)
    for word, split in splits.items():
        if len(split) == 1:
            continue
        # Initialised per word: it must never carry over from a previous word.
        new_split = []
        i = 0
        last = len(split) - 1
        while i <= last:
            if i < last and split[i] == first and split[i + 1] == second:
                new_split.append(pair)
                i += 2
            else:
                new_split.append(split[i])
                i += 1
        splits[word] = new_split

In [140]:
from copy import deepcopy

# Run this demo merge on a deep copy so that the global `splits` is not altered.
splits_copyed = deepcopy(splits)
merge_pair("##u", "##t", splits_copyed)
splits_copyed["about"]

['ab', '##o', '##ut']

## 8 训练

找到最大分数的Pair：

In [141]:
# Rank the pairs from highest to lowest score and look at the top five.
sorted_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True)
sorted_scores[:5]


[(('a', '##b'), 0.2),
 (('##f', '##u'), 0.2),
 (('F', '##a'), 0.14285714285714285),
 (('T', '##h'), 0.125),
 (('c', '##h'), 0.125)]

In [142]:
# Record the ("a", "##b") merge that the earlier demo cell already applied to
# `splits`. Guarded so that re-running this cell cannot add a duplicate entry.
if "ab" not in vocab:
    vocab.append("ab")

In [143]:
# Target vocabulary size; merging stops once it is reached.
vocab_size = 70
while len(vocab) < vocab_size:
    scores = compute_scores(word_freqs, splits)
    if not scores:
        # Every word is already a single token, so there is nothing left to merge.
        break
    # max() returns the first pair with the highest score, which is the same
    # pair a stable descending sort would place first.
    pair = max(scores, key=scores.get)
    merge_pair(*pair, splits)
    # The merged token drops the "##" prefix of the second element.
    token = pair[0] + pair[1][2:] if pair[1].startswith("##") else pair[0] + pair[1]
    vocab.append(token)

In [144]:
# Reference vocabulary that `vocab` is compared against in the next cell.
expected = ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##k',
 '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##y', '##z', ',', '.', 'C', 'F', 'H',
 'T', 'a', 'b', 'c', 'g', 'h', 'i', 's', 't', 'u', 'w', 'y', 'ab', '##fu', 'Fa', 'Fac', '##ct', '##ful', '##full', '##fully',
 'Th', 'ch', '##hm', 'cha', 'chap', 'chapt', '##thm', 'Hu', 'Hug', 'Hugg', 'sh', 'th', 'is', '##thms', '##za', '##zat',
 '##ut']

In [145]:
# Use a plain conditional instead of assert/except: assert statements are
# removed under `python -O`, which would make this check always report success.
if vocab == expected:
    print("actual vocab is equal to the expected!")
else:
    print(f"len of vocab: {len(vocab)}")
    print(f"len of expected: {len(expected)}")
    # Set differences show which tokens are missing on either side.
    s1 = set(vocab)
    s2 = set(expected)
    print(f"diff between vocab and expected: {s1.difference(s2)}")
    print(f"diff between expected and vocab: {s2.difference(s1)}")

actual vocab is equal to the expected!


In [146]:
# NOTE(review): the value of the expression below is discarded, so this cell
# has no visible effect. `s1` and `s2` are only assigned by the preceding cell
# when the comparison fails. Presumably a leftover debugging aid; confirm
# whether it should be printed or removed.
if vocab != expected:
    "ab" not in s1 and "##ta" not in s2

## 9 分词

In [147]:
def encode(word, vocab_set):
    """Tokenize one pre-tokenized word by greedy longest-match-first lookup.

    Starting at the beginning of the word, the longest prefix present in
    ``vocab_set`` is taken as a token. The remainder is processed the same
    way, with every candidate after the first prefixed by "##".

    Candidates are always built from at least one character of the word, so
    the bare prefix characters "#" or "##" can never be matched as a
    continuation token, even when "#" is itself in the vocabulary.

    Args:
        word: the word to tokenize.
        vocab_set: collection of known tokens supporting ``in`` tests.

    Returns:
        The list of tokens, ``["[UNK]"]`` if any part of the word cannot be
        matched, or ``[]`` for an empty word.
    """
    tokens = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Shrink the candidate from the right until it is in the vocabulary.
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = f"##{candidate}"
            if candidate in vocab_set:
                piece = candidate
                break
            end -= 1
        if piece is None:
            # One unmatched piece makes the whole word unknown.
            return ["[UNK]"]
        tokens.append(piece)
        start = end
    return tokens

In [148]:
# Build the lookup set once; `encode` only performs membership tests on it.
vocab_set = set(vocab)
print(encode("Hugging", vocab_set))
print(encode("HOgging", vocab_set))

['Hugg', '##i', '##n', '##g']
['[UNK]']


In [149]:
def tokenize(text):
    """Tokenize ``text`` into WordPiece tokens.

    The text is split into words with the notebook-level ``tokenizer``'s
    pre-tokenizer, each word is passed to ``encode`` together with the
    notebook-level ``vocab_set``, and the per-word results are concatenated.

    Args:
        text: the string to tokenize.

    Returns:
        A flat list of tokens.
    """
    pre_tokenized = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
    return [piece for word, _ in pre_tokenized for piece in encode(word, vocab_set)]

In [150]:
# Tokenize a sample sentence with the vocabulary trained above.
actual = tokenize("This is the Hugging Face course!")
actual

['Th',
 '##i',
 '##s',
 'is',
 'th',
 '##e',
 'Hugg',
 '##i',
 '##n',
 '##g',
 'Fac',
 '##e',
 'c',
 '##o',
 '##u',
 '##r',
 '##s',
 '##e',
 '[UNK]']

In [151]:
# Expected tokens for the sample sentence.
# NOTE: this rebinds `expected`, which earlier held the reference vocabulary.
expected = ['Th', '##i', '##s', 'is', 'th', '##e', 'Hugg', '##i', '##n', '##g', 'Fac', '##e', 'c', '##o', '##u', '##r', '##s',
 '##e', '[UNK]']

In [152]:
# Sanity check: the output of `tokenize` matches the expected token list.
assert actual == expected