In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook


In [3]:
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

In [4]:
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
#from transformers.optimization import WarmupLinearSchedule

In [5]:
##GPU 사용 시
device = torch.device("cuda:0")

In [6]:
bertmodel, vocab = get_pytorch_kobert_model()

using cached model
using cached model


In [7]:
dataset_train = nlp.data.TSVDataset("./kobert_train.tsv", field_indices=[1, 0], num_discard_samples=1)
dataset_test = nlp.data.TSVDataset("./kobert_test.tsv", field_indices=[1,0], num_discard_samples=1)

In [8]:
dataset_train[6]

['게임 개발 중 사용자의 수행수준이 중요한 변인이라면 사용자의 이전 경험을 고려해야 한다는 결론을 도출할 수 있다', '0']

In [9]:
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

using cached model


In [10]:
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]
        
    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))


In [11]:
## Setting parameters
max_len = 64
batch_size = 64
warmup_ratio = 0.1
num_epochs = 10
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-5

In [12]:
data_train = BERTDataset(dataset_train, 0, 1, tok, max_len, True, False)
data_test = BERTDataset(dataset_test, 0, 1, tok, max_len, True, False)

In [13]:
len(data_train[10])

4

In [14]:
data_train[10]

(array([   2, 3301, 6903, 4591,  517, 7963, 6631, 4645, 7095, 1568,  522,
         517, 5845, 6438, 6139, 7095,  968, 5760, 4814, 5468, 1568,  517,
          40,  517, 6896,  705,  446,  380,  389, 5086, 7095,  517,  430,
         405,  393,  390,  517,  380,  427,  449,  388,  410,  380,  427,
         457,  522,  680,  270,  270,  517,   40,  517, 6116,  768, 4297,
        7921, 5468, 4591,  517, 7963, 6633, 4616, 7782,    3], dtype=int32),
 array(64, dtype=int32),
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       dtype=int32),
 3)

In [15]:
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=5)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5)

In [16]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=9,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)

In [17]:
model = BERTClassifier(bertmodel,  dr_rate=0.5).to(device)

In [18]:
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

In [19]:
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [20]:
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

In [21]:
#scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_step, t_total=t_total)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

In [22]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

In [23]:
for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    model.eval()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=793.0), HTML(value='')))

epoch 1 batch id 1 loss 2.128507375717163 train acc 0.171875
epoch 1 batch id 201 loss 1.4840689897537231 train acc 0.2966417910447761
epoch 1 batch id 401 loss 0.9187043309211731 train acc 0.4617752493765586
epoch 1 batch id 601 loss 0.9373445510864258 train acc 0.5405054076539102

epoch 1 train acc 0.5799508722152165


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=199.0), HTML(value='')))


epoch 1 test acc 0.7282281945441493


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=793.0), HTML(value='')))

epoch 2 batch id 1 loss 0.6289991140365601 train acc 0.796875
epoch 2 batch id 201 loss 0.5713232755661011 train acc 0.7159514925373134
epoch 2 batch id 401 loss 0.6675209999084473 train acc 0.7192565461346634
epoch 2 batch id 601 loss 0.7053610682487488 train acc 0.7299552828618968

epoch 2 train acc 0.738134019896315


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=199.0), HTML(value='')))


epoch 2 test acc 0.7358444005743001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=793.0), HTML(value='')))

epoch 3 batch id 1 loss 0.5159103274345398 train acc 0.828125
epoch 3 batch id 201 loss 0.4267199635505676 train acc 0.7755752487562189
epoch 3 batch id 401 loss 0.5876497030258179 train acc 0.7788341645885287
epoch 3 batch id 601 loss 0.5810011625289917 train acc 0.7867876455906821

epoch 3 train acc 0.7924982860945976


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=199.0), HTML(value='')))


epoch 3 test acc 0.7385925161521896


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=793.0), HTML(value='')))

epoch 4 batch id 1 loss 0.3897703289985657 train acc 0.828125
epoch 4 batch id 201 loss 0.37464627623558044 train acc 0.8254819651741293
epoch 4 batch id 401 loss 0.4973014295101166 train acc 0.8244622817955112
epoch 4 batch id 601 loss 0.46402624249458313 train acc 0.8341046173044925

epoch 4 train acc 0.8395900263215837


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=199.0), HTML(value='')))


epoch 4 test acc 0.7311333452979182


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=793.0), HTML(value='')))

epoch 5 batch id 1 loss 0.20448654890060425 train acc 0.9375
epoch 5 batch id 201 loss 0.1971641182899475 train acc 0.8725124378109452
epoch 5 batch id 401 loss 0.3737642168998718 train acc 0.8675187032418953
epoch 5 batch id 601 loss 0.2685333490371704 train acc 0.8744280366056573

epoch 5 train acc 0.8784468889489382


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=199.0), HTML(value='')))


epoch 5 test acc 0.7275888370423547


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=793.0), HTML(value='')))

epoch 6 batch id 1 loss 0.22391463816165924 train acc 0.921875
epoch 6 batch id 201 loss 0.16189022362232208 train acc 0.9033737562189055
epoch 6 batch id 401 loss 0.23484256863594055 train acc 0.9018859102244389
epoch 6 batch id 601 loss 0.21354590356349945 train acc 0.9078099001663894

epoch 6 train acc 0.9111355686663064


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=199.0), HTML(value='')))


epoch 6 test acc 0.7309650933237617


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=793.0), HTML(value='')))

epoch 7 batch id 1 loss 0.1299387365579605 train acc 0.96875
epoch 7 batch id 201 loss 0.15183743834495544 train acc 0.9333022388059702
epoch 7 batch id 401 loss 0.13080507516860962 train acc 0.9300966334164589
epoch 7 batch id 601 loss 0.1544337272644043 train acc 0.9335222545757071

epoch 7 train acc 0.935469585059749


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=199.0), HTML(value='')))


epoch 7 test acc 0.7304939877961235


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=793.0), HTML(value='')))

epoch 8 batch id 1 loss 0.0873015895485878 train acc 0.984375
epoch 8 batch id 201 loss 0.0500311441719532 train acc 0.9494713930348259
epoch 8 batch id 401 loss 0.08589717000722885 train acc 0.9484881546134664
epoch 8 batch id 601 loss 0.18503877520561218 train acc 0.9509931364392679

epoch 8 train acc 0.952749692247643


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=199.0), HTML(value='')))


epoch 8 test acc 0.7319858219669778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=793.0), HTML(value='')))

epoch 9 batch id 1 loss 0.10757728666067123 train acc 0.96875
epoch 9 batch id 201 loss 0.07786057144403458 train acc 0.9640080845771144
epoch 9 batch id 401 loss 0.10698048770427704 train acc 0.9611518079800498
epoch 9 batch id 601 loss 0.10282407701015472 train acc 0.9634723377703827

epoch 9 train acc 0.9656361966812786


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=199.0), HTML(value='')))


epoch 9 test acc 0.7302584350323045


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=793.0), HTML(value='')))

epoch 10 batch id 1 loss 0.022513672709465027 train acc 0.984375
epoch 10 batch id 201 loss 0.07011513411998749 train acc 0.9713930348258707
epoch 10 batch id 401 loss 0.031669437885284424 train acc 0.9711268703241895
epoch 10 batch id 601 loss 0.06740203499794006 train acc 0.9732737104825291

epoch 10 train acc 0.9746019861286255


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=199.0), HTML(value='')))


epoch 10 test acc 0.7319073043790381


In [None]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc