crf模型定义与训练评估

In [4]:
import torch

# 配置模板
UNIGRAM_PATTERNS = [[-2], [-1], [0], [1], [2], [-2, -1], [-1, 0], [-1, 1], [0, 1], [2, 2]]
BIGRAM_PATTERNS = [[-2], [-1], [0], [1], [2], [-2, -1], [-1, 0], [-1, 1], [0, 1], [2, 2]]
TAGS = ["B", "M", "E", "S"]
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 特征权重
feature_weights = {}

加载训练数据

In [5]:
def load_training_data(file_path):
    sentences, labels = [], []
    buffer_text, buffer_label = "", ""
    with open(file_path, encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            if not line:
                if buffer_text and buffer_label:
                    sentences.append(buffer_text)
                    labels.append(buffer_label)
                buffer_text, buffer_label = "", ""
            else:
                char, tag = line.split("\t")
                buffer_text += char
                buffer_label += tag
    return sentences, labels

crf方法定义

In [6]:
def build_feature_key(template, template_id, text, idx, tag_sequence):
    key = str(template_id)
    for offset in template:
        pos = idx + offset
        key += text[pos] if 0 <= pos < len(text) else " "
    return f"{key}/{tag_sequence}"

def compute_score(text, pos, tag_or_pair, patterns, is_bigram=False):
    total = 0
    for i, template in enumerate(patterns):
        key = build_feature_key(template, i, text, pos, tag_or_pair)
        total += feature_weights.get(key, 0)
    return total

def tag_index_to_label(idx):
    return TAGS[idx]

def tag_label_to_index(label):
    return TAGS.index(label)

def viterbi_decode(text):
    n = len(text)
    dp = [[float('-inf')] * n for _ in range(4)]
    paths = [[""] * n for _ in range(4)]

    for tag_idx in range(4):
        tag = tag_index_to_label(tag_idx)
        dp[tag_idx][0] = compute_score(text, 0, tag, UNIGRAM_PATTERNS) + \
                         compute_score(text, 0, " " + tag, BIGRAM_PATTERNS)

    for i in range(1, n):
        for curr_idx in range(4):
            curr_tag = tag_index_to_label(curr_idx)
            best_score, best_prev = float('-inf'), 0
            for prev_idx in range(4):
                prev_tag = tag_index_to_label(prev_idx)
                score = dp[prev_idx][i - 1] + \
                        compute_score(text, i, curr_tag, UNIGRAM_PATTERNS) + \
                        compute_score(text, i, prev_tag + curr_tag, BIGRAM_PATTERNS)
                if score > best_score:
                    best_score, best_prev = score, prev_idx
            dp[curr_idx][i] = best_score
            paths[curr_idx][i] = tag_index_to_label(best_prev)

    best_final = max(range(4), key=lambda i: dp[i][-1])
    result = [""] * n
    result[-1] = tag_index_to_label(best_final)

    for i in range(n - 2, -1, -1):
        result[i] = paths[tag_label_to_index(result[i + 1])][i + 1]

    return "".join(result)

def update_feature_weights(text, gold_tags):
    pred_tags = viterbi_decode(text)
    for i, (pred, truth) in enumerate(zip(pred_tags, gold_tags)):
        if pred != truth:
            for j, template in enumerate(UNIGRAM_PATTERNS):
                for tag in [pred, truth]:
                    delta = 1 if tag == truth else -1
                    key = build_feature_key(template, j, text, i, tag)
                    feature_weights[key] = feature_weights.get(key, 0) + delta

            prev_pred = pred_tags[i - 1] if i > 0 else " "
            prev_truth = gold_tags[i - 1] if i > 0 else " "
            for j, template in enumerate(BIGRAM_PATTERNS):
                for pair, delta in [(prev_pred + pred, -1), (prev_truth + truth, 1)]:
                    key = build_feature_key(template, j, text, i, pair)
                    feature_weights[key] = feature_weights.get(key, 0) + delta

def count_errors(pred, truth):
    return sum(p != t for p, t in zip(pred, truth))

def count_matches(pred, truth):
    return sum(p == t for p, t in zip(pred, truth))

辅助评估方法定义，包括准确率，精确率，召回率和F值

In [7]:
def split_words(text, tags):
    words = []
    word = ''
    for ch, tag in zip(text, tags):
        word += ch
        if tag in ('S', 'E'):
            words.append(word)
            word = ''
    if word:
        words.append(word)
    return words

def calc_metrics(pred_tags_list, gold_tags_list, sentence_list):
    total_chars, correct_chars = 0, 0
    total_pred_words, total_gold_words, correct_words = 0, 0, 0

    for pred_tags, gold_tags, sentence in zip(pred_tags_list, gold_tags_list, sentence_list):
        total_chars += len(sentence)
        correct_chars += sum(p == g for p, g in zip(pred_tags, gold_tags))

        pred_words = set(_span_indices(split_words(sentence, pred_tags)))
        gold_words = set(_span_indices(split_words(sentence, gold_tags)))

        total_pred_words += len(pred_words)
        total_gold_words += len(gold_words)
        correct_words += len(pred_words & gold_words)

    accuracy = correct_chars / total_chars
    precision = correct_words / total_pred_words if total_pred_words else 0
    recall = correct_words / total_gold_words if total_gold_words else 0
    f1 = (2 * precision * recall / (precision + recall)) if precision + recall > 0 else 0
    return accuracy, precision, recall, f1

def _span_indices(words):
    spans = []
    pos = 0
    for word in words:
        spans.append((pos, pos + len(word)))
        pos += len(word)
    return spans

训练函数和评估函数

In [8]:
def train(train_path, model_path, epochs=5):
    sentences, labels = load_training_data(train_path)
    split = int(0.8 * len(sentences))

    for epoch in range(1, epochs + 1):
        total_errors = 0
        pred_tags_list, gold_tags_list = [], []

        # 训练部分
        for i in range(split):
            s, l = sentences[i], labels[i]
            update_feature_weights(s, l)
            pred_tags = viterbi_decode(s)
            pred_tags_list.append(pred_tags)
            gold_tags_list.append(l)
            total_errors += count_errors(pred_tags, l)

        acc, pre, rec, f1 = calc_metrics(pred_tags_list, gold_tags_list, sentences[:split])
        print(f"Epoch {epoch} - Train")
        print(f"  Accuracy:  {acc:.4f}")
        print(f"  Precision: {pre:.4f}")
        print(f"  Recall:    {rec:.4f}")
        print(f"  F1 Score:  {f1:.4f}")

        # 测试部分
        pred_tags_list, gold_tags_list = [], []
        for i in range(split, len(sentences)):
            s, l = sentences[i], labels[i]
            pred = viterbi_decode(s)
            pred_tags_list.append(pred)
            gold_tags_list.append(l)

        acc, pre, rec, f1 = calc_metrics(pred_tags_list, gold_tags_list, sentences[split:])
        print(f"Epoch {epoch} - Test")
        print(f"  Accuracy:  {acc:.4f}")
        print(f"  Precision: {pre:.4f}")
        print(f"  Recall:    {rec:.4f}")
        print(f"  F1 Score:  {f1:.4f}")

        torch.save({
            'feature_weights': feature_weights,
            'unigram_patterns': UNIGRAM_PATTERNS,
            'bigram_patterns': BIGRAM_PATTERNS
        }, model_path)
    print(f"模型 {model_path} 训练完成\n")


def predict(sentence, model_path):
    model = torch.load(model_path, map_location=DEVICE)
    global feature_weights, UNIGRAM_PATTERNS, BIGRAM_PATTERNS
    feature_weights = model['feature_weights']
    UNIGRAM_PATTERNS = model['unigram_patterns']
    BIGRAM_PATTERNS = model['bigram_patterns']

    tags = viterbi_decode(sentence)

    if tags[-1] in ('B', 'M'):
        tags = tags[:-1] + ('E' if tags[-2] in ('B', 'M') else 'S',)

    segmented = ''
    for i, ch in enumerate(sentence):
        segmented += ch
        if tags[i] in ('S', 'E') and i != len(sentence) - 1:
            segmented += '|'

    print("原始语句：", sentence)
    print("BMES标签：", ' '.join(tags))
    print("分词结果：", segmented)
    print()

In [11]:
train_path = "BMES_corpus.txt"
model_path = "crf_model.pt"
epochs = 5

train(train_path, model_path, epochs)

Epoch 1 - Train
  Accuracy:  0.9905
  Precision: 0.9884
  Recall:    0.9886
  F1 Score:  0.9885
Epoch 1 - Test
  Accuracy:  0.9191
  Precision: 0.9065
  Recall:    0.9034
  F1 Score:  0.9049
Epoch 2 - Train
  Accuracy:  0.9973
  Precision: 0.9967
  Recall:    0.9968
  F1 Score:  0.9968
Epoch 2 - Test
  Accuracy:  0.9345
  Precision: 0.9266
  Recall:    0.9196
  F1 Score:  0.9231
Epoch 3 - Train
  Accuracy:  0.9988
  Precision: 0.9986
  Recall:    0.9986
  F1 Score:  0.9986
Epoch 3 - Test
  Accuracy:  0.9425
  Precision: 0.9360
  Recall:    0.9288
  F1 Score:  0.9324
Epoch 4 - Train
  Accuracy:  0.9992
  Precision: 0.9991
  Recall:    0.9990
  F1 Score:  0.9990
Epoch 4 - Test
  Accuracy:  0.9443
  Precision: 0.9388
  Recall:    0.9302
  F1 Score:  0.9345
Epoch 5 - Train
  Accuracy:  0.9996
  Precision: 0.9995
  Recall:    0.9995
  F1 Score:  0.9995
Epoch 5 - Test
  Accuracy:  0.9477
  Precision: 0.9418
  Recall:    0.9351
  F1 Score:  0.9384
模型 crf_model.pt 训练完成



In [34]:
model_path = "crf_model.pt"  # 训练好的模型路径
sentence = "他说的的确有道理，他说的确实是对的"
predict(sentence, model_path)
sentence = "这是一项有意义的研究生活动，用来研究生活中的科学"
predict(sentence, model_path)
sentence = "4月1日出版的第7期《求是》杂志发表总书记重要文章《朝着建成科技强国的宏伟目标奋勇前进》。如何一步一个脚印把建成科技强国的战略目标变为现实？习总书记这样指引方向"
predict(sentence, model_path)
sentence = "关于紧急开发的中缅英翻译系统，该系统由国家应急语言服务团秘书处和北京语言大学迅速组建的语言服务支持团队，在仅仅七小时内利用AI开发完成。"
predict(sentence, model_path)
sentence = "他严格要求自己，从一个科举出身的进士成为一个伟大的民主主义者、一个共产主义战士。"
predict(sentence, model_path)
sentence = "但就在那一年，他为当时还默默无闻的青年选手奥尔同波士顿棕熊队谈判达成了一项年薪七万美元的合同。"
predict(sentence, model_path)
sentence = "行行出状元，干一行，行一行"
predict(sentence, model_path)


原始语句： 他说的的确有道理，他说的确实是对的
BMES标签： S S S B E S B E S S S S B E S S S
分词结果： 他|说|的|的确|有|道理|，|他|说|的|确实|是|对|的

原始语句： 这是一项有意义的研究生活动，用来研究生活中的科学
BMES标签： S S S S S B E S B M E B E S S S B E B E S S B E
分词结果： 这|是|一|项|有|意义|的|研究生|活动|，|用|来|研究|生活|中|的|科学

原始语句： 4月1日出版的第7期《求是》杂志发表总书记重要文章《朝着建成科技强国的宏伟目标奋勇前进》。如何一步一个脚印把建成科技强国的战略目标变为现实？习总书记这样指引方向
BMES标签： B E B E B E S B E S S B E S B E B E B M E B E B E S B E B E B E B E S B E B E B E B E S S B E B M M M M E S B E B E B E S B E B E B E B E S S B M E B E B E B E
分词结果： 4月|1日|出版|的|第7|期|《|求是|》|杂志|发表|总书记|重要|文章|《|朝着|建成|科技|强国|的|宏伟|目标|奋勇|前进|》|。|如何|一步一个脚印|把|建成|科技|强国|的|战略|目标|变为|现实|？|习|总书记|这样|指引|方向

原始语句： 关于紧急开发的中缅英翻译系统，该系统由国家应急语言服务团秘书处和北京语言大学迅速组建的语言服务支持团队，在仅仅七小时内利用AI开发完成。
BMES标签： B E B E B E S B M E B E B E S S B E S B E B E B E B M E B M E S B E B E B E B E B E S B E B E B E B E S S B E S B E S B E B E B E B E S
分词结果： 关于|紧急|开发|的|中缅英|翻译|系统|，|该|系统|由|国家|应急|语言|服务团|秘书处|和|北京|语言|大学|迅速|组建|的|语言|服务|支持|团队|，|在|仅仅|七|小时|内|利用|AI|开发|完成|。

原始语句： 他严格要求自己，从一个科举出身的进士成为一个伟大的民主主义者、一个共产主义战

In [9]:
model_path = "crf_model.pt"  # 训练好的模型路径
sentence = "明明明明明白白白喜欢他可是他就是不说"
predict(sentence, model_path)
sentence = "重庆市长江边的景色迷人，让人流连忘返。"
predict(sentence, model_path)


原始语句： 明明明明明白白白喜欢他可是他就是不说
BMES标签： B E B E B E S E S E S B E S S S S S
分词结果： 明明|明明|明白|白|白|喜|欢|他|可是|他|就|是|不|说

原始语句： 重庆市长江边的景色迷人，让人流连忘返。
BMES标签： B M E B E S S B E B E S S S B M M E S
分词结果： 重庆市|长江|边|的|景色|迷人|，|让|人|流连忘返|。

