In [None]:
# Colab environment setup.
# NOTE(review): only transformers is pinned. The install log of this run shows
# mxnet 1.8.0.post0 and gluonnlp 0.10.0 were installed; consider pinning those
# (and torch, sentencepiece, pandas, tqdm) so Restart & Run All reproduces the
# same environment.
!pip install mxnet
!pip install gluonnlp pandas tqdm
!pip install sentencepiece
# transformers is pinned to 3.x; BERTClassifier.forward unpacks the encoder output
# as a tuple (`_, pooler = self.bert(...)`) — TODO confirm compatibility before upgrading.
!pip install transformers==3
!pip install torch

Collecting mxnet
  Downloading mxnet-1.8.0.post0-py2.py3-none-manylinux2014_x86_64.whl (46.9 MB)
[K     |████████████████████████████████| 46.9 MB 39 kB/s 
[?25hCollecting graphviz<0.9.0,>=0.8.1
  Downloading graphviz-0.8.4-py2.py3-none-any.whl (16 kB)
Installing collected packages: graphviz, mxnet
  Attempting uninstall: graphviz
    Found existing installation: graphviz 0.10.1
    Uninstalling graphviz-0.10.1:
      Successfully uninstalled graphviz-0.10.1
Successfully installed graphviz-0.8.4 mxnet-1.8.0.post0
Collecting gluonnlp
  Downloading gluonnlp-0.10.0.tar.gz (344 kB)
[K     |████████████████████████████████| 344 kB 4.1 MB/s 
Building wheels for collected packages: gluonnlp
  Building wheel for gluonnlp (setup.py) ... [?25l[?25hdone
  Created wheel for gluonnlp: filename=gluonnlp-0.10.0-cp37-cp37m-linux_x86_64.whl size=595721 sha256=fb940b699c28d0a7f86dcf2274dddde7fbc191559c52cdbf115afc25312645a9
  Stored in directory: /root/.cache/pip/wheels/be/b4/06/7f3fdfaf707e6b5e98b

In [None]:
# Install KoBERT straight from GitHub.
# NOTE(review): this tracks the moving `master` branch (this run built kobert 0.1.2);
# pin a commit or tag for reproducibility.
!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

Collecting git+https://****@github.com/SKTBrain/KoBERT.git@master
  Cloning https://****@github.com/SKTBrain/KoBERT.git (to revision master) to /tmp/pip-req-build-qz0ex5sk
  Running command git clone -q 'https://****@github.com/SKTBrain/KoBERT.git' /tmp/pip-req-build-qz0ex5sk
Building wheels for collected packages: kobert
  Building wheel for kobert (setup.py) ... [?25l[?25hdone
  Created wheel for kobert: filename=kobert-0.1.2-py3-none-any.whl size=12771 sha256=15f578b06974a4ed9e8e59a046add7037d52b84315d59c087a845ceed0299c0a
  Stored in directory: /tmp/pip-ephem-wheel-cache-ii09ccub/wheels/d3/68/ca/334747dfb038313b49cf71f84832a33372f3470d9ddfd051c0
Successfully built kobert
Installing collected packages: kobert
Successfully installed kobert-0.1.2


In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook

In [None]:
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

In [None]:
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

In [None]:
# Mount Google Drive at /content/drive; the dataset and model paths used below
# live under '/content/drive/My Drive/KoBERT/'.
from google.colab import drive 
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
## Device selection: use the first GPU when available.
# Falls back to CPU instead of failing later at `.to(device)` on a runtime
# without CUDA (training on CPU will be slow).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the pretrained KoBERT encoder (PyTorch) and its vocabulary.
# `bertmodel` is wrapped by BERTClassifier below; `vocab` is used to build the tokenizer.
bertmodel, vocab = get_pytorch_kobert_model()

[██████████████████████████████████████████████████]
[██████████████████████████████████████████████████]


In [None]:
# Load both splits as tab-separated files from Drive. Only columns 1 and 2 are
# kept (used below as sentence = index 0, label = index 1), and the first line
# of each file (presumably a header) is discarded.
dataset_train, dataset_test = (
    nlp.data.TSVDataset(path, field_indices=[1, 2], num_discard_samples=1)
    for path in (
        "/content/drive/My Drive/KoBERT/data/ratings_train.txt",
        "/content/drive/My Drive/KoBERT/data/ratings_test.txt",
    )
)

In [None]:
# Build the SentencePiece-based BERT tokenizer from KoBERT's tokenizer model
# and the vocabulary loaded above; lower=False: do not lower-case the input.
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

using cached model


In [None]:
class BERTDataset(Dataset):
    """Map-style dataset of pre-transformed BERT inputs with integer labels.

    All rows are converted eagerly in ``__init__``: the text at ``sent_idx`` is
    run through gluonnlp's ``BERTSentenceTransform`` and the value at
    ``label_idx`` is cast to ``np.int32``.

    Args:
        dataset: indexable rows (e.g. a TSVDataset); iterated twice.
        sent_idx: position of the sentence text within a row.
        label_idx: position of the label within a row.
        bert_tokenizer: tokenizer handed to BERTSentenceTransform.
        max_len: maximum sequence length (``max_seq_length``).
        pad: whether sequences are padded to ``max_len``.
        pair: whether rows are sentence pairs.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        to_bert_inputs = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        # Sentences are transformed in one pass, labels collected in a second.
        self.sentences = [to_bert_inputs([row[sent_idx]]) for row in dataset]
        self.labels = [np.int32(row[label_idx]) for row in dataset]

    def __getitem__(self, i):
        # The transformed-sentence tuple with the label appended as its last item.
        label = self.labels[i]
        return self.sentences[i] + (label,)

    def __len__(self):
        return len(self.labels)


In [None]:
## Setting parameters
import random

max_len = 50           # max_seq_length passed to BERTSentenceTransform
batch_size = 8
warmup_ratio = 0.1     # fraction of total training steps used for LR warmup
num_epochs = 10
max_grad_norm = 1      # gradient-norm clipping threshold
log_interval = 200     # print training stats every `log_interval` batches
learning_rate = 2e-5

# Seed every RNG in play (Python, NumPy, PyTorch on all devices) so classifier
# initialisation, dropout and any DataLoader shuffling are reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Wrap both splits: row index 0 is the sentence, index 1 the label; sequences
# are padded to max_len and treated as single sentences (no pairs).
data_train = BERTDataset(dataset_train, sent_idx=0, label_idx=1, bert_tokenizer=tok,
                         max_len=max_len, pad=True, pair=False)
data_test = BERTDataset(dataset_test, sent_idx=0, label_idx=1, bert_tokenizer=tok,
                        max_len=max_len, pad=True, pair=False)

In [None]:
# Reshuffle the training set every epoch so batches are not drawn in file order.
# The test set keeps its order; shuffling it would not change the accuracy.
# num_workers=2 background processes per loader.
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, shuffle=True, num_workers=2)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=2)

In [None]:
class BERTClassifier(nn.Module):
    """Sentence classifier: a BERT encoder plus a linear head on the pooled output.

    Args:
        bert: encoder called as ``bert(input_ids=..., token_type_ids=...,
            attention_mask=...)``; assumed to return a 2-tuple whose second
            element is the pooled output of width ``hidden_size`` (as with the
            pinned transformers 3.x) — TODO confirm before upgrading.
        hidden_size: width of the pooled output fed to the linear head.
        num_classes: number of output logits.
        dr_rate: dropout probability on the pooled output; falsy (None/0)
            disables dropout.
        validation_split: unused; kept for backward compatibility.
        params: unused; kept for backward compatibility.
    """

    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes = 5,
                 dr_rate = None,
                 validation_split = 0.1,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        """Build a float mask that is 1.0 on the first ``valid_length[i]`` positions of row i.

        Args:
            token_ids: token tensor, assumed (batch, seq_len); only its shape and
                device are used.
            valid_length: per-row count of real tokens; a tensor (any device) or
                a sequence of ints.

        Returns:
            Float tensor of shape ``(batch, seq_len)`` on ``token_ids.device``,
            zeros elsewhere.
        """
        seq_len = token_ids.size(1)
        lengths = torch.as_tensor(valid_length, device=token_ids.device)
        positions = torch.arange(seq_len, device=token_ids.device)
        # (1, seq_len) < (batch, 1) broadcasts to (batch, seq_len); replaces a
        # per-row Python loop with one vectorised comparison.
        return (positions.unsqueeze(0) < lengths.unsqueeze(1)).float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Return class logits of shape ``(batch, num_classes)``."""
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        # Dropout is optional: without it the pooled output goes straight to the
        # head (the old code left `out` unassigned here and raised UnboundLocalError).
        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)

In [None]:
# Classifier with dropout 0.5 on the pooled output, moved to the selected device.
# Module.to works in place, so the wrapped `bertmodel` also ends up on `device`.
model = BERTClassifier(bert=bertmodel, dr_rate=0.5)
model = model.to(device)

In [None]:
# AdamW parameter groups: weight decay 0.01 for every parameter except biases
# and LayerNorm weights, which get no decay. (The LR schedule is built separately.)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {
        'params': [param for name, param in model.named_parameters()
                   if any(key in name for key in no_decay) == skip_decay],
        'weight_decay': 0.0 if skip_decay else 0.01,
    }
    for skip_decay in (False, True)  # decayed group first, then the no-decay group
]

In [None]:
# transformers' AdamW over the parameter groups; each group carries its own weight_decay.
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
# Multi-class loss on raw logits; targets are class indices (cast with .long() in the loop).
loss_fn = nn.CrossEntropyLoss()

In [None]:
# One optimizer/scheduler step per training batch, for every epoch; the first
# warmup_ratio fraction of those steps is used for warmup.
t_total = num_epochs * len(train_dataloader)
warmup_step = int(warmup_ratio * t_total)

In [None]:
# LR schedule: warmup for `warmup_step` steps, then cosine decay until `t_total`.
# The training loop calls scheduler.step() once per batch.
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

In [None]:
def calc_accuracy(X,Y):
    """Share of rows of ``X`` whose arg-max column equals the matching entry of ``Y``.

    Args:
        X: score/logit tensor with classes along dim 1, e.g. (batch, num_classes).
        Y: integer class labels, one per row, on the same device as ``X``.

    Returns:
        Batch accuracy as a NumPy float (correct count / number of rows).
    """
    predicted = torch.max(X, 1)[1]
    n_correct = (predicted == Y).sum().data.cpu().numpy()
    n_total = predicted.size()[0]
    return n_correct / n_total

In [None]:
# Train for num_epochs and report test accuracy after each epoch.
# Accuracies are the mean of per-batch accuracies, so a short final batch
# weighs as much as a full one.
from tqdm.notebook import tqdm as notebook_tqdm  # tqdm_notebook is deprecated

for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0

    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(notebook_tqdm(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        # valid_length is not moved to `device`; the model only uses it to build the attention mask.
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule (once per batch)
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))

    model.eval()
    # Evaluation needs no gradients; no_grad skips building the autograd graph,
    # saving memory and time.
    with torch.no_grad():
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(notebook_tqdm(test_dataloader)):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            label = label.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """


  0%|          | 0/5164 [00:00<?, ?it/s]

epoch 1 batch id 1 loss 1.6405898332595825 train acc 0.0
epoch 1 batch id 201 loss 1.5612671375274658 train acc 0.20708955223880596
epoch 1 batch id 401 loss 1.6258141994476318 train acc 0.24283042394014961
epoch 1 batch id 601 loss 1.9992477893829346 train acc 0.2576955074875208
epoch 1 batch id 801 loss 1.2825086116790771 train acc 0.27137952559300876
epoch 1 batch id 1001 loss 1.4572073221206665 train acc 0.286963036963037
epoch 1 batch id 1201 loss 1.2909042835235596 train acc 0.2981890091590341
epoch 1 batch id 1401 loss 1.4052515029907227 train acc 0.3173625981441827
epoch 1 batch id 1601 loss 1.267577886581421 train acc 0.3392410993129294
epoch 1 batch id 1801 loss 0.8705451488494873 train acc 0.3600777345918934
epoch 1 batch id 2001 loss 1.3695752620697021 train acc 0.37831084457771114
epoch 1 batch id 2201 loss 1.3963947296142578 train acc 0.39561562925942756
epoch 1 batch id 2401 loss 0.7046957612037659 train acc 0.4111307788421491
epoch 1 batch id 2601 loss 0.956589341163635

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/1292 [00:00<?, ?it/s]

epoch 1 test acc 0.6770510835913313


  0%|          | 0/5164 [00:00<?, ?it/s]

epoch 2 batch id 1 loss 1.1523172855377197 train acc 0.625
epoch 2 batch id 201 loss 1.6484578847885132 train acc 0.6461442786069652
epoch 2 batch id 401 loss 0.7136577367782593 train acc 0.6574189526184538
epoch 2 batch id 601 loss 1.6484447717666626 train acc 0.6543261231281198
epoch 2 batch id 801 loss 0.5348793864250183 train acc 0.6546504369538078
epoch 2 batch id 1001 loss 0.9420403838157654 train acc 0.663086913086913
epoch 2 batch id 1201 loss 0.3755582273006439 train acc 0.6659034138218152
epoch 2 batch id 1401 loss 1.2437561750411987 train acc 0.6660421127765882
epoch 2 batch id 1601 loss 0.34097906947135925 train acc 0.6691130543410369
epoch 2 batch id 1801 loss 0.24843907356262207 train acc 0.6699750138811771
epoch 2 batch id 2001 loss 0.6651713848114014 train acc 0.6708520739630185
epoch 2 batch id 2201 loss 0.4480745792388916 train acc 0.6734438891412994
epoch 2 batch id 2401 loss 0.1786186397075653 train acc 0.6749791753436069
epoch 2 batch id 2601 loss 0.700694024562835

  0%|          | 0/1292 [00:00<?, ?it/s]

epoch 2 test acc 0.7083010835913313


  0%|          | 0/5164 [00:00<?, ?it/s]

epoch 3 batch id 1 loss 0.5001364946365356 train acc 0.875
epoch 3 batch id 201 loss 1.9257278442382812 train acc 0.7512437810945274
epoch 3 batch id 401 loss 0.8355361223220825 train acc 0.7456359102244389
epoch 3 batch id 601 loss 1.1108640432357788 train acc 0.7387687188019967
epoch 3 batch id 801 loss 0.31600677967071533 train acc 0.737203495630462
epoch 3 batch id 1001 loss 0.9690392017364502 train acc 0.741008991008991
epoch 3 batch id 1201 loss 0.1247125193476677 train acc 0.746044962531224
epoch 3 batch id 1401 loss 1.0719716548919678 train acc 0.7447359029264811
epoch 3 batch id 1601 loss 0.18712612986564636 train acc 0.7478919425359151
epoch 3 batch id 1801 loss 0.13584505021572113 train acc 0.7490283176013326
epoch 3 batch id 2001 loss 0.5977839827537537 train acc 0.7480009995002499
epoch 3 batch id 2201 loss 0.36026492714881897 train acc 0.749204906860518
epoch 3 batch id 2401 loss 0.19291743636131287 train acc 0.7504685547688463
epoch 3 batch id 2601 loss 0.729812741279602

  0%|          | 0/1292 [00:00<?, ?it/s]

epoch 3 test acc 0.7117840557275542


  0%|          | 0/5164 [00:00<?, ?it/s]

epoch 4 batch id 1 loss 0.5564899444580078 train acc 0.75
epoch 4 batch id 201 loss 1.668927788734436 train acc 0.8090796019900498
epoch 4 batch id 401 loss 0.6630154252052307 train acc 0.8042394014962594
epoch 4 batch id 601 loss 1.375739574432373 train acc 0.7982529118136439
epoch 4 batch id 801 loss 0.22085824608802795 train acc 0.8008739076154806
epoch 4 batch id 1001 loss 0.521197497844696 train acc 0.8039460539460539
epoch 4 batch id 1201 loss 0.08048097044229507 train acc 0.8040174854288094
epoch 4 batch id 1401 loss 0.40576231479644775 train acc 0.8052284082798001
epoch 4 batch id 1601 loss 0.03625427559018135 train acc 0.80683947532792
epoch 4 batch id 1801 loss 0.5939341187477112 train acc 0.8089255968906163
epoch 4 batch id 2001 loss 0.8439434170722961 train acc 0.8074087956021989
epoch 4 batch id 2201 loss 0.14604386687278748 train acc 0.8082121762835075
epoch 4 batch id 2401 loss 0.016203487291932106 train acc 0.8105476884631404
epoch 4 batch id 2601 loss 0.101847648620605

  0%|          | 0/1292 [00:00<?, ?it/s]

epoch 4 test acc 0.7105263157894737


  0%|          | 0/5164 [00:00<?, ?it/s]

epoch 5 batch id 1 loss 0.06812218576669693 train acc 1.0
epoch 5 batch id 201 loss 1.842697262763977 train acc 0.8526119402985075
epoch 5 batch id 401 loss 0.5589606761932373 train acc 0.850997506234414
epoch 5 batch id 601 loss 1.5368480682373047 train acc 0.8523294509151415
epoch 5 batch id 801 loss 0.08778081834316254 train acc 0.8503433208489388
epoch 5 batch id 1001 loss 1.3300259113311768 train acc 0.8538961038961039
epoch 5 batch id 1201 loss 0.03610724210739136 train acc 0.8527268942547876
epoch 5 batch id 1401 loss 0.20574960112571716 train acc 0.8523376159885796
epoch 5 batch id 1601 loss 0.012914900667965412 train acc 0.8542317301686446
epoch 5 batch id 1801 loss 0.24577420949935913 train acc 0.854247640199889
epoch 5 batch id 2001 loss 0.20767436921596527 train acc 0.8550099950024987
epoch 5 batch id 2201 loss 0.014286486431956291 train acc 0.8567696501590186
epoch 5 batch id 2401 loss 0.005589216016232967 train acc 0.858600583090379
epoch 5 batch id 2601 loss 0.0956419333

  0%|          | 0/1292 [00:00<?, ?it/s]

epoch 5 test acc 0.7095588235294118


  0%|          | 0/5164 [00:00<?, ?it/s]

epoch 6 batch id 1 loss 0.6219064593315125 train acc 0.875
epoch 6 batch id 201 loss 1.1460357904434204 train acc 0.8893034825870647
epoch 6 batch id 401 loss 0.7497628927230835 train acc 0.8871571072319202
epoch 6 batch id 601 loss 0.9874901175498962 train acc 0.8887271214642263
epoch 6 batch id 801 loss 0.027782650664448738 train acc 0.889669163545568
epoch 6 batch id 1001 loss 0.5561395883560181 train acc 0.8927322677322678
epoch 6 batch id 1201 loss 0.23982393741607666 train acc 0.8924854288093256
epoch 6 batch id 1401 loss 0.8774123787879944 train acc 0.8919521770164168
epoch 6 batch id 1601 loss 0.004254028666764498 train acc 0.8947532792004997
epoch 6 batch id 1801 loss 0.010343589819967747 train acc 0.8967934480843975
epoch 6 batch id 2001 loss 0.011462201364338398 train acc 0.8956771614192903
epoch 6 batch id 2201 loss 0.00793051440268755 train acc 0.8960131758291685
epoch 6 batch id 2401 loss 0.0034057912416756153 train acc 0.8974906289046231
epoch 6 batch id 2601 loss 0.0065

  0%|          | 0/1292 [00:00<?, ?it/s]

epoch 6 test acc 0.7127515479876161


  0%|          | 0/5164 [00:00<?, ?it/s]

epoch 7 batch id 1 loss 0.01580733433365822 train acc 1.0
epoch 7 batch id 201 loss 2.7588706016540527 train acc 0.9166666666666666
epoch 7 batch id 401 loss 0.8201982975006104 train acc 0.9180174563591023
epoch 7 batch id 601 loss 0.48886021971702576 train acc 0.9190931780366056
epoch 7 batch id 801 loss 0.01269913837313652 train acc 0.9165106117353309
epoch 7 batch id 1001 loss 0.6909675002098083 train acc 0.9194555444555444
epoch 7 batch id 1201 loss 0.001901162089779973 train acc 0.9197543713572023
epoch 7 batch id 1401 loss 0.14806215465068817 train acc 0.920146324054247
epoch 7 batch id 1601 loss 0.003919568844139576 train acc 0.9216895690193629
epoch 7 batch id 1801 loss 0.001222037710249424 train acc 0.923375902276513
epoch 7 batch id 2001 loss 0.024489905685186386 train acc 0.9234132933533233
epoch 7 batch id 2201 loss 0.003901519114151597 train acc 0.9234438891412994
epoch 7 batch id 2401 loss 0.0018391464836895466 train acc 0.9252394835485215
epoch 7 batch id 2601 loss 0.003

  0%|          | 0/1292 [00:00<?, ?it/s]

epoch 7 test acc 0.7132352941176471


  0%|          | 0/5164 [00:00<?, ?it/s]

epoch 8 batch id 1 loss 0.006559649482369423 train acc 1.0
epoch 8 batch id 201 loss 1.6405693292617798 train acc 0.945273631840796
epoch 8 batch id 401 loss 0.9262850880622864 train acc 0.9448254364089775
epoch 8 batch id 601 loss 1.6637201309204102 train acc 0.94238768718802
epoch 8 batch id 801 loss 0.004604465328156948 train acc 0.9399188514357054
epoch 8 batch id 1001 loss 0.0025402253959327936 train acc 0.9428071928071928
epoch 8 batch id 1201 loss 0.0017522505950182676 train acc 0.9429641965029142
epoch 8 batch id 1401 loss 0.7333701848983765 train acc 0.9448608137044968
epoch 8 batch id 1601 loss 0.00252150883898139 train acc 0.945737039350406
epoch 8 batch id 1801 loss 0.0021679752971976995 train acc 0.9461410327595781
epoch 8 batch id 2001 loss 0.0011507621966302395 train acc 0.9465267366316842
epoch 8 batch id 2201 loss 0.0020558033138513565 train acc 0.946501590186279
epoch 8 batch id 2401 loss 0.0011278034653514624 train acc 0.9472615576842982
epoch 8 batch id 2601 loss 0.

  0%|          | 0/1292 [00:00<?, ?it/s]

epoch 8 test acc 0.7155572755417957


  0%|          | 0/5164 [00:00<?, ?it/s]

epoch 9 batch id 1 loss 0.0013052605791017413 train acc 1.0
epoch 9 batch id 201 loss 0.989808976650238 train acc 0.9713930348258707
epoch 9 batch id 401 loss 0.5968489050865173 train acc 0.9669576059850374
epoch 9 batch id 601 loss 0.8203917741775513 train acc 0.9633943427620633
epoch 9 batch id 801 loss 0.0014220543671399355 train acc 0.9605181023720349
epoch 9 batch id 1001 loss 0.6590324640274048 train acc 0.9616633366633367
epoch 9 batch id 1201 loss 0.0014990530908107758 train acc 0.9614904246461282
epoch 9 batch id 1401 loss 0.0012664045207202435 train acc 0.9620806566738044
epoch 9 batch id 1601 loss 0.001021787989884615 train acc 0.9626795752654591
epoch 9 batch id 1801 loss 0.0005523587460629642 train acc 0.9621043864519712
epoch 9 batch id 2001 loss 0.0015372599009424448 train acc 0.9617066466766616
epoch 9 batch id 2201 loss 0.29711779952049255 train acc 0.9609268514311676
epoch 9 batch id 2401 loss 0.0010886609088629484 train acc 0.9615785089546023
epoch 9 batch id 2601 lo

  0%|          | 0/1292 [00:00<?, ?it/s]

epoch 9 test acc 0.7161377708978328


  0%|          | 0/5164 [00:00<?, ?it/s]

epoch 10 batch id 1 loss 0.0016458046156913042 train acc 1.0
epoch 10 batch id 201 loss 1.721928596496582 train acc 0.9720149253731343
epoch 10 batch id 401 loss 0.09381672739982605 train acc 0.9713216957605985
epoch 10 batch id 601 loss 1.1790534257888794 train acc 0.9702579034941764
epoch 10 batch id 801 loss 0.0007453353609889746 train acc 0.967852684144819
epoch 10 batch id 1001 loss 0.8331896066665649 train acc 0.9685314685314685
epoch 10 batch id 1201 loss 0.0019233082421123981 train acc 0.9685678601165695
epoch 10 batch id 1401 loss 0.0027504386380314827 train acc 0.9688615274803711
epoch 10 batch id 1601 loss 0.0009082970209419727 train acc 0.969628357276702
epoch 10 batch id 1801 loss 0.00031762165599502623 train acc 0.9693225985563576
epoch 10 batch id 2001 loss 0.0012755015632137656 train acc 0.9685157421289355
epoch 10 batch id 2201 loss 0.0010276641696691513 train acc 0.9679123125851885
epoch 10 batch id 2401 loss 0.0009892748203128576 train acc 0.9678259058725531
epoch 10

  0%|          | 0/1292 [00:00<?, ?it/s]

epoch 10 test acc 0.7154605263157895


# 모델 저장 (Save the model)


In [None]:
# Save only the fine-tuned weights (state_dict). Create the target folder first
# so torch.save does not fail when it does not yet exist on Drive.
import os

PATH_PT = '/content/drive/My Drive/KoBERT/model/model_emo5.pt'
os.makedirs(os.path.dirname(PATH_PT), exist_ok=True)
torch.save(model.state_dict(), PATH_PT)

In [None]:
# Rebuild the classifier and load the saved weights for inference.
# BERTClassifier keeps a reference to the encoder it is given, and `model`
# already wraps `bertmodel` (which model.to(device) moved in place). Deep-copy it
# so the two classifiers do not share, and mutate, a single encoder.
import copy

model_PT = BERTClassifier(copy.deepcopy(bertmodel), dr_rate=0.5)  # same constructor args as `model`
# Keep the encoder and the new linear head on one device, and let a checkpoint
# saved on GPU load on whatever `device` is now.
model_PT.to(device)
model_PT.load_state_dict(torch.load(PATH_PT, map_location=device))
model_PT.eval()

BERTClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(8002, 768, padding_idx=1)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True