In [2]:
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# 模型测评与错误分析\n",
        "\n",
        "本笔记本加载训练好的 BiLSTM-CRF 命名实体识别模型（`ckpts/bilstm_crf.pkl`），在测试集上进行预测，按实体类型统计三类错误（遗漏实体、多余实体、类别混淆），并展示若干典型错误案例，帮助定位模型的薄弱环节。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "from collections import Counter\n",
        "from pathlib import Path\n",
        "from typing import Iterable, List, Sequence, Tuple\n",
        "\n",
        "from data import build_corpus\n",
        "from evaluating import Metrics\n",
        "from utils import extend_maps, load_model, prepocess_data_for_lstmcrf\n",
        "\n",
        "# All paths are resolved against the kernel's current working directory\n",
        "PROJECT_ROOT = Path.cwd()\n",
        "DATA_DIR = PROJECT_ROOT.joinpath(\"DataNER\")\n",
        "CKPT_DIR = PROJECT_ROOT.joinpath(\"ckpts\")\n",
        "MODEL_PATH = CKPT_DIR.joinpath(\"bilstm_crf.pkl\")\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Load all three splits; word2id / tag2id come from the training split (dev/test use make_vocab=False)\n",
        "train_word_lists, train_tag_lists, word2id, tag2id = build_corpus(\"train\", data_dir=str(DATA_DIR))\n",
        "dev_word_lists, dev_tag_lists = build_corpus(\"dev\", make_vocab=False, data_dir=str(DATA_DIR))\n",
        "test_word_lists, test_tag_lists = build_corpus(\"test\", make_vocab=False, data_dir=str(DATA_DIR))\n",
        "\n",
        "splits = ((\"训练集\", train_word_lists), (\"开发集\", dev_word_lists), (\"测试集\", test_word_lists))\n",
        "for split_label, split_sentences in splits:\n",
        "    print(f\"{split_label}句子数: {len(split_sentences)}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "def extract_entities(words: Sequence[str], tags: Sequence[str]) -> List[Tuple[int, int, str, str]]:\n",
        "    \"\"\"Decode a BMES / BIO(ES) tag sequence into entity spans.\n",
        "\n",
        "    Returns a list of ``(start, end, type, text)`` tuples, where ``start`` and ``end``\n",
        "    are inclusive token indices and ``text`` is ``\"\".join(words[start:end + 1])``.\n",
        "\n",
        "    Malformed sequences are decoded leniently (in the spirit of conlleval) instead of\n",
        "    being dropped, so malformed predictions still show up in the error statistics:\n",
        "    - ``B-X`` absorbs the following ``M-X``/``I-X`` tags and an optional closing ``E-X``;\n",
        "    - an ``M-X``/``I-X`` not preceded by a ``B-X`` run of the same type starts a new entity;\n",
        "    - an ``E-X`` that does not close a run is a one-token entity, like ``S-X``.\n",
        "    Tags without one of these prefixes (e.g. ``O``) are skipped.\n",
        "    \"\"\"\n",
        "    entities: List[Tuple[int, int, str, str]] = []\n",
        "    n_tags = len(tags)\n",
        "    idx = 0\n",
        "    while idx < n_tags:\n",
        "        prefix, ent_type = tags[idx][:2], tags[idx][2:]\n",
        "        if prefix in (\"S-\", \"E-\"):\n",
        "            entities.append((idx, idx, ent_type, words[idx]))\n",
        "            idx += 1\n",
        "        elif prefix in (\"B-\", \"M-\", \"I-\"):\n",
        "            start = idx\n",
        "            idx += 1\n",
        "            while idx < n_tags and tags[idx][:2] in (\"M-\", \"I-\") and tags[idx][2:] == ent_type:\n",
        "                idx += 1\n",
        "            if idx < n_tags and tags[idx][:2] == \"E-\" and tags[idx][2:] == ent_type:\n",
        "                idx += 1  # include the closing E- tag in this entity\n",
        "            end = idx - 1\n",
        "            entities.append((start, end, ent_type, \"\".join(words[start : end + 1])))\n",
        "        else:\n",
        "            idx += 1\n",
        "    return entities\n",
        "\n",
        "\n",
        "def compute_error_statistics(\n",
        "    word_lists: Iterable[Sequence[str]],\n",
        "    gold_tags: Iterable[Sequence[str]],\n",
        "    pred_tags: Iterable[Sequence[str]],\n",
        ") -> Tuple[Counter, Counter, Counter]:\n",
        "    \"\"\"Count entity-level errors between gold and predicted tag sequences.\n",
        "\n",
        "    Entities (decoded with ``extract_entities``) are matched by exact (start, end) span:\n",
        "    - a gold span with no predicted span at that position is *missed* (keyed by gold type);\n",
        "    - a predicted span with no gold span at that position is *spurious* (keyed by predicted type);\n",
        "    - a span present in both with different types is a ``(gold_type, pred_type)`` confusion.\n",
        "    A boundary error therefore counts as one missed plus one spurious entity.\n",
        "\n",
        "    Args:\n",
        "        word_lists: tokens of each sentence.\n",
        "        gold_tags: reference tag sequence of each sentence.\n",
        "        pred_tags: predicted tag sequence of each sentence.\n",
        "\n",
        "    Returns:\n",
        "        ``(missed, spurious, type_confusions)`` counters.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if a sentence's words, gold tags and predicted tags differ in length,\n",
        "            because span positions would then be silently misaligned.\n",
        "    \"\"\"\n",
        "    missed = Counter()\n",
        "    spurious = Counter()\n",
        "    type_confusions = Counter()\n",
        "\n",
        "    for sent_idx, (words, gold, pred) in enumerate(zip(word_lists, gold_tags, pred_tags)):\n",
        "        if not (len(words) == len(gold) == len(pred)):\n",
        "            raise ValueError(\n",
        "                f\"sentence {sent_idx}: {len(words)} words, {len(gold)} gold tags, {len(pred)} predicted tags\"\n",
        "            )\n",
        "        gold_spans = {(s, e): t for s, e, t, _text in extract_entities(words, gold)}\n",
        "        pred_spans = {(s, e): t for s, e, t, _text in extract_entities(words, pred)}\n",
        "\n",
        "        for span, gold_type in gold_spans.items():\n",
        "            pred_type = pred_spans.get(span)\n",
        "            if pred_type is None:\n",
        "                missed[gold_type] += 1\n",
        "            elif pred_type != gold_type:\n",
        "                type_confusions[(gold_type, pred_type)] += 1\n",
        "\n",
        "        for span, pred_type in pred_spans.items():\n",
        "            if span not in gold_spans:\n",
        "                spurious[pred_type] += 1\n",
        "\n",
        "    return missed, spurious, type_confusions\n",
        "\n",
        "\n",
        "def collect_error_cases(\n",
        "    word_lists: List[List[str]],\n",
        "    gold_tags: List[List[str]],\n",
        "    pred_tags: List[List[str]],\n",
        "    max_cases: int = 5,\n",
        "):\n",
        "    \"\"\"Collect up to ``max_cases`` sentences whose predicted tags differ from the gold tags.\n",
        "\n",
        "    Sentences are scanned in order. Each case is a dict with the joined sentence text,\n",
        "    the gold and predicted tag sequences, and the entities decoded from each by\n",
        "    ``extract_entities``. A non-positive ``max_cases`` yields an empty list.\n",
        "    \"\"\"\n",
        "    cases = []\n",
        "    if max_cases <= 0:\n",
        "        # Without this guard the first mismatching sentence was appended before the limit check.\n",
        "        return cases\n",
        "    for words, gold, pred in zip(word_lists, gold_tags, pred_tags):\n",
        "        # Compare element-wise so a list and a tuple holding the same tags count as equal.\n",
        "        if list(gold) == list(pred):\n",
        "            continue\n",
        "        cases.append(\n",
        "            {\n",
        "                \"sentence\": \"\".join(words),\n",
        "                \"gold_tags\": gold,\n",
        "                \"pred_tags\": pred,\n",
        "                \"gold_entities\": extract_entities(words, gold),\n",
        "                \"pred_entities\": extract_entities(words, pred),\n",
        "            }\n",
        "        )\n",
        "        if len(cases) >= max_cases:\n",
        "            break\n",
        "    return cases\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Load the trained BiLSTM-CRF checkpoint and predict on the test set\n",
        "crf_word2id, crf_tag2id = extend_maps(word2id, tag2id, for_crf=True)\n",
        "\n",
        "# Work on copies so test_word_lists / test_tag_lists are not changed by the preprocessing helper\n",
        "# (NOTE(review): it presumably appends \"<end>\" in place — confirm in utils.prepocess_data_for_lstmcrf)\n",
        "test_words_for_model = [list(seq) for seq in test_word_lists]\n",
        "test_tags_for_model = [list(seq) for seq in test_tag_lists]\n",
        "test_words_for_model, test_tags_for_model = prepocess_data_for_lstmcrf(\n",
        "    test_words_for_model, test_tags_for_model, test=True\n",
        ")\n",
        "\n",
        "# NOTE(review): if load_model unpickles the file, only load checkpoints from a trusted source\n",
        "bilstm_model = load_model(str(MODEL_PATH))\n",
        "# Re-compact the RNN weights after loading (presumably to avoid PyTorch's non-contiguous weights warning)\n",
        "bilstm_model.model.bilstm.bilstm.flatten_parameters()\n",
        "pred_tag_lists, tagged_targets = bilstm_model.test(\n",
        "    test_words_for_model, test_tags_for_model, crf_word2id, crf_tag2id\n",
        ")\n",
        "\n",
        "# Drop a trailing \"<end>\" token so each word list lines up with its tag sequences\n",
        "clean_words = [words[:-1] if words and words[-1] == \"<end>\" else words for words in test_words_for_model]\n",
        "clean_preds = pred_tag_lists\n",
        "clean_targets = tagged_targets\n",
        "\n",
        "# Entity spans are compared by token position, so fail loudly if anything is misaligned\n",
        "assert len(clean_words) == len(clean_preds) == len(clean_targets), \"sentence counts differ\"\n",
        "assert all(\n",
        "    len(words) == len(pred) == len(gold)\n",
        "    for words, pred, gold in zip(clean_words, clean_preds, clean_targets)\n",
        "), \"words / predicted tags / gold tags have different lengths in some sentence\"\n",
        "\n",
        "metrics = Metrics(clean_targets, clean_preds, remove_O=False)\n",
        "metrics.report_scores()\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Entity-level error statistics on the test set (gold = clean_targets, prediction = clean_preds)\n",
        "missed, spurious, type_confusions = compute_error_statistics(\n",
        "    clean_words, clean_targets, clean_preds\n",
        ")\n",
        "\n",
        "per_type_sections = (\n",
        "    (\"遗漏实体（按类型统计）：\", missed),\n",
        "    (\"多余实体（按类型统计）：\", spurious),\n",
        ")\n",
        "for section_title, type_counts in per_type_sections:\n",
        "    print(section_title)\n",
        "    for ent_type, count in type_counts.most_common():\n",
        "        print(f\"  {ent_type}: {count}\")\n",
        "\n",
        "print(\"类别混淆（gold -> pred）：\")\n",
        "for (gold_label, pred_label), n_confused in type_confusions.most_common():\n",
        "    print(f\"  {gold_label} -> {pred_label}: {n_confused}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Print a handful of mis-tagged test sentences with their gold and predicted tags and entities\n",
        "error_cases = collect_error_cases(clean_words, clean_targets, clean_preds, max_cases=5)\n",
        "for case_no, case in enumerate(error_cases, start=1):\n",
        "    print(f\"案例{case_no}：{case['sentence']}\")\n",
        "    for label, tag_key in ((\"  标准标签：\", \"gold_tags\"), (\"  预测标签：\", \"pred_tags\")):\n",
        "        print(label, \" \".join(case[tag_key]))\n",
        "    for label, entity_key in ((\"  标准实体：\", \"gold_entities\"), (\"  预测实体：\", \"pred_entities\")):\n",
        "        print(label, case[entity_key])\n",
        "    print(\"-\")\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.x"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}

NameError: name 'null' is not defined