# WordPiece tokenization

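This notebook trains a small WordPiece vocabulary (the subword tokenization algorithm used by BERT) from scratch on a toy corpus, then uses it to tokenize new text. Install the Transformers, Datasets, and Evaluate libraries to run this notebook.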

In [None]:
# Install required packages for building a WordPiece tokenizer from scratch
# - datasets: For loading and processing text datasets
# - evaluate: For model evaluation metrics
# - transformers[sentencepiece]: Core library with SentencePiece support
# Only `transformers` is imported by the cells below (for BERT's pre-tokenizer);
# `datasets` and `evaluate` are not used in this notebook.
# NOTE(review): `uv pip install` requires `uv` on PATH and installs into the environment
# uv resolves, which may not be the running kernel's; `%pip install ...` targets the kernel.
!uv pip install datasets evaluate transformers[sentencepiece]

In [None]:
# Define the same training corpus to compare WordPiece with BPE
# WordPiece uses a different algorithm to decide which pairs to merge
# Instead of pure frequency, it uses a likelihood-based scoring function
# The corpus has only four sentences. Every training step below can be inspected by hand,
# but the learned vocabulary only describes these sentences.
corpus = [
    "This is the Hugging Face Course.",
    "This chapter is about tokenization.",
    "This section shows several tokenizer algorithms.",
    "Hopefully, you will be able to understand how they are trained and generate tokens.",
]

In [None]:
# Load BERT tokenizer to understand WordPiece pre-tokenization
# BERT uses cased tokenization (preserves case) unlike the uncased version
# Only this tokenizer's pre-tokenizer (whitespace/punctuation splitting) is used below.
# The WordPiece vocabulary itself is trained from scratch in this notebook.
# `from_pretrained` fetches the tokenizer files from the Hugging Face Hub or the local cache.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

In [None]:
# Step 1: Pre-tokenize the corpus and count how often each word occurs
# WordPiece training starts from BERT's own pre-tokenization. Unlike GPT-2, there is no
# special space marker: text is simply split on whitespace and punctuation.
from collections import defaultdict

word_freqs = defaultdict(int)
for sentence in corpus:
    # pre_tokenize_str yields (word, (start, end)) pairs; only the word text is needed here
    pre_tokenized = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(sentence)
    for word, _offsets in pre_tokenized:
        word_freqs[word] += 1

word_freqs

In [None]:
# Step 2: Build the initial WordPiece alphabet with the "##" continuation marker
# WordPiece uses "##" to denote word-internal characters (not word beginnings):
# the first character of each word is stored as-is, later characters as "##<char>".
# Sets de-duplicate in O(1) per item; `sorted` yields the same order as `list.sort()`.
word_initial = {word[0] for word in word_freqs}
word_internal = {f"##{letter}" for word in word_freqs for letter in word[1:]}
alphabet = sorted(word_initial | word_internal)

print(alphabet)

In [None]:
# Step 3: Initialize the vocabulary with BERT's special tokens, followed by the alphabet
# [PAD]:  padding token for batch processing
# [UNK]:  unknown token for out-of-vocabulary words
# [CLS]:  classification token (start of sequence)
# [SEP]:  separator token (end of a sequence / between sentences)
# [MASK]: mask token for masked language modeling
special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
vocab = special_tokens + alphabet  # `+` returns a new list, so `alphabet` is not modified

In [None]:
# Step 4: Split every word into its initial WordPiece units
# The first character is kept unchanged and each later character gets the "##" prefix,
# so each split records which piece starts the word.
splits = {
    word: list(word[:1]) + [f"##{char}" for char in word[1:]]
    for word in word_freqs
}

In [None]:
# Step 5: WordPiece scoring function (different from BPE's pure frequency)
# WordPiece favors pairs that co-occur more often than their individual frequencies
# would predict, not simply the most frequent pairs.
def compute_pair_scores(splits, freqs=None):
    """Score every adjacent token pair with the WordPiece criterion.

    score(a, b) = freq(a, b) / (freq(a) * freq(b))

    This equals P(a, b) / (P(a) * P(b)) multiplied by a normalization constant that is
    the same for every pair in one call. It therefore ranks pairs identically.

    Args:
        splits: maps each word to its current list of tokens.
        freqs: maps each word to its corpus frequency. Defaults to the notebook-level
            ``word_freqs``. Every key must be present in ``splits``.

    Returns:
        Dict mapping (first, second) token pairs to their score, in the order the pairs
        are first encountered.
    """
    if freqs is None:
        freqs = word_freqs

    token_freqs = defaultdict(int)
    pair_freqs = defaultdict(int)
    for word, freq in freqs.items():
        split = splits[word]
        # Each token occurrence counts once, weighted by how often its word occurs
        for token in split:
            token_freqs[token] += freq
        # zip(split, split[1:]) yields every adjacent pair (none for single-token words)
        for pair in zip(split, split[1:]):
            pair_freqs[pair] += freq

    return {
        pair: pair_freq / (token_freqs[pair[0]] * token_freqs[pair[1]])
        for pair, pair_freq in pair_freqs.items()
    }

In [None]:
# Step 6: Score the initial splits and show the first six pairs
# Compare these with raw pair counts: a pair made of individually frequent tokens is
# divided by both token frequencies, so it can score lower than a rarer pair.
pair_scores = compute_pair_scores(splits)
for pair, score in list(pair_scores.items())[:6]:
    print(f"{pair}: {score}")

In [None]:
# Step 7: Pick the highest-scoring pair for the first merge
# WordPiece merges the pair with the best likelihood score, which is not necessarily the
# most frequent pair. On ties, max() keeps the first pair in `pair_scores` order.
best_pair, max_score = max(
    pair_scores.items(),
    key=lambda item: item[1],
    default=("", None),
)

print(best_pair, max_score)

In [None]:
# Step 8: Add the first merged token to the vocabulary
# "ab" is the merge of "a" + "##b"; the "##" marker of the second token is dropped.
# The pair is hard-coded to follow the walkthrough, so first check that it really is
# the pair Step 7 selected.
assert best_pair == ("a", "##b"), f"Step 7 selected {best_pair}, not ('a', '##b')"
if "ab" not in vocab:  # keep the cell idempotent when re-run
    vocab.append("ab")

In [None]:
# Step 9: WordPiece merge function handles the "##" prefix properly
# When the second token carries the "##" prefix, that prefix is dropped from the merged
# token, so "##" only ever appears at the start of a token (marking a word-internal piece).
def merge_pair(a, b, splits):
    """Replace every adjacent occurrence of (a, b) in each word's split with one merged token.

    Args:
        a: first token of the pair.
        b: second token of the pair (normally "##"-prefixed).
        splits: maps each word to its current token list; updated in place.

    Returns:
        The same ``splits`` dict, so callers can write ``splits = merge_pair(...)``.
    """
    # The merged token depends only on (a, b), so compute it once up front
    merged = a + b[2:] if b.startswith("##") else a + b
    # Iterate over `splits` itself instead of the global `word_freqs`: merging needs no
    # frequencies. Re-assigning existing keys while iterating is safe because the dict's
    # size never changes.
    for word, split in splits.items():
        if len(split) == 1:
            continue
        i = 0
        while i < len(split) - 1:
            if split[i] == a and split[i + 1] == b:
                split = split[:i] + [merged] + split[i + 2 :]
            else:
                i += 1
        splits[word] = split
    return splits

In [None]:
# Step 10: Apply the first merge everywhere and inspect one affected word
# "about" was ['a', '##b', '##o', '##u', '##t'] before this merge.
splits = merge_pair(a="a", b="##b", splits=splits)
splits["about"]  # now ['ab', '##o', '##u', '##t']

In [None]:
# Step 11: Continue WordPiece training until the target vocabulary size is reached
# Each iteration does three things: rescore all adjacent pairs, merge the best one in
# every word, and add the merged token to the vocabulary.
vocab_size = 70
while len(vocab) < vocab_size:
    scores = compute_pair_scores(splits)
    if not scores:
        # Every word is already a single token, so nothing is left to merge. Without this
        # check `best_pair` would be "" and `merge_pair(*best_pair, splits)` would raise.
        break
    # max() keeps the first of several equally scored pairs, like a strict `<` scan
    best_pair, max_score = max(scores.items(), key=lambda item: item[1])

    # Apply the merge and record the new token (dropping the second token's "##" marker)
    splits = merge_pair(*best_pair, splits)
    first, second = best_pair
    new_token = first + second[2:] if second.startswith("##") else first + second
    vocab.append(new_token)

In [None]:
# Step 12: Examine the final WordPiece vocabulary
# Notice how WordPiece discovered different subwords compared to BPE
# The likelihood-based scoring leads to different merge decisions
# Order: the 5 special tokens, then the sorted alphabet, then merged tokens in the order
# they were learned.
print(vocab)

In [None]:
# Step 13: WordPiece encoding function using greedy longest-match-first
# At each position, take the longest vocabulary entry that matches. If some part of
# the word cannot be matched, the WHOLE word becomes [UNK].
def encode_word(word, vocabulary=None):
    """Split one pre-tokenized word into WordPiece tokens by greedy longest match.

    Args:
        word: a single pre-tokenized word.
        vocabulary: iterable of known tokens. Defaults to the notebook-level ``vocab``.

    Returns:
        List of tokens, where continuation pieces carry the "##" prefix. Returns
        ["[UNK]"] if any part of the word has no match, and [] for an empty string.
    """
    if vocabulary is None:
        vocabulary = vocab
    known = set(vocabulary)  # O(1) membership instead of scanning a list for every prefix

    tokens = []
    # The first piece may be one character long. Every continuation piece starts with
    # "##", so it must cover at least one real character after the marker. Without this
    # bound, a vocabulary containing "#" could match part of the marker itself. The
    # remainder would then be re-prefixed forever, an infinite loop.
    min_len = 1
    while word:
        i = len(word)
        # Try progressively shorter prefixes until one is in the vocabulary
        while i >= min_len and word[:i] not in known:
            i -= 1
        if i < min_len:
            # No acceptable prefix is in the vocabulary: the whole word is unknown
            return ["[UNK]"]
        tokens.append(word[:i])
        word = word[i:]  # Continue with the remaining part of the word
        if word:
            word = f"##{word}"  # Mark the remainder as word-internal
            min_len = len("##") + 1
    return tokens

In [None]:
# Step 14: Try the WordPiece encoder on a seen word and an unseen one
# "Hugging" occurs in the training corpus and is split into known subwords.
# "HOgging": uppercase "O" never appears in the corpus, so no vocabulary entry can match
# the "##O..." continuation and the whole word collapses to [UNK].
for sample_word in ("Hugging", "HOgging"):
    print(encode_word(sample_word))

In [None]:
# Step 15: Complete tokenization function for full text
# Combines BERT's pre-tokenization with our WordPiece encoding of each word
def tokenize(text):
    """Tokenize raw text with BERT's pre-tokenizer followed by ``encode_word`` per word.

    Args:
        text: raw input string.

    Returns:
        Flat list of WordPiece tokens for all words, in input order.
    """
    # Use the public `backend_tokenizer` accessor (as Step 1 does) rather than the
    # private `_tokenizer` attribute.
    pre_tokenize_result = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
    # Flatten with a comprehension: `sum(lists, [])` copies the accumulated list on every
    # addition, which is quadratic in the number of tokens.
    return [
        token
        for word, _offsets in pre_tokenize_result
        for token in encode_word(word)
    ]

In [None]:
# Step 16: Test our WordPiece tokenizer on new text
# Notice the "##" prefixes indicating word-internal tokens
# The exclamation mark "!" becomes [UNK] since it wasn't in our training vocabulary.
# The pre-tokenizer splits punctuation into its own word, and the corpus contains only
# "." and "," as punctuation.
# Vocabulary lookups are case-sensitive: lowercase "course" is not matched against
# "Course" from the training corpus.
tokenize("This is the Hugging Face course!")