### WordPiece

In [1]:
# Toy training corpus for the WordPiece walkthrough. Each entry is later split
# on whitespace by build_stats(), so the single-word Chinese entries count as
# one word each and the English sentences contribute one word per token.
sentences = [
    "我",
    "喜欢",
    "吃",
    "苹果",
    "他",
    "不",
    "喜欢",
    "吃",
    "苹果派",
    "I like to eat apples",
    "She has a cute cat",
    "you are very cute",
    "give you a hug",
]

from collections import defaultdict
def build_stats(sentences):
    """Build the word-frequency table for the corpus.

    Args:
        sentences: Iterable of strings; each string is split on whitespace.

    Returns:
        A ``defaultdict(int)`` mapping each word to its number of occurrences,
        with keys in first-seen order.
    """
    word_counts = defaultdict(int)
    tokens = (token for line in sentences for token in line.split())
    for token in tokens:
        word_counts[token] += 1
    return word_counts

stats = build_stats(sentences)
print("stats:", stats)

# WordPiece alphabet: the first character of every word is kept as-is, and
# every later character is stored with the "##" continuation prefix.
alphabet = sorted(
    {
        ch if pos == 0 else f"##{ch}"
        for token in stats.keys()
        for pos, ch in enumerate(token)
    }
)

# Initial vocabulary (a copy, so that merges do not modify the alphabet)
vocab = alphabet.copy()
print("alphabet:", alphabet)

stats: defaultdict(<class 'int'>, {'我': 1, '喜欢': 2, '吃': 2, '苹果': 1, '他': 1, '不': 1, '苹果派': 1, 'I': 1, 'like': 1, 'to': 1, 'eat': 1, 'apples': 1, 'She': 1, 'has': 1, 'a': 2, 'cute': 2, 'cat': 1, 'you': 2, 'are': 1, 'very': 1, 'give': 1, 'hug': 1})
alphabet: ['##a', '##e', '##g', '##h', '##i', '##k', '##l', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##y', '##果', '##欢', '##派', 'I', 'S', 'a', 'c', 'e', 'g', 'h', 'l', 't', 'v', 'y', '不', '他', '吃', '喜', '我', '苹']


In [2]:
# Start with every word split into single characters; all characters except
# the first carry the "##" continuation prefix.
splits = {}
for word in stats.keys():
    splits[word] = [("" if pos == 0 else "##") + ch for pos, ch in enumerate(word)]
print("splits:", splits)

splits: {'我': ['我'], '喜欢': ['喜', '##欢'], '吃': ['吃'], '苹果': ['苹', '##果'], '他': ['他'], '不': ['不'], '苹果派': ['苹', '##果', '##派'], 'I': ['I'], 'like': ['l', '##i', '##k', '##e'], 'to': ['t', '##o'], 'eat': ['e', '##a', '##t'], 'apples': ['a', '##p', '##p', '##l', '##e', '##s'], 'She': ['S', '##h', '##e'], 'has': ['h', '##a', '##s'], 'a': ['a'], 'cute': ['c', '##u', '##t', '##e'], 'cat': ['c', '##a', '##t'], 'you': ['y', '##o', '##u'], 'are': ['a', '##r', '##e'], 'very': ['v', '##e', '##r', '##y'], 'give': ['g', '##i', '##v', '##e'], 'hug': ['h', '##u', '##g']}


In [3]:
def compute_pair_scores(splits, word_freqs=None):
    """Score every adjacent token pair with the WordPiece merge criterion.

    The score of a pair is
    ``freq(pair) / (freq(first_token) * freq(second_token))``, where all
    frequencies are weighted by how often the containing word occurs.

    Args:
        splits: Dict mapping each word to its current list of tokens.
        word_freqs: Dict mapping each word to its corpus frequency. When
            omitted, the module-level ``stats`` table is used, which is what
            the original one-argument call relied on implicitly.

    Returns:
        Dict mapping ``(first_token, second_token)`` to its float score, in
        the order the pairs are first encountered.
    """
    if word_freqs is None:
        word_freqs = stats

    letter_freqs = defaultdict(int)
    pair_freqs = defaultdict(int)
    for word, freq in word_freqs.items():
        split = splits[word]
        # Every token of the word counts once per occurrence of the word.
        for token in split:
            letter_freqs[token] += freq
        # zip() yields nothing for words with fewer than two tokens, so
        # single-token (or empty) splits contribute no pairs.
        for pair in zip(split, split[1:]):
            pair_freqs[pair] += freq

    return {
        pair: freq / (letter_freqs[pair[0]] * letter_freqs[pair[1]])
        for pair, freq in pair_freqs.items()
    }

pair_scores = compute_pair_scores(splits)
# Show only the first six scored pairs as a quick sanity check.
for pair, score in list(pair_scores.items())[:6]:
    print(f"{pair}: {score}")

('喜', '##欢'): 0.5
('苹', '##果'): 0.5
('##果', '##派'): 0.5
('l', '##i'): 0.5
('##i', '##k'): 0.5
('##k', '##e'): 0.125


In [4]:
# Pick the highest-scoring pair. On ties the pair that appears first in
# pair_scores wins; with no pairs at all the result is ("", None).
best_pair, max_score = max(
    pair_scores.items(),
    key=lambda item: item[1],
    default=("", None),
)

print(best_pair, max_score)

('S', '##h') 1.0


In [5]:
def merge_pair(a, b, splits):
    """Merge every adjacent occurrence of tokens ``a`` and ``b`` in ``splits``.

    The merged token is ``a`` followed by ``b`` with its ``"##"`` continuation
    prefix removed (if it has one).

    Args:
        a: First token of the pair.
        b: Second token of the pair.
        splits: Dict mapping each word to its current list of tokens. It is
            updated in place.

    Returns:
        The same ``splits`` dict, after merging.
    """
    merged = a + b[2:] if b.startswith("##") else a + b
    # Iterate over the splits themselves rather than a module-level frequency
    # table, so the function only depends on its arguments. Re-assigning the
    # value of an existing key while iterating is safe (the key set is fixed).
    for word, split in splits.items():
        if len(split) < 2:
            continue
        i = 0
        while i < len(split) - 1:
            if split[i] == a and split[i + 1] == b:
                split = split[:i] + [merged] + split[i + 2 :]
            else:
                i += 1
        splits[word] = split
    return splits

# Target vocabulary size for the WordPiece training loop.
vocab_size = 50
while len(vocab) < vocab_size:
    scores = compute_pair_scores(splits)
    if not scores:
        # Every word is already a single token, so nothing is left to merge.
        # Stop here instead of failing on an empty best pair; the vocabulary
        # then simply ends up smaller than vocab_size.
        break
    # max() returns the first pair with the highest score, i.e. ties are
    # resolved in favour of the pair encountered first.
    best_pair = max(scores, key=scores.get)
    max_score = scores[best_pair]
    splits = merge_pair(*best_pair, splits)
    first, second = best_pair
    # The new token drops the "##" prefix of the second half, matching the
    # token that merge_pair() writes into the splits.
    new_token = first + second[2:] if second.startswith("##") else first + second
    vocab.append(new_token)

print("vocab:", vocab)

vocab: ['##a', '##e', '##g', '##h', '##i', '##k', '##l', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##y', '##果', '##欢', '##派', 'I', 'S', 'a', 'c', 'e', 'g', 'h', 'l', 't', 'v', 'y', '不', '他', '吃', '喜', '我', '苹', 'Sh', '喜欢', '苹果', '苹果派', 'li', 'lik', 'gi', 'giv', '##pl', '##ppl', '##ry', 'to', 'yo', 'ea', 'eat']


### Byte-Pair Encoding (BPE)

In [6]:
# Toy training corpus for the BPE walkthrough (same entries as the WordPiece
# section, repeated so this section can be read on its own). Each entry is
# split on whitespace by build_stats().
sentences = [
    "我",
    "喜欢",
    "吃",
    "苹果",
    "他",
    "不",
    "喜欢",
    "吃",
    "苹果派",
    "I like to eat apples",
    "She has a cute cat",
    "you are very cute",
    "give you a hug",
]

In [7]:
def build_stats(sentences):
    """Count the occurrences of each whitespace-separated word.

    Args:
        sentences: Iterable of strings; each string is split on whitespace.

    Returns:
        A ``defaultdict(int)`` mapping word -> occurrence count, with keys in
        first-seen order.
    """
    frequencies = defaultdict(int)
    for line in sentences:
        words = line.split()
        for w in words:
            frequencies[w] = frequencies[w] + 1
    return frequencies

stats = build_stats(sentences)
print("stats:", stats)

# BPE alphabet: every distinct character that occurs in any word, sorted.
alphabet = sorted({ch for token in stats.keys() for ch in token})

# Initial vocabulary (a copy, so that merges do not modify the alphabet)
vocab = alphabet.copy()
print("alphabet:", alphabet)

stats: defaultdict(<class 'int'>, {'我': 1, '喜欢': 2, '吃': 2, '苹果': 1, '他': 1, '不': 1, '苹果派': 1, 'I': 1, 'like': 1, 'to': 1, 'eat': 1, 'apples': 1, 'She': 1, 'has': 1, 'a': 2, 'cute': 2, 'cat': 1, 'you': 2, 'are': 1, 'very': 1, 'give': 1, 'hug': 1})
alphabet: ['I', 'S', 'a', 'c', 'e', 'g', 'h', 'i', 'k', 'l', 'o', 'p', 'r', 's', 't', 'u', 'v', 'y', '不', '他', '吃', '喜', '我', '果', '欢', '派', '苹']


In [8]:
# Start with every word split into its individual characters.
splits = {token: list(token) for token in stats.keys()}
print("splits:", splits)

def compute_pair_freqs(splits, word_freqs=None):
    """Count how often each adjacent token pair occurs, weighted by word frequency.

    Args:
        splits: Dict mapping each word to its current list of tokens.
        word_freqs: Dict mapping each word to its corpus frequency. When
            omitted, the module-level ``stats`` table is used, which is what
            the original one-argument call relied on implicitly.

    Returns:
        A ``defaultdict(int)`` mapping ``(first_token, second_token)`` to its
        weighted count, in the order the pairs are first encountered.
    """
    if word_freqs is None:
        word_freqs = stats

    pair_freqs = defaultdict(int)
    for word, freq in word_freqs.items():
        split = splits[word]
        # zip() yields nothing for single-token words, so they add no pairs.
        for pair in zip(split, split[1:]):
            pair_freqs[pair] += freq
    return pair_freqs
pair_freqs = compute_pair_freqs(splits)

# Show only the first six pair counts as a quick sanity check.
for pair, count in list(pair_freqs.items())[:6]:
    print(f"{pair}: {count}")

splits: {'我': ['我'], '喜欢': ['喜', '欢'], '吃': ['吃'], '苹果': ['苹', '果'], '他': ['他'], '不': ['不'], '苹果派': ['苹', '果', '派'], 'I': ['I'], 'like': ['l', 'i', 'k', 'e'], 'to': ['t', 'o'], 'eat': ['e', 'a', 't'], 'apples': ['a', 'p', 'p', 'l', 'e', 's'], 'She': ['S', 'h', 'e'], 'has': ['h', 'a', 's'], 'a': ['a'], 'cute': ['c', 'u', 't', 'e'], 'cat': ['c', 'a', 't'], 'you': ['y', 'o', 'u'], 'are': ['a', 'r', 'e'], 'very': ['v', 'e', 'r', 'y'], 'give': ['g', 'i', 'v', 'e'], 'hug': ['h', 'u', 'g']}
('喜', '欢'): 2
('苹', '果'): 2
('果', '派'): 1
('l', 'i'): 1
('i', 'k'): 1
('k', 'e'): 1


In [9]:
# Pick the most frequent pair. On ties the pair that appears first in
# pair_freqs wins; with no pairs at all the result is ("", None).
best_pair, max_freq = max(
    pair_freqs.items(),
    key=lambda item: item[1],
    default=("", None),
)

print(best_pair, max_freq)

('喜', '欢') 2


In [10]:
def merge_pair(a, b, splits):
    """Merge every adjacent occurrence of tokens ``a`` and ``b`` into ``a + b``.

    Args:
        a: First token of the pair.
        b: Second token of the pair.
        splits: Dict mapping each word to its current list of tokens. It is
            updated in place.

    Returns:
        The same ``splits`` dict, after merging.
    """
    merged = a + b
    # Iterate over the splits themselves rather than a module-level frequency
    # table, so the function only depends on its arguments. Re-assigning the
    # value of an existing key while iterating is safe (the key set is fixed).
    for word, split in splits.items():
        if len(split) < 2:
            continue

        i = 0
        while i < len(split) - 1:
            if split[i] == a and split[i + 1] == b:
                split = split[:i] + [merged] + split[i + 2 :]
            else:
                i += 1
        splits[word] = split
    return splits

# Assume we want a vocabulary of 50 tokens.
merges = {}
vocab_size = 50

while len(vocab) < vocab_size:
    pair_freqs = compute_pair_freqs(splits)
    if not pair_freqs:
        # Every word is already a single token, so nothing is left to merge.
        # Stop here instead of failing on an empty best pair; the vocabulary
        # then simply ends up smaller than vocab_size.
        break
    # max() returns the first pair with the highest count, i.e. ties are
    # resolved in favour of the pair encountered first.
    best_pair = max(pair_freqs, key=pair_freqs.get)
    max_freq = pair_freqs[best_pair]
    splits = merge_pair(*best_pair, splits)
    new_token = best_pair[0] + best_pair[1]
    # Record the merge rule (pair -> merged token) in the order it was learned.
    merges[best_pair] = new_token
    vocab.append(new_token)

print("merges:", merges)
print("vocab:", vocab)

merges: {('喜', '欢'): '喜欢', ('苹', '果'): '苹果', ('a', 't'): 'at', ('c', 'u'): 'cu', ('cu', 't'): 'cut', ('cut', 'e'): 'cute', ('y', 'o'): 'yo', ('yo', 'u'): 'you', ('v', 'e'): 've', ('苹果', '派'): '苹果派', ('l', 'i'): 'li', ('li', 'k'): 'lik', ('lik', 'e'): 'like', ('t', 'o'): 'to', ('e', 'at'): 'eat', ('a', 'p'): 'ap', ('ap', 'p'): 'app', ('app', 'l'): 'appl', ('appl', 'e'): 'apple', ('apple', 's'): 'apples', ('S', 'h'): 'Sh', ('Sh', 'e'): 'She', ('h', 'a'): 'ha'}
vocab: ['I', 'S', 'a', 'c', 'e', 'g', 'h', 'i', 'k', 'l', 'o', 'p', 'r', 's', 't', 'u', 'v', 'y', '不', '他', '吃', '喜', '我', '果', '欢', '派', '苹', '喜欢', '苹果', 'at', 'cu', 'cut', 'cute', 'yo', 'you', 've', '苹果派', 'li', 'lik', 'like', 'to', 'eat', 'ap', 'app', 'appl', 'apple', 'apples', 'Sh', 'She', 'ha']


### Byte-level BPE (BBPE)

In [11]:
from collections import defaultdict
# Toy training corpus for the byte-level BPE walkthrough. Each entry is split
# on whitespace and UTF-8 encoded by build_stats().
sentences = [
    "我",
    "喜欢",
    "吃",
    "苹果",
    "他",
    "不",
    "喜欢",
    "吃",
    "苹果派",
    "I like to eat apples",
    "She has a cute cat",
    "you are very cute",
    "give you a hug",
]

# Base vocabulary: one single-byte `bytes` token for each of the 256 byte values.
initial_vocab = [bytes([byte]) for byte in range(256)]
vocab = initial_vocab.copy()
print("initial_vocab:", initial_vocab)


def build_stats(sentences):
    """Count how often each whitespace-separated word occurs, keyed by UTF-8 bytes.

    Args:
        sentences: Iterable of strings; each string is split on whitespace.

    Returns:
        A ``defaultdict(int)`` mapping the UTF-8 encoding of each word to its
        number of occurrences.
    """
    stats = defaultdict(int)
    for sentence in sentences:
        for symbol in sentence.split():
            stats[symbol.encode("utf-8")] += 1
    return stats


stats = build_stats(sentences)

# Iterating a bytes object yields ints, so wrap each value back into a
# one-byte `bytes` token. Tokens then have the same type as the vocabulary
# entries and can be concatenated when merged.
splits = {word: [bytes([byte]) for byte in word] for word in stats.keys()}


def compute_pair_freqs(splits, word_freqs=None):
    """Count how often each adjacent token pair occurs, weighted by word frequency.

    Args:
        splits: Dict mapping each word (bytes) to its current list of tokens.
        word_freqs: Dict mapping each word to its corpus frequency. Defaults
            to the module-level ``stats`` table.

    Returns:
        A ``defaultdict(int)`` mapping ``(first_token, second_token)`` to its
        weighted count, in the order the pairs are first encountered.
    """
    if word_freqs is None:
        word_freqs = stats

    pair_freqs = defaultdict(int)
    for word, freq in word_freqs.items():
        split = splits[word]
        # zip() yields nothing for single-token words, so they add no pairs.
        for pair in zip(split, split[1:]):
            pair_freqs[pair] += freq
    return pair_freqs


pair_freqs = compute_pair_freqs(splits)


def merge_pair(pair, splits):
    """Merge every adjacent occurrence of ``pair`` into one concatenated token.

    Args:
        pair: ``(first_token, second_token)``, both ``bytes``.
        splits: Dict mapping each word to its current list of ``bytes``
            tokens. It is updated in place.

    Returns:
        The same ``splits`` dict, after merging.
    """
    first, second = pair
    merged_byte = first + second
    for word, split in splits.items():
        if len(split) < 2:
            continue
        i = 0
        while i < len(split) - 1:
            # Compare element-wise: a list slice never equals a tuple, so
            # `split[i:i+2] == pair` would silently match nothing.
            if split[i] == first and split[i + 1] == second:
                split = split[:i] + [merged_byte] + split[i + 2 :]
            else:
                i += 1
        splits[word] = split
    return splits


# The base vocabulary already holds 256 tokens, so the target size has to be
# expressed as "base vocabulary + number of merges"; a target of 50 would be
# met before the first merge.
num_merges = 50
vocab_size = len(initial_vocab) + num_merges
while len(vocab) < vocab_size:
    pair_freqs = compute_pair_freqs(splits)
    if not pair_freqs:
        # Every word is already a single token, so nothing is left to merge.
        break
    # max() returns the first pair with the highest count, i.e. ties are
    # resolved in favour of the pair encountered first.
    best_pair = max(pair_freqs, key=pair_freqs.get)
    max_freq = pair_freqs[best_pair]
    splits = merge_pair(best_pair, splits)
    merged_byte = best_pair[0] + best_pair[1]
    vocab.append(merged_byte)

print("vocab:", vocab)

initial_vocab: [b'\x00', b'\x01', b'\x02', b'\x03', b'\x04', b'\x05', b'\x06', b'\x07', b'\x08', b'\t', b'\n', b'\x0b', b'\x0c', b'\r', b'\x0e', b'\x0f', b'\x10', b'\x11', b'\x12', b'\x13', b'\x14', b'\x15', b'\x16', b'\x17', b'\x18', b'\x19', b'\x1a', b'\x1b', b'\x1c', b'\x1d', b'\x1e', b'\x1f', b' ', b'!', b'"', b'#', b'$', b'%', b'&', b"'", b'(', b')', b'*', b'+', b',', b'-', b'.', b'/', b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b':', b';', b'<', b'=', b'>', b'?', b'@', b'A', b'B', b'C', b'D', b'E', b'F', b'G', b'H', b'I', b'J', b'K', b'L', b'M', b'N', b'O', b'P', b'Q', b'R', b'S', b'T', b'U', b'V', b'W', b'X', b'Y', b'Z', b'[', b'\\', b']', b'^', b'_', b'`', b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', b'z', b'{', b'|', b'}', b'~', b'\x7f', b'\x80', b'\x81', b'\x82', b'\x83', b'\x84', b'\x85', b'\x86', b'\x87', b'\x88', b'\x89', b'\x8a', b'\x8b', b'\x8c', b'\x8