In [79]:
import torch

from transformers import AutoTokenizer
clean_up_tokenization_spaces = True

#加载tokenizer
tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-chinese')

tokenizer

BertTokenizerFast(name_or_path='google-bert/bert-base-chinese', vocab_size=21128, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [80]:
from datasets import load_dataset,list_metrics

#加载数据集
dataset = load_dataset('csv', data_files='input_your_tran_dataset.csv', split='train')
dataset2 = load_dataset('csv', data_files='input_your_test_dataset.csv', split='train')


dataset,dataset2


(Dataset({
     features: ['text', 'label'],
     num_rows: 45000
 }),
 Dataset({
     features: ['text', 'label'],
     num_rows: 5000
 }))

In [81]:
#定义数据集遍历工具
def collate_fn(data):
    text = [i['text'] for i in data]
    label = [i['label'] for i in data]

    #文字编码
    data = tokenizer(text,
                     padding=True,
                     truncation=True,
                     max_length=500,
                     return_tensors='pt',
                     return_token_type_ids=False)

    #设置label
    data['label'] = torch.LongTensor(label)

    return data


loader = torch.utils.data.DataLoader(dataset,
                                     batch_size=16,
                                     shuffle=True,
                                     drop_last=True,
                                     collate_fn=collate_fn)

data = next(iter(loader))

for k, v in data.items():
    print(k, v.shape)

len(loader)

input_ids torch.Size([16, 32])
attention_mask torch.Size([16, 32])
label torch.Size([16])


2812

In [82]:
#定义模型
class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()

        #加载预训练模型
        from transformers import AutoModel
        self.pretrained = AutoModel.from_pretrained(
            'google-bert/bert-base-chinese')

        #self.fc = torch.nn.Linear(in_features=768, out_features=2)
        self.fc = torch.nn.Linear(in_features=768, out_features=5)

    def forward(self, input_ids, attention_mask, label=None):
        #使用预训练模型抽取数据特征
        with torch.no_grad():
            last_hidden_state = self.pretrained(
                input_ids=input_ids,
                attention_mask=attention_mask).last_hidden_state

        #只取第0个词的特征做分类
        last_hidden_state = last_hidden_state[:, 0]

        #对抽取的特征只取第一个字的结果做分类即可
        out = self.fc(last_hidden_state).softmax(dim=1)

        #计算loss
        loss = None
        if label is not None:
            loss = torch.nn.functional.cross_entropy(out, label)

        return loss, out


model = Model()

model(**data)



(tensor(1.5774, grad_fn=<NllLossBackward0>),
 tensor([[0.1326, 0.0947, 0.3107, 0.0700, 0.3920],
         [0.1602, 0.1156, 0.3069, 0.1552, 0.2621],
         [0.1686, 0.1709, 0.2465, 0.1592, 0.2548],
         [0.1389, 0.2530, 0.2221, 0.1494, 0.2366],
         [0.1646, 0.1782, 0.2519, 0.1446, 0.2607],
         [0.1081, 0.1388, 0.1878, 0.2193, 0.3461],
         [0.1370, 0.1442, 0.2241, 0.1694, 0.3253],
         [0.2040, 0.1116, 0.2902, 0.1238, 0.2704],
         [0.1699, 0.1952, 0.2487, 0.1516, 0.2346],
         [0.1314, 0.0869, 0.3251, 0.1338, 0.3228],
         [0.1588, 0.1595, 0.3180, 0.1217, 0.2421],
         [0.1491, 0.0854, 0.3384, 0.1225, 0.3046],
         [0.1103, 0.2380, 0.2126, 0.1618, 0.2773],
         [0.1630, 0.1464, 0.2718, 0.1444, 0.2745],
         [0.2125, 0.0777, 0.2353, 0.1508, 0.3238],
         [0.1560, 0.1113, 0.2628, 0.1964, 0.2736]], grad_fn=<SoftmaxBackward0>))

In [87]:
#执行训练
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
def train():
    y_true_all = torch.tensor([])
    y_pred_all = torch.tensor([])
    #optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

    for i, data in enumerate(loader):
        loss, out = model(**data)
        
        y_true = data.label
        #print(f'y_true: {y_true}')
        y_pred = out.argmax(dim=1)
        #print(f'y_pred: {y_pred}')
        y_true_all = torch.cat((y_true_all, y_true), dim=0)
        y_pred_all = torch.cat((y_pred_all, y_pred), dim=0)
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if i % 100 == 0:
            out = out.argmax(dim=1)
            acc = (out == data.label).sum().item() / len(data.label)
            print(i, len(loader), loss.item(), acc)

        #if i == 50:
        #    break
    report = classification_report(y_true_all, y_pred_all)  
    print(f'report: {report}')


train()

0 2812 1.0746947526931763 0.875
100 2812 1.015134572982788 0.875
200 2812 1.1222753524780273 0.75
300 2812 0.9760445952415466 0.9375
400 2812 0.9791123270988464 0.9375
500 2812 1.0803909301757812 0.8125
600 2812 1.0309051275253296 0.875
700 2812 0.9435992240905762 1.0
800 2812 1.1437382698059082 0.8125
900 2812 0.9202396869659424 1.0
1000 2812 1.2638517618179321 0.625
1100 2812 0.9314250349998474 1.0
1200 2812 1.0861499309539795 0.8125
1300 2812 1.0201774835586548 0.9375
1400 2812 1.1565898656845093 0.75
1500 2812 1.0710444450378418 0.8125
1600 2812 1.0073071718215942 0.875
1700 2812 0.9736364483833313 0.9375
1800 2812 1.221720814704895 0.6875
1900 2812 1.087439775466919 0.8125
2000 2812 1.0330674648284912 0.875
2100 2812 1.0394777059555054 0.875
2200 2812 0.9726205468177795 0.9375
2300 2812 0.9983895421028137 0.9375
2400 2812 1.0573872327804565 0.8125
2500 2812 1.0396744012832642 1.0
2600 2812 1.041595697402954 0.875
2700 2812 1.0770405530929565 0.8125
2800 2812 1.1819407939910889 0.6

In [90]:
#执行测试
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
def test():
    #loader_test = torch.utils.data.DataLoader(dataset['test'],
    loader_test = torch.utils.data.DataLoader(dataset2,
                                              batch_size=16,
                                              shuffle=True,
                                              drop_last=True,
                                              collate_fn=collate_fn)

    correct = 0
    total = 0
    y_true_all = torch.tensor([])
    y_pred_all = torch.tensor([])
    for i, data in enumerate(loader_test):
        with torch.no_grad():
            _, out = model(**data)

        out = out.argmax(dim=1)

        #模型评价函数
        y_true = data.label
        y_pred = out
        y_true_all = torch.cat((y_true_all, y_true), dim=0)
        y_pred_all = torch.cat((y_pred_all, y_pred), dim=0)
        
        correct += (out == data.label).sum().item()
        
        total += len(data.label)

        if i % 100 == 0:
            print(i, len(loader_test), correct / total)

        #if i == 5:
        #    break

    report = classification_report(y_true_all, y_pred_all)  
    print(f'report: {report}')
    return correct / total

    

test()

0 312 0.9375
1 312 0.84375
2 312 0.8125
3 312 0.84375
4 312 0.8625
5 312 0.875
6 312 0.8660714285714286
7 312 0.8671875
8 312 0.875
9 312 0.875
10 312 0.8693181818181818
11 312 0.859375
12 312 0.8653846153846154
13 312 0.8660714285714286
14 312 0.8666666666666667
15 312 0.85546875
16 312 0.8602941176470589
17 312 0.8576388888888888
18 312 0.8519736842105263
19 312 0.85625
20 312 0.8571428571428571
21 312 0.8579545454545454
22 312 0.8586956521739131
23 312 0.859375
24 312 0.8625
25 312 0.8605769230769231
26 312 0.8587962962962963
27 312 0.8616071428571429
28 312 0.8577586206896551
29 312 0.8583333333333333
30 312 0.8608870967741935
31 312 0.859375
32 312 0.8636363636363636
33 312 0.8658088235294118
34 312 0.8660714285714286
35 312 0.8680555555555556
36 312 0.8699324324324325
37 312 0.8601973684210527
38 312 0.8589743589743589
39 312 0.859375
40 312 0.8612804878048781
41 312 0.8601190476190477
42 312 0.8604651162790697
43 312 0.859375
44 312 0.8597222222222223
45 312 0.8627717391304348
4

0.8657852564102564