In [1]:
from pathlib import Path
from collections import Counter
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# 加载训练数据
def load_words(file_path: str | Path):
    with open(file_path, "r", encoding="utf8") as f:
        lines = f.readlines()
        return [word for line in lines if (word := line.rstrip("\n").split())]


train_words = load_words("pku_training.utf8")
gold_words = load_words("pku_test_gold.utf8")

In [3]:
# 进行 BIEOS 标注
def bieos_tag(vocab: set[str]):
    range_chinese = range(ord("\u4e00"), ord("\u9fff") + 1)
    range_lower_eng = range(ord("a"), ord("z") + 1)
    range_upper_eng = range(ord("A"), ord("Z") + 1)
    range_number = range(ord("0"), ord("9") + 1)

    def is_character(order: int):
        return order in range_chinese or order in range_lower_eng or order in range_upper_eng or order in range_number

    bieos_dataset = {}

    for word in vocab:
        lst = []
        l = 0
        for char in word:
            l += 1
            if is_character(ord(char)):
                lst.append("I")
            else:
                lst.append("O")
        for i in range(l):
            if lst[i] == "I":
                lst[i] = "B"
                break
        for i in range(l - 1, -1, -1):
            if lst[i] == "I":
                lst[i] = "E"
                break
        else:
            for i in range(l):
                if lst[i] == "B":
                    lst[i] = "S"
                    break
        bieos_dataset[word] = lst
    return bieos_dataset


bieos_dataset = bieos_tag(set(word for sent in train_words for word in sent))

In [4]:
# 创建分词器
class Tokenizer:
    def __init__(self, bieos_dataset: dict[str, list[str]], vocab: list[str] | None = None) -> None:
        self.bieos_dataset = bieos_dataset
        self.id2tag = ["<STOP>", "B", "I", "O", "E", "S", "<START>"]
        self.tag2id = {t: i for i, t in enumerate(self.id2tag)}
        if vocab is not None:
            self.id2token = vocab
        else:
            vocab_set = set()
            for word in bieos_dataset.keys():
                vocab_set.update(set(word))
                # vocab_set.add(word)
            self.id2token = sorted(vocab_set, key=lambda x: ord(x))
            self.id2token = ["[PAD]", "[UNK]", "[BOS]", "[EOS]"] + self.id2token

        self.token2id = {c: i for i, c in enumerate(self.id2token)}

    def encode(self, text: str) -> list[int]:
        return [self.token2id.get(c, 1) for c in text]

    def decode(self, indices: list[int]) -> str:
        return "".join(self.id2token[i] for i in indices)


tokenizer = Tokenizer(bieos_dataset)

In [5]:
# 准备训练数据
def dataset(train_words: list[list[str]], device="cpu"):
    all_tokens = []
    all_tags = []
    for sent in train_words:
        tokens = []
        tags = []
        for word in sent:
            tokens.extend(tokenizer.token2id[c] for c in word)
            tags.extend(tokenizer.tag2id[tag] for tag in bieos_dataset[word])
        assert len(tokens) == len(tags)
        all_tokens.append(torch.tensor(tokens, device=device))
        all_tags.append(torch.tensor(tags, device=device))
    return TensorDataset(
        pad_sequence(all_tokens, batch_first=True),
        pad_sequence(all_tags, batch_first=True),
    )


# 加载数据集
dataloader = DataLoader(dataset(train_words, DEVICE), batch_size=1024, shuffle=True)

In [6]:
# 准备测试代码
def evaluate_f1(pred_words, gold_words):
    pred_counter = Counter()
    gold_counter = Counter()
    correct_counter = Counter()

    for pred_sent, gold_sent in zip(pred_words, gold_words):
        # 统计词频（以词为单位）
        pred_counter.update(pred_sent)
        gold_counter.update(gold_sent)
        # 统计正确匹配的词
        correct_words = set(pred_sent) & set(gold_sent)
        correct_counter.update(correct_words)

    precision = sum(correct_counter.values()) / sum(pred_counter.values())
    recall = sum(correct_counter.values()) / sum(gold_counter.values())
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


def evaluate_oov(pred_words, gold_words, train_words):
    # 从训练集构建已知词表
    known_words = set()
    for sent in train_words:
        known_words.update(sent)

    oov_total = 0
    oov_correct = 0

    for gold_sent, pred_sent in zip(gold_words, pred_words):
        for gold_word in gold_sent:
            if gold_word not in known_words:  # 判断是否为OOV
                oov_total += 1
                if gold_word in pred_sent:  # 检查是否被正确切分
                    oov_correct += 1

    oov_recall = oov_correct / oov_total if oov_total > 0 else 0
    return oov_recall


testlines = ["".join(line) for line in gold_words]


def evaluate(cuts):
    pred_words = [cuts(line) for line in testlines]
    precision, recall, f1 = evaluate_f1(pred_words, gold_words)
    oov_recall = evaluate_oov(pred_words, gold_words, train_words)
    print(f"F1: {f1:.4f}, OOV Recall: {oov_recall:.4f}")

In [7]:
# 使用 jieba 分词器测试评估函数

import jieba_fast as jieba

evaluate(lambda line: list(jieba.cut(line)))

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\karis\AppData\Local\Temp\jieba.cache
Loading model cost 0.425 seconds.
Prefix dict has been built succesfully.


F1: 0.5747, OOV Recall: 0.5861


In [8]:
# 创建分词模型
class BiLSTM_CRF(torch.nn.Module):
    def __init__(self, tokenizer: Tokenizer, embed_dim=64, hidden_dim=128):
        super().__init__()
        self.tokenizer = tokenizer
        self.tag_size = len(tokenizer.id2tag)
        self.embedding = torch.nn.Embedding(len(tokenizer.id2token), embed_dim)
        self.lstm = torch.nn.LSTM(embed_dim, hidden_dim // 2, bidirectional=True, batch_first=True)
        self.hidden2tag = torch.nn.Linear(hidden_dim, self.tag_size)
        self.transitions = torch.nn.Parameter(torch.randn(self.tag_size, self.tag_size))
        # self.transitions.data[self.tokenizer.tag2id["B"]][self.tokenizer.tag2id["S"]] = -100.0  # 禁止B→S转移
        # self.transitions.data[self.tokenizer.tag2id["S"]][self.tokenizer.tag2id["E"]] = -100.0  # 禁止S→E转移
        # self.transitions.data[self.tokenizer.tag2id["S"]][self.tokenizer.tag2id["I"]] = -100.0  # 禁止S→I转移
        self.transitions.data[:, self.tokenizer.tag2id["<START>"]] = -100.0
        # self.transitions.data[self.tokenizer.tag2id["<START>"]][self.tokenizer.tag2id["B"]] = 100.0
        # self.transitions.data[self.tokenizer.tag2id["<STOP>"], :] = -1000.0
        # self.transitions.data[self.tokenizer.tag2id["<STOP>"]][self.tokenizer.tag2id["<STOP>"]] = 100.0

    def lstm_features(self, sentence: torch.Tensor) -> torch.Tensor:
        emb = self.embedding(sentence)
        lengths = (sentence != self.tokenizer.token2id["[PAD]"]).sum(dim=-1).cpu()
        packed = pack_padded_sequence(emb, lengths, enforce_sorted=False, batch_first=True)
        lstm_out, _ = self.lstm(packed)
        lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats

    def viterbi_decode(self, feats: torch.Tensor):
        backpointers = []
        forward_var = torch.full((1, model.tag_size), -2000.0, device=DEVICE)
        forward_var[0][model.tokenizer.tag2id["<START>"]] = 2000.0
        # 初始化向前传播变量，并让 (全部 B)初始 forward_var 的预测标签为 START
        # forward_var 的形状应该是 [B, tag_size]，但此时的形状是 [1,tag_size]
        # 在第一次与 feat 相加时会把第一维被广播成 B
        # feats:[B, T, tag_size], feat是遍历时间步,形状是 [B, tag_size]
        for feat in feats.unbind(dim=1):
            # bptrs_t = []  # bptrs_t: 保存了当前时间步的标签
            # viterbivars_t = []  # viterbivars_t: 保存当前时间步的 viterbi 变量
            # for iter_tag in range(self.tag_size):
            #     # iter_tag 是遍历的所有的标签
            #     iter_tag_var = forward_var + self.transitions[iter_tag]
            #     # iter_tag_var 是前一步的分数加上转移到 iter_tag 的分数 (广播 self.transitions[iter_tag])
            #     current_tag = torch.argmax(iter_tag_var, dim=-1)
            #     # current_tag 是如果前序标签是 iter_tag 时，此时最有可能的标签
            #     # current_tag 的形状是 [B]
            #     bptrs_t.append(current_tag)
            #     viterbivars_t.append(torch.gather(iter_tag_var, 1, current_tag.unsqueeze(-1)))
            #     # 保存 iter_tag 转移到 current_tag 的分数
            # backpointers.append(torch.stack(bptrs_t, dim=-1))
            # # viterbivars_t 是 list[Tensor[B]] (len(viterbivars_t) == tag_size)。需要转换成 [B, tag_size] 形状的 viterbi 变量
            # forward_var = torch.cat(viterbivars_t, dim=-1) + feat
            expanded_forward = forward_var.unsqueeze(1).expand(-1, self.tag_size, -1)  # [B, tag_size, tag_size]
            transitions = self.transitions.unsqueeze(0)  # [1, tag_size, tag_size]
            iter_tag_vars = expanded_forward + transitions  # [B, tag_size, tag_size]
            bptrs_t = torch.argmax(iter_tag_vars, dim=-1)  # [B, tag_size]
            viterbivars_t = torch.gather(iter_tag_vars, 2, bptrs_t.unsqueeze(-1))  # [B, tag_size]
            backpointers.append(bptrs_t)
            forward_var = viterbivars_t.squeeze(-1) + feat
        # 转移到STOP标签
        terminal_var = forward_var + self.transitions[self.tokenizer.tag2id["<STOP>"]]
        terminal_tag = torch.argmax(terminal_var, dim=-1)
        score = torch.gather(terminal_var, 1, terminal_tag.unsqueeze(-1))
        # 回溯得到最佳路径
        path = [terminal_tag]
        current_tag = terminal_tag.unsqueeze(-1)  # 这里的 current_tag 是批量的,需要给批次内的每个标签都回溯
        # 弹出开始标签
        backpointers.pop(0)
        for bptrs_t in reversed(backpointers):
            # bptrs_t 的形状为 [B, tag_size]
            current_tag = torch.gather(bptrs_t, 1, current_tag)
            path.append(current_tag.squeeze(-1))
        path.reverse()
        return path, score

    def forward(self, sentence: torch.Tensor):
        feats = self.lstm_features(sentence)
        path, score = self.viterbi_decode(feats)
        return path, score

    def tag(self, line: str):
        tensor = torch.tensor(self.tokenizer.encode(line), device=DEVICE)
        path, score = self(tensor.unsqueeze(0))
        return [self.tokenizer.id2tag[index.item()] for index in path]

    @torch.no_grad()
    def cut(self, line: str) -> list[str]:
        words = []
        current_word = []
        for i, tag in enumerate(self.tag(line)):
            match tag:
                case "B":
                    if current_word:  # 处理前一个未完成的词
                        words.append("".join(current_word))
                        current_word = []
                    current_word.append(line[i])
                case "I" | "O":
                    current_word.append(line[i])
                case "E":
                    current_word.append(line[i])
                    words.append("".join(current_word))
                    current_word = []
                case "S" | "<STOP>":
                    if current_word:
                        words.append("".join(current_word))
                        current_word = []
                    words.append(line[i])
        if current_word:
            words.append("".join(current_word))
        return words

model = BiLSTM_CRF(tokenizer).to(DEVICE) if False else torch.load("BiLSTM_CRF.pth", weights_only=False)

In [9]:
# 测试模型输入输出是否正常，并进行训练前评估
print(model.tag(testlines[0]))
print(model.cut(testlines[0]))
print(testlines[0])

evaluate(lambda line: model.cut(line))

['B', 'E', 'B', 'E', 'B', 'E', 'S', 'S', 'B', 'E', 'O', 'O', 'B', 'O', 'O', 'S', 'E', 'B', 'E', 'B', 'E']
['共同', '创造', '美好', '的', '新', '世纪', '——', '二○○', '一', '年', '新年', '贺词']
共同创造美好的新世纪——二○○一年新年贺词
F1: 0.6029, OOV Recall: 0.4644


In [10]:
# 训练模型

epochs = 0

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(epochs):
    for texts, tags in dataloader:
        forward_var = torch.full((1, model.tag_size), -2000.0, device=DEVICE)
        forward_var[0][model.tokenizer.tag2id["<START>"]] = 2000.0
        feats = model.lstm_features(texts)
        score = torch.zeros(texts.size(0), device=DEVICE)
        tags = torch.nn.functional.pad(tags, (1, 0), "constant", model.tokenizer.tag2id["<START>"])
        # 迭代句子中的每个词
        for i, feat in enumerate(feats.unbind(dim=1)):
            current_tag = tags[:, i + 1]
            forward_tag = tags[:, i]
            score += model.transitions[forward_tag, current_tag] + torch.gather(feat, 1, current_tag.unsqueeze(-1)).squeeze(-1)
            # alphas_t = []
            # for iter_tag in range(model.tag_size):
            #     emit_score = feat[:, iter_tag].unsqueeze(-1).expand(-1, model.tag_size)
            #     transition_score = model.transitions[iter_tag].unsqueeze(0)
            #     iter_tag_var = forward_var + transition_score + emit_score
            #     alphas_t.append(torch.logsumexp(iter_tag_var, dim=-1))
            alphas_t = feat.unsqueeze(-1).expand(-1, -1, model.tag_size) + forward_var.unsqueeze(1) + model.transitions.unsqueeze(0)
            forward_var = torch.logsumexp(alphas_t, dim=-1)
        # 转移到STOP标签
        terminal_var = forward_var + model.transitions[model.tokenizer.tag2id["<STOP>"]]
        alpha = torch.logsumexp(terminal_var, dim=-1)
        # 计算给定标签序列的分数
        loss = (alpha - score).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
    if epoch % 10 == 9:
        evaluate(lambda line: model.cut(line))

In [11]:
print(testlines[11])
print(model.tag(testlines[11]))
print(model.cut(testlines[11]))
print(list(jieba.cut(testlines[11])))

我相信，只要全世界人民以及所有关心人类前途和命运的政治家们共同努力，携手前进，我们居住的这个星球一定能够成为各国人民共享和平、共同发展和共同进步的美好世界！
['S', 'B', 'E', 'O', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'S', 'B', 'E', 'S', 'B', 'I', 'B', 'S', 'B', 'E', 'B', 'E', 'O', 'B', 'E', 'B', 'E', 'O', 'B', 'E', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'O', 'B', 'E', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'O']
['我', '相信', '，', '只要', '全', '世界', '人民', '以及', '所有', '关心', '人类', '前途', '和', '命运', '的', '政治', '家', '们', '共同', '努力', '，', '携手', '前进', '，', '我们', '居住', '的', '这个', '星球', '一定', '能够', '成为', '各国', '人民', '共享', '和平', '、', '共同', '发展', '和', '共同', '进步', '的', '美好', '世界', '！']
['我', '相信', '，', '只要', '全世界', '人民', '以及', '所有', '关心', '人类', '前途', '和', '命运', '的', '政治家', '们', '共同努力', '，', '携手前进', '，', '我们', '居住', '的', '这个', '星球', '一定', '能够', '成为', '各国', '人民', '共享', '和平', '、', '共同', '发展', '和', '共同进步', '的', '美好世界', '！']


In [12]:
torch.save(model, "BiLSTM_CRF.pth")