In [2]:
# Python libraries required by this notebook
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
from transformers import AdamW
from datasets import load_dataset
from datasets import load_from_disk
from sklearn import metrics
import numpy as np

In [4]:
# Dataset wrapper used by the data loaders
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset exposing one split of a dataset saved on disk.

    Each item is a ``(text, label)`` tuple read from the ``'text'`` and
    ``'label'`` fields of the selected split.

    Args:
        split: Name of the split to select from the loaded dataset
            (this notebook uses ``'train'`` and ``'test'``).
        path: Directory handed to ``load_from_disk``. Defaults to
            ``'./TikTok'``, the location the notebook originally hard-coded.
    """

    def __init__(self, split, path='./TikTok'):
        self.dataset = load_from_disk(path)[split]

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return the ``(text, label)`` pair stored at index ``i``."""
        # Fetch the row once; indexing the split separately for each field
        # would look the same row up twice per sample.
        row = self.dataset[i]
        return row['text'], row['label']

In [5]:
# Load the 'bert-base-chinese' tokenizer and register four punctuation marks
# as extra tokens (per the original note, the stock vocabulary does not
# recognize them). add_tokens is the last expression of the cell, so its
# return value is what the notebook displays below. The encoder's embedding
# table is resized to len(token) where the BERT model is loaded.
token = BertTokenizer.from_pretrained('bert-base-chinese')
token.add_tokens(new_tokens=['…','―','“','”'])

4

In [8]:
# Maximum encoded length of a text. collate_fn passes this as max_length with
# truncation and 'max_length' padding, so every encoded text has this length.
length = 30

# Batch collation: turn a list of (text, label) samples into model tensors
def collate_fn(data):
    """Tokenize a batch of samples and return the tensors the model consumes.

    Args:
        data: Sequence of samples where element 0 is the text and element 1
            is the label (the pairs produced by ``Dataset.__getitem__``).

    Returns:
        Tuple ``(input_ids, attention_mask, token_type_ids, labels)``. The
        first three come from the tokenizer, truncated/padded to the global
        ``length``; ``labels`` is a ``torch.LongTensor`` built from the
        sample labels.
    """
    texts = [sample[0] for sample in data]
    targets = [sample[1] for sample in data]

    # Every text is padded or truncated to exactly `length` token ids.
    encoded = token.batch_encode_plus(
        batch_text_or_text_pairs=texts,
        truncation=True,
        padding='max_length',
        max_length=length,
        return_tensors='pt',
        return_length=True,
    )

    return (
        encoded['input_ids'],
        encoded['attention_mask'],
        encoded['token_type_ids'],
        torch.LongTensor(targets),
    )

# Training loader: shuffled batches of 16; the final incomplete batch is dropped.
loader = torch.utils.data.DataLoader(
    Dataset('train'),
    batch_size=16,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
)

# Pull a single batch so its tensors are available for the model shape check
# performed after the classifier is defined.
for i, batch in enumerate(loader):
    input_ids, attention_mask, token_type_ids, labels = batch
    break


In [7]:
# Load the pretrained BERT encoder from Hugging Face, grow its embedding table
# to cover the tokens added to the tokenizer, and freeze every weight so the
# encoder is used purely as a fixed feature extractor.
pretrained = BertModel.from_pretrained('bert-base-chinese')
pretrained.resize_token_embeddings(len(token))
for frozen_param in pretrained.parameters():
    frozen_param.requires_grad = False

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
# LSTM classifier on top of the frozen BERT encoder
class LSTM(torch.nn.Module):
    """Two-layer bidirectional LSTM head that classifies BERT token features.

    The encoder is the module-level ``pretrained`` model; it is called under
    ``torch.no_grad()`` and is not registered as a submodule, so only the
    LSTM and the linear layer appear in ``parameters()``.
    """

    def __init__(self):
        super().__init__()
        # 768 input features -> 384 hidden units per direction, 2 stacked layers.
        self.rnn = nn.LSTM(
            input_size=768,
            hidden_size=384,
            num_layers=2,
            batch_first=True,
            dropout=0.5,
            bidirectional=True,
        )
        # Both directions are concatenated (2 * 384) before the 2-way output.
        self.fc = torch.nn.Linear(2 * 384, 2)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return one row of 2 class logits per input sequence.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            token_type_ids: Segment ids forwarded to the encoder.
        """
        # The encoder is only a feature extractor, so no graph is recorded for it.
        with torch.no_grad():
            encoded = pretrained(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids)

        # NOTE(review): the LSTM runs over every position of the padded input
        # without packing, so the final states presumably include padding
        # steps -- confirm this is intended.
        _, (h_n, _) = self.rnn(encoded.last_hidden_state)

        # The last two entries of h_n are the two directions of the top layer.
        top_forward, top_backward = h_n[-2], h_n[-1]
        features = self.dropout(torch.cat([top_forward, top_backward], dim=1))
        return self.fc(features)


# Instantiate the classifier used by the training and testing cells.
model1 = LSTM()

# Smoke test: run the batch pulled from the training loader through the model;
# the resulting shape is displayed as the cell output.
model1(input_ids=input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids).shape

torch.Size([16, 2])

In [10]:
# Training
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are pinned to the defaults of the transformers implementation
# (1e-6 and 0.0) so the optimizer settings stay the same.
optimizer = torch.optim.AdamW(model1.parameters(), lr=5e-4, eps=1e-6, weight_decay=0.0)
criterion = torch.nn.CrossEntropyLoss()
model1.train()
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    out = model1(input_ids=input_ids,
                 attention_mask=attention_mask,
                 token_type_ids=token_type_ids)
    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Report the loss and the accuracy on the current batch every 5 steps.
    if i % 5 == 0:
        preds = out.argmax(dim=1)
        accuracy = (preds == labels).sum().item() / len(labels)
        print(i, loss.item(), accuracy)

    # Single pass over the loader, stopping after step 450 at the latest.
    if i == 450:
        break



0 0.6786486506462097 0.6875
5 0.5608528256416321 0.6875
10 0.4853718876838684 0.8125
15 0.382571280002594 0.8125
20 0.7550916075706482 0.6875
25 0.389615535736084 0.8125
30 0.6661620140075684 0.6875
35 0.2125944346189499 0.9375
40 0.9701465368270874 0.6875
45 0.328982949256897 0.875
50 0.3119659125804901 0.875
55 0.12042928487062454 0.9375
60 0.41816550493240356 0.6875
65 0.16959929466247559 0.9375
70 0.45900094509124756 0.875
75 0.11932005733251572 0.9375
80 0.37557628750801086 0.8125
85 0.29105183482170105 0.875
90 0.3299608528614044 0.875
95 0.16580957174301147 0.9375
100 0.1903885155916214 0.875
105 0.30719897150993347 0.875
110 0.4996291697025299 0.75
115 0.20678968727588654 0.875
120 0.21921245753765106 0.875
125 0.16400021314620972 0.9375
130 0.15519458055496216 0.9375
135 0.20067480206489563 0.9375
140 0.38900670409202576 0.8125
145 0.24437777698040009 0.875
150 0.056191351264715195 1.0
155 0.14547888934612274 0.9375
160 0.1203714907169342 0.9375
165 0.17818517982959747 0.9375


In [12]:
# Testing
def test(max_batches=450):
    """Evaluate ``model1`` on the 'test' split and print classification metrics.

    Prints the accuracy, the scikit-learn classification report and the
    confusion matrix for the evaluated samples.

    Args:
        max_batches: Maximum number of single-sample batches to evaluate.
            ``None`` evaluates the whole split. Defaults to 450.
    """
    model1.eval()

    # shuffle=True is kept on purpose: when the cap is smaller than the split,
    # the evaluated subset is a random sample instead of the first rows.
    # Results therefore vary between runs unless the torch RNG is seeded.
    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),
                                              batch_size=1,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)
    labels_all = []
    pred_all = []

    with torch.no_grad():
        for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
            if max_batches is not None and i >= max_batches:
                break

            out = model1(input_ids=input_ids,
                         attention_mask=attention_mask,
                         token_type_ids=token_type_ids)

            # .tolist() yields plain Python ints, so the report lists the
            # classes as 0/1. Appending to plain lists also avoids re-copying
            # a growing array on every iteration.
            pred_all.extend(out.argmax(dim=1).tolist())
            labels_all.extend(labels.tolist())

    if not labels_all:
        print("No test samples were evaluated.")
        return

    accuracy = metrics.accuracy_score(labels_all, pred_all)
    print("Accuracy: %.2f%%" % (accuracy * 100.0))
    print(metrics.classification_report(labels_all, pred_all, digits=4))
    print(metrics.confusion_matrix(labels_all, pred_all))
test()

Accuracy: 92.00%
              precision    recall  f1-score   support

         0.0     0.8373    0.9392    0.8854       148
         1.0     0.9683    0.9106    0.9386       302

    accuracy                         0.9200       450
   macro avg     0.9028    0.9249    0.9120       450
weighted avg     0.9252    0.9200    0.9211       450

[[139   9]
 [ 27 275]]
