In [1]:
# -*- coding: utf-8 -*-
# @Time    : 2023/8/8 上午
# @Author  : 东光太狼 
# @Content : 预训练语言模型入门系列笔记（二）
# 该系列笔记将介绍基于transformer的预训练语言模型，使用该模型解决现实NLP任务。
# （二）使用bert微调

In [None]:
!pip install transformers==4.29.0
!pip install small_text==1.3.1

In [1]:
import datasets
import torch

import pandas as pd
import numpy as np

import logging
logging.getLogger('small_text').setLevel(logging.INFO)
import warnings
warnings.filterwarnings("ignore")

from sklearn.metrics import accuracy_score
from transformers import AutoTokenizer
from small_text import TransformersDataset
from small_text import TransformerModelArguments
from small_text import TransformerBasedClassificationFactory

# fix the random seed
seed = 2023
torch.manual_seed(seed)
np.random.seed(seed)

In [2]:
#第一部分：数据预处理
df = pd.read_csv('data/sample_多分类.csv')
df['label'] = 0
df.loc[df['categories']=='支付方式改革','label'] = 1
#df.head(3)
train = df.sample(frac=0.8)
test = df[~df.index.isin(train.index)]
num_classes = np.unique(train['label']).shape[0]

print('训练集数量为：',len(train),
      '\n测试集数量为：',len(test),
      '\n标签类数量为：',num_classes)

训练集数量为： 34704 
测试集数量为： 8676 
标签类数量为： 2


In [5]:
transformer_model_name = "Langboat/mengzi-bert-base"
tokenizer = AutoTokenizer.from_pretrained(transformer_model_name)
transformer_model = TransformerModelArguments(transformer_model_name)

Downloading (…)lve/main/config.json:   0%|          | 0.00/849 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/110k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/269k [00:00<?, ?B/s]

In [9]:
#第二部分：训练模型
## 将文本数据封装
train = TransformersDataset.from_arrays(list(train['text']),
                                        list(train['label']),
                                        tokenizer,
                                        max_length=90)
test = TransformersDataset.from_arrays(list(test['text']), 
                                       list(test['label']),
                                       tokenizer,
                                       max_length=90)

In [10]:
## 加载预训练模型
transformer_model = TransformerModelArguments(transformer_model_name)
clf_factory = TransformerBasedClassificationFactory(transformer_model, 
                                                    num_classes, 
                                                    kwargs=dict({'device': 'cuda', 
                                                                 'mini_batch_size': 256,
                                                                 'class_weight': 'balanced',
                                                                 'num_epochs': 20
                                                                }))

In [11]:
## 使用训练集进行训练
clf = clf_factory.new().fit(train)

Downloading (…)lve/main/config.json:   0%|          | 0.00/849 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/110k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/269k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/206M [00:00<?, ?B/s]

INFO:small_text.integrations.transformers.classifiers.classification:Epoch: 1 | 00:00:55
	Train Set Size: 31233
	Loss: 0.0013(train)	|	Acc: 92.7%(train)
	Loss: 0.0004(valid)	|	Acc: 99.2%(valid)
INFO:small_text.integrations.transformers.classifiers.classification:Epoch: 2 | 00:00:51
	Train Set Size: 31233
	Loss: 0.0002(train)	|	Acc: 99.4%(train)
	Loss: 0.0003(valid)	|	Acc: 99.0%(valid)
INFO:small_text.integrations.transformers.classifiers.classification:Epoch: 3 | 00:00:51
	Train Set Size: 31233
	Loss: 0.0001(train)	|	Acc: 99.7%(train)
	Loss: 0.0004(valid)	|	Acc: 99.4%(valid)
INFO:small_text.integrations.transformers.classifiers.classification:Epoch: 4 | 00:00:52
	Train Set Size: 31233
	Loss: 0.0001(train)	|	Acc: 99.8%(train)
	Loss: 0.0006(valid)	|	Acc: 99.5%(valid)
INFO:small_text.integrations.transformers.classifiers.classification:Epoch: 5 | 00:00:51
	Train Set Size: 31233
	Loss: 0.0000(train)	|	Acc: 99.9%(train)
	Loss: 0.0004(valid)	|	Acc: 99.3%(valid)
INFO:small_text.integrations.t

In [12]:
#第三部分：测试模型精度
clf.validate(test)

(0.00013596425689014045, 0.9925080682342093)