## 1. Library Import

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook

from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

In [2]:
##GPU 사용 시
device = torch.device("cuda:0")
device

device(type='cuda', index=0)

## Get Model / Tokenizer

In [3]:
# bert model
# vocab - Vocab(size=8002, unk="[UNK]", reserved="['[CLS]', '[SEP]', '[MASK]', '[PAD]']")
bertmodel, vocab = get_pytorch_kobert_model()

tokenizer = get_tokenizer() # str => /home/neuralworks/kobert/kobert_news_wiki_ko_cased-1087f8699e.spiece
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

# tok('안녕하세요. 저는 김선민 입니다')

using cached model
using cached model
using cached model


## Get Dataset

In [4]:
dataset_train = nlp.data.TSVDataset("data/ratings_train.txt", field_indices=[1,2], num_discard_samples=1)
dataset_test = nlp.data.TSVDataset("data/ratings_test.txt", field_indices=[1,2], num_discard_samples=1)
print(type(dataset_train))

<class 'gluonnlp.data.dataset.TSVDataset'>


In [5]:
max_len = 64 # 텍스트 데이터 최대 길이
batch_size = 32

In [6]:
class BERTDataset(Dataset):
    # dataset, 0, 1, tokenizer, max_len, True, False
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len, pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))

data_train = BERTDataset(dataset_train, 0, 1, tok, max_len, True, False)
data_test = BERTDataset(dataset_test, 0, 1, tok, max_len, True, False)

train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=5)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5)

## Modeling, Hyper Parameter

In [16]:
## Setting parameters
warmup_ratio = 0.1
num_epochs = 5
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-5

In [17]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=2, # 긍정 or 부정
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)
    
model = BERTClassifier(bertmodel,  dr_rate=0.5).to(device)

In [18]:
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

# optimizer, loss
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

# warmup scheduler
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

# calc accuracy
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

## Training

In [19]:
for e in range(num_epochs):
    train_acc, test_acc = 0.0, 0.0
    # train
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    
    # evaluate
    model.eval()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        test_acc += calc_accuracy(out, label)
    torch.save(model, 'kobert-nsmc-'+str(e)+'.pt')
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):


  0%|          | 0/4688 [00:00<?, ?it/s]

The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
epoch 1 batch id 1 loss 0.6857167482376099 train acc 0.65625
epoch 1 batch id 201 loss 0.3573457896709442 train acc 0.7815609452736318
epoch 1 batch id 401 loss 0.21133974194526672 train acc 0.8485037406483791
epoch 1 batch id 601 loss 0.16963686048984528 train acc 0.8707882695507487
epoch 1 batch id 801 loss 0.34017372131347656 train acc 0.8815933208489388
epoch 1 batch id 1001 loss 0.37174472212791443 train acc 0.887987012987013
epoch 1 batch id 1201 loss 0.46989017724990845 train acc 0.8932920482930891
epoch 1 batch id 1401 loss 0.23267492651939392 train acc 0.8973724125624554
epoch 1 

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):


  0%|          | 0/1563 [00:00<?, ?it/s]

The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
epoch 1 test acc 0.8825175943698017


  0%|          | 0/4688 [00:00<?, ?it/s]

The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
epoch 2 batch id 1 loss 0.583040177822113 train acc 0.8125
epoch 2 batch id 201 loss 0.2601645588874817 train acc 0.8849502487562189
epoch 2 batch id 401 loss 0.28650760650634766 train acc 0.8901184538653366
epoch 2 batch id 601 loss 0.135768324136734 train acc 0.8953826955074875
epoch 2 batch id 801 loss 0.3039056062698364 train acc 0.9001638576779026
epoch 2 batch id 1001 loss 0.30369770526885986 train acc 0.9031905594405595
epoch 2 batch id 1201 loss 0.5312280654907227 train acc 0.9064581598667777
epoch 2 batch id 1401 loss 0.234229177236557 train acc 0.9085028551034975
epoch 2 batch i

  0%|          | 0/1563 [00:00<?, ?it/s]

The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
epoch 2 test acc 0.8856365962891874


  0%|          | 0/4688 [00:00<?, ?it/s]

The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
epoch 3 batch id 1 loss 0.6300756335258484 train acc 0.84375
epoch 3 batch id 201 loss 0.1997765600681305 train acc 0.9264614427860697
epoch 3 batch id 401 loss 0.13893330097198486 train acc 0.9281483790523691
epoch 3 batch id 601 loss 0.10090045630931854 train acc 0.9331322795341098
epoch 3 batch id 801 loss 0.2151477187871933 train acc 0.9355493133583022
epoch 3 batch id 1001 loss 0.24839629232883453 train acc 0.9371565934065934
epoch 3 batch id 1201 loss 0.45139408111572266 train acc 0.9390091590341382
epoch 3 batch id 1401 loss 0.06209949031472206 train acc 0.9403327980014276
epoch 3 

  0%|          | 0/1563 [00:00<?, ?it/s]

The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
epoch 3 test acc 0.8914747280870121


  0%|          | 0/4688 [00:00<?, ?it/s]

The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
epoch 4 batch id 1 loss 0.5689281225204468 train acc 0.84375
epoch 4 batch id 201 loss 0.14891383051872253 train acc 0.9500932835820896
epoch 4 batch id 401 loss 0.06914684176445007 train acc 0.9559694513715711
epoch 4 batch id 601 loss 0.018289947882294655 train acc 0.9597545757071547
epoch 4 batch id 801 loss 0.09229589253664017 train acc 0.9620006242197253
epoch 4 batch id 1001 loss 0.08820468187332153 train acc 0.963317932067932
epoch 4 batch id 1201 loss 0.23846440017223358 train acc 0.9648209825145712
epoch 4 batch id 1401 loss 0.18070903420448303 train acc 0.9654264810849393
epoch 

  0%|          | 0/1563 [00:00<?, ?it/s]

The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
epoch 4 test acc 0.8943338131797824


  0%|          | 0/4688 [00:00<?, ?it/s]

The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
epoch 5 batch id 1 loss 0.4382215142250061 train acc 0.875
epoch 5 batch id 201 loss 0.02459258772432804 train acc 0.970771144278607
epoch 5 batch id 401 loss 0.07532317191362381 train acc 0.9747506234413965
epoch 5 batch id 601 loss 0.004637721460312605 train acc 0.9773294509151415
epoch 5 batch id 801 loss 0.1279895156621933 train acc 0.9777621722846442
epoch 5 batch id 1001 loss 0.1061367616057396 train acc 0.9782405094905094
epoch 5 batch id 1201 loss 0.25959092378616333 train acc 0.979131973355537
epoch 5 batch id 1401 loss 0.007385326083749533 train acc 0.979768915060671
epoch 5 bat

  0%|          | 0/1563 [00:00<?, ?it/s]

The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
The current process just got forked. Disabling parallelism to avoid deadlocks...
epoch 5 test acc 0.8936140435060781
