# 1 数据预处理



In [1]:
import os
import pandas as pd, numpy as np
import torch
from transformers import BertForSequenceClassification, BertTokenizer

data_path = 'data/toutiao_data.txt'
with open(data_path, 'r', encoding='utf-8') as file:
    data = file.readlines()

# 预处理
data = np.array([line[:-1].strip().split('_!_') for line in data])
data = pd.DataFrame(data, columns=['news_id', 'label', 'label_name', 'text', 'key_words'])
data.drop(['news_id', 'label_name'], axis=1, inplace=True)
data['label'] = data['label'].astype(int)
data['text'] = data['text'].astype(str)
data['key_words'] = data['key_words'].astype(str)
print(data.info())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 382688 entries, 0 to 382687
Data columns (total 3 columns):
 #   Column     Non-Null Count   Dtype 
---  ------     --------------   ----- 
 0   label      382688 non-null  int32 
 1   text       382688 non-null  object
 2   key_words  382688 non-null  object
dtypes: int32(1), object(2)
memory usage: 7.3+ MB
None


In [2]:
scale = 0.1
train_ratio = 0.8
total_size = int(data.shape[0] * scale)
train_size = int(total_size * train_ratio)

data = data.sample(frac=scale, random_state=42)

train_data = data[:train_size]
test_data = data[train_size:]
train_text = train_data.apply(lambda x: x['text'] + "," +  x['key_words'], axis=1).tolist()
train_label = train_data['label'].values
test_text = test_data.apply(lambda x: x['text'] + "," +  x['key_words'], axis=1).tolist()
test_label = test_data['label'].values

In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
train_text = tokenizer(train_text, padding=True, truncation=True, max_length=128, return_tensors='pt')
test_text = tokenizer(test_text, padding=True, truncation=True, max_length=128, return_tensors='pt')

In [4]:
# print(train_text)
# print(train_label)
# print(test_text)
# print(test_label)

label2name = {
    100: '民生',
    101: '文化',
    102: '娱乐',
    103: '体育',
    104: '财经',
    106: '房产',
    107: '汽车',
    108: '教育',
    109: '科技',
    110: '军事',
    112: '旅游',
    113: '国际',
    114: '证券',
    115: '农业',
    116: '电竞'
}
label_size = len(label2name) + 2

# 2 模型训练

In [5]:
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

epochs, batch_size = 5, 512
lr = 5e-3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
train_dataset = TensorDataset(train_text['input_ids'], train_text['attention_mask'], torch.tensor(train_label))
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = TensorDataset(test_text['input_ids'], test_text['attention_mask'], torch.tensor(test_label))
test_data_loader = DataLoader(test_dataset, batch_size=batch_size)

model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=label_size, local_files_only=True)
# 除了最后一层
for param in model.bert.parameters():
    param.requires_grad = False
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs * len(train_data_loader))
criterion = nn.CrossEntropyLoss()

for epoch in range(epochs):
    train_loss, train_acc = 0, 0
    bar = tqdm(train_data_loader)
    bar.set_description(f'Epoch {epoch+1}/{epochs}')
    for i, batch in enumerate(bar):
        input_ids, attention_mask, label = batch
        label = label.long() - 100
        input_ids, attention_mask, label = input_ids.to(device), attention_mask.to(device), label.to(device)

        logits = model(input_ids, attention_mask=attention_mask).logits

        loss = criterion(logits, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        bar.set_postfix(loss=loss.item())
        scheduler.step()



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  attn_output = torch.nn.functional.scaled_dot_product_attention(
Epoch 1/5:  22%|██▏       | 13/60 [00:21<01:16,  1.63s/it, loss=2.03]


KeyboardInterrupt: 

# 3 模型评估

In [47]:
model.eval()
correct, total = 0, 0
for batch in tqdm(test_data_loader):
    input_ids, attention_mask, label = batch
    label = label.long() - 100

    input_ids, attention_mask, label = input_ids.to(device), attention_mask.to(device), label.to(device)
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits
        pred = torch.argmax(logits, dim=1)
        correct += torch.sum(pred == label).item()
        total += label.size(0)

print(f'Accuracy: {correct / total:.4f}')

100%|██████████| 15/15 [00:05<00:00,  2.82it/s]

Accuracy: 0.8119





In [53]:
# 随机选一些展示效果
random_samples = []
for i in range(10):
    random_samples.append(test_data.sample(1))

for sample in random_samples:
    text = sample['text'].values[0] + "," + sample['key_words'].values[0]
    true_label = sample['label'].values[0]

    token = tokenizer(text, padding=True, truncation=True, max_length=128, return_tensors='pt')
    pred = model(token['input_ids'].to(device), attention_mask=token['attention_mask'].to(device)).logits
    pred = torch.argmax(pred, dim=1).item()  + 100

    print(f'{text} \n True Label: {label2name[true_label]} \n Pred Label: {label2name[pred]}\n')

口碑最好的10万级合资车，油耗4毛，关键是一年买几十万台,桑塔纳,PASS,威驰,轩逸,爱丽舍 
 True Label: 汽车 
 Pred Label: 汽车

阚清子拖着“沉重”的身躯默默行走在机场，面无表情尽显疲惫,纪凌尘,阚清子,综艺节目,机场 
 True Label: 娱乐 
 Pred Label: 娱乐

又一女团成员确定结婚！本月20日举行婚礼,婚礼,女团,Nine Muses,结为连理,DJ Da.Q 
 True Label: 娱乐 
 Pred Label: 娱乐

娱乐圈真正的“冻龄女神”，71年龄却保持着18岁的脸,娃娃脸,一代女皇,还珠格格,冻龄,潘迎紫,宋慧乔,武媚娘传奇 
 True Label: 娱乐 
 Pred Label: 娱乐

日本新兵因受不了虐待，满载炸弹撞向指挥部，当场炸死30名长官,攻击队,日本,新兵,武士道精神 
 True Label: 军事 
 Pred Label: 军事

《三国机密之潜龙在渊》热播 马天宇身份泄漏韩东君遭受酷刑,司马懿,刘平,曹操,郭嘉,任红昌 
 True Label: 娱乐 
 Pred Label: 娱乐

为什么坦克世界没人玩？,英雄联盟,排位赛,梦三国,百夫长,坦克世界,游戏 
 True Label: 电竞 
 Pred Label: 电竞

5000公里必换机油？恭喜你中了4S店的套路，一招辨别该不该换机油,宝马,机油,换机油,4s店 
 True Label: 汽车 
 Pred Label: 汽车

给孩子起名“女起诗经男起楚辞”的文化含意,邶风·新台,鄘风·君子偕老,邶风·静女,湘夫人,雨巷,周南·桃夭,郑风·出其东门,邶风·燕燕,灵均,诗经,离骚,邶风·,楚辞,周信芳,郑风·叔于田 
 True Label: 文化 
 Pred Label: 文化

为什么那么多人去上海的城隍庙？,玉华堂日记,龙宿郊民图,怀庆府推官刘君墓表,南吴旧话录,明朝,溪山秋霁图,刘钝,董其昌,潘允端,归有光 
 True Label: 文化 
 Pred Label: 文化

