In [12]:
# mxf 2025.03.10
import re
from dataclasses import dataclass
from collections import defaultdict, Counter

# Byte-Pair Encoding

In [40]:
def get_vocab(filename: str) -> dict:   # 获取语料库中单词频，并将此分成单字符形式并加上词尾记号</w>，如apple分解成'a p p l e </w>'
    vocab = defaultdict(int)
    with open(filename, 'r', encoding='UTF-8') as f:
        for line in f:    # 这里不需要f.readlines，因为这样可以只读取当前的行进内存，而readlines会将所有内容读进，占用内存。
            words = line.strip().split()
            for word in words:
                vocab[' '.join(list(word)) + '</w>'] += 1
    return vocab


def get_stats(vocab: dict) -> dict:   # 获取单词表中子词的频率
    pairs = defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i+1]] += freq  # 相邻的组合
    return pairs

# 这里应该是难点
def merge_vocab(pair: tuple, v_in: dict) -> dict:   # 合并单词表中的子词对， pair：子词对，v_in：词表
    v_out = {}
    bigram = re.escape(' '.join(pair))  # 用空格连接两个子词，并用escape函数转义特殊符号
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')  # 正则表达式，匹配独立的子词组
    for word in v_in:
        substr = ''.join(pair)   # 将两个子词合并成一个子词，用于替换词中的两个子词
        w_out = p.sub(substr, word) #在word尝试匹配正则表达式p，如'a b c a b d'匹配'a b'则变成'ab c ab d'，如匹配不成功，则返回原字符串。
        v_out[w_out] = v_in[word]
    return v_out


def get_tokens(vocab: dict) -> dict: # 获取token表
    tokens = defaultdict(int)
    for word, freq in vocab.items():
        word_tokens = word.split(' ')
        for word_token in word_tokens:
            tokens[word_token] += freq
    return tokens


In [41]:
vocab = get_vocab('pg16457.txt')

# print('==========')
# print('Tokens Before BPE')
# tokens = get_tokens(vocab)
# print('Tokens: {}'.format(tokens))
# print('Number of tokens: {}'.format(len(tokens)))
# print('==========')

num_merges = 1000
for _ in range(num_merges):
    pairs = get_stats(vocab)
    if not pairs:
        break
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    # print('Iter: {}'.format(i))
    # print('Best pair: {}'.format(best))
    tokens = get_tokens(vocab)
    # print('Tokens: {}'.format(tokens))
    # print('Number of tokens: {}'.format(len(tokens)))
    # print('==========')


In [11]:
@dataclass
class BPEconfig:
    filename: str
    num_merge: int = 1000

class BytePairEncoding:
    def __init__(self, config: BPEconfig):
        self.filename = config.filename
        self.num_merge = config.num_merge
        self.vocab = defaultdict(int)

    def _get_vocab(self) -> dict:   
        with open(self.filename, 'r', encoding='UTF-8') as f:
            for line in f: 
                words = line.strip().split()
                for word in words:
                    self.vocab[' '.join(list(word)) + '</w>'] += 1

    
    def _get_stats(vocab: dict) -> dict:   
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[symbols[i], symbols[i+1]] += freq  
        return pairs

    def _merge_vocab(self, pair: tuple) -> dict:  
        v_out = defaultdict(int)
        bigram = re.escape(' '.join(pair))  
        p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')  
        for word in self.vocab:
            substr = ''.join(pair)   
            w_out = p.sub(substr, word) 
            v_out[w_out] = self.vocab[word]
            
        self.vocab = v_out        

    def get_tokens(self) -> dict:
        tokens = defaultdict(int)
        for word, freq in self.vocab.items():
            word_tokens = word.split(' ')
            for word_token in word_tokens:
                self.tokens[word_token] += freq
        return tokens
    
    def train(self):
        self._get_vocab()
        for _ in range(self.num_merges):
            pairs = self._get_stats()
            if not pairs:
                break
            best = max(pairs, key=pairs.get)  
            self._merge_vocab(best)  


        

In [7]:
num_merges = 1000
for _ in range(num_merges):
    pairs = get_stats(vocab)
    if not pairs:
        break
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    # print('Iter: {}'.format(i))
    # print('Best pair: {}'.format(best))
    tokens = get_tokens(vocab)
    # print('Tokens: {}'.format(tokens))
    # print('Number of tokens: {}'.format(len(tokens)))
    # print('==========')

BPEconfig(filename='pg16457.txt', num_merge=1000)

# Byte Pair Encoding CLASS

In [35]:
@dataclass
class BPEconfig:
    filename: str
    num_merges: int = 1000


class BytePairEncoding:
    def __init__(self, config: BPEconfig):
        self.filename = config.filename
        self.num_merges = config.num_merges
        self.vocab = defaultdict(int)  # 存储单词及其词频
        self.merges = []  # 记录所有合并的子词对
        self.subword_vocab = set()  # 子词表
        self.subword_index = {}  # 子词及其对应的索引
        self.token_to_id = {}  # 子词到 ID 的映射
        self.id_to_token = {}  # ID 到子词的映射


    def _get_vocab(self): # 从文件中读取数据，初始化单词表及其词频
        with open(self.filename, "r", encoding="utf-8") as f:
            for line in f:
                words = line.strip().split()
                for word in words:
                    self.vocab[" ".join(list(word)) + " </w>"] += 1  # 将单词拆分为字符，添加结束符 '</w>'


    def _get_stats(self):  # 统计当前词汇表中所有子词对的频率
        pairs = defaultdict(int)
        for word, freq in self.vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += freq
        return pairs


    def _merge_vocab(self, pair):  # 合并子词对，并更新词汇表
        v_out = defaultdict(int)
        bigram = re.escape(" ".join(pair))
        p = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
        for word in self.vocab:
            w_out = p.sub("".join(pair), word)
            v_out[w_out] = self.vocab[word]
        self.vocab = v_out


    def train(self):
        self._get_vocab()
        for i in range(self.num_merges):
            # print('Round {}'.format(i))
            pairs = self._get_stats()
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            self._merge_vocab(best)
            self.merges.append(best)
        
        # 生成子词表及其索引
        self.subword_vocab = set()
        for word in self.vocab:
            for subword in word.split():
                self.subword_vocab.add(subword)
        self.subword_vocab = sorted(self.subword_vocab, key=lambda x: -self.vocab.get(x, 0))
        self.token_to_id = {token: idx for idx, token in enumerate(self.subword_vocab)}
        self.id_to_token = {idx: token for token, idx in self.token_to_id.items()}


    def tokenize(self, text: str):  # 将输入文本拆分为子词 
        tokens = []
        words = text.strip().split()
        for word in words:
            word += "</w>"
            while len(word) > 0:
                for i in range(len(word), 0, -1):  # 从最长的可能子词开始匹配
                    subword = word[:i]
                    if subword in self.subword_vocab:
                        tokens.append(subword)
                        word = word[i:]
                        break
                else:
                    tokens.append(word[0])  # 如果未找到匹配的子词，拆分为单个字符
                    word = word[1:]
        return tokens


    def encode(self, text: str):  # 将输入文本编码为子词 ID 序列
        tokens = self.tokenize(text)
        return [self.token_to_id[token] for token in tokens if token in self.token_to_id]


    def decode(self, ids: list):  # 将子词 ID 序列解码为文本
        tokens = [self.id_to_token.get(id, "<unk>") for id in ids]
        text = "".join(tokens).replace("</w>", " ")  # 去掉子词之间的空格和结束符
        return text.strip()


In [36]:
bpeconfig = BPEconfig(filename="pg16457.txt", num_merges=1000)
bpe = BytePairEncoding(bpeconfig)
bpe.train()

In [46]:
text = "Xiao-feng Mai: hello world!"
tokens = bpe.tokenize(text)
print("Tokens:", tokens)

encoded_ids = bpe.encode(text)
print("Encoded IDs:", encoded_ids)

decoded_text = bpe.decode(encoded_ids)
print("Decoded Text:", decoded_text)

Tokens: ['X', 'i', 'a', 'o', '-', 'fe', 'n', 'g', '</w>', 'M', 'a', 'i', ':</w>', 'he', 'll', 'o</w>', 'wor', 'l', 'd', '!</w>']
Encoded IDs: [322, 324, 544, 348, 518, 695, 830, 272, 857, 700, 544, 324, 588, 326, 767, 403, 833, 689, 538, 431]
Decoded Text: Xiao-feng Mai: hello world!
