In [3]:
import re
from collections import defaultdict

# Step 1: Compute pair scores
def compute_pair_scores(splits, word_freqs):
    """Score every adjacent token pair found in the current word splits.

    The score of a pair ``(a, b)`` is

        pair_freq(a, b) / (token_freq(a) * token_freq(b))

    where each frequency is the sum of the frequencies of the words in which
    the token (or pair) occurs, counted once per occurrence.

    Args:
        splits: dict mapping each word to its current list of tokens.
        word_freqs: dict mapping each word to its frequency.

    Returns:
        dict mapping ``(left, right)`` token tuples to float scores, in order
        of first occurrence. A pair whose denominator is zero (it only occurs
        in zero-frequency words) is scored 0.0 instead of raising
        ``ZeroDivisionError``. Words with a single token or an empty split
        contribute no pairs.

    Raises:
        KeyError: if a word in ``word_freqs`` has no entry in ``splits``.
    """
    token_freqs = defaultdict(int)
    pair_freqs = defaultdict(int)
    for word, freq in word_freqs.items():
        tokens = splits[word]
        # Every token occurrence counts, including the last one in the word.
        for token in tokens:
            token_freqs[token] += freq
        # zip stops at the shorter input, so 0- and 1-token words yield no pairs.
        for pair in zip(tokens, tokens[1:]):
            pair_freqs[pair] += freq

    scores = {}
    for pair, freq in pair_freqs.items():
        denominator = token_freqs[pair[0]] * token_freqs[pair[1]]
        # Guard against zero-frequency words making the denominator zero.
        scores[pair] = freq / denominator if denominator else 0.0
    return scores

# Step 2: Merge the selected pair in the splits
def merge_pair(pair, splits):
    """Fuse each adjacent occurrence of ``pair`` in every word's token list.

    Args:
        pair: two-item sequence ``(left, right)`` of tokens to join.
        splits: dict mapping each word to its current list of tokens; its
            values are replaced in place.

    Returns:
        The same ``splits`` dict that was passed in, after merging.
    """
    for word, tokens in splits.items():
        pos = 0
        while pos + 1 < len(tokens):
            if tokens[pos] != pair[0] or tokens[pos + 1] != pair[1]:
                pos += 1
                continue
            # Match: rebuild the list with the two tokens concatenated. The
            # cursor stays put, so the merged token is examined next.
            tokens = tokens[:pos] + [pair[0] + pair[1]] + tokens[pos + 2:]
        # Re-binding an existing key does not resize the dict, so this is
        # safe while iterating over it.
        splits[word] = tokens
    return splits

In [1]:
# Token Learner: Learn vocabulary and merge rules
def learn_wordpiece(word_freqs, vocab_size):
    """Learn a token vocabulary and an ordered list of merge rules.

    Each word starts as its characters plus the ``'</w>'`` end-of-word marker.
    The highest-scoring adjacent pair (per ``compute_pair_scores``) is merged
    repeatedly until the vocabulary holds ``vocab_size`` tokens or no pairs
    are left. Each merge is reported on stdout.

    Args:
        word_freqs: dict mapping each training word to its frequency.
        vocab_size: vocabulary size at which training stops.

    Returns:
        Tuple ``(vocab, merge_rules)``: the set of tokens and the list of
        merged pairs in the order they were applied.
    """
    # Character-level starting point, one token list per word.
    splits = {}
    for word in word_freqs:
        splits[word] = [*word, '</w>']

    # The initial vocabulary is every distinct starting symbol.
    vocab = set()
    for symbols in splits.values():
        vocab.update(symbols)

    merge_rules = []
    while len(vocab) < vocab_size:
        pair_scores = compute_pair_scores(splits, word_freqs)
        if not pair_scores:
            # Every word is already a single token; nothing left to merge.
            break

        # Ties go to the pair seen first, as max() keeps the first maximum.
        best_pair = max(pair_scores, key=lambda candidate: pair_scores[candidate])
        splits = merge_pair(best_pair, splits)
        vocab.add(''.join(best_pair))
        merge_rules.append(best_pair)

        print(f"Merged pair: {best_pair}, Score: {pair_scores[best_pair]:.6f}")

    return vocab, merge_rules

In [2]:
# Token Segmenter: Use learned vocabulary and rules to segment new words
def segment_word(word, merge_rules):
    """Split ``word`` into tokens by replaying learned merge rules in order.

    Args:
        word: the string to segment.
        merge_rules: sequence of ``(left, right)`` token pairs, applied one
            after another in the given order.

    Returns:
        List of tokens; the end-of-word marker ``'</w>'`` is appended to the
        characters before merging, so it is part of the result.
    """
    pieces = [*word, '</w>']
    for rule in merge_rules:
        cursor = 0
        while cursor + 1 < len(pieces):
            matches = pieces[cursor] == rule[0] and pieces[cursor + 1] == rule[1]
            if not matches:
                cursor += 1
                continue
            # Join the two tokens; the cursor stays on the merged token.
            head, tail = pieces[:cursor], pieces[cursor + 2:]
            pieces = head + [rule[0] + rule[1]] + tail
    return pieces

In [5]:
# Toy training data: each word mapped to its frequency
word_freqs = {'low': 5, 'lower': 2, 'newest': 6, 'widest': 3}

# Learn vocabulary and merge rules.
# vocab_size is the size at which training stops; training also stops
# earlier if no adjacent pairs remain to merge. Each merge is printed by
# learn_wordpiece as it happens.
vocab_size = 30
vocab, merge_rules = learn_wordpiece(word_freqs, vocab_size)
print("\nLearned Vocabulary:")
print(sorted(vocab))
print("\nLearned Merge Rules:")
print(merge_rules)

# Segment test words by replaying the learned merge rules.
# 'lowest' is not in word_freqs, so it shows the rules applied to a word
# that was not seen during training.
test_words = ['lowest', 'newest', 'widest']
print("\nSegmented Test Words:")
for word in test_words:
    segmented = segment_word(word, merge_rules)
    print(f"{word} -> {' '.join(segmented)}")

Merged pair: ('i', 'd'), Score: 0.333333
Merged pair: ('l', 'o'), Score: 0.142857
Merged pair: ('s', 't'), Score: 0.111111
Merged pair: ('lo', 'w'), Score: 0.062500
Merged pair: ('w', 'id'), Score: 0.111111
Merged pair: ('r', '</w>'), Score: 0.062500
Merged pair: ('st', '</w>'), Score: 0.071429
Merged pair: ('low', '</w>'), Score: 0.142857
Merged pair: ('low', 'e'), Score: 0.058824
Merged pair: ('lowe', 'r</w>'), Score: 0.500000
Merged pair: ('n', 'e'), Score: 0.066667
Merged pair: ('ne', 'w'), Score: 0.166667
Merged pair: ('new', 'e'), Score: 0.111111
Merged pair: ('wid', 'e'), Score: 0.333333
Merged pair: ('newe', 'st</w>'), Score: 0.111111
Merged pair: ('wide', 'st</w>'), Score: 0.333333

Learned Vocabulary:
['</w>', 'd', 'e', 'i', 'id', 'l', 'lo', 'low', 'low</w>', 'lowe', 'lower</w>', 'n', 'ne', 'new', 'newe', 'newest</w>', 'o', 'r', 'r</w>', 's', 'st', 'st</w>', 't', 'w', 'wid', 'wide', 'widest</w>']

Learned Merge Rules:
[('i', 'd'), ('l', 'o'), ('s', 't'), ('lo', 'w'), ('w', 'i