In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import sys

print(os.getcwd())

/content


In [None]:
!pip install mxnet
!pip install gluonnlp pandas tqdm
!pip install sentencepiece
!pip install transformers==3.0.2
!pip install torch

Collecting mxnet
  Downloading mxnet-1.8.0.post0-py2.py3-none-manylinux2014_x86_64.whl (46.9 MB)
[K     |████████████████████████████████| 46.9 MB 72 kB/s 
Collecting graphviz<0.9.0,>=0.8.1
  Downloading graphviz-0.8.4-py2.py3-none-any.whl (16 kB)
Installing collected packages: graphviz, mxnet
  Attempting uninstall: graphviz
    Found existing installation: graphviz 0.10.1
    Uninstalling graphviz-0.10.1:
      Successfully uninstalled graphviz-0.10.1
Successfully installed graphviz-0.8.4 mxnet-1.8.0.post0
Collecting gluonnlp
  Downloading gluonnlp-0.10.0.tar.gz (344 kB)
[K     |████████████████████████████████| 344 kB 12.9 MB/s 
Building wheels for collected packages: gluonnlp
  Building wheel for gluonnlp (setup.py) ... [?25l[?25hdone
  Created wheel for gluonnlp: filename=gluonnlp-0.10.0-cp37-cp37m-linux_x86_64.whl size=595721 sha256=dfe6311b216f19570c559ccb1f2649568a3fc8fe9f354727121a3356159820d1
  Stored in directory: /root/.cache/pip/wheels/be/b4/06/7f3fdfaf707e6b5e98b79c04

In [None]:
#깃허브에서 KoBERT 파일 로드
!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

Collecting git+https://****@github.com/SKTBrain/KoBERT.git@master
  Cloning https://****@github.com/SKTBrain/KoBERT.git (to revision master) to /tmp/pip-req-build-791zzreb
  Running command git clone -q 'https://****@github.com/SKTBrain/KoBERT.git' /tmp/pip-req-build-791zzreb
Building wheels for collected packages: kobert
  Building wheel for kobert (setup.py) ... [?25l[?25hdone
  Created wheel for kobert: filename=kobert-0.1.2-py3-none-any.whl size=12770 sha256=5c9944bbc984c1f06e7da857d6471f3932827e5ccb2fd584d725a555dc209609
  Stored in directory: /tmp/pip-ephem-wheel-cache-7aiyl862/wheels/d3/68/ca/334747dfb038313b49cf71f84832a33372f3470d9ddfd051c0
Successfully built kobert
Installing collected packages: kobert
Successfully installed kobert-0.1.2


# Data Preparation and Training

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook

In [None]:
#kobert
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

#transformers
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

In [None]:
#GPU 사용
device = torch.device("cuda:0")

In [None]:
#BERT 모델, Vocabulary 불러오기
bertmodel, vocab = get_pytorch_kobert_model()

[██████████████████████████████████████████████████]
[██████████████████████████████████████████████████]


In [None]:
import pandas as pd
chatbot_data = pd.read_csv('/content/drive/MyDrive/test/chatbot_sentiment_dataset.csv')

In [None]:
chatbot_data.sample(n=10)

Unnamed: 0.1,Unnamed: 0,Sentence,Emotion
215838,183142,"갑자기라기보다, 처음부터 전제를 잘못 잡고 있었어. 걔는 학생회장을 하기 싫었던 게...",중립
154238,121542,하는 일 없이 노래교실에 다니는 친구와 내 처지가 비교 돼서 너무 화가 나.맞아. ...,분노
30735,36633,평화로운집회니 음악이어우러지니 어쩌니하니 닭은 진짜 축제같은건줄아나봄?,혐오
115229,82533,그가 나와의 결혼식을 며칠 앞두고 파혼을 요청해와서 상처받았어.응. 갑자기 파혼을 ...,슬픔
112504,79808,중년이 되니 부부 사이도 애틋한 거보단 각자 삶이 있는 것 같아.남편이 어느 순간 ...,당황
94601,61905,아들이 죽음에 대해서 질문하는데 내가 다 숙연해지더라.그렇지. 아들보다 내가 죽음에...,불안
206929,174233,포기한거 아냐.,중립
13592,19490,승진이빠르던가 예쁜여친이있던지 될사람은 되,슬픔
82886,50190,면접을 보러 갔는데 사무실에 간판이 없어서 혼란스러웠어.간판이 없는 것도 그렇고 면...,불안
190872,158176,의사선생한테 오늘도 연락 없어?,중립


In [None]:
chatbot_data.loc[(chatbot_data['Emotion'] == "불안"), 'Emotion'] = 0  #불안 => 0
chatbot_data.loc[(chatbot_data['Emotion'] == "당황"), 'Emotion'] = 1  #당황 => 1
chatbot_data.loc[(chatbot_data['Emotion'] == "분노"), 'Emotion'] = 2  #분노 => 2
chatbot_data.loc[(chatbot_data['Emotion'] == "슬픔"), 'Emotion'] = 3  #슬픔 => 3
chatbot_data.loc[(chatbot_data['Emotion'] == "중립"), 'Emotion'] = 4  #중립 => 4
chatbot_data.loc[(chatbot_data['Emotion'] == "행복"), 'Emotion'] = 5  #행복 => 5
chatbot_data.loc[(chatbot_data['Emotion'] == "혐오"), 'Emotion'] = 6  #혐오 => 6

data_list = []
for q, label in zip(chatbot_data['Sentence'], chatbot_data['Emotion'])  :
    data = []
    data.append(q)
    data.append(str(label))

    data_list.append(data)

In [None]:
print(data_list[0])
print(data_list[6000])
print(data_list[12000])
print(data_list[18000])
print(data_list[24000])
print(data_list[30000])
print(data_list[-1])

['언니 동생으로 부르는게 맞는 일인가요..??', '0']
['우리는 미개 합니다', '2']
['지금 내모습은 컴퓨터앞에서 하루종일 취업사이트만 검색중...', '3']
['쉬는 시간에 오답이나 맞추고 있으면 모교로 돌아가서 중간고사 기말고사나 봐야지', '4']
[' 빨리빨리 파악해서 바로 고쳐주니깐 이러니 고객들이 믿고 살수밖에..', '5']
['미안한데 니들은 인간이 어떻게 이런말 쓰지마라', '6']
['그 여자랑 내가 무슨 상관인데? 아까는 탐정님이 부탁하기에 너 구하는 김에 주워왔지만, 민폐니까 얼른 나가.', '4']


In [None]:
#train & test 데이터로 나누기
from sklearn.model_selection import train_test_split
                                                         
dataset_train, dataset_test = train_test_split(data_list, test_size=0.13, random_state=0)

In [None]:
print(len(dataset_train))
print(len(dataset_test))

190109
28408


In [None]:
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))

In [None]:
#Setting parameters
max_len = 64
batch_size = 64
warmup_ratio = 0.1
num_epochs = 15
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-5

In [None]:
#토큰화
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

data_train = BERTDataset(dataset_train, 0, 1, tok, max_len, True, False)
data_test = BERTDataset(dataset_test, 0, 1, tok, max_len, True, False)

using cached model


In [None]:
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=4)

In [None]:
# for data in train_dataloader:
#   print(data)
#   break

In [None]:
#kobert 학습모델 만들기
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=7,   ##클래스 수 조정##
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)

In [None]:
#BERT 모델 불러오기
model = BERTClassifier(bertmodel,  dr_rate=0.5).to(device)

#optimizer와 schedule 설정
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

#정확도 측정을 위한 함수 정의
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc
    
# train_dataloader

In [None]:
best_acc=0.0
best_loss=99999999
ckpt_path="/content/drive/MyDrive/test/" #your own path
ckpt_name=ckpt_path+"saved_model.pt"

In [55]:
#kobert 모델 학습시키기
for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        # scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    
    model.eval()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        test_loss=loss_fn(out,label)
        test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {} test loss {}".format(e+1, test_acc / (batch_id+1),test_loss.data.cpu().numpy()))

    if test_acc>best_acc and test_loss.data.cpu().numpy()<best_loss:
      torch.save({'epoch':e+1,
                  'model_state_dict':model.state_dict(),
                  'optimizer_state_dict':optimizer.state_dict(),
                  'loss':test_loss.data.cpu().numpy()},
                 ckpt_name)
      best_loss=test_loss.data.cpu().numpy()
      bset_acc=test_acc
      
      print('current best model saved')
    

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=2971.0), HTML(value='')))

epoch 1 batch id 1 loss 0.16824722290039062 train acc 0.96875
epoch 1 batch id 201 loss 0.07003601640462875 train acc 0.9851523631840796
epoch 1 batch id 401 loss 0.0038964352570474148 train acc 0.9832839775561097
epoch 1 batch id 601 loss 0.10422869771718979 train acc 0.983153078202995
epoch 1 batch id 801 loss 0.01879844442009926 train acc 0.9830095193508115
epoch 1 batch id 1001 loss 0.026083270087838173 train acc 0.9826891858141859
epoch 1 batch id 1201 loss 0.048541221767663956 train acc 0.982826810990841
epoch 1 batch id 1401 loss 0.0507928729057312 train acc 0.9830032119914347
epoch 1 batch id 1601 loss 0.00401336932554841 train acc 0.9828915521549032
epoch 1 batch id 1801 loss 0.0845431387424469 train acc 0.9827005830094392
epoch 1 batch id 2001 loss 0.015039879828691483 train acc 0.9827351949025487
epoch 1 batch id 2201 loss 0.03284340724349022 train acc 0.9826073375738301
epoch 1 batch id 2401 loss 0.05721035227179527 train acc 0.9824357038733861
epoch 1 batch id 2601 loss 0.

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=444.0), HTML(value='')))


epoch 1 test acc 0.7747043918918919 test loss 1.5845586061477661
current best model saved


HBox(children=(FloatProgress(value=0.0, max=2971.0), HTML(value='')))

epoch 2 batch id 1 loss 0.13414990901947021 train acc 0.984375
epoch 2 batch id 201 loss 0.03162827715277672 train acc 0.9853855721393034
epoch 2 batch id 401 loss 0.002318520564585924 train acc 0.9831670822942643
epoch 2 batch id 601 loss 0.10684362798929214 train acc 0.9833610648918469
epoch 2 batch id 801 loss 0.06144106388092041 train acc 0.9830095193508115
epoch 2 batch id 1001 loss 0.010969359427690506 train acc 0.9830482017982018
epoch 2 batch id 1201 loss 0.0868084654211998 train acc 0.9831130308076603
epoch 2 batch id 1401 loss 0.11100742965936661 train acc 0.9833600999286224
epoch 2 batch id 1601 loss 0.005947912577539682 train acc 0.9832331355402874
epoch 2 batch id 1801 loss 0.022519104182720184 train acc 0.9828914491948917
epoch 2 batch id 2001 loss 0.006647337228059769 train acc 0.9831959020489756
epoch 2 batch id 2201 loss 0.05384117737412453 train acc 0.9832391526578828
epoch 2 batch id 2401 loss 0.042004507035017014 train acc 0.9830474281549354
epoch 2 batch id 2601 lo

HBox(children=(FloatProgress(value=0.0, max=444.0), HTML(value='')))


epoch 2 test acc 0.7760064752252253 test loss 1.6062480211257935


HBox(children=(FloatProgress(value=0.0, max=2971.0), HTML(value='')))

epoch 3 batch id 1 loss 0.14825627207756042 train acc 0.96875
epoch 3 batch id 201 loss 0.05484083667397499 train acc 0.9854633084577115
epoch 3 batch id 401 loss 0.033072542399168015 train acc 0.9840632793017456
epoch 3 batch id 601 loss 0.07665172219276428 train acc 0.9836210482529119
epoch 3 batch id 801 loss 0.004734979942440987 train acc 0.9841799313358303
epoch 3 batch id 1001 loss 0.0023480127565562725 train acc 0.9843437812187812
epoch 3 batch id 1201 loss 0.046262890100479126 train acc 0.9844270399666945
epoch 3 batch id 1401 loss 0.03865916281938553 train acc 0.9843526945039258
epoch 3 batch id 1601 loss 0.018172040581703186 train acc 0.9843554809494066
epoch 3 batch id 1801 loss 0.038168665021657944 train acc 0.9840366463076069
epoch 3 batch id 2001 loss 0.017210708931088448 train acc 0.9842188280859571
epoch 3 batch id 2201 loss 0.018611472100019455 train acc 0.9842117219445706
epoch 3 batch id 2401 loss 0.08401986956596375 train acc 0.9841732611411912
epoch 3 batch id 2601

HBox(children=(FloatProgress(value=0.0, max=444.0), HTML(value='')))


epoch 3 test acc 0.7745284346846847 test loss 1.6882390975952148


HBox(children=(FloatProgress(value=0.0, max=2971.0), HTML(value='')))

epoch 4 batch id 1 loss 0.12162604928016663 train acc 0.984375
epoch 4 batch id 201 loss 0.056494373828172684 train acc 0.9867848258706468
epoch 4 batch id 401 loss 0.0014850989682599902 train acc 0.9855439526184538
epoch 4 batch id 601 loss 0.05209595710039139 train acc 0.9852329450915142
epoch 4 batch id 801 loss 0.02535257861018181 train acc 0.9853113295880149
epoch 4 batch id 1001 loss 0.0019128206185996532 train acc 0.9851710789210789
epoch 4 batch id 1201 loss 0.029009295627474785 train acc 0.9851816194837635
epoch 4 batch id 1401 loss 0.04513497278094292 train acc 0.9851556923625981
epoch 4 batch id 1601 loss 0.0016212810296565294 train acc 0.9850874453466584
epoch 4 batch id 1801 loss 0.005215710960328579 train acc 0.9848434897279289
epoch 4 batch id 2001 loss 0.004297970328480005 train acc 0.9850699650174912
epoch 4 batch id 2201 loss 0.029412290081381798 train acc 0.9849358246251704
epoch 4 batch id 2401 loss 0.09075654298067093 train acc 0.9847068929612661
epoch 4 batch id 2

HBox(children=(FloatProgress(value=0.0, max=444.0), HTML(value='')))


epoch 4 test acc 0.774634009009009 test loss 1.7536518573760986


HBox(children=(FloatProgress(value=0.0, max=2971.0), HTML(value='')))

epoch 5 batch id 1 loss 0.1498275101184845 train acc 0.953125
epoch 5 batch id 201 loss 0.017048435285687447 train acc 0.9868625621890548
epoch 5 batch id 401 loss 0.0044841887429356575 train acc 0.9855439526184538
epoch 5 batch id 601 loss 0.05258815735578537 train acc 0.9854149334442596
epoch 5 batch id 801 loss 0.02362985722720623 train acc 0.985330836454432
epoch 5 batch id 1001 loss 0.0031217888463288546 train acc 0.9850774225774226
epoch 5 batch id 1201 loss 0.024736057966947556 train acc 0.9853507493755204
epoch 5 batch id 1401 loss 0.046498578041791916 train acc 0.9853229835831548
epoch 5 batch id 1601 loss 0.0018912128871306777 train acc 0.9852045596502186
epoch 5 batch id 1801 loss 0.03898327425122261 train acc 0.9849736257634647
epoch 5 batch id 2001 loss 0.019201332703232765 train acc 0.9850777736131934
epoch 5 batch id 2201 loss 0.04513823613524437 train acc 0.9850423103134939
epoch 5 batch id 2401 loss 0.11347952485084534 train acc 0.9849216472303207
epoch 5 batch id 2601

HBox(children=(FloatProgress(value=0.0, max=444.0), HTML(value='')))


epoch 5 test acc 0.7765695382882883 test loss 1.7369470596313477


HBox(children=(FloatProgress(value=0.0, max=2971.0), HTML(value='')))

epoch 6 batch id 1 loss 0.11702898889780045 train acc 0.984375
epoch 6 batch id 201 loss 0.08267038315534592 train acc 0.9867848258706468
epoch 6 batch id 401 loss 0.001193334348499775 train acc 0.9854270573566085
epoch 6 batch id 601 loss 0.05787070468068123 train acc 0.9856749168053245
epoch 6 batch id 801 loss 0.07862555980682373 train acc 0.9853893570536829
epoch 6 batch id 1001 loss 0.0018106624484062195 train acc 0.98565497002997
epoch 6 batch id 1201 loss 0.013285724446177483 train acc 0.9859101790174855
epoch 6 batch id 1401 loss 0.04179564490914345 train acc 0.9860144539614561
epoch 6 batch id 1601 loss 0.004630193114280701 train acc 0.9860146002498439
epoch 6 batch id 1801 loss 0.004789506085216999 train acc 0.9860147140477512
epoch 6 batch id 2001 loss 0.025103826075792313 train acc 0.9860148050974513
epoch 6 batch id 2201 loss 0.010947881266474724 train acc 0.9860645729213994
epoch 6 batch id 2401 loss 0.003624905366450548 train acc 0.9858847875885048
epoch 6 batch id 2601 

HBox(children=(FloatProgress(value=0.0, max=444.0), HTML(value='')))


epoch 6 test acc 0.7755841779279279 test loss 1.7973835468292236


HBox(children=(FloatProgress(value=0.0, max=2971.0), HTML(value='')))

epoch 7 batch id 1 loss 0.1160278171300888 train acc 0.984375
epoch 7 batch id 201 loss 0.021170740947127342 train acc 0.9884172885572139
epoch 7 batch id 401 loss 0.0012499522417783737 train acc 0.9869856608478803
epoch 7 batch id 601 loss 0.08000056445598602 train acc 0.986870840266223
epoch 7 batch id 801 loss 0.005495088640600443 train acc 0.9869889200998752
epoch 7 batch id 1001 loss 0.013487214222550392 train acc 0.9866851898101898
epoch 7 batch id 1201 loss 0.006150592118501663 train acc 0.9867167985012489
epoch 7 batch id 1401 loss 0.04026871174573898 train acc 0.9868397573162027
epoch 7 batch id 1601 loss 0.0020684751216322184 train acc 0.9866001717676453
epoch 7 batch id 1801 loss 0.016393257305026054 train acc 0.9864658523042754
epoch 7 batch id 2001 loss 0.009176122024655342 train acc 0.9866160669665167
epoch 7 batch id 2201 loss 0.005472928285598755 train acc 0.9867176851431168
epoch 7 batch id 2401 loss 0.017107408493757248 train acc 0.9864444502290712
epoch 7 batch id 26

HBox(children=(FloatProgress(value=0.0, max=444.0), HTML(value='')))


epoch 7 test acc 0.774387668918919 test loss 1.8343675136566162


HBox(children=(FloatProgress(value=0.0, max=2971.0), HTML(value='')))

epoch 8 batch id 1 loss 0.08483948558568954 train acc 0.984375
epoch 8 batch id 201 loss 0.029725877568125725 train acc 0.9873289800995025
epoch 8 batch id 401 loss 0.0013092367444187403 train acc 0.9860115336658354
epoch 8 batch id 601 loss 0.08946385234594345 train acc 0.986610856905158
epoch 8 batch id 801 loss 0.005521014332771301 train acc 0.9866377965043696
epoch 8 batch id 1001 loss 0.008900640532374382 train acc 0.9867007992007992
epoch 8 batch id 1201 loss 0.021447181701660156 train acc 0.9867688384679434
epoch 8 batch id 1401 loss 0.03404104337096214 train acc 0.9869066738044254
epoch 8 batch id 1601 loss 0.06174039840698242 train acc 0.9869515146783261
epoch 8 batch id 1801 loss 0.009239567443728447 train acc 0.9866827456968351
epoch 8 batch id 2001 loss 0.03830472007393837 train acc 0.9866394927536232
epoch 8 batch id 2201 loss 0.032598838210105896 train acc 0.986604100408905
epoch 8 batch id 2401 loss 0.1191234439611435 train acc 0.9865355581007913
epoch 8 batch id 2601 lo

HBox(children=(FloatProgress(value=0.0, max=444.0), HTML(value='')))


epoch 8 test acc 0.775267454954955 test loss 1.8524469137191772


HBox(children=(FloatProgress(value=0.0, max=2971.0), HTML(value='')))

epoch 9 batch id 1 loss 0.09952840954065323 train acc 0.984375
epoch 9 batch id 201 loss 0.030905332416296005 train acc 0.9878731343283582
epoch 9 batch id 401 loss 0.0031795362010598183 train acc 0.9860504987531172
epoch 9 batch id 601 loss 0.05599892511963844 train acc 0.9859608985024958
epoch 9 batch id 801 loss 0.00523745222017169 train acc 0.9860720973782772
epoch 9 batch id 1001 loss 0.030446259304881096 train acc 0.9861232517482518
epoch 9 batch id 1201 loss 0.02420222945511341 train acc 0.986417568692756
epoch 9 batch id 1401 loss 0.09632450342178345 train acc 0.9865497858672377
epoch 9 batch id 1601 loss 0.002402503741905093 train acc 0.9863464241099313
epoch 9 batch id 1801 loss 0.1333174854516983 train acc 0.9862055802332038
epoch 9 batch id 2001 loss 0.003359818132594228 train acc 0.9863037231384307
epoch 9 batch id 2201 loss 0.011002270504832268 train acc 0.9863201385733757
epoch 9 batch id 2401 loss 0.07389777898788452 train acc 0.9862492190753852
epoch 9 batch id 2601 lo

HBox(children=(FloatProgress(value=0.0, max=444.0), HTML(value='')))


epoch 9 test acc 0.7756193693693694 test loss 1.9291362762451172


HBox(children=(FloatProgress(value=0.0, max=2971.0), HTML(value='')))

epoch 10 batch id 1 loss 0.11138397455215454 train acc 0.984375
epoch 10 batch id 201 loss 0.010439303703606129 train acc 0.9881840796019901


KeyboardInterrupt: ignored

# Inference

In [None]:
#새로운 문장 테스트
#토큰화
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

def predict(predict_sentence):

    data = [predict_sentence, '0']
    dataset_another = [data]

    another_test = BERTDataset(dataset_another, 0, 1, tok, max_len, True, False)
    test_dataloader = torch.utils.data.DataLoader(another_test, batch_size=batch_size, num_workers=5)
    
    model.eval()

    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataloader):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)

        valid_length= valid_length
        label = label.long().to(device)

        out = model(token_ids, valid_length, segment_ids)


        test_eval=[]
        for i in out:
            logits=i
            logits = logits.detach().cpu().numpy()

            if np.argmax(logits) == 0:
                test_eval.append("불안이")
            elif np.argmax(logits) == 1:
                test_eval.append("당황이")
            elif np.argmax(logits) == 2:
                test_eval.append("분노가")
            elif np.argmax(logits) == 3:
                test_eval.append("슬픔이")
            elif np.argmax(logits) == 4:
                test_eval.append("중립이")
            elif np.argmax(logits) == 5:
                test_eval.append("행복이")
            elif np.argmax(logits) == 6:
                test_eval.append("혐오가")

        print(">> 입력하신 내용에서 " + test_eval[0] + " 느껴집니다.")

using cached model


In [53]:
checkpoint=torch.load(ckpt_name)

In [54]:
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [None]:
#질문 무한반복하기! 0 입력시 종료
end = 1
while end == 1 :
    sentence = input("하고싶은 말을 입력해주세요 : ")
    if sentence == 0 :
        break
    predict(sentence)
    print("\n")

하고싶은 말을 입력해주세요 : 짜증나네


  cpuset_checked))


>> 입력하신 내용에서 혐오가 느껴집니다.


하고싶은 말을 입력해주세요 : 열받게 하네
>> 입력하신 내용에서 분노가 느껴집니다.


하고싶은 말을 입력해주세요 : 죽고싶어
>> 입력하신 내용에서 슬픔이 느껴집니다.


하고싶은 말을 입력해주세요 : 어떻게 하면 좋지
>> 입력하신 내용에서 중립이 느껴집니다.


하고싶은 말을 입력해주세요 : 걱정이 태산이야
>> 입력하신 내용에서 불안이 느껴집니다.


하고싶은 말을 입력해주세요 : 어떻게 해야할지 모르겠어
>> 입력하신 내용에서 불안이 느껴집니다.


하고싶은 말을 입력해주세요 : 남친이랑 싸웠어
>> 입력하신 내용에서 슬픔이 느껴집니다.


하고싶은 말을 입력해주세요 : 여친이랑 싸웠어
>> 입력하신 내용에서 슬픔이 느껴집니다.


하고싶은 말을 입력해주세요 : 행복해
>> 입력하신 내용에서 행복이 느껴집니다.


하고싶은 말을 입력해주세요 : 우울하지 않아
>> 입력하신 내용에서 중립이 느껴집니다.


하고싶은 말을 입력해주세요 : 즐겁지 않아
>> 입력하신 내용에서 슬픔이 느껴집니다.


하고싶은 말을 입력해주세요 : 행복하지 않아
>> 입력하신 내용에서 슬픔이 느껴집니다.


하고싶은 말을 입력해주세요 : 안우울해
>> 입력하신 내용에서 슬픔이 느껴집니다.


하고싶은 말을 입력해주세요 : 안즐거워
>> 입력하신 내용에서 중립이 느껴집니다.


하고싶은 말을 입력해주세요 : 안슬퍼
>> 입력하신 내용에서 슬픔이 느껴집니다.


하고싶은 말을 입력해주세요 : 남자친구랑 헤어졌어
>> 입력하신 내용에서 슬픔이 느껴집니다.


하고싶은 말을 입력해주세요 : 역겨워
>> 입력하신 내용에서 혐오가 느껴집니다.


하고싶은 말을 입력해주세요 : 그냥 그래
>> 입력하신 내용에서 중립이 느껴집니다.


하고싶은 말을 입력해주세요 : 뭐?
>> 입력하신 내용에서 당황이 느껴집니다.


하고싶은 말을 입력해주세요 : 나보고 어쩌라고.....
>> 입력하신 내용에서 중립이 느껴집니다.


하고싶은 말을 입력해주세요 

KeyboardInterrupt: ignored