In [1]:
import random
import torch
import numpy as np
import torch.nn as nn
from config import Config
from dataset import IntentionDataset
from torch.utils.data import DataLoader
from model import Bert
from tqdm import tqdm

In [2]:
index = [i for i in range(Config.dataset_size)]
random.shuffle(index)
train_index = index[:-100]
test_index = index[-100:]

In [3]:
train_dataset = IntentionDataset(Config, train_index)
test_dataset = IntentionDataset(Config, test_index)

In [4]:
train_dataloader = DataLoader(train_dataset, batch_size = Config.batch_size, shuffle = True)
test_dataloader = DataLoader(test_dataset, batch_size = Config.batch_size, shuffle = False)

In [5]:
next(iter(train_dataloader))

[('暂时不需要了',
  '我不知道怎么签',
  '我不想要这玩意',
  '肯定不买',
  '请直说',
  '对的。',
  '搞错了',
  '有事，不方便',
  '不可以',
  '他是谁啊',
  '嗯哼',
  '不喜欢',
  '我没有没时间',
  '这个让人很纠结',
  '我无法转告他',
  '不可以',
  '没有看到啊',
  '我现在没有时间',
  '放弃了',
  '把它发个我吧',
  '短信没收到啊',
  '不明白',
  '我需要考虑清楚',
  '不可以的，再见',
  '我没有时间',
  '我都不知道有没有此人',
  '别来烦我好不好',
  '抱歉，我没有空',
  '不可以',
  '十分纠结',
  '别发了',
  '等下我看看产品信息'),
 tensor([4, 3, 4, 4, 1, 1, 0, 2, 4, 0, 3, 4, 4, 5, 4, 4, 3, 4, 4, 3, 3, 5, 5, 4,
         4, 0, 4, 4, 4, 5, 4, 3])]

In [6]:
device = Config.device

In [6]:
model = Bert(Config).to(device)

Some weights of the model checkpoint at chinese-bert were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
criterion = nn.CrossEntropyLoss()

In [8]:
params_dict = [{'params': model.bert.parameters(), 'lr': 5e-5}, {'params': model.linear.parameters(), 'lr': 1e-3}]
optimizer = torch.optim.Adam(params_dict)

In [9]:
best_test_accuracy = 0
for epoch in range(Config.epoch):
    model.train()
    total_loss = []
    total_right = 0
    for batch in tqdm(train_dataloader):
        sentences, labels = batch[0], batch[1].to(device)
        outputs = model(sentences)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss.append(loss.detach().cpu().numpy())
        predictions = torch.argmax(outputs, dim=1).cpu().tolist()
        labels = labels.cpu().tolist()
        total_right += sum([True if predictions[i] == labels[i] else False for i in range(len(labels))])
    print(f"Epoch: {epoch+1}, Train loss: {np.mean(total_loss):.4f}, Train accuracy: {total_right/len(train_dataset):.4f}")
    
    model.eval()
    total_loss = []
    total_right = 0
    for batch in tqdm(test_dataloader):
        sentences, labels = batch[0], batch[1].to(device)
        outputs = model(sentences)
        loss = criterion(outputs, labels)
        total_loss.append(loss.detach().cpu().numpy())
        predictions = torch.argmax(outputs, dim=1).cpu().tolist()
        labels = labels.cpu().tolist()
        total_right += sum([True if predictions[i] == labels[i] else False for i in range(len(labels))])
    print(f"Epoch: {epoch+1}, Test loss: {np.mean(total_loss):.4f}, Test accuracy: {total_right/len(test_dataset):.4f}")
    if total_right/len(test_dataset) > best_test_accuracy:
        best_test_accuracy = total_right/len(test_dataset)
        print(f"saving best model...")
        torch.save(model, Config.save_path)
print(f"Training Done. Best test accuracy: {best_test_accuracy:.4f}")

100%|██████████| 25/25 [00:03<00:00,  7.53it/s]
 20%|██        | 5/25 [00:00<00:00, 41.55it/s]

Epoch: 1, Train loss: 1.1732, Train accuracy: 0.5959


100%|██████████| 25/25 [00:00<00:00, 46.91it/s]


Epoch: 1, Test loss: 0.3835, Test accuracy: 0.8780
saving best model...


100%|██████████| 25/25 [00:02<00:00,  9.41it/s]
 16%|█▌        | 4/25 [00:00<00:00, 37.56it/s]

Epoch: 2, Train loss: 0.3692, Train accuracy: 0.8831


100%|██████████| 25/25 [00:00<00:00, 43.71it/s]


Epoch: 2, Test loss: 0.2156, Test accuracy: 0.8996
saving best model...


100%|██████████| 25/25 [00:03<00:00,  8.06it/s]
 20%|██        | 5/25 [00:00<00:00, 43.15it/s]

Epoch: 3, Train loss: 0.2297, Train accuracy: 0.9263


100%|██████████| 25/25 [00:00<00:00, 46.90it/s]


Epoch: 3, Test loss: 0.2133, Test accuracy: 0.9263
saving best model...


100%|██████████| 25/25 [00:03<00:00,  7.84it/s]
 20%|██        | 5/25 [00:00<00:00, 49.17it/s]

Epoch: 4, Train loss: 0.2582, Train accuracy: 0.9085


100%|██████████| 25/25 [00:00<00:00, 49.42it/s]


Epoch: 4, Test loss: 0.1648, Test accuracy: 0.9377
saving best model...


100%|██████████| 25/25 [00:03<00:00,  7.34it/s]
 16%|█▌        | 4/25 [00:00<00:00, 35.08it/s]

Epoch: 5, Train loss: 0.1997, Train accuracy: 0.9314


100%|██████████| 25/25 [00:00<00:00, 34.70it/s]


Epoch: 5, Test loss: 0.1517, Test accuracy: 0.9454
saving best model...


100%|██████████| 25/25 [00:03<00:00,  7.69it/s]
 12%|█▏        | 3/25 [00:00<00:00, 23.07it/s]

Epoch: 6, Train loss: 0.1760, Train accuracy: 0.9377


100%|██████████| 25/25 [00:00<00:00, 28.76it/s]
  4%|▍         | 1/25 [00:00<00:03,  7.06it/s]

Epoch: 6, Test loss: 0.1638, Test accuracy: 0.9377


100%|██████████| 25/25 [00:02<00:00,  9.03it/s]
 20%|██        | 5/25 [00:00<00:00, 48.73it/s]

Epoch: 7, Train loss: 0.1861, Train accuracy: 0.9263


100%|██████████| 25/25 [00:00<00:00, 35.37it/s]


Epoch: 7, Test loss: 0.1254, Test accuracy: 0.9466
saving best model...


100%|██████████| 25/25 [00:03<00:00,  7.68it/s]
 12%|█▏        | 3/25 [00:00<00:00, 25.29it/s]

Epoch: 8, Train loss: 0.1482, Train accuracy: 0.9416


100%|██████████| 25/25 [00:00<00:00, 31.18it/s]
  4%|▍         | 1/25 [00:00<00:02,  8.35it/s]

Epoch: 8, Test loss: 0.1600, Test accuracy: 0.9263


100%|██████████| 25/25 [00:02<00:00,  8.90it/s]
 12%|█▏        | 3/25 [00:00<00:00, 27.00it/s]

Epoch: 9, Train loss: 0.1530, Train accuracy: 0.9428


100%|██████████| 25/25 [00:00<00:00, 34.61it/s]


Epoch: 9, Test loss: 0.1158, Test accuracy: 0.9517
saving best model...


100%|██████████| 25/25 [00:03<00:00,  7.56it/s]
 12%|█▏        | 3/25 [00:00<00:00, 27.18it/s]

Epoch: 10, Train loss: 0.1299, Train accuracy: 0.9479


100%|██████████| 25/25 [00:00<00:00, 32.61it/s]

Epoch: 10, Test loss: 0.1083, Test accuracy: 0.9504
Training Done. Best test accuracy: 0.9517





In [7]:
import torch
from config import Config
model = torch.load(Config.save_path)

In [11]:
for batch in tqdm(train_dataloader):
    break

  0%|          | 0/22 [00:00<?, ?it/s]


In [22]:
batch[0][1:2]

('想签约又不敢签约',)

In [25]:
text = '想签约又不敢签约'

In [37]:
output = model((text,))

In [38]:
output.shape

torch.Size([1, 7])

In [55]:
prediction = torch.argmax(output, dim=1).tolist()[0]

In [59]:
torch.nn.functional.softmax(output, dim=1).tolist()[0][prediction]

0.9993461966514587