In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import BertTokenizerFast, BertForTokenClassification, AdamW,BertConfig
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import classification_report
from tqdm import tqdm
from torchcrf import CRF

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class NERDataset(Dataset):
    """Token-classification dataset read from a CoNLL-style text file.

    The file holds one whitespace-separated ``token tag`` pair per line and
    a blank line between sentences. Each item is tokenized on access and
    returned as a dict of tensors with an extra ``labels`` entry of length
    ``max_len``; positions without a label hold -100.

    Labels are written to consecutive positions starting at index 1, so the
    alignment assumes one tokenizer token per input token and one special
    token at each end of the sequence (e.g. [CLS]/[SEP] for BERT)
    -- TODO confirm for data whose tokens are longer than one character.
    """

    def __init__(self, file_path, tokenizer, max_len):
        """Read ``file_path`` eagerly and keep the tokenizer settings.

        Args:
            file_path: Path to the UTF-8 encoded annotation file.
            tokenizer: Callable tokenizer accepting pre-split words.
            max_len: Fixed sequence length used for padding and truncation.
        """
        self.sentences, self.labels = self.read_ner_data(file_path)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def read_ner_data(self, file_path):
        """Parse the annotation file into parallel token and tag lists.

        Blocks that contain no ``token tag`` line (produced by three or more
        consecutive blank lines, or by an empty file) are skipped, because
        an empty sentence has no labelled position at all.

        Returns:
            Tuple ``(sentences, labels)`` of equally long lists of lists.

        Raises:
            ValueError: If a non-blank line does not split into exactly two
                whitespace-separated fields.
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            blocks = file.read().strip().split('\n\n')

        sentences = []
        labels = []
        for block in blocks:
            sentence = []
            label = []
            for line in block.split('\n'):
                if line.strip():
                    token, tag = line.split()
                    sentence.append(token)
                    label.append(tag)
            if sentence:
                sentences.append(sentence)
                labels.append(label)

        return sentences, labels

    def __len__(self):
        """Return the number of sentences."""
        return len(self.sentences)

    def __getitem__(self, idx):
        """Tokenize sentence ``idx`` and attach its aligned label ids.

        Returns:
            Dict mapping every tokenizer output key, plus ``labels``, to a
            tensor.

        Raises:
            KeyError: If the sentence contains a tag missing from ``tag2id``.
        """
        sentence = self.sentences[idx]
        tag_ids = [tag2id[tag] for tag in self.labels[idx]]

        encoding = self.tokenizer(sentence,
                                  is_split_into_words=True,
                                  padding='max_length',
                                  truncation=True,
                                  max_length=self.max_len)

        label_ids = [-100] * self.max_len
        # Only max_len - 2 positions can hold real tokens once the two
        # special tokens are accounted for. Capping here keeps a label from
        # landing on the closing special token of a truncated sentence.
        usable = max(self.max_len - 2, 0)
        for position, tag_id in enumerate(tag_ids[:usable], start=1):
            label_ids[position] = tag_id

        encoding['labels'] = label_ids
        return {key: torch.tensor(val) for key, val in encoding.items()}

# Mapping from tag name to label id, and its inverse.
tag2id = {'O': 0, 'B-ORG': 1, 'I-ORG': 2, 'B-PER': 3, 'I-PER': 4, 'B-LOC': 5, 'I-LOC': 6}
id2tag = {v: k for k, v in tag2id.items()}

In [3]:
# Run configuration
train_file_path = 'data/train.txt'
test_file_path = 'data/test.txt'
# pretrained_model_name = 'bert-base-chinese'
pretrained_model_name = './models--bert-base-chinese/snapshots/c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f'
max_len = 128
batch_size = 16
epochs = 3
learning_rate = 2e-5

use_crf = True  # whether to add a CRF layer on top of the emissions
use_bilstm = True  # whether to add a bidirectional LSTM after BERT
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's global RNG so that shuffling, random splits and the
# initialisation of new layers repeat across "Restart & Run All".
seed = 42
torch.manual_seed(seed)

In [4]:
# Load the tokenizer that belongs to the pretrained checkpoint.
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name)
# Build the datasets and data loaders.
train_dataset = NERDataset(train_file_path, tokenizer, max_len)
test_dataset = NERDataset(test_file_path, tokenizer, max_len)
print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")
# Hold out 10% of the training file for validation.
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
# A dedicated, seeded generator makes the train/validation split identical
# on every run, independent of how much global RNG state was consumed before.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    train_dataset, [train_size, val_size], generator=split_generator
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

Train dataset size: 46364
Test dataset size: 4365


In [5]:
from typing import List
class NER(nn.Module):
    """BERT token classifier with an optional BiLSTM and an optional CRF.

    The last BERT hidden layer is (optionally) passed through a BiLSTM and
    then through the linear classifier of the wrapped
    ``BertForTokenClassification`` to produce per-token emissions. The loss
    is either the CRF negative log-likelihood or a cross-entropy.
    """

    def __init__(self, model_name, num_labels,
                 use_bilstm=False, use_crf=False,
                 dropout=0.1):
        """Build the network.

        Args:
            model_name: Name or path given to ``from_pretrained``.
            num_labels: Number of tag classes.
            use_bilstm: Insert a bidirectional LSTM after BERT.
            use_crf: Train and decode with a CRF layer.
            dropout: Dropout probability applied to the BiLSTM output.
        """
        super().__init__()
        # Core switches
        self.use_bilstm = use_bilstm
        self.use_crf = use_crf

        # BERT backbone (token-classification variant, so it ships a classifier)
        self.bert_tc = BertForTokenClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            output_hidden_states=True  # forward() reads the last hidden layer
        )
        self.hidden_dim = self.bert_tc.config.hidden_size
        # Optional BiLSTM: two directions of hidden_dim // 2 units each
        if use_bilstm:
            self.bilstm = nn.LSTM(
                input_size=self.bert_tc.config.hidden_size,
                hidden_size=self.hidden_dim // 2,
                bidirectional=True,
                batch_first=True
            )

        # Optional CRF
        if use_crf:
            self.crf = CRF(num_labels, batch_first=True)

        # Shared dropout layer
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        """Compute emissions and, when ``labels`` is given, the training loss.

        Args:
            input_ids: Token ids.
            attention_mask: 1 for real tokens, 0 for padding.
            token_type_ids: Optional segment ids forwarded to BERT.
            labels: Optional label ids; -100 marks positions to ignore.

        Returns:
            Dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"emissions"`` (classifier output for every position).
        """
        # Labels are not forwarded to the wrapped model: its built-in loss
        # was never used, so computing it was wasted work.
        outputs = self.bert_tc(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Last hidden layer (the wrapped model's own logits are not used)
        hidden_states = outputs.hidden_states[-1]  # (B, L, H)

        if self.use_bilstm:
            lstm_out, _ = self.bilstm(hidden_states)
            lstm_out = self.dropout(lstm_out)
            emissions = self.bert_tc.classifier(lstm_out)
        else:
            # NOTE(review): no dropout is applied on this path -- confirm intended.
            emissions = self.bert_tc.classifier(hidden_states)

        loss = None
        if labels is not None:
            if self.use_crf:
                # The CRF cannot take -100 as a tag: replace it with a valid
                # id and exclude those positions through the mask instead.
                crf_labels = labels.clone()
                crf_labels[labels == -100] = 0
                crf_mask = attention_mask.bool().clone()
                crf_mask[labels == -100] = False
                # The first position is dropped so that every sequence starts
                # on a labelled token (presumably because the CRF needs the
                # first timestep unmasked -- TODO confirm).
                loss = -self.crf(emissions[:, 1:], tags=crf_labels[:, 1:], mask=crf_mask[:, 1:], reduction='mean')
            else:
                loss_fct = nn.CrossEntropyLoss()
                active_loss = attention_mask.view(-1) == 1
                active_logits = emissions.view(-1, self.bert_tc.config.num_labels)
                active_labels = labels.view(-1)
                loss = loss_fct(active_logits[active_loss], active_labels[active_loss])

        return {"loss": loss, "emissions": emissions}

    def decode(self, emissions, mask) -> List[int]:
        """Turn emissions into one flat list of predicted label ids.

        Args:
            emissions: Batch of per-token scores.
            mask: Tensor with the same leading two dimensions; truthy where a
                prediction should be produced.

        Returns:
            Predictions for all unmasked positions, concatenated sequence by
            sequence in batch order.
        """
        mask = mask.bool()
        predict = []
        if self.use_crf:
            for tags in self.crf.decode(emissions, mask=mask):
                predict.extend(tags)
        else:
            best_tags = torch.argmax(emissions, dim=-1)
            # Pair each sequence with its own mask row; indexing a single
            # 1-D row with the whole 2-D batch mask raises an IndexError.
            for row_tags, row_mask in zip(best_tags, mask):
                predict.extend(row_tags[row_mask].tolist())
        return predict

In [6]:

# Build the model, the optimizer and the learning-rate schedule.
model = NER(pretrained_model_name, num_labels=len(tag2id), use_bilstm=use_bilstm, use_crf=use_crf)
# torch.optim.AdamW is used instead of the deprecated transformers.AdamW.
# eps and weight_decay are pinned to the values the transformers class used
# by default, so the optimisation hyper-parameters stay the same.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0)
total_steps = len(train_loader) * epochs
# Linear decay to zero over the whole run, without warm-up.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at ./models--bert-base-chinese/snapshots/c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
def evaluate_model(loader, dataset_type="Validation", device=None):
    """Print a per-tag classification report for the module-level ``model``.

    Reads the notebook globals ``model``, ``tag2id`` and ``id2tag``, and
    leaves the model in eval mode.

    Args:
        loader: DataLoader yielding dicts of tensors that include
            ``attention_mask`` and ``labels``.
        dataset_type: Name shown in the printed header.
        device: Device the batches are moved to. Defaults to the device the
            model parameters live on, so the call also works on CPU-only
            machines.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    true_labels = []
    pred_labels = []

    with torch.no_grad():
        for batch in loader:
            inputs = {key: val.to(device) for key, val in batch.items()}
            outputs = model(**inputs)
            emissions = outputs['emissions']
            # Score only real tokens that carry a label (-100 = ignore).
            mask = inputs['attention_mask'].bool() & (inputs['labels'] != -100)
            # Drop the first position for predictions and gold labels alike,
            # so both lists are built from exactly the same positions.
            predictions = model.decode(emissions[:, 1:], mask[:, 1:])
            true_labels.extend(inputs['labels'][:, 1:][mask[:, 1:]].tolist())
            pred_labels.extend(predictions)

    # Passing `labels` explicitly keeps the report aligned with target_names
    # even when some tag does not occur in this split.
    label_ids = list(range(len(tag2id)))
    report = classification_report(
        true_labels,
        pred_labels,
        labels=label_ids,
        target_names=[id2tag[i] for i in label_ids],
        zero_division=0,
    )
    print(f"{dataset_type} metrics:\n{report}")

In [8]:
# Train the model
model.to(device)
for epoch in range(epochs):
    model.train()
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        inputs = {key: val.to(device) for key, val in batch.items()}
        # Clear gradients first: if this cell was interrupted between
        # backward() and the old trailing zero_grad(), a re-run would
        # otherwise add stale gradients to the first step.
        optimizer.zero_grad()
        outputs = model(**inputs)
        loss = outputs['loss']
        loss.backward()

        optimizer.step()
        scheduler.step()

    # Evaluate on the validation split at the end of every epoch.
    evaluate_model(val_loader, "Validation", device)

print("Training complete.")

Training Epoch 1: 100%|██████████| 2608/2608 [09:17<00:00,  4.68it/s]


Validation metrics:
              precision    recall  f1-score   support

           O       1.00      1.00      1.00    189930
       B-ORG       0.94      0.94      0.94      2069
       I-ORG       0.95      0.97      0.96      8139
       B-PER       0.99      0.99      0.99      1705
       I-PER       0.99      0.99      0.99      3348
       B-LOC       0.97      0.97      0.97      3470
       I-LOC       0.97      0.95      0.96      4781

    accuracy                           0.99    213442
   macro avg       0.97      0.97      0.97    213442
weighted avg       0.99      0.99      0.99    213442



Training Epoch 2: 100%|██████████| 2608/2608 [09:07<00:00,  4.77it/s]


Validation metrics:
              precision    recall  f1-score   support

           O       1.00      1.00      1.00    189930
       B-ORG       0.96      0.94      0.95      2069
       I-ORG       0.96      0.96      0.96      8139
       B-PER       0.99      0.99      0.99      1705
       I-PER       1.00      0.99      0.99      3348
       B-LOC       0.97      0.98      0.98      3470
       I-LOC       0.96      0.97      0.96      4781

    accuracy                           0.99    213442
   macro avg       0.98      0.98      0.98    213442
weighted avg       0.99      0.99      0.99    213442



Training Epoch 3: 100%|██████████| 2608/2608 [09:09<00:00,  4.74it/s]


Validation metrics:
              precision    recall  f1-score   support

           O       1.00      1.00      1.00    189930
       B-ORG       0.96      0.95      0.95      2069
       I-ORG       0.97      0.96      0.97      8139
       B-PER       0.99      0.99      0.99      1705
       I-PER       0.99      0.99      0.99      3348
       B-LOC       0.97      0.98      0.98      3470
       I-LOC       0.97      0.97      0.97      4781

    accuracy                           1.00    213442
   macro avg       0.98      0.98      0.98    213442
weighted avg       1.00      1.00      1.00    213442

Training complete.


In [9]:
# Evaluate on the test set, on the device the model was moved to
# (omitting it falls back to evaluate_model's default device).
evaluate_model(test_loader, "Test", device)

Test metrics:
              precision    recall  f1-score   support

           O       1.00      1.00      1.00    150668
       B-ORG       0.92      0.95      0.93      1302
       I-ORG       0.93      0.97      0.95      5460
       B-PER       0.98      0.99      0.99      1401
       I-PER       0.98      0.99      0.98      2647
       B-LOC       0.98      0.96      0.97      2851
       I-LOC       0.97      0.95      0.96      4356

    accuracy                           0.99    168685
   macro avg       0.97      0.97      0.97    168685
weighted avg       0.99      0.99      0.99    168685

