# bert进行分类任务
### 目录介绍
    home/dataset:存放数据目录
    home/work:存放vocab.txt、Bert预训练模型结构、config文件
    home/model:存放训练存储的模型参数
    
### 其他
    主要工具：采用transformers来加载模型与配置文件
    * 加载tokenizer:
        BertTokenizer.from_pretrained(file_path)  
            #file_path='bert-base-uncased'     (线上)
            #file_path='bert-base-uncased-xx.txt'   (线下自定义)
            #file_path='xxx'    (xxx为目录，目录下必须包含vocab.txt)
       * tokenizer.encode(item[1],add_special_tokens=False)
               等价于 tokenizer.tokenize(sentence) 分词 + tokenizer.convert_tokens_to_ids()
    * 加载config:
        BertConfig.from_pretrained(file_path)  
            #file_path=''     (同上)
    * 加载model:
        model=BertForSequenceClassification(config=config)
        model=BertForSequenceClassification.from_pretrained(file_path)  
            #file_path=''     (同上)
    框架：pytorch

In [40]:
import torch
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader
from transformers import BertTokenizer,BertForSequenceClassification,BertConfig,AdamW,get_linear_schedule_with_warmup

import csv
import json
import numpy as np
import pandas as pd
import os

In [41]:
# Hyper-parameters
EPOCHS = 10         # number of passes over the training set
BATCH_SIZE = 4      # samples per batch
MAX_LEN = 300       # text ids kept per sample (before [CLS]/[SEP] are added)
LR = 1e-5           # learning rate
WARMUP_STEPS = 150  # warm-up steps of the linear LR schedule
# NOTE(review): 1834 ~= 7337 training rows / BATCH_SIZE, i.e. roughly ONE epoch.
# If used as a linear schedule's total step count while EPOCHS = 10, the learning
# rate reaches 0 after the first epoch and stays there.
T_TOTAL = 1834      # total number of scheduler steps

In [84]:
#自定义dataset类
class My_dataset(Dataset):
    def __init__(self,data_list):
        self.dataset=data_list
    
    def __getitem__(self,item_index):
        #构建一条x,y
        text=self.dataset[item_index][1]
        label=self.dataset[item_index][2]
        return text,label
    
    def __len__(self):
        return len(self.dataset)
    
# Load one CSV split and turn each row's text column into a fixed-length id list.
def load_dataset(file_path, max_len, tokenizer=None):
    """Read a CSV file and encode its text column as fixed-length BERT input ids.

    The first CSV record is treated as a header and skipped; empty records
    (blank lines) are ignored. For every remaining row, column 1 is stripped,
    has all spaces removed, is cut to ``max_len`` characters, encoded without
    special tokens, cut/padded to exactly ``max_len`` ids with the tokenizer's
    pad id, and wrapped as ``[CLS] ids [SEP]``. The resulting list is stored
    back into column 1 as its ``str()`` form (e.g. ``"[101, 2769, 0, 102]"``),
    which is valid JSON. All other columns (the label is column 2) are left as
    the raw CSV strings.

    Args:
        file_path: path of a UTF-8 CSV file whose first record is a header.
        max_len: number of text ids kept per row; each encoded row holds
            ``max_len + 2`` ids including [CLS] and [SEP].
        tokenizer: optional object providing ``encode(text,
            add_special_tokens=False)``, ``pad_token_id``, ``cls_token_id`` and
            ``sep_token_id``. Defaults to the ``BertTokenizer`` in ``home/work``.

    Returns:
        List of CSV rows (lists of strings) with column 1 replaced as above.
    """
    if tokenizer is None:
        tokenizer = BertTokenizer.from_pretrained("home/work")
    pad_id = tokenizer.pad_token_id
    cls_id = tokenizer.cls_token_id
    sep_id = tokenizer.sep_token_id

    # `with` closes the file handle; newline="" is what the csv module expects
    # so that quoted fields containing newlines are parsed correctly.
    with open(file_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header record
        data_list = [row for row in reader if row]  # drop blank lines

    for item in data_list:
        # Character-level cut first, then a token-level cut: a character can
        # expand into several word pieces.
        text = item[1].strip().replace(" ", "")[:max_len]
        ids = tokenizer.encode(text, add_special_tokens=False)[:max_len]
        # Pad directly with the [PAD] id instead of appending literal "[PAD]"
        # strings to the text and tokenizing them again.
        ids = ids + [pad_id] * (max_len - len(ids))
        item[1] = str([cls_id] + ids + [sep_id])
    return data_list

In [85]:
# Encode both CSV splits and wrap the training split in a shuffling DataLoader.
# NOTE(review): test_data is loaded here but not wrapped in a loader in this cell.
train_data = load_dataset("home/dataset/Train.csv", max_len=MAX_LEN)
test_data = load_dataset("home/dataset/Test.csv", max_len=MAX_LEN)
train_dataset = My_dataset(train_data)
print(len(train_data))  # number of training rows
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

7337


In [86]:
# Build the classifier, its optimizer and the learning-rate schedule.
config = BertConfig.from_pretrained("home/work")  # load the model config
config.num_labels = 3  # number of target classes
# Alternative: BertForSequenceClassification(config) builds the same architecture
# with randomly initialised weights instead of loading the pre-trained ones.
model = BertForSequenceClassification.from_pretrained("home/work", config=config)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): the model is not moved to `device` here; doing so also requires
# the training loop to move each batch onto the same device.
optimizer = AdamW(model.parameters(), lr=LR, correct_bias=False)
# The linear schedule decays the LR to 0 at `num_training_steps` and keeps it at
# 0 afterwards, so the horizon must cover every optimizer step of every epoch
# (a fixed one-epoch value would freeze training after the first epoch).
num_training_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=num_training_steps,
)
device

Some weights of the model checkpoint at home/work were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at home/wo

'cuda'

In [87]:
# Accuracy of one batch.
def batch_accuracy(pre, label):
    """Return the fraction of rows whose argmax over dim 1 matches ``label``.

    Args:
        pre: score tensor whose dim 1 indexes the classes (e.g. logits).
        label: tensor of gold class indices, one per row of ``pre``.

    Returns:
        A Python float; an empty batch raises ZeroDivisionError.
    """
    predicted = torch.argmax(pre, dim=1)
    hits = int(torch.eq(predicted, label).sum())
    return hits / len(label)

In [90]:
# Fine-tune the classifier, resuming from and saving to a single checkpoint file.
def train(model, train_loader, optimizer, scheduler,
          ckpt_path="home/model/bert_cla.ckpt", log_every=100):
    """Fine-tune ``model`` for ``EPOCHS`` epochs over ``train_loader``.

    If ``ckpt_path`` exists, its state dict is loaded first so an interrupted
    run can be resumed. The same path is used for saving. Each batch is a pair
    ``(text, label)`` of sequences of JSON strings (id lists and class ids).
    They are decoded into tensors on the model's device and fed to the model
    together with an attention mask that hides padding ids. After each epoch,
    the mean batch accuracy is printed. The state dict is saved to
    ``ckpt_path`` when that mean is at least the best seen so far.

    Args:
        model: sequence-classification model returning ``(loss, logits, ...)``
            when called with ``labels``.
        train_loader: iterable of ``(text, label)`` batches as described above.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped once per optimizer step.
        ckpt_path: checkpoint file used for both resuming and saving.
        log_every: print batch stats every ``log_every`` batches (0 disables).
    """
    model.train()
    # Move each batch to wherever the model's parameters live (CPU or GPU).
    device = next(model.parameters()).device
    config = getattr(model, "config", None)
    pad_id = getattr(config, "pad_token_id", None)
    if pad_id is None:
        pad_id = 0  # [PAD] id in the standard BERT vocab, as padded by load_dataset

    if os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path, map_location=device))

    best_acc = 0.0
    print("开始训练。。。")
    for epoch in range(EPOCHS):
        acc_sum = 0.0
        num_batches = 0
        for count, (text, label) in enumerate(train_loader):
            text_tensor = torch.tensor([json.loads(t) for t in text], device=device)
            label_tensor = torch.tensor([json.loads(lab) for lab in label], device=device)
            # Without a mask, BERT attends to the (often many) padding positions.
            attention_mask = (text_tensor != pad_id).long()
            outputs = model(text_tensor, attention_mask=attention_mask, labels=label_tensor)
            loss, logits = outputs[:2]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()  # must follow optimizer.step(), per PyTorch

            acc = batch_accuracy(logits, label_tensor)
            acc_sum += acc
            num_batches += 1
            if log_every and count % log_every == 0:
                print('epoch:{} | acc:{} | loss:{}'.format(epoch, acc, loss.item()))

        if num_batches == 0:
            continue
        epoch_acc = acc_sum / num_batches
        print('epoch:{} | mean acc:{}'.format(epoch, epoch_acc))
        # Save once per epoch: the state dict of a BERT model is hundreds of MB,
        # and per-batch accuracy on tiny batches hits its maximum constantly.
        if epoch_acc >= best_acc:
            best_acc = epoch_acc
            ckpt_dir = os.path.dirname(ckpt_path)
            if ckpt_dir:
                os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(model.state_dict(), ckpt_path)
    print('训练完成...')

In [91]:
# Launch fine-tuning with the objects built in the cells above.
train(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    scheduler=scheduler,
)

开始训练。。。
epoch:0 | acc:0.75 | loss:0.9358950853347778
epoch:0 | acc:0.5 | loss:0.8102744221687317
epoch:0 | acc:0.75 | loss:0.6711229085922241
epoch:0 | acc:0.75 | loss:0.8209071159362793
epoch:0 | acc:0.5 | loss:1.00160551071167
epoch:0 | acc:0.75 | loss:0.796048641204834
epoch:0 | acc:0.5 | loss:0.8536520600318909
epoch:0 | acc:0.75 | loss:0.8585497140884399


KeyboardInterrupt: 