In [1]:
import time
import torch
import argparse
import numpy as np
import torch.nn as nn
from tqdm.auto import tqdm
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [2]:
class BertClassifier(nn.Module):
    def __init__(self,
                 pretrained: str,
                 num_classes=3,
                 pooling_output_layer=-1):
        super(BertClassifier, self).__init__()
        D_in, H, D_out = 768, 768, num_classes
        self.bert = BertModel.from_pretrained(pretrained)
        self.classifier = nn.Sequential(nn.Linear(D_in, H), nn.Tanh(), nn.Linear(H, D_out))
        self.dropout = nn.Dropout(0.1)
        self.pooling_output_layer = pooling_output_layer

    def forward(self, input_ids, attention_mask, output_hidden_states=True):
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            output_hidden_states=output_hidden_states)
        sentence_embeddings = outputs[1]
        sentence_embeddings = self.dropout(sentence_embeddings)
        logits = self.classifier(sentence_embeddings.to(device))
        return logits

In [3]:
max_sequence_length=130
batch_size=128
epochs=2
warmup_steps=2000
lr=3e-5
max_grad_norm=1.0
log_step=500

In [4]:
class MyDataset(Dataset):
    def __init__(self, text_list, label_list, tokenizer, max_sequence_len):
        self.input_ids = []
        self.token_type_ids = []
        self.attention_mask = []
        self.label_list = label_list
        self.len = len(label_list)
        for text in tqdm(text_list):
            text = text[:max_sequence_len - 2]
            title_ids = tokenizer.encode_plus(text, padding='max_length', max_length=max_sequence_len, truncation=True)
            self.input_ids.append(title_ids['input_ids'])
            self.attention_mask.append(title_ids["attention_mask"])

    def __getitem__(self, index):
        tmp_input_ids = self.input_ids[index]
        tmp_attention_mask = self.attention_mask[index]
        tmp_label = self.label_list[index]
        output = {"input_ids": torch.tensor(tmp_input_ids).to(device),
                  "attention_mask": torch.tensor(tmp_attention_mask).to(device)}
        return output, tmp_label

    def __len__(self):
        return self.len

In [5]:
def data_loader(x_list, y_list, tokenizer, max_sequence_len, batch_size, shuffle):
    dataset = MyDataset(x_list, y_list, tokenizer, max_sequence_len)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batch_size,
                            shuffle=shuffle)
    return dataloader

In [6]:
def load_model(num_labels):
    model = BertClassifier("bert-base-uncased", num_labels)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    return model, tokenizer

In [7]:
def compute_acc(logits, label):
    predicted_class_id = torch.tensor([w.argmax().item() for w in logits])
    return float((predicted_class_id == label).float().sum()) / label.shape[0]

In [9]:
import pandas as pd
label_dict = {1: 0, 2: 0, 3: 1, 4: 2, 5: 2}
def load_raw_data(data_path):
    # load raw data
    train_x, train_y = [], []
    eval_x, eval_y = [], []
    df = pd.read_csv(data_path)
    for idx in range(len(df)):
        if idx < 1300000:
            train_x.append(df['review_text'][idx])
            train_y.append(label_dict[int(df['rating'][idx])])
        else:
            eval_x.append(df['review_text'][idx])
            eval_y.append(label_dict[int(df['rating'][idx])])
    print('训练数据:', len(train_x), len(train_y))
    print('验证数据:', len(eval_x), len(eval_y))
    return train_x, train_y, eval_x, eval_y

In [10]:
def train(model, dataloader, device):
    num_training_steps = epochs * len(dataloader)
    optimizer = Adam(model.parameters(), lr=lr)
    model.to(device)
    model.train()
    batch_steps = 0
    loss_fct = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        for batch, label in dataloader:
            batch_steps += 1
            logits = model(**batch)
            acc = compute_acc(logits, label)
            loss = loss_fct(logits.view(-1, 3).to(device), label.view(-1).to(device))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            if batch_steps % log_step == 0:
                print("train epoch {}/{}, batch {}/{}, loss {}, acc {}".format(
                    epoch + 1, epochs,
                    batch_steps,
                    num_training_steps,
                    loss,
                    acc))
    torch.save(model, 'model_final.pth')

In [11]:
from sklearn.metrics import classification_report
def evaluate(dataloader):
    model = torch.load("model_final.pth")
    model.to(device)
    model.eval()
    loss_list = []
    acc_list = []
    loss_fct = nn.CrossEntropyLoss()
    labels_all = []
    predict_all = []
    with torch.no_grad():
        for batch, label in dataloader:
            labels_all.extend(label)
            logits = model(**batch)
            acc = compute_acc(logits, label)
            loss = loss_fct(logits.view(-1, 3).to(device), label.view(-1).to(device))
            loss_list.append(float(loss))
            acc_list.append(float(acc))
            predict_label = [w.argmax().item() for w in logits]
            predict_all.extend(predict_label)
    print("loss: {},".format(np.mean(loss_list)),
          "accuracy: {}.".format(np.mean(acc_list)))
    labels_all = [w.item() for w in labels_all]
    acc = accuracy_score(labels_all, predict_all)
    p = precision_score(labels_all, predict_all, average='weighted')
    r = recall_score(labels_all, predict_all, average='weighted')
    f1 = f1_score(labels_all, predict_all, average='weighted')
    print(classification_report(labels_all, predict_all))
    print('acc:', acc)
    print('precision:', p)
    print('recall:', r)
    print('f1:', f1)

In [12]:
def predict(device, text, tokenizer):
    model = torch.load("model_final.pth")
    model.to(device)
    model.eval()
    time_start = time.time()
    with torch.no_grad():
        text = text[:max_sequence_length - 2]
        inputs = tokenizer.encode_plus(text,
                                       padding='max_length',
                                       max_length=max_sequence_length,
                                       return_tensors="pt")
        inputs = {"input_ids": inputs['input_ids'].to(device),
                  "attention_mask": inputs['attention_mask'].to(device)}
        logits = model(**inputs)
        print("predict time cost {}".format(time.time() - time_start))
        predicted_class_id = logits.argmax().item()
    print("text: {}".format(text))
    print("predict label: {}".format(predicted_class_id))

In [13]:
model, tokenizer = load_model(num_labels=3)
train_x, train_y, eval_x, eval_y = load_raw_data("./data/review.csv")
train_dataloader = data_loader(train_x,
                               train_y,
                               tokenizer,
                               max_sequence_length,
                               batch_size,
                               True)
train(model, train_dataloader, device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


训练数据: 1300000 1300000
验证数据: 187747 187747


  0%|          | 0/1300000 [00:00<?, ?it/s]

train epoch 1/2, batch 500/20314, loss 0.29863736033439636, acc 0.890625
train epoch 1/2, batch 1000/20314, loss 0.3169819712638855, acc 0.8671875
train epoch 1/2, batch 1500/20314, loss 0.2742278277873993, acc 0.90625
train epoch 1/2, batch 2000/20314, loss 0.23564806580543518, acc 0.90625
train epoch 1/2, batch 2500/20314, loss 0.20182210206985474, acc 0.921875
train epoch 1/2, batch 3000/20314, loss 0.1342770904302597, acc 0.96875
train epoch 1/2, batch 3500/20314, loss 0.21181869506835938, acc 0.921875
train epoch 1/2, batch 4000/20314, loss 0.20062929391860962, acc 0.9375
train epoch 1/2, batch 4500/20314, loss 0.22458180785179138, acc 0.8984375
train epoch 1/2, batch 5000/20314, loss 0.15703251957893372, acc 0.9453125
train epoch 1/2, batch 5500/20314, loss 0.17514324188232422, acc 0.9453125
train epoch 1/2, batch 6000/20314, loss 0.18991750478744507, acc 0.953125
train epoch 1/2, batch 6500/20314, loss 0.1247263103723526, acc 0.953125
train epoch 1/2, batch 7000/20314, loss 0.17

In [14]:
# print(" loss and acc")
eval_dataloader = data_loader(eval_x,
                              eval_y,
                              tokenizer,
                              max_sequence_length,
                              batch_size,
                              False)
evaluate(eval_dataloader)

  0%|          | 0/187747 [00:00<?, ?it/s]

loss: 0.19506543190120107, accuracy: 0.9313024157732747.
              precision    recall  f1-score   support

           0       0.73      0.77      0.75     11049
           1       0.54      0.36      0.43     10973
           2       0.96      0.98      0.97    165725

    accuracy                           0.93    187747
   macro avg       0.75      0.70      0.72    187747
weighted avg       0.92      0.93      0.93    187747

acc: 0.9313011659307472
precision: 0.9238351922014859
recall: 0.9313011659307472
f1: 0.926392549375741


In [15]:
print("PREDICT ")
text = "Great Dj & dance music, food needs a bit of improvement as it lacks flavor"
predict(device, text, tokenizer)


单条预测～～～
predict time cost 0.01199030876159668
text: Great Dj & dance music, food needs a bit of improvement as it lacks flavor
predict label: 2
