In [None]:
from dataclasses import dataclass
from collections import defaultdict
from abc import ABC

In [2]:
@dataclass()
class BPETokenizerParms:
    """Learned state of a byte-level BPE tokenizer.

    Built by ``train_bpe`` and consumed by ``BPETokenizer``.
    """

    # Token id -> the raw bytes that token stands for.
    vocab: dict[int, bytes]
    # Adjacent (left_id, right_id) pair -> id of the merged token.
    # Insertion order is the order in which the rules are replayed when encoding.
    merges: dict[tuple[int, int], int]

In [3]:
class Tokenizer(ABC):
    """Base interface for tokenizers that map text to integer ids and back.

    Concrete subclasses are expected to override both ``encode`` and ``decode``.
    """

    def encode(self, string: str) -> list[int]:
        """Convert ``string`` into a list of token indices.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

    def decode(self, indices: list[int]) -> str:
        """Convert token ``indices`` back into a string.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        # Raise (not return) so a missing override fails loudly instead of
        # handing the exception class back to the caller as if it were a result.
        raise NotImplementedError

In [None]:
def merge(indices: list[int], pair: tuple[int, int], new_index: int) -> list[int]:
    """Replace each non-overlapping occurrence of ``pair`` with ``new_index``.

    The scan is left to right, and a matched pair consumes both of its tokens.
    For example, ``[a, a, a]`` with ``pair=(a, a)`` becomes ``[new_index, a]``.

    Args:
        indices: Token ids to scan.
        pair: Adjacent ``(left, right)`` ids to look for.
        new_index: Id written in place of every matched pair.

    Returns:
        A new list; ``indices`` itself is left unchanged.
    """
    length = len(indices)
    merged: list[int] = []
    pos = 0
    while pos < length:
        # A match needs a following token and both ids equal to the pair.
        is_match = (
            pos + 1 < length
            and indices[pos] == pair[0]
            and indices[pos + 1] == pair[1]
        )
        merged.append(new_index if is_match else indices[pos])
        # Skip the partner token too when we just merged.
        pos += 2 if is_match else 1
    return merged

In [None]:
class BPETokenizer(Tokenizer):
    """Byte-level BPE tokenizer driven by a trained ``BPETokenizerParms``."""

    def __init__(self, params: BPETokenizerParms):
        """Store the learned vocabulary and merge rules.

        Args:
            params: Vocabulary (id -> bytes) and ordered merge rules
                ((left, right) -> new id), e.g. as returned by ``train_bpe``.
        """
        self.params = params

    def encode(self, string: str) -> list[int]:
        """Encode ``string`` into a list of token indices.

        The text is first turned into its UTF-8 bytes (ids 0-255). Every merge
        rule is then applied in the order it appears in ``params.merges``.
        """
        # Iterating over ``bytes`` already yields ints.
        indices = list(string.encode("utf-8"))

        # Replay rules in stored (training) order: a later rule may refer to an
        # id that only exists after an earlier rule has been applied.
        for pair, new_index in self.params.merges.items():
            indices = merge(indices, pair, new_index)
        return indices

    def decode(self, indices: list[int], errors: str = "strict") -> str:
        """Decode token ``indices`` back into a string.

        Args:
            indices: Token ids; each must be a key of ``params.vocab``.
            errors: Error handler passed to ``bytes.decode``. The default
                ``"strict"`` raises on bytes that are not valid UTF-8 (e.g. a
                slice of tokens that splits a multi-byte character). Pass
                ``"replace"`` or ``"ignore"`` to decode such slices anyway.

        Returns:
            The concatenated bytes of all tokens, decoded as UTF-8.

        Raises:
            KeyError: if an index is not present in the vocabulary.
            UnicodeDecodeError: if ``errors="strict"`` and the bytes are invalid UTF-8.
        """
        vocab = self.params.vocab
        try:
            # Use [] rather than .get so an unknown id fails here with a clear
            # message, instead of becoming None and breaking bytes.join later.
            chunks = [vocab[index] for index in indices]
        except KeyError as exc:
            raise KeyError(
                f"token index {exc.args[0]!r} is not in the vocabulary"
            ) from None
        return b"".join(chunks).decode("utf-8", errors=errors)

In [None]:
def train_bpe(string: str, num_merges: int) -> BPETokenizerParms:
    """Learn byte-level BPE merge rules from ``string``.

    Training starts from the 256 single-byte tokens. Each step finds the most
    frequent adjacent token pair, gives it the next free id (256, 257, ...),
    and merges it everywhere in the working sequence.

    Args:
        string: Training text; it is encoded as UTF-8.
        num_merges: Maximum number of merges to learn. Training stops early
            once the working sequence has fewer than two tokens, because no
            pair is left to merge.

    Returns:
        ``BPETokenizerParms`` whose ``vocab`` holds the 256 byte tokens plus
        one entry per learned merge, and whose ``merges`` keeps learning order.
    """
    # 1. Initialise: every UTF-8 byte is its own token (ids 0-255).
    indices = list(string.encode("utf-8"))
    merges: dict[tuple[int, int], int] = {}  # learned rules, in learning order
    vocab: dict[int, bytes] = {x: bytes([x]) for x in range(256)}

    # 2. Perform up to num_merges merge steps.
    for i in range(num_merges):
        # Count how often each adjacent token pair occurs.
        counts: dict[tuple[int, int], int] = defaultdict(int)
        for index1, index2 in zip(indices, indices[1:]):
            counts[(index1, index2)] += 1

        # No pairs left (empty or 1-byte input, or the sequence has already
        # collapsed into one token). max() would raise on an empty dict.
        if not counts:
            break

        # Most frequent pair. On ties max() returns the first pair encountered,
        # so the result is deterministic.
        pair = max(counts, key=counts.get)
        index1, index2 = pair

        # Next free id. Ids stay contiguous even when training stops early.
        new_index = 256 + i
        merges[pair] = new_index
        # The new token's bytes are the concatenation of its two parts.
        vocab[new_index] = vocab[index1] + vocab[index2]

        # Apply the new rule to the working sequence.
        indices = merge(indices=indices, pair=pair, new_index=new_index)

    return BPETokenizerParms(vocab=vocab, merges=merges)

In [15]:
def bpe_tokenizer(
    train_text: str = "the cat in the hat",
    test_text: str = "the quick brown fox",
    num_merges: int = 3,
):
    """Demo: train a tiny BPE tokenizer, then round-trip a test string.

    Prints the learned merge rules, the encoding of ``test_text`` and the
    decoded result. It then asserts that decoding reproduces ``test_text``
    exactly.

    Args:
        train_text: Text used to learn the merge rules.
        test_text: Text encoded and decoded with the trained tokenizer.
        num_merges: Number of merges to learn.

    Raises:
        AssertionError: if ``decode(encode(test_text)) != test_text``.
    """
    print(f"训练字符串: {train_text}")
    print(f"原始字节: {list(train_text.encode('utf-8'))}")

    params = train_bpe(train_text, num_merges=num_merges)
    print(f"\n训练后的词汇表大小: {len(params.vocab)}")
    print(f"合并规则数量: {len(params.merges)}")
    print("\n合并规则详情:")
    for pair, new_idx in params.merges.items():
        idx1, idx2 = pair
        token1 = params.vocab[idx1]
        token2 = params.vocab[idx2]
        new_token = params.vocab[new_idx]
        print(
            f"  {pair} -> {new_idx}: {token1} + {token2} = {new_token} ('{new_token.decode('utf-8', errors='ignore')}')"
        )

    tokenizer = BPETokenizer(params)

    print(f"\n{'='*50}")
    print(f"测试字符串: {test_text}")
    print(f"原始字节: {list(test_text.encode('utf-8'))}")

    indices = tokenizer.encode(test_text)
    print(f"编码后的indices: {indices}")
    print("编码后token详情: ", end="")
    # errors='ignore' only affects this display; a single token may hold part
    # of a multi-byte character.
    print(
        " | ".join(
            f"'{params.vocab[idx].decode('utf-8', errors='ignore')}'"
            for idx in indices
        )
    )

    reconstructed_string = tokenizer.decode(indices)
    print(f"\n解码后的字符串: {reconstructed_string}")
    print(f"断言通过: {test_text == reconstructed_string}")
    # Actually enforce the round-trip so a regression fails the cell instead
    # of just printing False.
    assert reconstructed_string == test_text, (
        f"round-trip mismatch: {reconstructed_string!r} != {test_text!r}"
    )

bpe_tokenizer()

训练字符串: the cat in the hat
原始字节: [116, 104, 101, 32, 99, 97, 116, 32, 105, 110, 32, 116, 104, 101, 32, 104, 97, 116]

训练后的词汇表大小: 259
合并规则数量: 3

合并规则详情:
  (116, 104) -> 256: b't' + b'h' = b'th' ('th')
  (256, 101) -> 257: b'th' + b'e' = b'the' ('the')
  (257, 32) -> 258: b'the' + b' ' = b'the ' ('the ')

测试字符串: the quick brown fox
原始字节: [116, 104, 101, 32, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120]
编码后的indices: [258, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120]
编码后token详情: 'the ' | 'q' | 'u' | 'i' | 'c' | 'k' | ' ' | 'b' | 'r' | 'o' | 'w' | 'n' | ' ' | 'f' | 'o' | 'x'

解码后的字符串: the quick brown fox
断言通过: True
