# 1 数据预处理

In [56]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from transformers import BertTokenizerFast, BertForTokenClassification

In [57]:
def read_data(file_path):
    """
    读取数据
    """
    sentences = []
    labels = []
    with open(file_path, 'r', encoding='utf-8') as f:
        sentence = []
        label = []
        for line in f:
            line = line.strip()
            if line == "":
                if sentence:
                    sentences.append(sentence)
                    labels.append(label)
                    sentence = []
                    label = []
                continue
            parts = line.split()
            if len(parts) == 2:
                word, tag = parts
                sentence.append(word)
                label.append(tag)
        if sentence:
            sentences.append(sentence)
            labels.append(label)
    return sentences, labels

def handle_data(file_path):
    """
    处理数据，将数据转化为模型可以接受的格式
    """
    sentences, labels = read_data(file_path)

    tokenized_inputs = {"input_ids": [], "attention_mask": []}
    tokenized_labels = []

    tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')

    for sentence, label in zip(sentences, labels):
        encoded = tokenizer(sentence, is_split_into_words=True, truncation=True, padding="max_length", max_length=32 , return_tensors='pt')


        word_ids = encoded.word_ids()
        label_ids = []
        for word_id in word_ids:
            if word_id is None:
                label_ids.append(-100)
            else:
                label_ids.append(label2id[label[word_id]])

        tokenized_inputs["input_ids"].append(encoded["input_ids"])
        tokenized_inputs["attention_mask"].append(encoded["attention_mask"])
        tokenized_labels.append(label_ids)

    return torch.cat(tokenized_inputs["input_ids"]), torch.cat(tokenized_inputs["attention_mask"]), torch.tensor(tokenized_labels)

file_path = './data/youku/train.txt'

In [58]:
id2label = {0: 'O', 1: 'B-PER', 2: 'I-PER', 3: 'B-MISC', 4: 'I-MISC', 5: 'B-TELEVISION', 6: 'I-TELEVISION'}
label2id = {label: i for i, label in enumerate(id2label.values())}

input_ids, attention_mask, labels = handle_data(file_path)

train_dataset = TensorDataset(input_ids, attention_mask, labels)

# 2 模型训练

In [59]:
epochs, batch_size = 8, 64
lr = 5e-3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [61]:
from tqdm import tqdm

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

model = BertForTokenClassification.from_pretrained( 'bert-base-chinese', num_labels=7)
model.to(device)
loss = torch.nn.CrossEntropyLoss()
trainer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(trainer, T_max=len(train_loader) * epochs)

for epoch in range(epochs):
    model.train()
    total_loss = 0
    bar = tqdm(train_loader)
    bar.set_description(f"epoch: {epoch + 1}")
    for batch in bar:
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
        trainer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        l = loss(outputs.logits.view(-1, 7), labels.view(-1))
        l.backward()
        trainer.step()
        scheduler.step()

        total_loss += l.item()
        bar.set_postfix(loss=total_loss / (batch[0].shape[0] * (epoch + 1)))



Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
epoch: 1: 100%|██████████| 126/126 [00:21<00:00,  5.96it/s, loss=51.8] 
epoch: 2: 100%|██████████| 126/126 [00:20<00:00,  6.01it/s, loss=13.4] 
epoch: 3: 100%|██████████| 126/126 [00:21<00:00,  5.97it/s, loss=5.99]  
epoch: 4: 100%|██████████| 126/126 [00:20<00:00,  6.07it/s, loss=2.81]  
epoch: 5: 100%|██████████| 126/126 [00:20<00:00,  6.08it/s, loss=1.4]   
epoch: 6: 100%|██████████| 126/126 [00:20<00:00,  6.07it/s, loss=0.795] 
epoch: 7: 100%|██████████| 126/126 [00:20<00:00,  6.05it/s, loss=0.557]  
epoch: 8: 100%|██████████| 126/126 [00:21<00:00,  5.97it/s, loss=0.413]  


# 3 模型评估

In [66]:
# 测试
file_path = './data/youku/test.txt'
input_ids, attention_mask, true_labels = handle_data(file_path)
test_dataset = TensorDataset(input_ids, attention_mask, true_labels)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
model.eval()
preds = []
total_acc = 0
for batch in tqdm(test_loader):
    input_ids, attention_mask, true_labels = batch
    input_ids, attention_mask, true_labels = input_ids.to(device), attention_mask.to(device), true_labels.to(device)
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    preds.extend(torch.argmax(outputs.logits, dim=-1).cpu().numpy().tolist())
    # 实体标注的评估
    # i = 3
    # print(outputs.logits.argmax(dim=-1)[i])
    # print(true_labels[i])
    # print(attention_mask[i])
    right = (outputs.logits.argmax(dim=-1) == true_labels).sum(-1)
    all = (attention_mask.sum(-1) - 2)
    total_acc += (right / all).sum().item() / len(batch[0])

print(total_acc / len(test_loader))

100%|██████████| 16/16 [00:01<00:00, 13.69it/s]

0.9367983786434662





In [123]:
# 随机展示一些效果
import random
# random.seed(42)
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
for i in range(10):
    idx = random.randint(0, len(input_ids) - 1 )
    input_id = input_ids[idx]
    mask = attention_mask[idx]

    # print(input_id )
    # print( mask)

    pred= model(input_ids = input_id.unsqueeze(0), attention_mask = mask.unsqueeze(0))
    pred = pred.logits.argmax(dim=-1).cpu().numpy().tolist()

    valid_len = mask.sum()
    for i in range(1, valid_len - 1):
        # 实体命名开头
        if id2label[pred[0][i]][0] == "B" : print()

        if id2label[pred[0][i]] == 'O':  print(tokenizer.decode(input_id[i]))
        else: print(tokenizer.decode(input_id[i]) , id2label[pred[0][i]])


网 B-TELEVISION
球 I-TELEVISION
王 I-TELEVISION
子 I-TELEVISION
1
6

江 B-MISC
阴 I-MISC
外
语
培
训

江 B-MISC
阴 I-MISC
学
外
语
哪
里
好
女
神
喊
你
学
外
语

爱 B-TELEVISION
情 I-TELEVISION
保 I-TELEVISION
卫 I-TELEVISION
战 I-TELEVISION

爱 B-TELEVISION
要 I-TELEVISION
有 I-TELEVISION
你 I-TELEVISION
搞
笑
视
频

西 B-TELEVISION
游 I-TELEVISION
搞
笑
视
频
搞
笑
电
影
,
恶
搞

西 B-TELEVISION
游 I-TELEVISION
记 I-TELEVISION

孙 B-MISC
悟 I-MISC
空 I-MISC
1
采
访

郑 B-PER
伊 I-PER
健 I-PER
,

北 B-MISC
京 I-MISC
电
影
节
,
闭
幕
式
,
2
0
1
4
0
4
2
3
,
标
清
三
哥

苗 B-PER
僑 I-PER
偉 I-PER
（

使 B-TELEVISION
徒 I-TELEVISION
行 I-TELEVISION
者 I-TELEVISION
剪
輯
片
）
【

奇 B-PER
怪 I-PER
君 I-PER
-
[UNK]
家
】
,
[UNK]
i
n
e
c
r
a
f
t
,
我
的
世
界
,
神
奇
宝
贝
口
袋

柯 B-TELEVISION
南 I-TELEVISION
剧
场

九 B-TELEVISION
水 I-TELEVISION
平 I-TELEVISION
线 I-TELEVISION
上 I-TELEVISION
的 I-TELEVISION
阴 I-TELEVISION
谋 I-TELEVISION
【
粤
语
】
[UNK]

南 B-PER
阳 I-PER
任 I-PER
国 I-PER
熙 I-PER
珍
藏
,
,
港
台
经
典
恐
怖
鬼
片
《

阴 B-TELEVISION
阳 I-TELEVISION
界 I-TELEVISION
》
{
国
语
}
这

西 B-TELEVISION
游 I