In [1]:
import torch
from transformers import AdamW
from data_loader import TalkDataset
from model_budling import PHI_NER
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [2]:
BATCH_SIZE = 4
data_path = "./dataset/sample_512_bert_data.pt"

list_of_dict = torch.load(data_path)
# train_list = list_of_dict[:80]

""" model setting (training)"""
trainSet = TalkDataset("train", list_of_dict)
trainLoader = DataLoader(trainSet, batch_size=BATCH_SIZE)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device)
model = PHI_NER()
optimizer = AdamW(model.parameters(), lr=1e-5) # AdamW = BertAdam

BIO_weight = torch.FloatTensor([89.94618834, 50.65151515,  1.]).cuda()
type_weight = torch.FloatTensor([1, 7.42883207e+02, 3.78451358e+02, 4.71952716e+01,
 9.70063365e+07, 9.70063365e+07, 6.68553928e+03, 9.70063365e+07,
 9.70063365e+07, 9.70063365e+07, 9.70063365e+07, 9.70063365e+07,
 9.70063365e+07, 2.00575855e+03, 9.70063365e+07, 5.49531139e+02,
 9.70063365e+07, 3.10975750e+02, 9.70063365e+07]).cuda()

BIO_loss_fct = nn.CrossEntropyLoss(weight=BIO_weight)
type_loss_fct = nn.CrossEntropyLoss(weight=type_weight)

# high-level 顯示此模型裡的 modules
print("""
name            module
----------------------""")
for name, module in model.named_children():
    if name == "bert":
        for n, _ in module.named_children():
            print(f"{name}:{n}")
#             print(_)
    else:
        print("{:15} {}".format(name, module))

device: cuda:0

name            module
----------------------
bert:embeddings
bert:encoder
bert:pooler
type_classifier Linear(in_features=768, out_features=19, bias=True)
BIO_classifier  Linear(in_features=768, out_features=3, bias=True)
softmax         Softmax(dim=-1)


In [3]:
""" training """
from datetime import datetime,timezone,timedelta

model = model.to(device)
model.train()

EPOCHS = 10
dt1 = datetime.utcnow().replace(tzinfo=timezone.utc)
dt2 = dt1.astimezone(timezone(timedelta(hours=8))) # 轉換時區 -> 東八區
print(dt2)
for epoch in range(EPOCHS):
    running_loss = 0.0
    type_running_loss = 0.0
    BIO_running_loss = 0.0
    for data in trainLoader:
        tokens_tensors, segments_tensors, masks_tensors, \
        type_label, BIO_label = [t.to(device) for t in data]

    # 將參數梯度歸零
    optimizer.zero_grad()

    # forward pass
    outputs = model(input_ids=tokens_tensors, 
                  token_type_ids=segments_tensors, 
                  attention_mask=masks_tensors)
    
    type_pred = outputs[0]
    type_pred = torch.transpose(type_pred, 1, 2)
#     print(type_pred.size(), type_label.size())
    type_running_loss = type_loss_fct(type_pred, type_label)

    BIO_pred = outputs[1]
    BIO_pred = torch.transpose(BIO_pred, 1, 2)
#     print(BIO_pred.size(), BIO_label.size())
    BIO_loss = BIO_loss_fct(BIO_pred, BIO_label)

    loss = BIO_loss + type_running_loss

    # backward
    loss.backward()
    optimizer.step()

    # 紀錄當前 batch loss
    running_loss += loss.item()
    type_running_loss += type_running_loss.item()
    BIO_running_loss += BIO_loss.item()

    CHECKPOINT_NAME = './model/full_train_1_bert_wwm_E' + str(epoch) + '.pt' 
    torch.save(model.state_dict(), CHECKPOINT_NAME)

    dt1 = datetime.utcnow().replace(tzinfo=timezone.utc)
    dt2 = dt1.astimezone(timezone(timedelta(hours=8))) # 轉換時區 -> 東八區
    print('%s\t[epoch %d] loss: %.3f, type_loss: %.3f, BIO_loss: %.3f' %
          (dt2, epoch + 1, running_loss, type_running_loss, BIO_running_loss))

2020-09-17 22:08:26.817272+08:00
2020-09-17 22:08:27.189362+08:00	[epoch 1] loss: 4.048, type_loss: 5.871, BIO_loss: 1.112
2020-09-17 22:08:27.532100+08:00	[epoch 2] loss: 3.984, type_loss: 5.861, BIO_loss: 1.054
2020-09-17 22:08:27.872703+08:00	[epoch 3] loss: 3.929, type_loss: 5.848, BIO_loss: 1.005
2020-09-17 22:08:28.214932+08:00	[epoch 4] loss: 3.875, type_loss: 5.828, BIO_loss: 0.960
2020-09-17 22:08:28.609836+08:00	[epoch 5] loss: 3.841, type_loss: 5.816, BIO_loss: 0.933
2020-09-17 22:08:28.949210+08:00	[epoch 6] loss: 3.795, type_loss: 5.783, BIO_loss: 0.904
2020-09-17 22:08:29.290107+08:00	[epoch 7] loss: 3.749, type_loss: 5.749, BIO_loss: 0.875
2020-09-17 22:08:29.629484+08:00	[epoch 8] loss: 3.704, type_loss: 5.685, BIO_loss: 0.861
2020-09-17 22:08:29.970222+08:00	[epoch 9] loss: 3.659, type_loss: 5.634, BIO_loss: 0.842
2020-09-17 22:08:30.312435+08:00	[epoch 10] loss: 3.628, type_loss: 5.618, BIO_loss: 0.819


In [7]:
import torch
path = "./dataset/sample_512_bert_data.pt"
t = torch.load(path)

In [None]:
for i in range(len(t)):
    print(i, len(t[i]['input_ids']), len(t[i]['seg']), len(t[i]['att']), len(t[i]['BIO_label']), len(t[i]['type_label']))

In [None]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("hfl/chinese-bert-wwm")
print(tokenizer.convert_ids_to_tokens(t[63]['input_ids']))