In [47]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import BertTokenizerFast, BertForTokenClassification, AdamW
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import classification_report
from tqdm import tqdm
from torchcrf import CRF

In [48]:
class NERDataset(Dataset):
    """Token-classification dataset read from a ``token tag`` per-line text file.

    Sentences are separated by a blank line. Each item is tokenized with
    ``is_split_into_words=True``, padded/truncated to ``max_len`` and returned
    as a dict of tensors containing the tokenizer outputs plus ``labels``.

    Label layout per item:
      * first attended position  -> ``<START>``
      * every word piece         -> the tag id of the word it belongs to
        (continuation pieces of a ``B-X`` word get ``I-X`` when that tag exists)
      * last attended position   -> ``<END>``
      * everything else (padding) -> ``-100``

    Args:
        file_path: Path of the UTF-8 data file.
        tokenizer: Callable tokenizer accepting the keyword arguments used in
            ``__getitem__``. When its output offers ``word_ids()`` labels are
            aligned through it; otherwise labels are placed positionally
            (one label per token), which is only right for 1:1 tokenization.
        max_len: Fixed sequence length of every returned item.
        tag2id: Optional ``{tag name: id}`` mapping. When omitted, the
            module-level ``tag2id`` is looked up at item-access time.
    """

    def __init__(self, file_path, tokenizer, max_len, tag2id=None):
        self.sentences, self.labels = self.read_ner_data(file_path)
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.tag2id = tag2id

    def read_ner_data(self, file_path):
        """Parse ``file_path`` into parallel lists of token lists and tag lists.

        Raises:
            ValueError: If a non-blank line does not consist of exactly two
                whitespace-separated fields.
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            blocks = file.read().strip().split('\n\n')

        sentences = []
        labels = []
        for block in blocks:
            sentence = []
            label = []
            for line_text in block.split('\n'):
                if not line_text.strip():
                    continue
                fields = line_text.split()
                if len(fields) != 2:
                    raise ValueError(
                        f"Expected 'token tag' on every line of {file_path!r}, "
                        f"got {line_text!r}"
                    )
                token, tag = fields
                sentence.append(token)
                label.append(tag)
            # An empty file or a run of blank lines yields an empty block;
            # it carries no tokens, so it must not become a dataset item.
            if sentence:
                sentences.append(sentence)
                labels.append(label)

        return sentences, labels

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        """Return the encoded sentence ``idx`` as a dict of tensors."""
        sentence = self.sentences[idx]

        mapping = getattr(self, 'tag2id', None)
        if mapping is None:
            mapping = tag2id  # fall back to the module-level mapping
        tags = [mapping[label] for label in self.labels[idx]]

        encoding = self.tokenizer(sentence,
                                  is_split_into_words=True,
                                  padding='max_length',
                                  truncation=True,
                                  max_length=self.max_len)

        word_ids = self._word_ids(encoding)
        if word_ids is None:
            label_ids = self._positional_labels(tags, mapping)
        else:
            label_ids = self._aligned_labels(
                tags, word_ids, encoding['attention_mask'], mapping
            )

        encoding['labels'] = label_ids
        return {key: torch.tensor(val) for key, val in encoding.items()}

    @staticmethod
    def _word_ids(encoding):
        """Return the token -> word index list, or None when unavailable."""
        getter = getattr(encoding, 'word_ids', None)
        if getter is None:
            return None
        try:
            return getter()
        except ValueError:
            # Raised by encodings that do not track word alignment.
            return None

    def _positional_labels(self, tags, mapping):
        """Place one label per position, assuming exactly one token per word."""
        label_ids = [-100] * self.max_len
        label_ids[0] = mapping['<START>']
        kept = tags[:self.max_len - 1]
        label_ids[1:1 + len(kept)] = kept
        if len(tags) < self.max_len - 1:
            label_ids[len(tags) + 1] = mapping['<END>']
        else:
            # Sentence fills the window: the last slot marks the end.
            label_ids[-1] = mapping['<END>']
        return label_ids

    @staticmethod
    def _aligned_labels(tags, word_ids, attention_mask, mapping):
        """Assign labels token by token using the word index of each token."""
        # B-X -> I-X for the second and later pieces of a word, so that every
        # attended position keeps a real label (no -100 gaps inside a sentence).
        continuation = {}
        for name, tag_id in mapping.items():
            if name.startswith('B-'):
                continuation[tag_id] = mapping.get('I-' + name[2:], tag_id)

        label_ids = [-100] * len(word_ids)
        attended = [pos for pos, flag in enumerate(attention_mask) if flag == 1]
        previous_word = None
        for pos in attended:
            word_index = word_ids[pos]
            if word_index is None:
                previous_word = None
                continue
            tag_id = tags[word_index]
            if word_index == previous_word:
                tag_id = continuation.get(tag_id, tag_id)
            label_ids[pos] = tag_id
            previous_word = word_index

        if attended:
            label_ids[attended[0]] = mapping['<START>']
            label_ids[attended[-1]] = mapping['<END>']
        return label_ids

# Tag name -> integer label id, and the inverse lookup used when reporting.
tag2id = dict(zip(
    ['O', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', '<START>', '<END>'],
    range(9),
))
id2tag = {idx: tag for tag, idx in tag2id.items()}

In [49]:
# Configuration: data locations
train_file_path = 'data/train.txt'
test_file_path = 'data/test.txt'

# Local checkpoint directory; the visible cells only load the tokenizer from it.
pretrained_model_name = './models--bert-base-chinese/snapshots/c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f'

# Architecture switches
use_bilstm = True  # run embeddings through a bidirectional LSTM
use_crf = True  # put a CRF layer on top of the classifier

# Sequence length and optimisation settings
max_len = 128
batch_size = 16
epochs = 3
# NOTE(review): 2e-5 is applied to a model whose parameters are all randomly
# initialised in this notebook; the recorded per-entity scores stay very low
# after three epochs, so a larger rate is worth trying. Value left unchanged.
learning_rate = 2e-5

# Prefer the GPU when one is available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [50]:
# Load the tokenizer that ships with the checkpoint; only its vocabulary is
# needed here (the vocabulary size dimensions the embedding table later on).
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name)
vocab_size = tokenizer.vocab_size

# Build the datasets
train_dataset = NERDataset(train_file_path, tokenizer, max_len)
test_dataset = NERDataset(test_file_path, tokenizer, max_len)
print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

# Hold out 10% of the training file for validation. The split is drawn from a
# dedicated, seeded generator so the same sentences are held out on every
# Restart & Run All, keeping validation numbers comparable between runs.
split_seed = 42
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(
    train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Data loaders: only the training data is shuffled
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

Train dataset size: 46364
Test dataset size: 4365


In [51]:
from torch import nn
class BiLSTM_CRF_NER(nn.Module):
    """Embedding -> (optional BiLSTM) -> linear classifier -> (optional CRF) tagger.

    Args:
        vocab_size: Number of rows of the embedding table.
        embedding_dim: Size of each token embedding.
        hidden_dim: Hidden size of the LSTM, per direction.
        num_labels: Number of output tags.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between LSTM layers (only when
            ``num_layers > 1``) and on the LSTM output.
        use_bilstm: When False the classifier reads the embeddings directly.
        use_crf: When True a ``CRF`` layer provides the loss and decoding;
            otherwise cross-entropy and arg-max are used.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256,
                 num_labels=7, num_layers=2, dropout=0.1,
                 use_bilstm=True, use_crf=True):
        super(BiLSTM_CRF_NER, self).__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_labels = num_labels
        self.use_bilstm = use_bilstm
        self.use_crf = use_crf

        # Token embedding; index 0 is declared as the padding index.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        # Optional bidirectional LSTM encoder
        if self.use_bilstm:
            self.bilstm = nn.LSTM(
                input_size=embedding_dim,
                hidden_size=hidden_dim,
                num_layers=num_layers,
                batch_first=True,
                dropout=dropout if num_layers > 1 else 0,
                bidirectional=True
            )
            self.dropout = nn.Dropout(dropout)
            lstm_output_dim = hidden_dim * 2  # forward + backward states
        else:
            lstm_output_dim = embedding_dim

        # Per-token tag scores
        self.classifier = nn.Linear(lstm_output_dim, num_labels)

        # Optional CRF layer on top of the tag scores
        if self.use_crf:
            self.crf = CRF(num_labels, batch_first=True)

        self.init_weights()

    def init_weights(self):
        """Xavier-initialise weight matrices, zero the biases, keep padding at zero."""
        nn.init.xavier_uniform_(self.embedding.weight)
        # The Xavier call above also overwrites the padding row; put the zero
        # vector back so the declared padding index really embeds to zeros.
        padding_idx = self.embedding.padding_idx
        if padding_idx is not None:
            with torch.no_grad():
                self.embedding.weight[padding_idx].zero_()

        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.constant_(self.classifier.bias, 0)

        if self.use_bilstm:
            for name, param in self.bilstm.named_parameters():
                if 'weight' in name:
                    nn.init.xavier_uniform_(param)
                elif 'bias' in name:
                    nn.init.constant_(param, 0)

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        """Compute tag scores and, when ``labels`` is given, the training loss.

        Args:
            input_ids: Integer tensor of shape (batch, seq_len).
            attention_mask: Tensor of shape (batch, seq_len); its row sums are
                used as the sequence lengths for packing.
            token_type_ids: Accepted so tokenizer output can be splatted in;
                not used.
            labels: Optional tensor of shape (batch, seq_len); ``-100`` marks
                positions excluded from the loss.

        Returns:
            dict with ``'logits'`` of shape (batch, seq_len, num_labels) and,
            when ``labels`` is given, ``'loss'``.
        """
        seq_len = input_ids.size(1)

        embeddings = self.embedding(input_ids)  # (batch, seq_len, embedding_dim)

        if self.use_bilstm:
            # Number of attended tokens per sequence; packing needs them on CPU.
            lengths = attention_mask.sum(dim=1).cpu()

            packed_embeddings = nn.utils.rnn.pack_padded_sequence(
                embeddings, lengths, batch_first=True, enforce_sorted=False
            )
            packed_output, _ = self.bilstm(packed_embeddings)

            # Unpack back to the full padded length so shapes match the labels.
            lstm_output, _ = nn.utils.rnn.pad_packed_sequence(
                packed_output, batch_first=True, total_length=seq_len
            )

            lstm_output = self.dropout(lstm_output)
            logits = self.classifier(lstm_output)
        else:
            logits = self.classifier(embeddings)

        outputs = {}

        if labels is not None:
            if self.use_crf:
                # Score only attended positions that carry a real label.
                mask = (labels != -100) & (attention_mask == 1)

                # The CRF indexes its transition table with the labels, so the
                # -100 placeholders are swapped for a valid id; those positions
                # are masked out above.
                crf_labels = labels.clone()
                crf_labels[labels == -100] = 0

                # The CRF call returns a log-likelihood; negate it for a loss.
                loss = -self.crf(logits, crf_labels, mask=mask, reduction='mean')
                outputs['loss'] = loss
            else:
                # Plain token-level cross-entropy; -100 positions are ignored.
                loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs['loss'] = loss

        outputs['logits'] = logits
        return outputs

    def decode(self, logits, attention_mask):
        """Return the predicted tag ids of all selected positions as one flat list.

        Args:
            logits: Tag scores of shape (batch, seq_len, num_labels).
            attention_mask: Tensor of shape (batch, seq_len) selecting the
                positions to decode.

        Returns:
            A Python list of tag ids: the selected positions of the first
            sequence, followed by those of the second, and so on.
        """
        if self.use_crf:
            mask = attention_mask.bool()
            paths = self.crf.decode(logits, mask=mask)
            # One path per sequence; concatenate them in batch order.
            return [tag for path in paths for tag in path]

        # Without a CRF, take the best-scoring tag at every position. Boolean
        # indexing walks the batch row by row, which yields the same flat order.
        best = torch.argmax(logits, dim=-1)
        return best[attention_mask == 1].tolist()

In [52]:
# Build the tagger with the switches chosen in the configuration cell
model = BiLSTM_CRF_NER(
    vocab_size=vocab_size,
    embedding_dim=128,
    hidden_dim=256,
    num_labels=len(tag2id),
    use_bilstm=use_bilstm,
    use_crf=use_crf
)
model.to(device)

# Optimizer and learning-rate schedule.
# PyTorch's AdamW is used instead of the deprecated optimizer class exported by
# transformers. eps and weight_decay are spelled out to mirror that class's
# defaults (PyTorch's own default weight decay is non-zero).
optimizer = torch.optim.AdamW(
    model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0
)
# One scheduler step is taken per batch, so the decay spans every batch of
# every epoch; no warm-up steps are used.
num_training_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)




In [53]:
def evaluate_model(loader, dataset_type="Validation", device=None):
    """Print a token-level classification report for the global ``model``.

    Only positions that are attended AND carry a real label (not ``-100``)
    are scored. The model is left in evaluation mode afterwards.

    Args:
        loader: Iterable of batches; each batch is a dict of tensors holding
            at least ``attention_mask`` and ``labels`` plus the model inputs.
        dataset_type: Name printed in front of the report.
        device: Device the batches are moved to. Defaults to the device the
            model's parameters live on, so the call also works on CPU-only
            machines.
    """
    if device is None:
        device = next(model.parameters()).device

    label_ids = list(range(len(tag2id)))

    model.eval()
    true_labels = []
    pred_labels = []

    with torch.no_grad():
        for batch in loader:
            inputs = {key: val.to(device) for key, val in batch.items()}
            # Labels are only needed for scoring; keeping them out of the
            # forward pass avoids computing a loss nobody reads.
            labels = inputs.pop('labels')
            emissions = model(**inputs)['logits']

            # Score attended positions that have a real label. Built as a new
            # tensor so the batch's attention mask is not modified in place.
            mask = inputs['attention_mask'].bool() & (labels != -100)

            predictions = model.decode(emissions, mask)
            # Boolean indexing flattens row by row, matching decode()'s order.
            true_labels.extend(labels[mask].tolist())
            pred_labels.extend(predictions)

    # Passing `labels` pins the report rows to the full tag set, so a class
    # that happens to be absent from this data does not make the call fail on
    # a target_names length mismatch.
    report = classification_report(
        true_labels,
        pred_labels,
        labels=label_ids,
        target_names=[id2tag[i] for i in label_ids],
        zero_division=0,
    )
    print(f"{dataset_type} metrics:\n{report}")

## BiLSTM + CRF

In [54]:
# Train the tagger, reporting validation metrics after every epoch
model.to(device)
test = True  # NOTE(review): not read by any cell visible here
for epoch in range(epochs):
    model.train()
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        # Move every tensor of the batch onto the training device
        inputs = dict((key, val.to(device)) for key, val in batch.items())

        # Forward pass returns the loss because the batch carries labels
        outputs = model(**inputs)
        loss = outputs['loss']
        loss.backward()

        # Apply the update, advance the LR schedule, then clear gradients
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    # End of epoch: score the held-out validation split
    evaluate_model(val_loader, dataset_type="Validation", device=device)

print("Training complete.")

Training Epoch 1: 100%|██████████| 2608/2608 [06:32<00:00,  6.64it/s]


Validation metrics:
              precision    recall  f1-score   support

           O       0.91      0.99      0.95    187557
       B-ORG       0.00      0.00      0.00      2022
       I-ORG       0.32      0.25      0.28      8587
       B-PER       0.00      0.00      0.00      1794
       I-PER       0.57      0.00      0.00      3509
       B-LOC       0.00      0.00      0.00      3508
       I-LOC       0.50      0.00      0.00      5002
     <START>       0.93      1.00      0.96      4637
       <END>       1.00      1.00      1.00      4637

    accuracy                           0.89    221253
   macro avg       0.47      0.36      0.35    221253
weighted avg       0.84      0.89      0.86    221253



Training Epoch 2: 100%|██████████| 2608/2608 [06:31<00:00,  6.66it/s]


Validation metrics:
              precision    recall  f1-score   support

           O       0.93      0.99      0.96    187557
       B-ORG       0.00      0.00      0.00      2022
       I-ORG       0.32      0.40      0.35      8587
       B-PER       0.00      0.00      0.00      1794
       I-PER       0.65      0.00      0.01      3509
       B-LOC       0.45      0.06      0.10      3508
       I-LOC       0.38      0.06      0.10      5002
     <START>       1.00      1.00      1.00      4637
       <END>       0.99      1.00      1.00      4637

    accuracy                           0.90    221253
   macro avg       0.52      0.39      0.39    221253
weighted avg       0.87      0.90      0.87    221253



Training Epoch 3: 100%|██████████| 2608/2608 [06:31<00:00,  6.66it/s]


Validation metrics:
              precision    recall  f1-score   support

           O       0.93      0.99      0.96    187557
       B-ORG       0.62      0.02      0.04      2022
       I-ORG       0.32      0.39      0.35      8587
       B-PER       0.00      0.00      0.00      1794
       I-PER       0.61      0.01      0.01      3509
       B-LOC       0.43      0.12      0.19      3508
       I-LOC       0.40      0.09      0.15      5002
     <START>       1.00      1.00      1.00      4637
       <END>       1.00      1.00      1.00      4637

    accuracy                           0.90    221253
   macro avg       0.59      0.40      0.41    221253
weighted avg       0.88      0.90      0.88    221253

Training complete.


In [55]:
# Evaluate on the test set. The configured device is passed explicitly, as in
# the validation calls, instead of relying on the function's "cuda" default
# (which fails on a CPU-only machine).
evaluate_model(test_loader, "Test", device)

Test metrics:
              precision    recall  f1-score   support

           O       0.94      0.99      0.96    150599
       B-ORG       0.39      0.01      0.01      1302
       I-ORG       0.26      0.36      0.30      5459
       B-PER       0.00      0.00      0.00      1401
       I-PER       0.35      0.00      0.01      2645
       B-LOC       0.55      0.15      0.23      2850
       I-LOC       0.56      0.11      0.19      4356
     <START>       1.00      0.99      0.99      4365
       <END>       0.99      0.99      0.99      4365

    accuracy                           0.91    177342
   macro avg       0.56      0.40      0.41    177342
weighted avg       0.88      0.91      0.89    177342



## BiLSTM without CRF

In [56]:
# Second experiment: same architecture, CRF layer switched off.
# This replaces the global `model` with a freshly initialised one.
use_crf = False
model = BiLSTM_CRF_NER(
    vocab_size=vocab_size, num_labels=len(tag2id),
    embedding_dim=128, hidden_dim=256,
    use_bilstm=use_bilstm, use_crf=use_crf,
)
model.to(device)

BiLSTM_CRF_NER(
  (embedding): Embedding(21128, 128, padding_idx=0)
  (bilstm): LSTM(128, 256, num_layers=2, batch_first=True, dropout=0.1, bidirectional=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=512, out_features=9, bias=True)
)

In [57]:
# Optimizer and learning-rate schedule for the freshly built model.
# PyTorch's AdamW is used instead of the deprecated optimizer class exported by
# transformers. eps and weight_decay are spelled out to mirror that class's
# defaults (PyTorch's own default weight decay is non-zero).
optimizer = torch.optim.AdamW(
    model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0
)
# One scheduler step is taken per batch, so the decay spans every batch of
# every epoch; no warm-up steps are used.
num_training_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)



In [58]:
# Train the CRF-free tagger, reporting validation metrics after every epoch
for epoch in range(epochs):
    model.train()
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        # Move every tensor of the batch onto the training device
        inputs = dict((key, val.to(device)) for key, val in batch.items())

        # Forward pass returns the loss because the batch carries labels
        outputs = model(**inputs)
        loss = outputs['loss']
        loss.backward()

        # Apply the update, advance the LR schedule, then clear gradients
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    # End of epoch: score the held-out validation split
    evaluate_model(val_loader, dataset_type="Validation", device=device)

print("Training complete.")

Training Epoch 1: 100%|██████████| 2608/2608 [01:29<00:00, 29.30it/s]


Validation metrics:
              precision    recall  f1-score   support

           O       0.92      0.99      0.95    187557
       B-ORG       0.00      0.00      0.00      2022
       I-ORG       0.31      0.32      0.32      8587
       B-PER       1.00      0.00      0.00      1794
       I-PER       0.00      0.00      0.00      3509
       B-LOC       0.00      0.00      0.00      3508
       I-LOC       0.00      0.00      0.00      5002
     <START>       0.88      0.99      0.93      4637
       <END>       0.98      1.00      0.99      4637

    accuracy                           0.89    221253
   macro avg       0.46      0.37      0.35    221253
weighted avg       0.84      0.89      0.86    221253



Training Epoch 2: 100%|██████████| 2608/2608 [01:29<00:00, 29.26it/s]


Validation metrics:
              precision    recall  f1-score   support

           O       0.94      0.99      0.96    187557
       B-ORG       0.39      0.00      0.01      2022
       I-ORG       0.32      0.44      0.37      8587
       B-PER       0.50      0.00      0.00      1794
       I-PER       0.22      0.00      0.00      3509
       B-LOC       0.45      0.10      0.17      3508
       I-LOC       0.39      0.04      0.08      5002
     <START>       0.99      1.00      1.00      4637
       <END>       0.99      1.00      0.99      4637

    accuracy                           0.90    221253
   macro avg       0.58      0.40      0.40    221253
weighted avg       0.87      0.90      0.88    221253



Training Epoch 3: 100%|██████████| 2608/2608 [01:29<00:00, 29.21it/s]


Validation metrics:
              precision    recall  f1-score   support

           O       0.94      0.99      0.96    187557
       B-ORG       0.54      0.03      0.05      2022
       I-ORG       0.33      0.41      0.37      8587
       B-PER       0.50      0.00      0.00      1794
       I-PER       0.25      0.00      0.00      3509
       B-LOC       0.42      0.16      0.23      3508
       I-LOC       0.38      0.08      0.13      5002
     <START>       1.00      1.00      1.00      4637
       <END>       0.99      1.00      0.99      4637

    accuracy                           0.90    221253
   macro avg       0.59      0.41      0.42    221253
weighted avg       0.88      0.90      0.88    221253

Training complete.


In [59]:
# Evaluate on the test set. The configured device is passed explicitly, as in
# the validation calls, instead of relying on the function's "cuda" default
# (which fails on a CPU-only machine).
evaluate_model(test_loader, "Test", device)

Test metrics:
              precision    recall  f1-score   support

           O       0.94      0.99      0.97    150599
       B-ORG       0.37      0.01      0.02      1302
       I-ORG       0.27      0.38      0.31      5459
       B-PER       1.00      0.00      0.00      1401
       I-PER       0.32      0.00      0.00      2645
       B-LOC       0.52      0.20      0.29      2850
       I-LOC       0.53      0.09      0.15      4356
     <START>       1.00      0.98      0.99      4365
       <END>       0.98      0.99      0.98      4365

    accuracy                           0.91    177342
   macro avg       0.66      0.40      0.41    177342
weighted avg       0.89      0.91      0.89    177342

