## CS310 Natural Language Processing
## Assignment 4. Long Short Term Memory (LSTM) Network for Named Entity Recognition (NER)

**Total points**: 50 + (10 bonus)

In this assignment, you will implement a Long Short Term Memory (LSTM) network for Named Entity Recognition (NER). 

Re-use the code in Lab 5.

### 0. Import Necessary Libraries

In [38]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from collections import defaultdict
import os
import requests
import zipfile
import io

### 1. Build the Model

In [39]:
import torch
from torch.utils.data import Dataset
from collections import defaultdict

class CoNLLDataset(Dataset):
    def __init__(self, file_path, word2idx, label2idx):
        self.sentences = []
        self.labels = []
        self.word2idx = word2idx
        self.label2idx = label2idx
        
        current_sentence = []
        current_labels = []
        
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line == '':
                    if current_sentence:
                        self.sentences.append([word2idx.get(w.lower(), word2idx['<UNK>']) for w in current_sentence])
                        self.labels.append([label2idx[l] for l in current_labels])
                        current_sentence = []
                        current_labels = []
                else:
                    parts = line.split()
                    current_sentence.append(parts[0])
                    current_labels.append(parts[-1])
    
    def __len__(self):
        return len(self.sentences)
    
    def __getitem__(self, idx):
        return torch.tensor(self.sentences[idx]), torch.tensor(self.labels[idx])

def build_vocab(train_path):
    word2idx = {'<PAD>': 0, '<UNK>': 1}
    label2idx = {}
    word_freq = defaultdict(int)
    
    with open(train_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                parts = line.split()
                word = parts[0].lower()
                label = parts[-1]
                word_freq[word] += 1
                if label not in label2idx:
                    label2idx[label] = len(label2idx)
    
    for word, freq in word_freq.items():
        if freq > 1 and word not in word2idx:
            word2idx[word] = len(word2idx)
    
    return word2idx, label2idx

def collate_fn(batch):
    batch.sort(key=lambda x: len(x[0]), reverse=True)
    sentences, labels = zip(*batch)
    max_len = len(sentences[0])
    
    padded_sentences = []
    padded_labels = []
    for sentence, label in zip(sentences, labels):
        padding_len = max_len - len(sentence)
        padded_sentence = torch.cat([sentence, torch.zeros(padding_len, dtype=torch.long)])
        padded_label = torch.cat([label, torch.zeros(padding_len, dtype=torch.long)])
        padded_sentences.append(padded_sentence)
        padded_labels.append(padded_label)
    
    return torch.stack(padded_sentences), torch.stack(padded_labels)

def load_glove_embeddings(word2idx, embed_dim=100):
    glove_path = 'glove.6B.100d.txt'
    print("正在加载GloVe词向量...")
    
    embeddings = np.random.uniform(-0.25, 0.25, (len(word2idx), embed_dim))
    
    try:
        with open(glove_path, 'r', encoding='utf-8') as f:
            for line in f:
                values = line.split()
                word = values[0]
                if word in word2idx:
                    vector = np.array(values[1:], dtype='float32')
                    embeddings[word2idx[word]] = vector
        print("GloVe词向量加载完成！")
    except Exception as e:
        print(f"加载词向量时出错: {e}")
        print("使用随机初始化的词向量继续...")
    
    return torch.FloatTensor(embeddings)

def collate_fn(batch):
    """
    自定义的collate_fn函数，用于处理不同长度的序列
    """
    # 将batch中的样本按句子长度排序（降序）
    batch.sort(key=lambda x: len(x[0]), reverse=True)
    
    # 分离句子和标签
    sentences, labels = zip(*batch)
    
    # 获取这个batch中最长句子的长度
    max_len = len(sentences[0])
    
    # 对句子和标签进行padding
    padded_sentences = []
    padded_labels = []
    for sentence, label in zip(sentences, labels):
        # 计算需要padding的长度
        padding_len = max_len - len(sentence)
        
        # 对句子进行padding（使用0，即<PAD>的索引）
        padded_sentence = torch.cat([sentence, torch.zeros(padding_len, dtype=torch.long)])
        
        # 对标签进行padding（使用0，即<PAD>的索引）
        padded_label = torch.cat([label, torch.zeros(padding_len, dtype=torch.long)])
        
        padded_sentences.append(padded_sentence)
        padded_labels.append(padded_label)
    
    # 将列表转换为tensor
    padded_sentences = torch.stack(padded_sentences)
    padded_labels = torch.stack(padded_labels)
    
    return padded_sentences, padded_labels

# 然后修改DataLoader的创建
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn)

def build_vocab(train_path):
    word2idx = {'<PAD>': 0, '<UNK>': 1}
    label2idx = {}
    word_freq = defaultdict(int)
    
    with open(train_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                parts = line.split()
                word = parts[0].lower()
                label = parts[-1]
                word_freq[word] += 1
                if label not in label2idx:
                    label2idx[label] = len(label2idx)
    
    for word, freq in word_freq.items():
        if freq > 1 and word not in word2idx:
            word2idx[word] = len(word2idx)
    
    return word2idx, label2idx

# 初始化数据
train_path = 'data/train.txt'
dev_path = 'data/dev.txt'
test_path = 'data/test.txt'

word2idx, label2idx = build_vocab(train_path)
train_dataset = CoNLLDataset(train_path, word2idx, label2idx)
dev_dataset = CoNLLDataset(dev_path, word2idx, label2idx)
test_dataset = CoNLLDataset(test_path, word2idx, label2idx)

batch_size = 32


# 加载预训练词嵌入
pretrained_embeddings = load_glove_embeddings(word2idx)

# 在加载完预训练词向量之后，添加以下代码：

class BiLSTM_NER(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, num_tags, dropout=0.5, pretrained_embeddings=None):
        super(BiLSTM_NER, self).__init__()
        
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # 词嵌入层
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        if pretrained_embeddings is not None:
            self.word_embeddings.weight.data.copy_(pretrained_embeddings)
            # 冻结预训练的词向量
            self.word_embeddings.weight.requires_grad = False
        
        # 增加embedding dropout
        self.embed_dropout = nn.Dropout(0.2)
        
        # BiLSTM层
        self.lstm = nn.LSTM(embedding_dim, 
                           hidden_dim // 2,  # 因为是双向的，所以hidden_dim要除以2
                           num_layers=num_layers,
                           bidirectional=True,
                           dropout=dropout if num_layers > 1 else 0,
                           batch_first=True)
        
        # 增加层归一化
        self.layer_norm = nn.LayerNorm(hidden_dim)
        
        # Dropout层
        self.dropout = nn.Dropout(dropout)
        
        # 线性层
        self.hidden2tag = nn.Linear(hidden_dim, num_tags)
    
    def forward(self, x):
        embeds = self.word_embeddings(x)
        embeds = self.embed_dropout(embeds)
        
        lstm_out, _ = self.lstm(embeds)
        lstm_out = self.layer_norm(lstm_out)
        lstm_out = self.dropout(lstm_out)
        
        tag_space = self.hidden2tag(lstm_out)
        return tag_space


正在加载GloVe词向量...
GloVe词向量加载完成！


### 2. Train and Evaluate

In [26]:
from utils import get_tag_indices_from_scores
from metrics import MetricsHandler

labels_str = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']
labels_int = list(range(len(labels_str)))
train_metrics = MetricsHandler(labels_int)

def train_model(model, data_loader, optimizer, loss_func, train_metrics, num_epochs=5, **kwargs):
    model.train()
    losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch in data_loader:
            inputs, targets = batch
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                targets = targets.cuda()
            
            optimizer.zero_grad()
            outputs = model(inputs)
            
            # 计算损失
            batch_size, seq_len, num_tags = outputs.size()
            outputs_reshaped = outputs.view(-1, num_tags)
            targets_reshaped = targets.view(-1)
            
            loss = loss_func(outputs_reshaped, targets_reshaped)
            loss.backward()
            optimizer.step()
            
            # 计算预测结果
            predictions = get_tag_indices_from_scores(outputs.detach().cpu().numpy())
            predictions = predictions.reshape(-1)
            true_values = targets.cpu().numpy().reshape(-1)
            
            # 过滤掉padding位置
            mask = (true_values != 0)  # 0是PAD的索引
            predictions = predictions[mask]
            true_values = true_values[mask]
            
            # 更新指标
            train_metrics.update(predictions, true_values)
            
            running_loss += loss.item() * inputs.size(0)
        
        epoch_loss = running_loss / len(data_loader.dataset)
        losses.append(epoch_loss)
    
    return model, train_metrics, losses

def evaluate_model(model, data_loader, criterion, metrics):
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for batch in data_loader:
            inputs, targets = batch
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                targets = targets.cuda()
            
            outputs = model(inputs)
            
            # 计算损失
            batch_size, seq_len, num_tags = outputs.size()
            outputs_reshaped = outputs.view(-1, num_tags)
            targets_reshaped = targets.view(-1)
            
            loss = criterion(outputs_reshaped, targets_reshaped)
            total_loss += loss.item() * inputs.size(0)
            
            # 修改预测结果的获取方式
            predictions = torch.argmax(outputs, dim=2)  # [batch_size, seq_len]
            predictions = predictions.cpu().numpy().reshape(-1)
            true_values = targets.cpu().numpy().reshape(-1)
            
            # 过滤掉padding位置
            mask = (true_values != 0)  # 0是PAD的索引
            predictions = predictions[mask]
            true_values = true_values[mask]
            
            # 在更新指标之前收集当前批次的指标
            metrics.collect()
            # 更新指标
            metrics.update(predictions, true_values)
    
    return total_loss / len(data_loader.dataset), metrics

# 初始化模型参数
vocab_size = len(word2idx)
embedding_dim = 100  # GloVe维度
hidden_dim = 256
num_layers = 2
num_tags = len(label2idx)
dropout = 0.3  # 降低dropout率

# 创建模型实例
model = BiLSTM_NER(vocab_size, embedding_dim, hidden_dim, num_layers, num_tags, 
                   dropout=dropout, pretrained_embeddings=pretrained_embeddings)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss(ignore_index=0)  # 忽略PAD的损失
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.001)  # 降低学习率和权重衰减

# 学习率调度器
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', 
                                                factor=0.5, patience=2, 
                                                verbose=True, min_lr=1e-6)

# 如果有GPU可用，将模型移到GPU
if torch.cuda.is_available():
    model = model.cuda()

# 训练循环
print("开始训练...")
num_epochs = 5
best_f1 = -1
patience = 5
no_improve = 0

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    train_metrics = MetricsHandler(labels_int)
    dev_metrics = MetricsHandler(labels_int)
    
    # 训练阶段
    for batch_idx, batch in enumerate(train_loader):
        inputs, targets = batch
        if torch.cuda.is_available():
            inputs = inputs.cuda()
            targets = targets.cuda()
        
        optimizer.zero_grad()
        outputs = model(inputs)
        
        batch_size, seq_len, num_tags = outputs.size()
        outputs_reshaped = outputs.view(-1, num_tags)
        targets_reshaped = targets.view(-1)
        
        # 检查数值是否有效
        if torch.isnan(outputs_reshaped).any():
            print(f"Warning: NaN in outputs at batch {batch_idx}")
            continue
            
        loss = criterion(outputs_reshaped, targets_reshaped)
        
        # 检查损失是否有效
        if torch.isnan(loss):
            print(f"Warning: NaN loss at batch {batch_idx}")
            continue
            
        loss.backward()
        
        # 更严格的梯度裁剪
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        
        # 检查梯度是否有效
        for name, param in model.named_parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any():
                    print(f"Warning: NaN gradient in {name}")
                    continue
        
        optimizer.step()
        
        current_loss = loss.item()
        if not np.isnan(current_loss):
            total_loss += current_loss * inputs.size(0)
        
        with torch.no_grad():
            predictions = torch.argmax(outputs, dim=2)
            predictions = predictions.detach().cpu().numpy().reshape(-1)
            true_values = targets.cpu().numpy().reshape(-1)
            
            mask = (true_values != 0)
            predictions = predictions[mask]
            true_values = true_values[mask]
            
            train_metrics.collect()
            train_metrics.update(predictions, true_values)
    
    # 计算训练集性能
    train_loss = total_loss / len(train_loader.dataset)
    train_metrics_dict = train_metrics.get_metrics()
    train_f1 = train_metrics_dict['F1-score'][-1] if train_metrics_dict['F1-score'] else 0.0
    
    # 验证阶段
    model.eval()
    dev_total_loss = 0
    with torch.no_grad():
        for batch in dev_loader:
            inputs, targets = batch
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                targets = targets.cuda()
            
            outputs = model(inputs)
            
            batch_size, seq_len, num_tags = outputs.size()
            outputs_reshaped = outputs.view(-1, num_tags)
            targets_reshaped = targets.view(-1)
            
            loss = criterion(outputs_reshaped, targets_reshaped)
            if not torch.isnan(loss):
                dev_total_loss += loss.item() * inputs.size(0)
            
            predictions = torch.argmax(outputs, dim=2)
            predictions = predictions.cpu().numpy().reshape(-1)
            true_values = targets.cpu().numpy().reshape(-1)
            
            mask = (true_values != 0)
            predictions = predictions[mask]
            true_values = true_values[mask]
            
            dev_metrics.collect()
            dev_metrics.update(predictions, true_values)
    
    dev_loss = dev_total_loss / len(dev_loader.dataset)
    dev_metrics_dict = dev_metrics.get_metrics()
    dev_f1 = dev_metrics_dict['F1-score'][-1] if dev_metrics_dict['F1-score'] else 0.0
    
    # 学习率调整
    scheduler.step(dev_f1)
    
    print(f'Epoch {epoch+1}/{num_epochs}')
    print(f'Train Loss: {train_loss:.4f}, Train F1: {train_f1:.4f}')
    print(f'Dev Loss: {dev_loss:.4f}, Dev F1: {dev_f1:.4f}')
    print(f'Learning rate: {optimizer.param_groups[0]["lr"]:.6f}')
    
    # 保存最佳模型
    if dev_f1 > best_f1:
        best_f1 = dev_f1
        print(f"保存模型到 best_model.pt (F1: {dev_f1:.4f})")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_f1': best_f1,
        }, 'best_model.pt')
        no_improve = 0
    else:
        no_improve += 1
        print(f"验证集F1未改善，已经{no_improve}个epoch")
    
    # 早停检查
    if no_improve >= patience:
        print(f"早停：验证集F1在{patience}个epoch内未改善")
        break

# 加载最佳模型进行测试
try:
    checkpoint = torch.load('best_model.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"成功加载最佳模型 (Epoch {checkpoint['epoch']+1}, Best F1: {checkpoint['best_f1']:.4f})")
except Exception as e:
    print(f"加载模型时出错: {e}")
    print("使用当前模型进行测试...")

# 最终测试
model.eval()
test_metrics = MetricsHandler(labels_int)
with torch.no_grad():
    test_loss, test_metrics = evaluate_model(model, test_loader, criterion, test_metrics)
    test_metrics_dict = test_metrics.get_metrics()
    test_f1 = test_metrics_dict['F1-score'][-1] if test_metrics_dict['F1-score'] else 0.0

print(f'\nFinal Test F1: {test_f1:.4f}')

开始训练...
Epoch 1/5
Train Loss: 1.3262, Train F1: 0.7881
Dev Loss: 0.9756, Dev F1: 0.5086
Learning rate: 0.001000
保存模型到 best_model.pt (F1: 0.5086)
Epoch 2/5
Train Loss: 0.9678, Train F1: 0.7540
Dev Loss: 0.8506, Dev F1: 0.7202
Learning rate: 0.001000
保存模型到 best_model.pt (F1: 0.7202)
Epoch 3/5
Train Loss: 0.8137, Train F1: 0.5907
Dev Loss: 0.7648, Dev F1: 0.5754
Learning rate: 0.001000
验证集F1未改善，已经1个epoch
Epoch 4/5
Train Loss: 0.6934, Train F1: 0.4730
Dev Loss: 0.6937, Dev F1: 0.6210
Learning rate: 0.001000
验证集F1未改善，已经2个epoch
Epoch 5/5
Train Loss: 0.6074, Train F1: 0.7235
Dev Loss: 0.6348, Dev F1: 0.6833
Learning rate: 0.000500
验证集F1未改善，已经3个epoch
成功加载最佳模型 (Epoch 2, Best F1: 0.7202)

Final Test F1: 0.9412


### 3. Other Experiments

In [46]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import numpy as np
from collections import defaultdict
from sklearn.metrics import precision_recall_fscore_support

# 定义MetricsHandler类
class MetricsHandler:
    def __init__(self, labels):
        self.labels = labels
        self.predictions = []
        self.true_values = []
        self.metrics = {
            'Precision': [],
            'Recall': [],
            'F1-score': []
        }
    
    def collect(self):
        if self.predictions and self.true_values:
            precision, recall, f1, _ = precision_recall_fscore_support(
                self.true_values, 
                self.predictions, 
                labels=self.labels, 
                average='weighted',
                zero_division=0
            )
            self.metrics['Precision'].append(precision)
            self.metrics['Recall'].append(recall)
            self.metrics['F1-score'].append(f1)
    
    def update(self, predictions, true_values):
        self.predictions.extend(predictions)
        self.true_values.extend(true_values)
    
    def get_metrics(self):
        return self.metrics

# 定义数据集类
class CoNLLDataset(Dataset):
    def __init__(self, file_path, word2idx, label2idx):
        self.sentences = []
        self.labels = []
        self.word2idx = word2idx
        self.label2idx = label2idx
        
        current_sentence = []
        current_labels = []
        
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line == '':
                    if current_sentence:
                        self.sentences.append([word2idx.get(w.lower(), word2idx['<UNK>']) for w in current_sentence])
                        self.labels.append([label2idx[l] for l in current_labels])
                        current_sentence = []
                        current_labels = []
                else:
                    parts = line.split()
                    current_sentence.append(parts[0])
                    current_labels.append(parts[-1])
    
    def __len__(self):
        return len(self.sentences)
    
    def __getitem__(self, idx):
        return torch.tensor(self.sentences[idx]), torch.tensor(self.labels[idx])

def build_vocab(train_path):
    word2idx = {'<PAD>': 0, '<UNK>': 1}
    label2idx = {}
    word_freq = defaultdict(int)
    
    with open(train_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                parts = line.split()
                word = parts[0].lower()
                label = parts[-1]
                word_freq[word] += 1
                if label not in label2idx:
                    label2idx[label] = len(label2idx)
    
    for word, freq in word_freq.items():
        if freq > 1 and word not in word2idx:
            word2idx[word] = len(word2idx)
    
    return word2idx, label2idx

def collate_fn(batch):
    batch.sort(key=lambda x: len(x[0]), reverse=True)
    sentences, labels = zip(*batch)
    max_len = len(sentences[0])
    
    padded_sentences = []
    padded_labels = []
    for sentence, label in zip(sentences, labels):
        padding_len = max_len - len(sentence)
        padded_sentence = torch.cat([sentence, torch.zeros(padding_len, dtype=torch.long)])
        padded_label = torch.cat([label, torch.zeros(padding_len, dtype=torch.long)])
        padded_sentences.append(padded_sentence)
        padded_labels.append(padded_label)
    
    return torch.stack(padded_sentences), torch.stack(padded_labels)

def load_glove_embeddings(word2idx, embed_dim=100):
    glove_path = 'glove.6B.100d.txt'
    print("正在加载GloVe词向量...")
    
    embeddings = np.random.uniform(-0.25, 0.25, (len(word2idx), embed_dim))
    
    try:
        with open(glove_path, 'r', encoding='utf-8') as f:
            for line in f:
                values = line.split()
                word = values[0]
                if word in word2idx:
                    vector = np.array(values[1:], dtype='float32')
                    embeddings[word2idx[word]] = vector
        print("GloVe词向量加载完成！")
    except Exception as e:
        print(f"加载词向量时出错: {e}")
        print("使用随机初始化的词向量继续...")
    
    return torch.FloatTensor(embeddings)

# CRF模型定义
class CRF(nn.Module):
    def __init__(self, num_tags):
        super(CRF, self).__init__()
        self.num_tags = num_tags
        # 使用更小的初始值初始化转移矩阵
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags) * 0.1)
        
        # 初始化转移矩阵
        self.transitions.data[0, :] = -10000  # 从PAD标签转移的代价很高
        self.transitions.data[:, 0] = -10000  # 转移到PAD标签的代价很高
    
    def forward(self, emissions, tags, mask=None):
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8)
        
        loss = -self.log_likelihood(emissions, tags, mask)
        return loss.mean()
    
    def decode(self, emissions, mask=None):
        """使用维特比算法进行解码"""
        if mask is None:
            mask = torch.ones(emissions.size(0), emissions.size(1), dtype=torch.uint8, device=emissions.device)
        
        batch_size, seq_len, num_tags = emissions.size()
        
        # 初始化
        viterbi = emissions.new_full((batch_size, num_tags), -10000)
        viterbi[:, 0] = 0  # 从PAD标签开始
        
        # 存储最佳路径
        backpointers = emissions.new_zeros((batch_size, seq_len, num_tags), dtype=torch.long)
        
        # 前向传播
        for t in range(seq_len):
            # 计算当前时间步的得分
            emit_score = emissions[:, t].unsqueeze(2)  # [batch_size, num_tags, 1]
            trans_score = self.transitions.unsqueeze(0)  # [1, num_tags, num_tags]
            
            # 计算所有可能的路径得分
            next_tag_var = viterbi.unsqueeze(2) + trans_score + emit_score  # [batch_size, num_tags, num_tags]
            
            # 找到最佳路径
            best_tag_scores, best_tag_ids = next_tag_var.max(dim=1)  # [batch_size, num_tags]
            
            # 应用mask
            best_tag_scores = best_tag_scores * mask[:, t].unsqueeze(1)
            
            # 更新viterbi和backpointers
            viterbi = best_tag_scores
            backpointers[:, t] = best_tag_ids
        
        # 回溯找到最佳路径
        best_path_scores, best_tag_ids = viterbi.max(dim=1)
        best_paths = [best_tag_ids]
        
        # 从后向前回溯
        for t in range(seq_len-1, 0, -1):
            best_tag_ids = torch.gather(backpointers[:, t], 1, best_tag_ids.unsqueeze(1)).squeeze(1)
            best_paths.insert(0, best_tag_ids)
        
        # 将路径转换为tensor
        best_paths = torch.stack(best_paths, dim=1)
        
        return best_paths
    
    def log_likelihood(self, emissions, tags, mask=None):
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8)
        
        forward_score = self._forward_alg(emissions, mask)
        gold_score = self._score_sentence(emissions, tags, mask)
        return gold_score - forward_score
    
    def _forward_alg(self, emissions, mask):
        batch_size, seq_len, num_tags = emissions.size()
        
        alpha = emissions.new_full((batch_size, num_tags), -10000)
        alpha[:, 0] = 0
        
        for t in range(seq_len):
            emit_score = emissions[:, t].unsqueeze(2)
            trans_score = self.transitions.unsqueeze(0)
            next_tag_var = alpha.unsqueeze(2) + trans_score + emit_score
            alpha = torch.logsumexp(next_tag_var, dim=1) * mask[:, t].unsqueeze(1)
        
        return torch.logsumexp(alpha, dim=1)
    
    def _score_sentence(self, emissions, tags, mask):
        batch_size, seq_len, num_tags = emissions.size()
        score = emissions.new_zeros(batch_size)
        
        score += emissions[torch.arange(batch_size), 0, tags[:, 0]] * mask[:, 0]
        
        for t in range(1, seq_len):
            trans_score = self.transitions[tags[:, t-1], tags[:, t]]
            emit_score = emissions[torch.arange(batch_size), t, tags[:, t]]
            score += (trans_score + emit_score) * mask[:, t]
        
        return score

# BiLSTM-CRF模型定义
class BiLSTM_CRF_NER(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, num_tags, dropout=0.3, pretrained_embeddings=None):
        super(BiLSTM_CRF_NER, self).__init__()
        
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_tags = num_tags
        
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        if pretrained_embeddings is not None:
            self.word_embeddings.weight.data.copy_(pretrained_embeddings)
            self.word_embeddings.weight.requires_grad = False
        
        self.lstm = nn.LSTM(embedding_dim,
                           hidden_dim // 2,
                           num_layers=num_layers,
                           bidirectional=True,
                           dropout=dropout if num_layers > 1 else 0,
                           batch_first=True)
        
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.hidden2tag = nn.Linear(hidden_dim, num_tags)
        self.crf = CRF(num_tags)
    
    def forward(self, x, tags=None, mask=None):
        emissions = self._get_emissions(x)
        
        if self.training and tags is not None:
            # 训练模式：计算损失
            return self.crf(emissions, tags, mask)
        else:
            # 评估模式：使用维特比算法解码
            return self.crf.decode(emissions, mask)
    
    def _get_emissions(self, x):
        embeds = self.word_embeddings(x)
        lstm_out, _ = self.lstm(embeds)
        lstm_out = self.layer_norm(lstm_out)
        lstm_out = self.dropout(lstm_out)
        emissions = self.hidden2tag(lstm_out)
        return emissions

# MEMM模型定义
class MEMM_NER(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, num_tags, dropout=0.3, pretrained_embeddings=None):
        super(MEMM_NER, self).__init__()
        
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_tags = num_tags
        
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        if pretrained_embeddings is not None:
            self.word_embeddings.weight.data.copy_(pretrained_embeddings)
            self.word_embeddings.weight.requires_grad = False
        
        self.tag_embeddings = nn.Embedding(num_tags, embedding_dim)
        
        self.lstm = nn.LSTM(embedding_dim + embedding_dim,
                           hidden_dim // 2,
                           num_layers=num_layers,
                           bidirectional=True,
                           dropout=dropout if num_layers > 1 else 0,
                           batch_first=True)
        
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.hidden2tag = nn.Linear(hidden_dim, num_tags)
    
    def forward(self, x, prev_tags=None):
        batch_size, seq_len = x.size()
        
        word_embeds = self.word_embeddings(x)
        
        if prev_tags is None:
            prev_tags = torch.zeros(batch_size, seq_len, dtype=torch.long, device=x.device)
        else:
            prev_tags = torch.cat([torch.zeros(batch_size, 1, dtype=torch.long, device=x.device),
                                 prev_tags[:, :-1]], dim=1)
        
        tag_embeds = self.tag_embeddings(prev_tags)
        combined_embeds = torch.cat([word_embeds, tag_embeds], dim=-1)
        
        lstm_out, _ = self.lstm(combined_embeds)
        lstm_out = self.layer_norm(lstm_out)
        lstm_out = self.dropout(lstm_out)
        tag_space = self.hidden2tag(lstm_out)
        
        return tag_space

def beam_search_decode(model, sentence, beam_size=5):
    model.eval()
    with torch.no_grad():
        batch_size, seq_len = sentence.size()
        device = sentence.device
        
        all_predictions = []
        for b in range(batch_size):
            mask = (sentence[b] != 0)
            
            beams = [([], 0.0)]
            
            for t in range(seq_len):
                if not mask[t]:
                    continue
                    
                candidates = []
                for seq, score in beams:
                    prev_tags = torch.zeros(1, seq_len, dtype=torch.long, device=device)
                    for i, tag in enumerate(seq):
                        prev_tags[0, i] = tag
                    
                    current_input = sentence[b:b+1]
                    outputs = model(current_input, prev_tags)
                    probs = F.log_softmax(outputs[0, t], dim=-1)
                    
                    topk_probs, topk_tags = probs.topk(beam_size)
                    
                    for prob, tag in zip(topk_probs, topk_tags):
                        if tag != 0:
                            candidates.append((seq + [tag.item()], score + prob.item()))
            
                beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_size]
        
            if beams:
                best_seq = beams[0][0]
                best_seq.extend([0] * (seq_len - len(best_seq)))
                all_predictions.append(best_seq)
            else:
                all_predictions.append([0] * seq_len)
        
        return torch.tensor(all_predictions, device=device)

def evaluate_with_beam_search(model, data_loader, beam_size=5):
    metrics = MetricsHandler(list(range(len(label2idx))))
    
    with torch.no_grad():
        for batch in data_loader:
            inputs, targets = batch
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                targets = targets.cuda()
            
            predictions = beam_search_decode(model, inputs, beam_size)
            
            mask = (targets != 0)
            predictions = predictions[mask].cpu().numpy()
            true_values = targets[mask].cpu().numpy()
            
            valid_mask = (predictions > 0) & (true_values > 0)
            predictions = predictions[valid_mask]
            true_values = true_values[valid_mask]
            
            if len(predictions) > 0:
                metrics.collect()
                metrics.update(predictions, true_values)
    
    metrics_dict = metrics.get_metrics()
    return metrics_dict['F1-score'][-1] if metrics_dict['F1-score'] else 0.0

def evaluate_crf(model, data_loader):
    metrics = MetricsHandler(list(range(len(label2idx))))
    
    with torch.no_grad():
        for batch in data_loader:
            inputs, targets = batch
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                targets = targets.cuda()
            
            mask = (targets != 0)
            predictions = model(inputs, mask=mask)
            
            predictions = predictions[mask].cpu().numpy()
            true_values = targets[mask].cpu().numpy()
            
            valid_mask = (predictions > 0) & (true_values > 0)
            predictions = predictions[valid_mask]
            true_values = true_values[valid_mask]
            
            if len(predictions) > 0:
                metrics.collect()
                metrics.update(predictions, true_values)
    
    metrics_dict = metrics.get_metrics()
    return metrics_dict['F1-score'][-1] if metrics_dict['F1-score'] else 0.0

def train_and_evaluate_models():
    # 初始化数据
    train_path = 'data/train.txt'
    dev_path = 'data/dev.txt'
    test_path = 'data/test.txt'
    
    word2idx, label2idx = build_vocab(train_path)
    train_dataset = CoNLLDataset(train_path, word2idx, label2idx)
    dev_dataset = CoNLLDataset(dev_path, word2idx, label2idx)
    test_dataset = CoNLLDataset(test_path, word2idx, label2idx)
    
    # 加载预训练词嵌入
    pretrained_embeddings = load_glove_embeddings(word2idx)
    
    # 创建DataLoader
    train_loader = DataLoader(
        train_dataset, 
        batch_size=256, 
        shuffle=True, 
        collate_fn=collate_fn,
        pin_memory=True,
        num_workers=0
    )
    
    dev_loader = DataLoader(
        dev_dataset, 
        batch_size=256, 
        collate_fn=collate_fn,
        pin_memory=True,
        num_workers=0
    )
    
    test_loader = DataLoader(
        test_dataset, 
        batch_size=256, 
        collate_fn=collate_fn,
        pin_memory=True,
        num_workers=0
    )
    
    # 初始化模型参数
    vocab_size = len(word2idx)
    embedding_dim = 100
    hidden_dim = 256
    num_layers = 2
    num_tags = len(label2idx)
    dropout = 0.3
    
    # 初始化模型
    memm_model = MEMM_NER(vocab_size, embedding_dim, hidden_dim, num_layers, num_tags, 
                         dropout=dropout, pretrained_embeddings=pretrained_embeddings)
    crf_model = BiLSTM_CRF_NER(vocab_size, embedding_dim, hidden_dim, num_layers, num_tags,
                              dropout=dropout, pretrained_embeddings=pretrained_embeddings)
    
    if torch.cuda.is_available():
        memm_model = memm_model.cuda()
        crf_model = crf_model.cuda()
    
    # 优化器
    memm_optimizer = optim.AdamW(memm_model.parameters(), lr=0.001, weight_decay=0.001)
    crf_optimizer = optim.AdamW(crf_model.parameters(), lr=0.0001, weight_decay=0.001)  # 降低CRF的学习率
    
    # 学习率调度器
    memm_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        memm_optimizer, mode='max', factor=0.5, patience=2, verbose=True)
    crf_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        crf_optimizer, mode='max', factor=0.5, patience=2, verbose=True)
    
    # 创建梯度缩放器
    use_amp = torch.cuda.is_available()
    memm_scaler = GradScaler() if use_amp else None
    crf_scaler = GradScaler() if use_amp else None
    
    # 训练循环
    num_epochs = 20
    best_f1 = {'memm': -1, 'crf': -1}
    patience = 5
    no_improve_count = {'memm': 0, 'crf': 0}
    
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        
        # 训练MEMM模型
        memm_model.train()
        memm_train_loss = 0
        for batch_idx, batch in enumerate(train_loader):
            inputs, targets = batch
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                targets = targets.cuda()
            
            memm_optimizer.zero_grad()
            
            if use_amp:
                with autocast():
                    outputs = memm_model(inputs, targets)
                    mask = (targets != 0).float()
                    loss = criterion(outputs.view(-1, num_tags), targets.view(-1))
                    loss = (loss * mask.view(-1)).sum() / mask.sum()
                
                memm_scaler.scale(loss).backward()
                memm_scaler.unscale_(memm_optimizer)
                torch.nn.utils.clip_grad_norm_(memm_model.parameters(), 1.0)  # 增加梯度裁剪阈值
                memm_scaler.step(memm_optimizer)
                memm_scaler.update()
            else:
                outputs = memm_model(inputs, targets)
                mask = (targets != 0).float()
                loss = criterion(outputs.view(-1, num_tags), targets.view(-1))
                loss = (loss * mask.view(-1)).sum() / mask.sum()
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(memm_model.parameters(), 1.0)
                memm_optimizer.step()
            
            memm_train_loss += loss.item()
            
            if (batch_idx + 1) % 10 == 0:
                print(f'MEMM Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}')
        
        # 评估MEMM模型
        memm_model.eval()
        dev_f1_memm = evaluate_with_beam_search(memm_model, dev_loader)
        print(f"MEMM - Train Loss: {memm_train_loss/len(train_loader):.4f}, Dev F1: {dev_f1_memm:.4f}")
        
        # 训练CRF模型
        crf_model.train()
        crf_train_loss = 0
        for batch_idx, batch in enumerate(train_loader):
            inputs, targets = batch
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                targets = targets.cuda()
            
            crf_optimizer.zero_grad()
            
            if use_amp:
                with autocast():
                    mask = (targets != 0)
                    loss = crf_model(inputs, targets, mask)
                
                crf_scaler.scale(loss).backward()
                crf_scaler.unscale_(crf_optimizer)
                torch.nn.utils.clip_grad_norm_(crf_model.parameters(), 1.0)
                crf_scaler.step(crf_optimizer)
                crf_scaler.update()
            else:
                mask = (targets != 0)
                loss = crf_model(inputs, targets, mask)
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(crf_model.parameters(), 1.0)
                crf_optimizer.step()
            
            crf_train_loss += loss.item()
            
            if (batch_idx + 1) % 10 == 0:
                print(f'CRF Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}')
        
        # 评估CRF模型
        crf_model.eval()
        dev_f1_crf = evaluate_crf(crf_model, dev_loader)
        print(f"CRF - Train Loss: {crf_train_loss/len(train_loader):.4f}, Dev F1: {dev_f1_crf:.4f}")
        
        # 更新学习率
        memm_scheduler.step(dev_f1_memm)
        crf_scheduler.step(dev_f1_crf)
        
        print(f"MEMM LR: {memm_optimizer.param_groups[0]['lr']:.6f}")
        print(f"CRF LR: {crf_optimizer.param_groups[0]['lr']:.6f}")
        
        # 保存最佳模型
        if dev_f1_memm > best_f1['memm']:
            best_f1['memm'] = dev_f1_memm
            torch.save({
                'epoch': epoch,
                'model_state_dict': memm_model.state_dict(),
                'optimizer_state_dict': memm_optimizer.state_dict(),
                'scheduler_state_dict': memm_scheduler.state_dict(),
                'best_f1': dev_f1_memm,
            }, 'best_memm_model.pt')
            no_improve_count['memm'] = 0
        else:
            no_improve_count['memm'] += 1
            
        if dev_f1_crf > best_f1['crf']:
            best_f1['crf'] = dev_f1_crf
            torch.save({
                'epoch': epoch,
                'model_state_dict': crf_model.state_dict(),
                'optimizer_state_dict': crf_optimizer.state_dict(),
                'scheduler_state_dict': crf_scheduler.state_dict(),
                'best_f1': dev_f1_crf,
            }, 'best_crf_model.pt')
            no_improve_count['crf'] = 0
        else:
            no_improve_count['crf'] += 1
        
        # 早停检查
        if min(no_improve_count.values()) >= patience:
            print(f"早停：{patience}个epoch未改善")
            break
    
    # 最终测试
    print("\n最终测试结果：")
    
    # 测试MEMM模型
    checkpoint = torch.load('best_memm_model.pt')
    memm_model.load_state_dict(checkpoint['model_state_dict'])
    test_f1_memm = evaluate_with_beam_search(memm_model, test_loader)
    print(f"MEMM Test F1: {test_f1_memm:.4f}")
    
    # 测试CRF模型
    checkpoint = torch.load('best_crf_model.pt')
    crf_model.load_state_dict(checkpoint['model_state_dict'])
    test_f1_crf = evaluate_crf(crf_model, test_loader)
    print(f"CRF Test F1: {test_f1_crf:.4f}")

# 运行训练和评估
train_and_evaluate_models()

正在加载GloVe词向量...
GloVe词向量加载完成！

Epoch 1/20




MEMM Batch 10/59, Loss: 0.1508
MEMM Batch 20/59, Loss: 0.0380
MEMM Batch 30/59, Loss: 0.0131
MEMM Batch 40/59, Loss: 0.0064
MEMM Batch 50/59, Loss: 0.0076
MEMM - Train Loss: 0.1618, Dev F1: 0.0512
CRF Batch 10/59, Loss: 11094.7148
CRF Batch 20/59, Loss: 11678.8809
CRF Batch 30/59, Loss: 11521.1592
CRF Batch 40/59, Loss: 13120.8760
CRF Batch 50/59, Loss: 11791.7402
CRF - Train Loss: 11576.5412, Dev F1: 0.1704
MEMM LR: 0.001000
CRF LR: 0.000100

Epoch 2/20
MEMM Batch 10/59, Loss: 0.0079
MEMM Batch 20/59, Loss: 0.0047
MEMM Batch 30/59, Loss: 0.0039
MEMM Batch 40/59, Loss: 0.0022
MEMM Batch 50/59, Loss: 0.0058
MEMM - Train Loss: 0.0044, Dev F1: 0.0741
CRF Batch 10/59, Loss: 11476.9453
CRF Batch 20/59, Loss: 12647.1357
CRF Batch 30/59, Loss: 12099.1865
CRF Batch 40/59, Loss: 11826.0596
CRF Batch 50/59, Loss: 11513.5566
CRF - Train Loss: 11590.3948, Dev F1: 0.0730
MEMM LR: 0.001000
CRF LR: 0.000100

Epoch 3/20
MEMM Batch 10/59, Loss: 0.0048
MEMM Batch 20/59, Loss: 0.0013
MEMM Batch 30/59, Lo

KeyboardInterrupt: 