In [1]:
import warnings
warnings.filterwarnings(action='ignore') 
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
import random
from tqdm import tqdm, tqdm_notebook

In [2]:
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

In [3]:
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

In [4]:
import pickle


def _load_pickle(path):
    """Deserialize and return the single object stored in the pickle file at ``path``.

    NOTE(review): pickle.load can execute arbitrary code -- only use it on
    trusted, locally produced files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


train = _load_pickle('./data/train.pickle')
val = _load_pickle('./data/val.pickle')
test = _load_pickle('./data/test.pickle')

In [5]:
# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook can still run (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
# Load the pretrained KoBERT encoder and its vocabulary. The "using cached model"
# lines in the cell output come from this call reusing a locally cached copy.
bertmodel, vocab = get_pytorch_kobert_model()

using cached model
using cached model


In [7]:
def _to_text_label_pair(sample):
    """Return ``[text, label]``.

    ``text`` is the first three text fields of ``sample[0]`` joined with single
    spaces, and ``label`` is ``str(sample[1])``.
    """
    fields = sample[0]
    joined = ' '.join([fields[0], fields[1], fields[2]])
    return [joined, str(sample[1])]


dataset_train = [_to_text_label_pair(sample) for sample in train]
dataset_validation = [_to_text_label_pair(sample) for sample in val]
dataset_test = [_to_text_label_pair(sample) for sample in test]

In [8]:
# Peek at the first six [text, label] pairs; labels are stored as strings here.
dataset_train[:6]

[['사람이 되고 싶어서 그대가 말한 온갖 작품을 가슴 속에 새기고', '1'],
 ['내 모습에 한숨 쉬네 오랜만에 느껴지는 이 떨림이 날 단순하게 만들어', '1'],
 ['참 오래됐지 우리 서로 헤어진 지 나도 네가 없는 삶에 많이 익숙해졌어', '0'],
 ['욕심을 잃고 초심에 대한 촛농을 녹였지 얼마나 뜨거울지 몰라', '1'],
 ['아예 선을 그어 주던가 네가 나를 잡던가 잡힐 손을 주던가 오늘도 이렇게 너를 보낸다', '1'],
 ['이제 지난 일로 간직할게 매일 찾아오는 공허함을 버틴다는 게 어렵지만', '0']]

In [9]:
# KoBERT tokenizer. `get_tokenizer()` presumably returns the SentencePiece model
# (path) that BERTSPTokenizer wraps together with the KoBERT vocab -- TODO confirm.
# lower=False: input text is not lower-cased before tokenization.
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

using cached model


In [10]:
## Setting parameters
# Tokenization / batching
max_len = 64            # max_seq_length passed to BERTSentenceTransform
batch_size = 64

# Optimisation
num_epochs = 5
learning_rate = 5e-5
warmup_ratio = 0.1      # fraction of all training steps used for LR warm-up
max_grad_norm = 1       # gradient-norm clipping threshold

# Logging
log_interval = 200      # print training stats every `log_interval` batches

In [11]:
# Single-sentence transform padded to `max_len`.
# NOTE(review): nothing in the visible cells uses this object (BERTDataset builds
# its own transform), so it appears to be unused. The trailing "140" is presumably
# an earlier max_len value -- confirm and remove if obsolete.
transform = nlp.data.BERTSentenceTransform(tokenizer=tok, max_seq_length=max_len, pad=True, pair=False)  # 140

In [12]:
class BERTDataset(Dataset):
    """Map-style dataset that encodes every sentence once, up front.

    Each item is the tuple produced by ``BERTSentenceTransform`` for the row's
    sentence, extended with the row's label as an ``np.int32``.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len, pad, pair):
        encode = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair
        )
        # Encode all sentences first, then convert all labels (two passes over `dataset`).
        self.sentences = [encode([row[sent_idx]]) for row in dataset]
        self.labels = [np.int32(row[label_idx]) for row in dataset]

    def __getitem__(self, i):
        encoded = self.sentences[i]
        return encoded + (self.labels[i],)

    def __len__(self):
        return len(self.labels)

In [13]:
# Column 0 of every row is the sentence and column 1 the label; sequences are
# padded to `max_len` and encoded as single sentences (pad=True, pair=False).
# (`data_validataion` keeps its original spelling because later cells use it.)
data_train, data_validataion, data_test = (
    BERTDataset(rows, 0, 1, tok, max_len, True, False)
    for rows in (dataset_train, dataset_validation, dataset_test)
)

In [14]:
# One loader per split with identical settings. No `shuffle` argument is given,
# so every loader yields samples in dataset order; later cells rely on this to
# line predictions up with the label lists.
# NOTE(review): the training split is therefore not shuffled between epochs --
# confirm the pickled training data was shuffled upstream.
train_dataloader, validation_dataloader, test_dataloader = (
    DataLoader(split, batch_size=batch_size, num_workers=5)
    for split in (data_train, data_validataion, data_test)
)

In [15]:
class BERTClassifier(nn.Module):
    """Sentence classifier: BERT pooled output -> optional dropout -> linear layer."""

    def __init__(self,
                 bert,
                 hidden_size=768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        """
        Args:
            bert: encoder module. It is called with ``input_ids``, ``token_type_ids``
                and ``attention_mask`` keyword arguments and must return a pair
                whose second item is the pooled output.
                NOTE(review): with newer ``transformers`` versions the model may
                need ``return_dict=False`` for this unpacking to yield tensors --
                confirm against the installed version.
            hidden_size: width of the pooled output fed to the classifier.
            num_classes: number of output logits.
            dr_rate: dropout probability applied to the pooled output; a falsy
                value (None or 0) disables dropout.
            params: unused; kept for backward compatibility.
        """
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size, num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        """Return a float mask shaped like ``token_ids`` (batch, seq_len).

        Row ``i`` is 1.0 for its first ``valid_length[i]`` positions and 0.0
        elsewhere. The mask is built on ``token_ids.device``.
        """
        seq_len = token_ids.size(1)
        lengths = torch.as_tensor(valid_length, device=token_ids.device)
        positions = torch.arange(seq_len, device=token_ids.device)
        # Vectorised replacement for a per-row Python loop: compare every
        # position index against that row's length.
        return (positions.unsqueeze(0) < lengths.unsqueeze(1)).float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Return classifier logits of shape (batch, num_classes)."""
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        _, pooler = self.bert(input_ids=token_ids,
                              token_type_ids=segment_ids.long(),
                              attention_mask=attention_mask)
        # Dropout is optional; without it the pooled output goes straight to
        # the classifier (`out` must be bound on both paths).
        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)

In [16]:
# Classifier on top of KoBERT with dropout p=0.5 on the pooled output, placed on `device`.
model = BERTClassifier(bertmodel, dr_rate=0.5)
model = model.to(device)

In [17]:
# AdamW parameter groups: parameters whose names contain any `no_decay` pattern
# (biases, LayerNorm weights) get no weight decay; all others get 0.01.
no_decay = ['bias', 'LayerNorm.weight']


def _is_no_decay(param_name):
    """True when ``param_name`` contains any of the ``no_decay`` patterns."""
    return any(pattern in param_name for pattern in no_decay)


named_params = list(model.named_parameters())
optimizer_grouped_parameters = [
    {'params': [p for n, p in named_params if not _is_no_decay(n)], 'weight_decay': 0.01},
    {'params': [p for n, p in named_params if _is_no_decay(n)], 'weight_decay': 0.0},
]

In [18]:
# Cross-entropy over the classifier logits.
loss_fn = nn.CrossEntropyLoss()
# AdamW (transformers' implementation) over the two weight-decay parameter groups.
# NOTE(review): transformers.AdamW is deprecated in newer transformers releases;
# torch.optim.AdamW is the usual replacement but has a different default eps --
# confirm before switching.
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

In [19]:
# Total optimizer steps for the run (one per training batch per epoch); the first
# `warmup_ratio` fraction of them (truncated to an int) is used for warm-up.
t_total = num_epochs * len(train_dataloader)
warmup_step = int(warmup_ratio * t_total)

In [20]:
# Learning-rate schedule: warm up for `warmup_step` steps, then follow the cosine
# schedule until `t_total` steps. The training loop calls scheduler.step() once per batch.
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

In [21]:
def calc_accuracy(X, Y):
    """Fraction of rows where the argmax of ``X`` over dim 1 equals the label in ``Y``.

    The count of matches is moved to the CPU as a NumPy value, so the result is
    a NumPy float.
    """
    predicted = torch.max(X, dim=1).indices
    n_correct = (predicted == Y).sum().detach().cpu().numpy()
    n_rows = predicted.size(0)
    return n_correct / n_rows

def calc_precision(X, Y, pos_label=1):
    """Precision of the argmax predictions for class ``pos_label``.

    Args:
        X: scores with classes along dim 1 (e.g. (batch, num_classes) logits);
            the predicted class is the argmax over dim 1.
        Y: integer class labels, one per row of ``X``.
        pos_label: class treated as positive (default 1).

    Returns:
        TP / (TP + FP) as a Python float. Returns 0.0 when no row is predicted
        as ``pos_label``, which is also the value scikit-learn's
        precision_score returns by default in that case.
    """
    _, predicted = torch.max(X, 1)
    predicted_pos = predicted == pos_label
    true_pos = (predicted_pos & (Y == pos_label)).sum().item()
    n_predicted_pos = predicted_pos.sum().item()
    return true_pos / n_predicted_pos if n_predicted_pos else 0.0

def calc_recall(X, Y, pos_label=1):
    """Recall of the argmax predictions for class ``pos_label``.

    Args:
        X: scores with classes along dim 1 (e.g. (batch, num_classes) logits);
            the predicted class is the argmax over dim 1.
        Y: integer class labels, one per row of ``X``.
        pos_label: class treated as positive (default 1).

    Returns:
        TP / (TP + FN) as a Python float. Returns 0.0 when ``Y`` contains no
        ``pos_label``, which is also the value scikit-learn's recall_score
        returns by default in that case.
    """
    _, predicted = torch.max(X, 1)
    actual_pos = Y == pos_label
    true_pos = ((predicted == pos_label) & actual_pos).sum().item()
    n_actual_pos = actual_pos.sum().item()
    return true_pos / n_actual_pos if n_actual_pos else 0.0

def calc_f1(X, Y, pos_label=1):
    """F1 score of the argmax predictions for class ``pos_label``.

    Uses the form 2*TP / (2*TP + FP + FN), which equals the harmonic mean of
    precision and recall but avoids dividing by a zero precision/recall.

    Args:
        X: scores with classes along dim 1 (e.g. (batch, num_classes) logits);
            the predicted class is the argmax over dim 1.
        Y: integer class labels, one per row of ``X``.
        pos_label: class treated as positive (default 1).

    Returns:
        The F1 score as a Python float. Returns 0.0 when there are no true
        positives, false positives or false negatives for ``pos_label``.
    """
    _, predicted = torch.max(X, 1)
    pred_pos = predicted == pos_label
    actual_pos = Y == pos_label
    true_pos = (pred_pos & actual_pos).sum().item()
    false_pos = (pred_pos & ~actual_pos).sum().item()
    false_neg = (~pred_pos & actual_pos).sum().item()
    denom = 2 * true_pos + false_pos + false_neg
    return 2 * true_pos / denom if denom else 0.0

In [21]:
for e in range(num_epochs):
    train_acc = 0.0
    val_acc = 0.0

    # Training mode (dropout on). It is set once per epoch because the
    # validation pass at the end of each epoch switches the model to eval mode.
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(train_dataloader):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        # valid_length is not moved here; the model only uses it to build the attention mask.
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule

        train_acc += calc_accuracy(out, label)

        if batch_id % log_interval == 0:
            print('epoch {} batch id {}'.format(e+1, batch_id+1))
            print("Train : loss {:.3f} Acc {:.3f}".format(loss.data.cpu().numpy(), train_acc / (batch_id+1)))

    model.eval()
    n_val_batches = 0
    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for n_val_batches, (token_ids, valid_length, segment_ids, label) in enumerate(validation_dataloader, start=1):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            label = label.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            val_acc += calc_accuracy(out, label)

    # Mean of per-batch accuracies (guarded against an empty validation loader).
    print("Val : Acc {:.3f} ".format(val_acc / max(n_val_batches, 1)))

epoch 1 batch id 1
Train : loss 0.732 Acc 0.516
Val : Acc 0.715 
epoch 2 batch id 1
Train : loss 0.365 Acc 0.859
Val : Acc 0.749 
epoch 3 batch id 1
Train : loss 0.216 Acc 0.922
Val : Acc 0.739 
epoch 4 batch id 1
Train : loss 0.040 Acc 0.984
Val : Acc 0.743 
epoch 5 batch id 1
Train : loss 0.017 Acc 0.984
Val : Acc 0.754 


In [22]:
def predict_labels(dataloader):
    """Return the argmax class index of every sample yielded by ``dataloader``.

    Runs the global ``model`` in eval mode without gradient tracking.
    Predictions are returned in the loader's iteration order; the loaders in
    this notebook are not shuffled, so the result lines up with the
    corresponding label list.
    """
    predictions = []
    model.eval()
    with torch.no_grad():
        for token_ids, valid_length, segment_ids, _label in dataloader:
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            _, max_indices = torch.max(out, 1)
            predictions.extend(max_indices.cpu().tolist())
    return predictions


# The metric cells below score train, validation and test predictions, so all
# three must be computed here (empty lists would make those cells fail).
train_pred = predict_labels(train_dataloader)
val_pred = predict_labels(validation_dataloader)
test_pred = predict_labels(test_dataloader)

In [23]:
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score

In [24]:
# Gold labels: element 1 of every pickled sample, in dataset order.
train_y, val_y, test_y = (
    [sample[1] for sample in split] for split in (train, val, test)
)

In [25]:
# Accuracy: fraction of test samples whose predicted label equals the gold label.
test_acc = sum(1 for i in range(len(test_y)) if test_y[i] == test_pred[i])

print(test_acc / len(test_y))

0.750252780586451


In [26]:
# F1 score per class: each line prints (label 0, label 1) for train, validation, test.
for y_true, y_pred in ((train_y, train_pred), (val_y, val_pred), (test_y, test_pred)):
    print(*[f1_score(y_true, y_pred, pos_label=label) for label in (0, 1)])

0.9927143657762002 0.9931578947368421
0.7549751243781095 0.7525125628140703
0.7576054955839058 0.7424400417101147


In [27]:
# Precision per class: each line prints (label 0, label 1) for train, validation, test.
for y_true, y_pred in ((train_y, train_pred), (val_y, val_pred), (test_y, test_pred)):
    print(*[precision_score(y_true, y_pred, pos_label=label) for label in (0, 1)])

0.9921583271097835 0.9936808846761453
0.7375455650060754 0.7709137709137709
0.7086903304773562 0.8012003000750187


In [28]:
# Recall per class: each line prints (label 0, label 1) for train, validation, test.
for y_true, y_pred in ((train_y, train_pred), (val_y, val_pred), (test_y, test_pred)):
    print(*[recall_score(y_true, y_pred, pos_label=label) for label in (0, 1)])

0.9932710280373832 0.9926354550236718
0.7732484076433122 0.7349693251533742
0.8137737174982431 0.6917098445595855
