In [1]:
import os
import numpy as np
from collections import Counter
from typing import List, Tuple

In [2]:
def load_and_preprocess_data(data_path: str, max_seq_len: int = 30) -> Tuple[np.ndarray, np.ndarray, dict, dict, dict]:
    """
    Load a token-per-line tagged corpus and convert it to fixed-length index arrays.

    The file is read as UTF-8. Each non-blank line is split on whitespace; the
    first field is the token and the second field is its tag (lines with fewer
    than two fields are ignored). A blank line ends the current sentence.

    Parameters:
    data_path: path of the corpus file
    max_seq_len: every sentence is truncated / right-padded to this length (default 30)

    Returns:
    inputs: int32 array of shape (num_sentences, max_seq_len) with token indices
    outputs: int32 array of shape (num_sentences, max_seq_len) with tag indices
    word2idx: token -> index mapping ('<PAD>' is 0, '<UNK>' is 1)
    label2idx: tag -> index mapping ('<PAD>' is 0, the seven tags are 1..7)
    idx2label: inverse of label2idx
    """
    all_sentences = []
    all_labels = []
    current_sentence = []
    current_labels = []

    # Stream the file instead of materialising every line with readlines().
    with open(data_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # Blank line: sentence boundary.
                if current_sentence:
                    all_sentences.append(current_sentence)
                    all_labels.append(current_labels)
                    current_sentence = []
                    current_labels = []
                continue

            parts = line.split()
            if len(parts) >= 2:
                current_sentence.append(parts[0])
                current_labels.append(parts[1])

    # The file may not end with a blank line; flush the trailing sentence.
    if current_sentence:
        all_sentences.append(current_sentence)
        all_labels.append(current_labels)

    print(f"总共加载了 {len(all_sentences)} 个句子")

    # Vocabulary: reserved entries first, then corpus tokens by descending frequency.
    word_counter = Counter()
    for sentence in all_sentences:
        word_counter.update(sentence)

    word2idx = {'<PAD>': 0, '<UNK>': 1}
    for word, _ in word_counter.most_common():
        # A corpus token spelled like a reserved entry must not steal its index.
        if word not in word2idx:
            word2idx[word] = len(word2idx)

    # Fixed tag inventory; numbering starts at 1 so that 0 stays reserved for padding.
    tag_names = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']
    label2idx = {label: idx + 1 for idx, label in enumerate(tag_names)}
    label2idx['<PAD>'] = 0
    idx2label = {idx: label for label, idx in label2idx.items()}

    pad_word = word2idx['<PAD>']
    pad_label = label2idx['<PAD>']
    unk_word = word2idx['<UNK>']

    processed_inputs = []
    processed_outputs = []

    for sentence, sentence_tags in zip(all_sentences, all_labels):
        word_indices = [word2idx.get(word, unk_word) for word in sentence][:max_seq_len]
        # Tags outside the fixed inventory are mapped to 0 (the padding index).
        label_indices = [label2idx.get(tag, 0) for tag in sentence_tags][:max_seq_len]

        padding_length = max_seq_len - len(word_indices)
        processed_inputs.append(word_indices + [pad_word] * padding_length)
        processed_outputs.append(label_indices + [pad_label] * padding_length)

    # Explicit reshape keeps the result 2-D even when no sentence was read.
    num_rows = len(processed_inputs)
    processed_inputs = np.array(processed_inputs, dtype=np.int32).reshape(num_rows, max_seq_len)
    processed_outputs = np.array(processed_outputs, dtype=np.int32).reshape(num_rows, max_seq_len)

    return processed_inputs, processed_outputs, word2idx, label2idx, idx2label


In [3]:
def validate_preprocessing(inputs, outputs, idx2label, word2idx, num_samples=3, expected_seq_len=30):
    """
    Print a sanity report for the preprocessed arrays and verify the tag numbering.

    Parameters:
    inputs: preprocessed token-index array
    outputs: preprocessed tag-index array
    idx2label: index -> tag mapping
    word2idx: token -> index mapping (must contain '<PAD>')
    num_samples: how many leading samples to print
    expected_seq_len: sequence length the report checks the inputs against (default 30)

    Returns:
    True when every expected tag is present in idx2label with the expected
    index; False when a tag is missing or numbered differently.
    """
    pad_idx = word2idx['<PAD>']
    shown = min(num_samples, len(inputs))

    print("=" * 60)
    print("验证预处理结果:")
    print("=" * 60)

    # 1. Shapes. The ndim guard avoids an IndexError on a 1-D (e.g. empty) array.
    length_ok = inputs.ndim == 2 and inputs.shape[1] == expected_seq_len
    print(f"1. 输入数据形状: {inputs.shape}")
    print(f"   输出数据形状: {outputs.shape}")
    print(f"   所有序列长度应为{expected_seq_len}: {length_ok}")

    # 2. Padding: count the non-padding positions of the first few samples.
    print(f"\n2. 验证填充:")
    for i in range(shown):
        original_len = np.sum(inputs[i] != pad_idx)
        print(f"   样本{i+1}: 原始长度={original_len}, 填充后长度={len(inputs[i])}")

    # 3. Decode the first few samples back to tokens and tags.
    print(f"\n3. 显示前{num_samples}个样本:")
    idx2word = {idx: word for word, idx in word2idx.items()}

    for i in range(shown):
        print(f"\n   样本{i+1}:")

        # Keep only the positions whose token is not padding.
        words = []
        labels = []
        for word_idx, label_idx in zip(inputs[i], outputs[i]):
            if word_idx != pad_idx:
                words.append(idx2word.get(int(word_idx), '<UNK>'))
                labels.append(idx2label.get(int(label_idx), '<PAD>'))

        print(f"   词: {' '.join(words)}")
        print(f"   标签: {' '.join(labels)}")

        # Both lists are filled in the same loop iteration, so this is a
        # consistency echo rather than an independent check.
        if len(words) == len(labels):
            print(f"   词和标签数量匹配: ✓")
        else:
            print(f"   词和标签数量不匹配: ✗")

    # 4. Tag distribution over every position (padding included).
    print(f"\n4. 标签分布统计:")
    unique_labels, counts = np.unique(outputs, return_counts=True)
    total_positions = outputs.size
    for label_idx, count in zip(unique_labels, counts):
        label_name = idx2label.get(int(label_idx), f"未知({label_idx})")
        print(f"   标签 {label_name}: {count} 次 ({count/total_positions*100:.2f}%)")

    # 5. Tag numbering.
    print(f"\n5. 验证标签编号:")
    expected_labels = {'O': 1, 'B-PER': 2, 'I-PER': 3, 'B-ORG': 4,
                      'I-ORG': 5, 'B-LOC': 6, 'I-LOC': 7, '<PAD>': 0}

    # Reverse lookup built once; setdefault keeps the first index seen for a name.
    label_to_idx = {}
    for idx, name in idx2label.items():
        label_to_idx.setdefault(name, idx)

    all_correct = True
    for label_name, expected_idx in expected_labels.items():
        if label_name not in label_to_idx:
            # A tag that is absent altogether is a failure, not something to skip.
            print(f"   {label_name}: 缺失 (期望{expected_idx})")
            all_correct = False
            continue

        actual_idx = label_to_idx[label_name]
        if actual_idx == expected_idx:
            print(f"   {label_name}: 编号正确 ({actual_idx})")
        else:
            print(f"   {label_name}: 编号错误 (期望{expected_idx}, 实际{actual_idx})")
            all_correct = False

    return all_correct


In [4]:
import pickle
data_path = "chinese/train_data"  # assumed location of the training corpus

# Run the preprocessing pipeline on the training corpus.
inputs, outputs, word2idx, label2idx, idx2label = load_and_preprocess_data(data_path, max_seq_len=30)
print(f"\n预处理完成!")
print(f"词汇表大小: {len(word2idx)}")
print(f"标签数量: {len(label2idx)}")

# Only persist the artifacts when the sanity report passes.
is_valid = validate_preprocessing(inputs, outputs, idx2label, word2idx)

if not is_valid:
    print(f"\n✗ 预处理验证失败，请检查代码!")
else:
    print(f"\n✓ 预处理验证通过!")

    # Index arrays consumed by the training cell.
    np.save("processed_inputs.npy", inputs)
    np.save("processed_outputs.npy", outputs)

    # Lookup tables, one pickle file per mapping.
    for pkl_name, mapping in (
        ("word2idx.pkl", word2idx),
        ("label2idx.pkl", label2idx),
        ("idx2label.pkl", idx2label),
    ):
        with open(pkl_name, "wb") as f:
            pickle.dump(mapping, f)

    print(f"预处理结果已保存到文件!")

总共加载了 50658 个句子

预处理完成!
词汇表大小: 4769
标签数量: 8
验证预处理结果:
1. 输入数据形状: (50658, 30)
   输出数据形状: (50658, 30)
   所有序列长度应为30: True

2. 验证填充:
   样本1: 原始长度=30, 填充后长度=30
   样本2: 原始长度=30, 填充后长度=30
   样本3: 原始长度=30, 填充后长度=30

3. 显示前3个样本:

   样本1:
   词: 当 希 望 工 程 救 助 的 百 万 儿 童 成 长 起 来 ， 科 教 兴 国 蔚 然 成 风 时 ， 今 天 有
   标签: O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O
   词和标签数量匹配: ✓

   样本2:
   词: 藏 书 本 来 就 是 所 有 传 统 收 藏 门 类 中 的 第 一 大 户 ， 只 是 我 们 结 束 温 饱 的
   标签: O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O
   词和标签数量匹配: ✓

   样本3:
   词: 因 有 关 日 寇 在 京 掠 夺 文 物 详 情 ， 藏 界 较 为 重 视 ， 也 是 我 们 收 藏 北 京 史
   标签: O O O B-LOC O O B-LOC O O O O O O O O O O O O O O O O O O O O B-LOC I-LOC O
   词和标签数量匹配: ✓

4. 标签分布统计:
   标签 <PAD>: 132114 次 (8.69%)
   标签 O: 1207207 次 (79.44%)
   标签 B-PER: 13983 次 (0.92%)
   标签 I-PER: 26122 次 (1.72%)
   标签 B-ORG: 16199 次 (1.07%)
   标签 I-ORG: 62941 次 (4.14%)
   标签 B-LOC: 26163 次 (1.72%)
   标签 I-LOC: 35011 次 (2.30%)

5. 验证标签编号:
   O: 编号正确 (1)
   B-PER: 编号正确 (2)
   

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pickle
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import re

# ==================== 1. Load the preprocessed data ====================

# Load the index arrays from the .npy files in the working directory.
print("加载预处理数据...")
processed_inputs = np.load('processed_inputs.npy')
processed_outputs = np.load('processed_outputs.npy')

# Load the mapping dictionaries.
# NOTE(review): pickle.load can execute arbitrary code from the file; only load
# pickles that were produced locally by a trusted run.
with open('word2idx.pkl', 'rb') as f:
    word2idx = pickle.load(f)

with open('label2idx.pkl', 'rb') as f:
    label2idx = pickle.load(f)

with open('idx2label.pkl', 'rb') as f:
    idx2label = pickle.load(f)

print(f"训练数据形状: inputs={processed_inputs.shape}, outputs={processed_outputs.shape}")
print(f"词汇表大小: {len(word2idx)}")
print(f"标签数量: {len(label2idx)}")

# Sequence length is taken from the second axis of the input array; it is also
# used below as the default `max_len` of later function definitions.
seq_len = processed_inputs.shape[1]
print(f"序列长度: {seq_len}")

# ==================== 2. 加载预训练词向量 ====================
def load_pretrained_embeddings(embedding_path, word2idx, embedding_dim=300):
    """
    Build an embedding matrix for `word2idx` from a text word-vector file.

    Expected file layout: an optional first line "<vocab_count> <dim>" (two
    integers), then one entry per line consisting of the word followed by
    `embedding_dim` whitespace-separated floats. The last `embedding_dim`
    fields of a line are taken as the vector and everything before them
    (re-joined with single spaces) as the word.

    Parameters:
    embedding_path: path of the UTF-8 text file
    word2idx: token -> row index mapping
    embedding_dim: number of vector components per word (default 300)

    Returns:
    torch.FloatTensor of shape (len(word2idx) + 1, embedding_dim). Row 0 is all
    zeros; rows of words found in the file hold their vector; every other row
    keeps its uniform(-0.25, 0.25) initialisation.

    Read errors are reported (message + traceback) and the matrix built so far
    is still returned, so a missing file yields a purely random matrix.
    """
    print(f"加载预训练词向量: {embedding_path}")

    # One row more than len(word2idx), kept as-is so the returned shape does not change.
    vocab_size = len(word2idx) + 1
    embeddings = np.random.uniform(-0.25, 0.25, (vocab_size, embedding_dim))
    embeddings[0] = np.zeros(embedding_dim)  # padding row

    # Row indices that received a vector; a set so duplicates are counted once.
    found_indices = set()

    try:
        with open(embedding_path, 'r', encoding='utf-8') as f:
            # A header is exactly two integers. Anything else (including a
            # two-token data line such as "word 0.5") is treated as data.
            first_fields = f.readline().split()
            has_header = False
            if len(first_fields) == 2:
                try:
                    vocab_count, dim = map(int, first_fields)
                    has_header = True
                except ValueError:
                    has_header = False

            if has_header:
                print(f"词向量文件信息: 词数量={vocab_count}, 维度={dim}")
                if dim != embedding_dim:
                    print(f"警告: 词向量文件维度({dim})与embedding_dim({embedding_dim})不一致")
            else:
                # No header: rewind so the first line is parsed as an entry.
                f.seek(0)

            for line in tqdm(f, desc="加载词向量"):
                fields = line.split()
                if len(fields) < embedding_dim + 1:
                    continue

                word = ' '.join(fields[:-embedding_dim])
                # Check membership first: most file entries are out of
                # vocabulary and their floats never need to be parsed.
                if word not in word2idx:
                    continue

                try:
                    vector = np.array([float(x) for x in fields[-embedding_dim:]])
                except ValueError:
                    continue

                # Reject nan/inf components rather than poisoning the matrix.
                if not np.isfinite(vector).all():
                    continue

                idx = word2idx[word]
                embeddings[idx] = vector
                found_indices.add(idx)

    except Exception as e:
        # Deliberate best-effort: report and fall back to what has been built.
        print(f"加载词向量时出错: {e}")
        import traceback
        traceback.print_exc()

    print(f"成功加载 {len(found_indices)}/{len(word2idx)} 个词的预训练向量")
    return torch.FloatTensor(embeddings)

# Load the pretrained word vectors for the vocabulary in `word2idx`
# (embedding_dim is left at the function's default).
embedding_path = 'sgns.renmin.bigram-char'
pretrained_embeddings = load_pretrained_embeddings(embedding_path, word2idx)

# ==================== 3. 构建BiLSTM+CRF模型 ====================

class BiLSTM_CRF(nn.Module):
    """
    Sequence tagger: embedding -> 2-layer bidirectional LSTM -> linear -> CRF.

    When the `torchcrf` package cannot be imported the CRF layer is disabled
    and the model falls back to per-token cross-entropy / argmax.

    Parameters:
    vocab_size: number of embedding rows (used only without pretrained_embeddings)
    tag_size: number of tag classes
    embedding_dim: embedding width (ignored when pretrained_embeddings is given;
        the width of that tensor is used instead)
    hidden_dim: LSTM hidden size per direction
    pretrained_embeddings: optional 2-D float tensor used to initialise the
        embedding layer (kept trainable)
    dropout: dropout probability between LSTM layers and on the LSTM output
    """

    def __init__(self, vocab_size, tag_size, embedding_dim=300, hidden_dim=256,
                 pretrained_embeddings=None, dropout=0.5):
        super(BiLSTM_CRF, self).__init__()

        num_layers = 2

        # Index 0 is the padding index used by the masks below; padding_idx
        # keeps that row out of the gradient updates.
        if pretrained_embeddings is not None:
            self.embedding = nn.Embedding.from_pretrained(
                pretrained_embeddings, freeze=False, padding_idx=0)
            # The LSTM input width must follow the actual embedding matrix,
            # otherwise a mismatching `embedding_dim` argument breaks forward().
            embedding_dim = self.embedding.embedding_dim
        else:
            self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            # Inter-layer dropout is only meaningful with more than one layer.
            dropout=dropout if num_layers > 1 else 0
        )

        # Projects the concatenated forward/backward states to tag scores.
        self.fc = nn.Linear(hidden_dim * 2, tag_size)

        self.dropout = nn.Dropout(dropout)

        # CRF layer from the pytorch-crf package, if available.
        try:
            from torchcrf import CRF
            self.crf = CRF(tag_size, batch_first=True)
        except ImportError:
            print("警告: 未找到torchcrf库，将使用简化的CRF实现")
            self.crf = None
            # NOTE(review): this layer is never called in forward(); it is kept
            # so state dicts saved by earlier runs of the fallback still load.
            self.classifier = nn.Linear(tag_size, tag_size)

    def forward(self, x, tags=None, mask=None):
        """
        Compute the training loss (when `tags` is given) or decode tag indices.

        x: LongTensor of token indices, batch first
        tags: optional LongTensor of gold tag indices, same shape as x
        mask: optional bool tensor, True on real tokens; defaults to `x != 0`

        Returns a scalar loss tensor when `tags` is given. Otherwise returns a
        list of per-sequence tag-index lists: the CRF decode result, or, in the
        fallback, the argmax over every position including padding.
        """
        if mask is None:
            mask = (x != 0).bool()

        embeds = self.embedding(x)

        lstm_out, _ = self.lstm(embeds)
        lstm_out = self.dropout(lstm_out)

        emissions = self.fc(lstm_out)

        if self.crf is not None:
            if tags is not None:
                # The CRF returns a log-likelihood; negate it to get a loss.
                return -self.crf(emissions, tags, mask=mask, reduction='mean')
            # Viterbi decoding.
            return self.crf.decode(emissions, mask=mask)

        if tags is not None:
            # Cross-entropy over unmasked positions; tag 0 (padding) is ignored.
            loss_fn = nn.CrossEntropyLoss(ignore_index=0)
            active = mask.reshape(-1) == 1
            active_logits = emissions.reshape(-1, emissions.shape[2])[active]
            active_labels = tags.reshape(-1)[active]
            return loss_fn(active_logits, active_labels)

        return torch.argmax(emissions, dim=2).tolist()

# ==================== 4. Prepare the training data ====================

# Convert the index arrays to LongTensors.
train_inputs = torch.LongTensor(processed_inputs)
train_outputs = torch.LongTensor(processed_outputs)

# Mask is True wherever the token index is not 0 (0 is used as the padding index).
train_mask = (train_inputs != 0).bool()

# Dataset of (inputs, tags, mask) triples and a shuffling loader over it.
batch_size = 128
train_dataset = TensorDataset(train_inputs, train_outputs, train_mask)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

print(f"训练数据批次数量: {len(train_loader)}")

# ==================== 5. 加载测试数据 ====================

def load_and_preprocess_data(file_path, word2idx, label2idx, max_len=seq_len):
    """
    Load a token-per-line tagged file and index it with existing vocabularies.

    NOTE(review): this re-uses the name of the preprocessing function defined
    in an earlier cell (different signature) and replaces it from here on.

    Each non-blank line is split on whitespace. The first field is the token.
    The tag is the third field when a line has three or more fields, and the
    second field when it has exactly two (the layout the training-data loader
    reads). Lines with a single field are ignored. A blank line ends a sentence.

    Parameters:
    file_path: path of the UTF-8 data file
    word2idx: token -> index mapping; unknown tokens map to word2idx['<UNK>']
        (1 if that key is absent)
    label2idx: tag -> index mapping; unknown tags map to 0
    max_len: sentences are truncated / right-padded with 0 to this length

    Returns:
    (inputs, outputs): two int64 arrays of shape (num_sentences, max_len)
    """
    print(f"加载数据: {file_path}")

    unk_idx = word2idx.get('<UNK>', 1)

    def _fit(indices):
        # Truncate to max_len, then right-pad with 0 (the padding index).
        fitted = indices[:max_len]
        return fitted + [0] * (max_len - len(fitted))

    inputs = []
    outputs = []
    current_input = []
    current_output = []

    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()

            if not parts:
                # Blank line: sentence boundary.
                if current_input:
                    inputs.append(_fit(current_input))
                    outputs.append(_fit(current_output))
                    current_input = []
                    current_output = []
                continue

            if len(parts) >= 3:
                label = parts[2]
            elif len(parts) == 2:
                # Two-column "token tag" lines would otherwise be dropped
                # silently, leaving the data set empty.
                label = parts[1]
            else:
                continue

            current_input.append(word2idx.get(parts[0], unk_idx))
            # In the label2idx built by the preprocessing step, 0 is '<PAD>', not 'O'.
            current_output.append(label2idx.get(label, 0))

    # The file may not end with a blank line; flush the trailing sentence.
    if current_input:
        inputs.append(_fit(current_input))
        outputs.append(_fit(current_output))

    # Explicit dtype/shape keep the result a 2-D integer array even when empty.
    num_rows = len(inputs)
    return (np.array(inputs, dtype=np.int64).reshape(num_rows, max_len),
            np.array(outputs, dtype=np.int64).reshape(num_rows, max_len))

# Load the test data with the vocabularies loaded above.
test_file = './chinese/test_data'
test_inputs, test_outputs = load_and_preprocess_data(test_file, word2idx, label2idx)

# Convert to tensors; the mask is True wherever the token index is not 0.
test_inputs_tensor = torch.LongTensor(test_inputs)
test_outputs_tensor = torch.LongTensor(test_outputs)
test_mask = (test_inputs_tensor != 0).bool()

# Test dataset and a non-shuffling loader (same batch size as training).
test_dataset = TensorDataset(test_inputs_tensor, test_outputs_tensor, test_mask)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"测试数据批次数量: {len(test_loader)}")

# ==================== 6. Initialise the model and optimizer ====================

# Model hyper-parameters.
# NOTE(review): the preprocessing cell already stores '<PAD>' in word2idx, so
# the +1 looks like one spare row rather than the padding row; confirm.
vocab_size = len(word2idx) + 1  # +1 for padding
tag_size = len(label2idx)
embedding_dim = 300
hidden_dim = 256

# Build the model and move it to the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BiLSTM_CRF(vocab_size, tag_size, embedding_dim, hidden_dim, pretrained_embeddings)
model = model.to(device)

# Adam with L2 weight decay.
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

# Halve the learning rate every 5 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# ==================== 7. 训练模型 ====================

def train_epoch(model, data_loader, optimizer, device):
    """
    Run one optimisation pass over `data_loader`.

    Each batch is an (inputs, tags, mask) triple. The model is called with
    `tags` and `mask` and must return a scalar loss. Gradients are clipped to
    a total norm of 5.0 before every optimizer step.

    Returns the sum of the batch losses divided by the number of batches.
    """
    model.train()
    batch_losses = []

    for token_ids, tag_ids, pad_mask in tqdm(data_loader, desc="训练"):
        token_ids = token_ids.to(device)
        tag_ids = tag_ids.to(device)
        pad_mask = pad_mask.to(device)

        # Forward pass: the model returns the loss directly.
        loss = model(token_ids, tags=tag_ids, mask=pad_mask)

        # Backward pass with gradient clipping.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(data_loader)

def evaluate(model, data_loader, device):
    """
    Collect predictions and gold tags for every unmasked position.

    The model is put in eval mode and called without tags, so it must return
    one indexable prediction sequence per batch row. For each row the first
    `mask.sum()` entries of the prediction and of the gold tags are kept.

    Returns (all_predictions, all_labels): two flat lists spanning every batch.
    """
    model.eval()
    flat_predictions = []
    flat_gold = []

    with torch.no_grad():
        for token_ids, tag_ids, pad_mask in tqdm(data_loader, desc="评估"):
            token_ids = token_ids.to(device)
            tag_ids = tag_ids.to(device)
            pad_mask = pad_mask.to(device)

            decoded = model(token_ids, mask=pad_mask)

            for row, path in enumerate(decoded):
                # Number of unmasked positions in this row.
                n_real = pad_mask[row].sum().item()
                flat_predictions.extend(path[:n_real])
                flat_gold.extend(tag_ids[row][:n_real].cpu().tolist())

    return flat_predictions, flat_gold

def calculate_metrics(predictions, labels, idx2label):
    """
    Entity-level precision, recall and F1 with exact span matching.

    An entity is a `B-X` tag followed by any number of consecutive `I-X` tags
    of the same type; it is identified by (start, end, type). A predicted
    entity is a true positive only when an identical gold entity exists.
    Predicted entities without a gold twin are false positives, and gold
    entities without a predicted twin are false negatives regardless of what
    was predicted at their position.

    Parameters:
    predictions: flat sequence of predicted tag indices
    labels: flat sequence of gold tag indices, position-aligned with predictions
    idx2label: index -> tag-name mapping; indices not in it are treated as 'O'

    Returns:
    (precision, recall, f1); each is 0 when its denominator is 0.

    NOTE(review): both inputs are flat lists, so sentence boundaries are not
    visible here; a `B-X` closing one sentence followed by an `I-X` opening
    the next would be read as one entity.
    """
    def _extract_entities(tag_names):
        # Collect (start, end_exclusive, type) for every B-/I- run.
        spans = set()
        total = len(tag_names)
        pos = 0
        while pos < total:
            tag = tag_names[pos]
            if tag.startswith('B-'):
                entity_type = tag[2:]
                inside_tag = f'I-{entity_type}'
                end = pos + 1
                while end < total and tag_names[end] == inside_tag:
                    end += 1
                spans.add((pos, end, entity_type))
                pos = end
            else:
                pos += 1
        return spans

    pred_entities = _extract_entities([idx2label.get(p, 'O') for p in predictions])
    true_entities = _extract_entities([idx2label.get(l, 'O') for l in labels])

    tp = len(pred_entities & true_entities)
    fp = len(pred_entities) - tp
    fn = len(true_entities) - tp

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    return precision, recall, f1

# Train the model, evaluating after every epoch and keeping the best checkpoint.
# NOTE(review): the checkpoint is selected on the test loader, so the reported
# best F1 is not an unbiased estimate; a separate validation split would be.
num_epochs = 20
best_f1 = 0

print("开始训练模型...")
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # One optimisation pass over the training loader.
    train_loss = train_epoch(model, train_loader, optimizer, device)
    print(f"训练损失: {train_loss:.4f}")

    # Entity-level metrics on the evaluation loader.
    predictions, labels = evaluate(model, test_loader, device)
    precision, recall, f1 = calculate_metrics(predictions, labels, idx2label)

    print(f"测试集 - 精确率: {precision:.4f}, 召回率: {recall:.4f}, F1: {f1:.4f}")

    # Always write a checkpoint after the first epoch: with a strict
    # `f1 > best_f1` test alone, a run whose F1 stays at 0 would never create
    # 'best_model.pth', and the later load would fail or pick up a stale file.
    if epoch == 0 or f1 > best_f1:
        best_f1 = f1
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"保存最佳模型，F1: {f1:.4f}")

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

print(f"\n训练完成，最佳F1分数: {best_f1:.4f}")

# ==================== 8. DEMO演示 ====================

def predict_sentence(model, sentence, word2idx, idx2label, device, max_len=seq_len):
    """
    Tag a single sentence character by character.

    The stripped sentence is split into individual characters, mapped to
    indices (unknown characters use word2idx['<UNK>'], or 1 if that key is
    absent), truncated / right-padded with 0 to `max_len`, and passed to the
    model without tags.

    Parameters:
    model: tagger returning, for a batch of one, either a list of tag-index
        lists or a tensor of tag indices
    sentence: input text
    word2idx: character -> index mapping
    idx2label: index -> tag-name mapping; unknown indices are shown as 'O'
    device: torch device the input tensor is created on
    max_len: fixed input length; characters beyond it are not tagged

    Returns:
    list of (character, tag) pairs, at most `max_len` long; [] for a sentence
    that is empty after stripping.
    """
    # Simple per-character segmentation.
    words = list(sentence.strip())

    # An empty sentence would produce an all-False mask, which the CRF layer
    # rejects (the first timestep must be unmasked). Nothing to tag anyway.
    if not words:
        return []

    unk_idx = word2idx.get('<UNK>', 1)
    word_indices = [word2idx.get(word, unk_idx) for word in words][:max_len]
    word_indices = word_indices + [0] * (max_len - len(word_indices))

    # Batch of one; the mask is derived from the tensor, so it is already on `device`.
    input_tensor = torch.LongTensor([word_indices]).to(device)
    mask = (input_tensor != 0).bool()

    model.eval()
    with torch.no_grad():
        predictions = model(input_tensor, mask=mask)

    # The model may return nested lists or a tensor.
    if isinstance(predictions, list):
        pred_indices = predictions[0]
    else:
        pred_indices = predictions[0].cpu().tolist()

    pred_labels = [idx2label.get(idx, 'O') for idx in pred_indices[:len(words)]]

    # zip stops at the shorter sequence, so truncated characters are dropped.
    return list(zip(words, pred_labels))

# Restore the checkpoint from 'best_model.pth' and switch to eval mode.
model.load_state_dict(torch.load('best_model.pth'))
model.eval()

print("\n" + "="*50)
print("DEMO演示")
print("="*50)

# Example sentences for the demo.
demo_sentences = [
    "我爱北京天安门",
    "清华大学是中国著名的高等学府",
    "李白是唐代著名诗人",
    "马云是阿里巴巴集团的创始人"
]

for sentence in demo_sentences:
    print(f"\n输入句子: {sentence}")
    results = predict_sentence(model, sentence, word2idx, idx2label, device)
    
    print("实体识别结果:")
    # Print the tag next to every character that is not tagged 'O'.
    for word, label in results:
        if label != 'O':
            print(f"  {word}: {label}")
        else:
            print(f"  {word}")

print("\n" + "="*50)
print("实体识别演示完成！")
print("="*50)

# ==================== 9. Save the complete model ====================

# Bundle the weights with the lookup tables and sequence length in a single file.
torch.save({
    'model_state_dict': model.state_dict(),
    'word2idx': word2idx,
    'label2idx': label2idx,
    'idx2label': idx2label,
    'seq_len': seq_len
}, 'ner_model_complete.pth')

print("\n模型已保存为 'ner_model_complete.pth'")

加载预处理数据...
训练数据形状: inputs=(50658, 30), outputs=(50658, 30)
词汇表大小: 4769
标签数量: 8
序列长度: 30
加载预训练词向量: sgns.renmin.bigram-char
词向量文件信息: 词数量=356053, 维度=300


加载词向量: 75033it [00:09, 7700.64it/s]

: 