In [1]:
import time
import torch
import argparse
import numpy as np
import torch.nn as nn
from tqdm.auto import tqdm
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [2]:
class BertClassifier(nn.Module):
    def __init__(self,
                 pretrained: str,
                 num_classes=3,
                 pooling_output_layer=-1):
        super(BertClassifier, self).__init__()
        D_in, H, D_out = 768, 768, num_classes
        self.bert = BertModel.from_pretrained(pretrained)
        self.classifier = nn.Sequential(nn.Linear(D_in, H), nn.Tanh(), nn.Linear(H, D_out))
        self.dropout = nn.Dropout(0.1)
        self.pooling_output_layer = pooling_output_layer

    def forward(self, input_ids, attention_mask, output_hidden_states=True):
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            output_hidden_states=output_hidden_states)
        sentence_embeddings = outputs[1]
        sentence_embeddings = self.dropout(sentence_embeddings)
        logits = self.classifier(sentence_embeddings.to(device))
        return logits

In [3]:
max_sequence_length=130
batch_size=128
epochs=2
warmup_steps=2000
lr=3e-5
max_grad_norm=1.0
log_step=500

In [4]:
class MyDataset(Dataset):
    def __init__(self, text_list, label_list, tokenizer, max_sequence_len):
        self.input_ids = []
        self.token_type_ids = []
        self.attention_mask = []
        self.label_list = label_list
        self.len = len(label_list)
        for text in tqdm(text_list):
            text = text[:max_sequence_len - 2]
            title_ids = tokenizer.encode_plus(text, padding='max_length', max_length=max_sequence_len, truncation=True)
            self.input_ids.append(title_ids['input_ids'])
            self.attention_mask.append(title_ids["attention_mask"])

    def __getitem__(self, index):
        tmp_input_ids = self.input_ids[index]
        tmp_attention_mask = self.attention_mask[index]
        tmp_label = self.label_list[index]
        output = {"input_ids": torch.tensor(tmp_input_ids).to(device),
                  "attention_mask": torch.tensor(tmp_attention_mask).to(device)}
        return output, tmp_label

    def __len__(self):
        return self.len

In [5]:
def data_loader(x_list, y_list, tokenizer, max_sequence_len, batch_size, shuffle):
    dataset = MyDataset(x_list, y_list, tokenizer, max_sequence_len)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batch_size,
                            shuffle=shuffle)
    return dataloader

In [6]:
def load_model(num_labels):
    model = BertClassifier("bert-base-uncased", num_labels)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    return model, tokenizer

In [7]:
def compute_acc(logits, label):
    predicted_class_id = torch.tensor([w.argmax().item() for w in logits])
    return float((predicted_class_id == label).float().sum()) / label.shape[0]

In [9]:
import pandas as pd
label_dict = {1: 0, 2: 0, 3: 1, 4: 2, 5: 2}
def load_raw_data(data_path):
    # load raw data ，200000line for train and scale the test set
    train_x, train_y = [], []
    eval_x, eval_y = [], []
    df = pd.read_csv(data_path)
    df['label'] = df['rating'].replace(label_dict)
    use_df = pd.concat([df[df.label!=2], df[df.label==2].sample(100000)])
    use_df = use_df.sample(frac=1.0).reset_index()
    for idx in range(len(use_df)):
        if idx < 200000:
            train_x.append(use_df['review_text'][idx])
            train_y.append(int(use_df['label'][idx]))
        else:
            eval_x.append(use_df['review_text'][idx])
            eval_y.append(int(use_df['label'][idx]))
    print('训练数据:', len(train_x), len(train_y))
    print('验证数据:', len(eval_x), len(eval_y))
    return train_x, train_y, eval_x, eval_y

In [10]:
def train(model, dataloader, device):
    num_training_steps = epochs * len(dataloader)
    optimizer = Adam(model.parameters(), lr=lr)
    model.to(device)
    model.train()
    batch_steps = 0
    loss_fct = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        for batch, label in dataloader:
            batch_steps += 1
            logits = model(**batch)
            acc = compute_acc(logits, label)
            loss = loss_fct(logits.view(-1, 3).to(device), label.view(-1).to(device))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            if batch_steps % log_step == 0:
                print("train epoch {}/{}, batch {}/{}, loss {}, acc {}".format(
                    epoch + 1, epochs,
                    batch_steps,
                    num_training_steps,
                    loss,
                    acc))
    torch.save(model, 'bert_mini_sample_model.pth')

In [11]:
from sklearn.metrics import classification_report
def evaluate(dataloader):
    model = torch.load("bert_mini_sample_model.pth")
    model.to(device)
    model.eval()
    loss_list = []
    acc_list = []
    loss_fct = nn.CrossEntropyLoss()
    labels_all = []
    predict_all = []
    with torch.no_grad():
        for batch, label in dataloader:
            labels_all.extend(label)
            logits = model(**batch)
            acc = compute_acc(logits, label)
            loss = loss_fct(logits.view(-1, 3).to(device), label.view(-1).to(device))
            loss_list.append(float(loss))
            acc_list.append(float(acc))
            predict_label = [w.argmax().item() for w in logits]
            predict_all.extend(predict_label)
    print("loss: {},".format(np.mean(loss_list)),
          "accuracy: {}.".format(np.mean(acc_list)))
    labels_all = [w.item() for w in labels_all]
    acc = accuracy_score(labels_all, predict_all)
    p = precision_score(labels_all, predict_all, average='weighted')
    r = recall_score(labels_all, predict_all, average='weighted')
    f1 = f1_score(labels_all, predict_all, average='weighted')
    print(classification_report(labels_all, predict_all))
    print('acc:', acc)
    print('precision:', p)
    print('recall:', r)
    print('f1:', f1)

In [12]:
def predict(device, text, tokenizer):
    model = torch.load("bert_mini_sample_model.pth")
    model.to(device)
    model.eval()
    time_start = time.time()
    with torch.no_grad():
        text = text[:max_sequence_length - 2]
        inputs = tokenizer.encode_plus(text,
                                       padding='max_length',
                                       max_length=max_sequence_length,
                                       return_tensors="pt")
        inputs = {"input_ids": inputs['input_ids'].to(device),
                  "attention_mask": inputs['attention_mask'].to(device)}
        logits = model(**inputs)
        print("predict time cost {}".format(time.time() - time_start))
        predicted_class_id = logits.argmax().item()
    print("text: {}".format(text))
    print("predict label: {}".format(predicted_class_id))

In [14]:
model, tokenizer = load_model(num_labels=3)
train_x, train_y, eval_x, eval_y = load_raw_data("./data/review.csv")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


训练数据: 200000 200000
验证数据: 85187 85187


In [37]:

train_dataloader = data_loader(train_x,
                               train_y,
                               tokenizer,
                               max_sequence_length,
                               batch_size,
                               True)
train(model, train_dataloader, device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


训练数据: 200000 200000
验证数据: 85187 85187


  0%|          | 0/200000 [00:00<?, ?it/s]

train epoch 1/2, batch 500/3126, loss 0.5722352266311646, acc 0.7265625
train epoch 1/2, batch 1000/3126, loss 0.5043110251426697, acc 0.78125
train epoch 1/2, batch 1500/3126, loss 0.5758607983589172, acc 0.734375
train epoch 2/2, batch 2000/3126, loss 0.3973851203918457, acc 0.84375
train epoch 2/2, batch 2500/3126, loss 0.4306446611881256, acc 0.796875
train epoch 2/2, batch 3000/3126, loss 0.3966715931892395, acc 0.828125


In [15]:
# print("计算验证集的loss和acc～～～")
eval_dataloader = data_loader(eval_x,
                              eval_y,
                              tokenizer,
                              max_sequence_length,
                              batch_size,
                              False)
evaluate(eval_dataloader)

  0%|          | 0/85187 [00:00<?, ?it/s]

loss: 0.4445248742898305, accuracy: 0.8193493750280131.
              precision    recall  f1-score   support

           0       0.83      0.87      0.85     28711
           1       0.74      0.71      0.72     26437
           2       0.88      0.87      0.87     30039

    accuracy                           0.82     85187
   macro avg       0.82      0.82      0.82     85187
weighted avg       0.82      0.82      0.82     85187

acc: 0.8193268926009837
precision: 0.8186208921022787
recall: 0.8193268926009837
f1: 0.8187668699877619


In [16]:
print("PREDICT ")
text = "Great Dj & dance music, food needs a bit of improvement as it lacks flavor"
predict(device, text, tokenizer)

单条预测～～～
predict time cost 0.009909868240356445
text: Great Dj & dance music, food needs a bit of improvement as it lacks flavor
predict label: 1
