In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import BertTokenizerFast, BertForTokenClassification, AdamW
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import classification_report
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class NERDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_len):
        self.sentences, self.labels = self.read_ner_data(file_path)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def read_ner_data(self, file_path):
        with open(file_path, 'r', encoding='utf-8') as file:
            data = file.read().strip().split('\n\n')
        
        sentences = []
        labels = []
        for block in data:
            tokens_labels = block.split('\n')
            sentence = []
            label = []
            for token_label in tokens_labels:
                if token_label.strip():
                    token, tag = token_label.split()
                    sentence.append(token)
                    label.append(tag)
            sentences.append(sentence)
            labels.append(label)
        
        return sentences, labels

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        labels = self.labels[idx]

        encoding = self.tokenizer(sentence,
                                  is_split_into_words=True,
                                  padding='max_length',
                                  truncation=True,
                                  max_length=self.max_len)

        labels = [tag2id[label] for label in labels]

        # label_ids = [tag2id['pad']]*self.max_len
        label_ids = [-100] * self.max_len
        for i, label in enumerate(labels):
            if i < self.max_len - 1:
                label_ids[i + 1] = label
        encoding['labels'] = label_ids
        return {key: torch.tensor(val) for key, val in encoding.items()}

# 标签到ID的映射
tag2id = {'O': 0, 'B-ORG': 1, 'I-ORG': 2, 'B-PER': 3, 'I-PER': 4, 'B-LOC': 5, 'I-LOC': 6}
id2tag = {v: k for k, v in tag2id.items()}

In [None]:
# 参数设置
train_file_path = 'data/train.txt'
test_file_path = 'data/test.txt'
# pretrained_model_name = 'bert-base-chinese'
pretrained_model_name = './models--bert-base-chinese/snapshots/c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f'
max_len = 128
batch_size = 16
epochs = 3
learning_rate = 2e-5

预训练模型下载地址
https://hf-mirror.com/google-bert/bert-base-chinese

In [None]:
# 加载预训练模型和分词器
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name)
model = BertForTokenClassification.from_pretrained(pretrained_model_name, num_labels=len(tag2id))

Some weights of BertForTokenClassification were not initialized from the model checkpoint at ./models--bert-base-chinese/snapshots/c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
# 创建数据集和数据加载器
train_dataset = NERDataset(train_file_path, tokenizer, max_len)
test_dataset = NERDataset(test_file_path, tokenizer, max_len)
print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

Train dataset size: 46364
Test dataset size: 4365


In [14]:
# 设置优化器和学习率调度器
optimizer = AdamW(model.parameters(), lr=learning_rate)
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)



In [15]:
# 评估模型
def evaluate_model(loader, dataset_type="Validation", device="cuda"):
    model.eval()
    true_labels = []
    pred_labels = []

    with torch.no_grad():
        for batch in loader:
            inputs = {key: val.to(device) for key, val in batch.items()}
            outputs = model(**inputs)
            logits = outputs.logits
            
            predictions = torch.argmax(logits, dim=-1)
            for i in range(predictions.shape[0]):
                true_labels.extend(inputs['labels'][i][inputs['labels'][i] != -100].tolist())
                pred_labels.extend(predictions[i][inputs['labels'][i] != -100].tolist())

    report = classification_report(true_labels, pred_labels, target_names=[id2tag[i] for i in range(len(tag2id))], zero_division=0)
    print(f"{dataset_type} metrics:\n{report}")

In [16]:
# 训练模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
for epoch in range(epochs):
    model.train()
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        inputs = {key: val.to(device) for key, val in batch.items()}
        
        outputs = model(**inputs)
        logits = outputs.logits
        loss = outputs.loss
        loss.backward()
        
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    
    # 在每个epoch结束时评估验证集
    evaluate_model(val_loader, "Validation", device)

print("Training complete.")

Training Epoch 1: 100%|██████████| 2608/2608 [04:16<00:00, 10.17it/s]


Validation metrics:
              precision    recall  f1-score   support

           O       1.00      1.00      1.00    189316
       B-ORG       0.95      0.93      0.94      1974
       I-ORG       0.96      0.94      0.95      8103
       B-PER       0.99      0.98      0.98      1765
       I-PER       1.00      0.98      0.99      3453
       B-LOC       0.97      0.97      0.97      3562
       I-LOC       0.96      0.97      0.96      4718

    accuracy                           0.99    212891
   macro avg       0.97      0.97      0.97    212891
weighted avg       0.99      0.99      0.99    212891



Training Epoch 2: 100%|██████████| 2608/2608 [04:14<00:00, 10.26it/s]


Validation metrics:
              precision    recall  f1-score   support

           O       1.00      1.00      1.00    189316
       B-ORG       0.96      0.94      0.95      1974
       I-ORG       0.97      0.95      0.96      8103
       B-PER       0.99      0.99      0.99      1765
       I-PER       1.00      0.99      0.99      3453
       B-LOC       0.98      0.97      0.98      3562
       I-LOC       0.97      0.97      0.97      4718

    accuracy                           0.99    212891
   macro avg       0.98      0.97      0.98    212891
weighted avg       0.99      0.99      0.99    212891



Training Epoch 3: 100%|██████████| 2608/2608 [04:14<00:00, 10.26it/s]


Validation metrics:
              precision    recall  f1-score   support

           O       1.00      1.00      1.00    189316
       B-ORG       0.95      0.95      0.95      1974
       I-ORG       0.97      0.96      0.97      8103
       B-PER       0.99      0.99      0.99      1765
       I-PER       1.00      0.99      0.99      3453
       B-LOC       0.97      0.98      0.98      3562
       I-LOC       0.96      0.98      0.97      4718

    accuracy                           1.00    212891
   macro avg       0.98      0.98      0.98    212891
weighted avg       1.00      1.00      1.00    212891

Training complete.


In [None]:
# 测试模型
evaluate_model(test_loader, "Test")

Test metrics:
              precision    recall  f1-score   support

           O       1.00      1.00      1.00    159325
       B-ORG       0.91      0.95      0.93      1302
       I-ORG       0.93      0.97      0.95      5460
       B-PER       0.98      0.99      0.98      1401
       I-PER       0.97      0.98      0.98      2647
       B-LOC       0.98      0.96      0.97      2851
       I-LOC       0.97      0.96      0.96      4356

    accuracy                           0.99    177342
   macro avg       0.96      0.97      0.97    177342
weighted avg       0.99      0.99      0.99    177342

