In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm
from tqdm.notebook import tqdm as tqdm_notebook

import pandas as pd
import gluonnlp as nlp
import numpy as np


from transformers.optimization import get_cosine_schedule_with_warmup
from transformers import BertModel
from transformers import AdamW

In [2]:
from kobert_hf.kobert_tokenizer import KoBERTTokenizer
tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1')

In [3]:
tokenizer.vocab_file

'C:\\Users\\cpprh/.cache\\huggingface\\transformers\\6920ce54223b52af14e36b32047ced34c47ec88ac51f45ce0141aaa1054e3263.7eed87d19282a93a2d45e130f20b4d8e831cbf8e957f1476628fd4ab99ae977f'

In [4]:
from kobert.pytorch_kobert import get_pytorch_kobert_model
bertmodel, vocab = get_pytorch_kobert_model('C:\\Users\\cpprh/.cache\\huggingface\\transformers', '.cache') 

using cached model. C:\Users\cpprh\Documents\GitHub\EatShare-AI\.cache\kobert_v1.zip
using cached model. C:\Users\cpprh\Documents\GitHub\EatShare-AI\.cache\kobert_news_wiki_ko_cased-1087f8699e.spiece


In [5]:
bertmodel

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(8002, 768, padding_idx=1)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )

In [10]:
data = pd.read_csv('./final.csv', encoding='utf8')

In [11]:
data.drop(['Unnamed: 0'], axis=1, inplace=True)

In [12]:
data

Unnamed: 0,DESC_KOR,label
0,굴국밥,0
1,물회,0
2,생선물회,0
3,닭칼국수,0
4,삼선자장면,0
...,...,...
78193,닭고기표고버섯렌틸콩진밥 밀키트,26
78194,설성목장 한우 사골육수 떡볶이,26
78195,우둔 스테이크 도시락,26
78196,"표고버섯, 배지재배, 갓, 말린것",26


In [13]:
data_list = []
for quest, label in zip(data['DESC_KOR'], data['label']):
    data=[]
    data.append(quest)
    data.append(str(label))
    
    data_list.append(data)

In [14]:
data_list, len(data_list)

([['굴국밥', '0'],
  ['물회', '0'],
  ['생선물회', '0'],
  ['닭칼국수', '0'],
  ['삼선자장면', '0'],
  ['우동(중식)', '0'],
  ['해물크림소스스파게티', '0'],
  ['돼지국밥', '0'],
  ['잡탕밥', '0'],
  ['장어덮밥', '0'],
  ['참치덮밥', '0'],
  ['해물덮밥', '0'],
  ['회덮밥', '0'],
  ['치즈피자', '0'],
  ['페퍼로니피자', '0'],
  ['돼지등갈비찜', '0'],
  ['붕어찜', '0'],
  ['안동찜닭', '0'],
  ['치킨데리야끼', '0'],
  ['미트볼 토마토 스파게티', '0'],
  ['쟁반국수', '0'],
  ['콩국수', '0'],
  ['회냉면', '0'],
  ['돼지머리국밥', '0'],
  ['모듬회덮밥', '0'],
  ['비빔밥', '0'],
  ['제육덮밥', '0'],
  ['짬뽕밥', '0'],
  ['오리고기죽', '0'],
  ['부대찌개', '0'],
  ['클래식치즈피자', '0'],
  ['햄앤체다피자', '0'],
  ['고르곤졸라피자', '0'],
  ['페파로니', '0'],
  ['페파로니매니아피자', '0'],
  ['청양페파로니피자', '0'],
  ['하와이안피자', '0'],
  ['베이컨포테이토피자', '0'],
  ['불고기피자', '0'],
  ['부라타치즈피자', '0'],
  ['나폴리슈림프피자', '0'],
  ['디아볼라피자', '0'],
  ['볼로네즈', '0'],
  ['고향 만두', '0'],
  ['신 비비고 한섬 만두', '0'],
  ['서원 교자만두', '0'],
  ['밴쯔 덤플링', '0'],
  ['매생이 삼계탕', '0'],
  ['더욱 맛있어진 뉴 백설군만두', '0'],
  ['청고추만두', '0'],
  ['알찬소시지', '0'],
  ['전통 우리만두', '0'],
  ['맘마밀1', '0'],
  ['맘마밀2', '0'],

In [15]:
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer,vocab, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(bert_tokenizer, max_seq_length=max_len,vocab=vocab, pad=pad, pair=pair)
        
        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))
         

    def __len__(self):
        return (len(self.labels))

In [16]:
max_len = 64
batch_size = 4
warmup_ratio = 0.1
num_epochs = 5  
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-5

device = torch.device("cuda:0")

In [17]:
from sklearn.model_selection import train_test_split
dataset_train, dataset_test = train_test_split(data_list, test_size=0.2, shuffle=True, random_state=10)

In [18]:
tok=tokenizer.tokenize
data_train = BERTDataset(dataset_train, 0, 1, tok, vocab, max_len, True, False)
data_test = BERTDataset(dataset_test,0, 1, tok, vocab,  max_len, True, False)

In [19]:
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size)

In [20]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=27,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device),return_dict=False)
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)

In [21]:
model = BERTClassifier(bertmodel,  dr_rate=0.5).to(device)
 
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc
    
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x1f35595aac0>

In [None]:
train_history=[]
test_history=[]
loss_history=[]
for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
         
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch: {}, batch id: {}, loss: {}, train acc: {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
            train_history.append(train_acc / (batch_id+1))
            loss_history.append(loss.data.cpu().numpy())
    print("epoch: {}, train acc: {}".format(e+1, train_acc / (batch_id+1)))
    
    model.eval()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm(test_dataloader)):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        test_acc += calc_accuracy(out, label)
    print("epoch: {}, test acc: {}".format(e+1, test_acc / (batch_id+1)))
    test_history.append(test_acc / (batch_id+1))

  0%|                                                                              | 3/15640 [00:00<1:58:17,  2.20it/s]

epoch: 1, batch id: 1, loss: 3.1744022369384766, train acc: 0.0


  1%|█                                                                             | 203/15640 [00:15<21:03, 12.21it/s]

epoch: 1, batch id: 201, loss: 3.235738515853882, train acc: 0.1318407960199005


  3%|█▉                                                                            | 401/15640 [00:32<22:13, 11.43it/s]

epoch: 1, batch id: 401, loss: 2.210911750793457, train acc: 0.20199501246882792


  4%|███                                                                           | 602/15640 [00:49<23:12, 10.80it/s]

epoch: 1, batch id: 601, loss: 3.1308937072753906, train acc: 0.2391846921797005


  5%|███▉                                                                          | 802/15640 [01:05<17:58, 13.76it/s]

epoch: 1, batch id: 801, loss: 2.6245386600494385, train acc: 0.25124843945068664


  6%|████▉                                                                        | 1002/15640 [01:22<19:22, 12.59it/s]

epoch: 1, batch id: 1001, loss: 2.6792309284210205, train acc: 0.26373626373626374


  8%|█████▉                                                                       | 1202/15640 [01:38<17:11, 14.00it/s]

epoch: 1, batch id: 1201, loss: 3.2044057846069336, train acc: 0.2722731057452123


  9%|██████▉                                                                      | 1403/15640 [01:55<18:34, 12.77it/s]

epoch: 1, batch id: 1401, loss: 1.58082115650177, train acc: 0.2780157030692363


 10%|███████▉                                                                     | 1603/15640 [02:11<23:41,  9.88it/s]

epoch: 1, batch id: 1601, loss: 2.6531379222869873, train acc: 0.2812304809494066


 12%|████████▉                                                                    | 1803/15640 [02:28<18:25, 12.52it/s]

epoch: 1, batch id: 1801, loss: 2.8854317665100098, train acc: 0.28387007218212107


 13%|█████████▊                                                                   | 2002/15640 [02:44<17:03, 13.33it/s]

epoch: 1, batch id: 2001, loss: 2.0180273056030273, train acc: 0.28473263368315843


 14%|██████████▊                                                                  | 2202/15640 [03:00<19:12, 11.66it/s]

epoch: 1, batch id: 2201, loss: 3.010880470275879, train acc: 0.2869150386188096


 15%|███████████▊                                                                 | 2402/15640 [03:18<19:37, 11.24it/s]

epoch: 1, batch id: 2401, loss: 2.4640228748321533, train acc: 0.29081632653061223


 17%|████████████▊                                                                | 2603/15640 [03:34<20:36, 10.55it/s]

epoch: 1, batch id: 2601, loss: 1.92498779296875, train acc: 0.291522491349481


 18%|█████████████▊                                                               | 2802/15640 [03:50<18:02, 11.86it/s]

epoch: 1, batch id: 2801, loss: 2.7982964515686035, train acc: 0.2918600499821492


 19%|██████████████▊                                                              | 3003/15640 [04:08<16:26, 12.81it/s]

epoch: 1, batch id: 3001, loss: 1.9277865886688232, train acc: 0.29423525491502833


 20%|███████████████▊                                                             | 3203/15640 [04:23<14:22, 14.42it/s]

epoch: 1, batch id: 3201, loss: 2.8211846351623535, train acc: 0.2935020306154327


 22%|████████████████▋                                                            | 3402/15640 [04:40<15:42, 12.98it/s]

epoch: 1, batch id: 3401, loss: 3.8794713020324707, train acc: 0.29432519847103794


 23%|█████████████████▋                                                           | 3602/15640 [04:57<20:28,  9.80it/s]

epoch: 1, batch id: 3601, loss: 1.3454034328460693, train acc: 0.2956817550680367


 24%|██████████████████▋                                                          | 3802/15640 [05:13<14:11, 13.91it/s]

epoch: 1, batch id: 3801, loss: 2.526069402694702, train acc: 0.2972244146277295


 26%|███████████████████▋                                                         | 4002/15640 [05:30<20:12,  9.60it/s]

epoch: 1, batch id: 4001, loss: 2.1125786304473877, train acc: 0.29967508122969255


 27%|████████████████████▋                                                        | 4203/15640 [05:47<14:24, 13.23it/s]

epoch: 1, batch id: 4201, loss: 2.3238441944122314, train acc: 0.3029040704594144


 28%|█████████████████████▋                                                       | 4403/15640 [06:03<13:48, 13.57it/s]

epoch: 1, batch id: 4401, loss: 3.4850072860717773, train acc: 0.3026016814360373


 29%|██████████████████████▋                                                      | 4603/15640 [06:18<14:57, 12.30it/s]

epoch: 1, batch id: 4601, loss: 2.748196601867676, train acc: 0.3034123016735492


 31%|███████████████████████▋                                                     | 4802/15640 [06:36<15:39, 11.54it/s]

epoch: 1, batch id: 4801, loss: 2.213165760040283, train acc: 0.3053530514476151


 32%|████████████████████████▌                                                    | 5001/15640 [06:52<17:23, 10.19it/s]

epoch: 1, batch id: 5001, loss: 3.1321821212768555, train acc: 0.30663867226554686


 33%|█████████████████████████▌                                                   | 5203/15640 [07:09<14:33, 11.94it/s]

epoch: 1, batch id: 5201, loss: 1.5834251642227173, train acc: 0.3077773505095174


 35%|██████████████████████████▌                                                  | 5403/15640 [07:26<13:13, 12.90it/s]

epoch: 1, batch id: 5401, loss: 2.6389999389648438, train acc: 0.310081466395112


 36%|███████████████████████████▌                                                 | 5603/15640 [07:42<11:46, 14.20it/s]

epoch: 1, batch id: 5601, loss: 1.3452110290527344, train acc: 0.31280128548473485


 37%|████████████████████████████▌                                                | 5803/15640 [07:58<12:03, 13.59it/s]

epoch: 1, batch id: 5801, loss: 1.4709264039993286, train acc: 0.3135666264437166


 38%|█████████████████████████████▌                                               | 6003/15640 [08:14<14:08, 11.35it/s]

epoch: 1, batch id: 6001, loss: 1.5384900569915771, train acc: 0.3148225295784036


 40%|██████████████████████████████▌                                              | 6202/15640 [08:31<12:06, 12.98it/s]

epoch: 1, batch id: 6201, loss: 2.737062692642212, train acc: 0.31627963231736816


 41%|███████████████████████████████▌                                             | 6403/15640 [08:47<14:24, 10.69it/s]

epoch: 1, batch id: 6401, loss: 3.7658157348632812, train acc: 0.3170598344008749


 42%|████████████████████████████████▌                                            | 6603/15640 [09:04<11:46, 12.79it/s]

epoch: 1, batch id: 6601, loss: 3.132086753845215, train acc: 0.31858809271322525


 43%|█████████████████████████████████▍                                           | 6803/15640 [09:21<10:47, 13.64it/s]

epoch: 1, batch id: 6801, loss: 1.4654819965362549, train acc: 0.31951183649463316


 45%|██████████████████████████████████▍                                          | 7003/15640 [09:36<09:52, 14.59it/s]

epoch: 1, batch id: 7001, loss: 2.491446018218994, train acc: 0.31991858305956294


 46%|███████████████████████████████████▍                                         | 7202/15640 [09:53<11:47, 11.93it/s]

epoch: 1, batch id: 7201, loss: 2.8804852962493896, train acc: 0.3204416053325927


 47%|████████████████████████████████████▍                                        | 7402/15640 [10:09<14:18,  9.59it/s]

epoch: 1, batch id: 7401, loss: 1.652545690536499, train acc: 0.3212065937035536


 49%|█████████████████████████████████████▍                                       | 7602/15640 [10:25<09:53, 13.55it/s]

epoch: 1, batch id: 7601, loss: 1.6997361183166504, train acc: 0.32281936587291143


 50%|██████████████████████████████████████▍                                      | 7801/15640 [10:42<11:36, 11.25it/s]

epoch: 1, batch id: 7801, loss: 1.75899076461792, train acc: 0.3243494423791822


 51%|███████████████████████████████████████▍                                     | 8003/15640 [10:58<08:42, 14.62it/s]

epoch: 1, batch id: 8001, loss: 0.5727085471153259, train acc: 0.3253030871141107


 52%|████████████████████████████████████████▍                                    | 8202/15640 [11:14<08:57, 13.84it/s]

epoch: 1, batch id: 8201, loss: 2.1749658584594727, train acc: 0.3257529569564687


 54%|█████████████████████████████████████████▎                                   | 8402/15640 [11:30<09:21, 12.89it/s]

epoch: 1, batch id: 8401, loss: 2.1598706245422363, train acc: 0.32695512438995356


 55%|██████████████████████████████████████████▎                                  | 8602/15640 [11:47<08:19, 14.08it/s]

epoch: 1, batch id: 8601, loss: 0.957973301410675, train acc: 0.32810138356005114


 56%|███████████████████████████████████████████▎                                 | 8801/15640 [12:03<10:46, 10.58it/s]

epoch: 1, batch id: 8801, loss: 1.7618780136108398, train acc: 0.3287126462901943


 58%|████████████████████████████████████████████▎                                | 9002/15640 [12:20<09:00, 12.29it/s]

epoch: 1, batch id: 9001, loss: 1.6113612651824951, train acc: 0.32982446394845016


 59%|█████████████████████████████████████████████▎                               | 9203/15640 [12:38<08:12, 13.08it/s]

epoch: 1, batch id: 9201, loss: 2.2073609828948975, train acc: 0.33083360504293013


 60%|██████████████████████████████████████████████▎                              | 9403/15640 [12:53<07:11, 14.44it/s]

epoch: 1, batch id: 9401, loss: 3.3381311893463135, train acc: 0.33124135730241466


 61%|███████████████████████████████████████████████▎                             | 9602/15640 [13:10<07:36, 13.23it/s]

epoch: 1, batch id: 9601, loss: 1.8594939708709717, train acc: 0.3331944589105302


 63%|████████████████████████████████████████████████▎                            | 9802/15640 [13:27<09:42, 10.02it/s]

epoch: 1, batch id: 9801, loss: 1.9502155780792236, train acc: 0.3338179777573717


 64%|████████████████████████████████████████████████▌                           | 10002/15640 [13:43<08:20, 11.27it/s]

epoch: 1, batch id: 10001, loss: 1.4844577312469482, train acc: 0.3344915508449155


 65%|█████████████████████████████████████████████████▌                          | 10202/15640 [14:00<08:46, 10.33it/s]

epoch: 1, batch id: 10201, loss: 1.483496069908142, train acc: 0.3354818155082835


 67%|██████████████████████████████████████████████████▌                         | 10402/15640 [14:16<07:01, 12.44it/s]

epoch: 1, batch id: 10401, loss: 1.9280543327331543, train acc: 0.33578502067108934


 68%|███████████████████████████████████████████████████▌                        | 10603/15640 [14:33<06:12, 13.53it/s]

epoch: 1, batch id: 10601, loss: 2.837447166442871, train acc: 0.3363833600603717


 69%|████████████████████████████████████████████████████▍                       | 10803/15640 [14:49<06:22, 12.64it/s]

epoch: 1, batch id: 10801, loss: 1.9117681980133057, train acc: 0.3373298768632534


 70%|█████████████████████████████████████████████████████▍                      | 11002/15640 [15:06<05:54, 13.09it/s]

epoch: 1, batch id: 11001, loss: 2.5293726921081543, train acc: 0.33785564948641034


 72%|██████████████████████████████████████████████████████▍                     | 11202/15640 [15:21<06:23, 11.56it/s]

epoch: 1, batch id: 11201, loss: 1.0469824075698853, train acc: 0.33842960449959825


 73%|███████████████████████████████████████████████████████▍                    | 11402/15640 [15:38<05:51, 12.06it/s]

epoch: 1, batch id: 11401, loss: 1.745417594909668, train acc: 0.3394877642312078


 74%|████████████████████████████████████████████████████████▎                   | 11601/15640 [15:55<06:48,  9.90it/s]

epoch: 1, batch id: 11601, loss: 0.8167910575866699, train acc: 0.34078958710456


 75%|█████████████████████████████████████████████████████████▎                  | 11803/15640 [16:11<04:34, 13.98it/s]

epoch: 1, batch id: 11801, loss: 2.153538227081299, train acc: 0.34156003728497586


 77%|██████████████████████████████████████████████████████████▎                 | 12003/15640 [16:27<04:36, 13.14it/s]

epoch: 1, batch id: 12001, loss: 1.8185888528823853, train acc: 0.34203399716690275


 78%|███████████████████████████████████████████████████████████▎                | 12201/15640 [16:44<05:13, 10.98it/s]

epoch: 1, batch id: 12201, loss: 2.3973724842071533, train acc: 0.3426973198918121


 79%|████████████████████████████████████████████████████████████▎               | 12403/15640 [17:00<03:52, 13.92it/s]

epoch: 1, batch id: 12401, loss: 1.1702300310134888, train acc: 0.34380291911942584


 81%|█████████████████████████████████████████████████████████████▏              | 12602/15640 [17:16<04:38, 10.91it/s]

epoch: 1, batch id: 12601, loss: 2.543942928314209, train acc: 0.34435759066740734


 82%|██████████████████████████████████████████████████████████████▏             | 12802/15640 [17:33<03:47, 12.45it/s]

epoch: 1, batch id: 12801, loss: 1.078452229499817, train acc: 0.34505116787750956


 83%|███████████████████████████████████████████████████████████████▏            | 13002/15640 [17:49<03:16, 13.46it/s]

epoch: 1, batch id: 13001, loss: 1.6297571659088135, train acc: 0.34568494731174526


 84%|████████████████████████████████████████████████████████████████▏           | 13202/15640 [18:05<03:13, 12.60it/s]

epoch: 1, batch id: 13201, loss: 4.272177219390869, train acc: 0.34622377092644496


 86%|█████████████████████████████████████████████████████████████████           | 13402/15640 [18:22<03:16, 11.39it/s]

epoch: 1, batch id: 13401, loss: 2.3703773021698, train acc: 0.3469703753451235


 87%|██████████████████████████████████████████████████████████████████          | 13602/15640 [18:38<03:06, 10.94it/s]

epoch: 1, batch id: 13601, loss: 1.5284205675125122, train acc: 0.34767664142342475


 88%|███████████████████████████████████████████████████████████████████         | 13802/15640 [18:55<02:35, 11.81it/s]

epoch: 1, batch id: 13801, loss: 1.9884403944015503, train acc: 0.3481994058401565


 90%|████████████████████████████████████████████████████████████████████        | 14003/15640 [19:12<02:36, 10.46it/s]

epoch: 1, batch id: 14001, loss: 2.2928261756896973, train acc: 0.3485643882579816


 91%|█████████████████████████████████████████████████████████████████████       | 14203/15640 [19:28<01:41, 14.17it/s]

epoch: 1, batch id: 14201, loss: 0.5309339761734009, train acc: 0.34930638687416377


 92%|█████████████████████████████████████████████████████████████████████▉      | 14403/15640 [19:45<01:32, 13.37it/s]

epoch: 1, batch id: 14401, loss: 0.7991712689399719, train acc: 0.3503402541490174


 93%|██████████████████████████████████████████████████████████████████████▉     | 14603/15640 [20:01<01:22, 12.54it/s]

epoch: 1, batch id: 14601, loss: 1.973130464553833, train acc: 0.35095198958975415


 95%|███████████████████████████████████████████████████████████████████████▉    | 14802/15640 [20:18<00:59, 13.98it/s]

epoch: 1, batch id: 14801, loss: 1.4011927843093872, train acc: 0.3515134112559962


 96%|████████████████████████████████████████████████████████████████████████▉   | 15002/15640 [20:34<01:10,  9.08it/s]

epoch: 1, batch id: 15001, loss: 1.3850641250610352, train acc: 0.35222651823211787


 97%|█████████████████████████████████████████████████████████████████████████▊  | 15202/15640 [20:51<00:34, 12.55it/s]

epoch: 1, batch id: 15201, loss: 1.7076590061187744, train acc: 0.35297019932899154


 98%|██████████████████████████████████████████████████████████████████████████▊ | 15402/15640 [21:08<00:18, 13.18it/s]

epoch: 1, batch id: 15401, loss: 3.1067214012145996, train acc: 0.3535322381663528


100%|███████████████████████████████████████████████████████████████████████████▊| 15602/15640 [21:23<00:02, 14.12it/s]

epoch: 1, batch id: 15601, loss: 1.233143925666809, train acc: 0.35412794051663354


100%|████████████████████████████████████████████████████████████████████████████| 15640/15640 [21:26<00:00, 12.16it/s]
  0%|                                                                                 | 5/3910 [00:00<01:24, 46.27it/s]

epoch: 1, train acc: 0.35429987212276215


100%|██████████████████████████████████████████████████████████████████████████████| 3910/3910 [00:56<00:00, 69.35it/s]
  0%|                                                                                | 2/15640 [00:00<18:14, 14.28it/s]

epoch: 1, test acc: 0.4108695652173913
epoch: 2, batch id: 1, loss: 1.3945467472076416, train acc: 0.5


  1%|█                                                                             | 201/15640 [00:17<25:59,  9.90it/s]

epoch: 2, batch id: 201, loss: 1.9479695558547974, train acc: 0.40298507462686567


  3%|██                                                                            | 403/15640 [00:33<17:50, 14.23it/s]

epoch: 2, batch id: 401, loss: 2.4828946590423584, train acc: 0.4021197007481297


  4%|███                                                                           | 603/15640 [00:49<18:42, 13.40it/s]

epoch: 2, batch id: 601, loss: 2.081570863723755, train acc: 0.40931780366056575


  5%|████                                                                          | 803/15640 [01:05<18:33, 13.32it/s]

epoch: 2, batch id: 801, loss: 1.2131773233413696, train acc: 0.4085518102372035


  6%|████▉                                                                        | 1003/15640 [01:22<18:33, 13.15it/s]

epoch: 2, batch id: 1001, loss: 2.0278873443603516, train acc: 0.40409590409590407


  8%|█████▉                                                                       | 1201/15640 [01:38<23:09, 10.39it/s]

epoch: 2, batch id: 1201, loss: 2.4686062335968018, train acc: 0.40903413821815154


  9%|██████▉                                                                      | 1403/15640 [01:55<20:27, 11.60it/s]

epoch: 2, batch id: 1401, loss: 0.9865354895591736, train acc: 0.4054246966452534


 10%|███████▉                                                                     | 1603/15640 [02:13<17:57, 13.03it/s]

epoch: 2, batch id: 1601, loss: 3.132720470428467, train acc: 0.40677701436602126


 12%|████████▉                                                                    | 1803/15640 [02:28<16:03, 14.37it/s]

epoch: 2, batch id: 1801, loss: 3.021733045578003, train acc: 0.40685730149916716


 13%|█████████▊                                                                   | 2002/15640 [02:45<17:08, 13.26it/s]

epoch: 2, batch id: 2001, loss: 1.3057390451431274, train acc: 0.4065467266366817


 14%|██████████▊                                                                  | 2202/15640 [03:02<22:34,  9.92it/s]

epoch: 2, batch id: 2201, loss: 2.878096580505371, train acc: 0.4075420263516583


 15%|███████████▊                                                                 | 2402/15640 [03:19<18:53, 11.67it/s]

epoch: 2, batch id: 2401, loss: 2.4054791927337646, train acc: 0.4080591420241566


 17%|████████████▊                                                                | 2602/15640 [03:36<22:07,  9.82it/s]

epoch: 2, batch id: 2601, loss: 1.1751011610031128, train acc: 0.4076316801230296


 18%|█████████████▊                                                               | 2802/15640 [03:52<16:18, 13.12it/s]

epoch: 2, batch id: 2801, loss: 1.244735598564148, train acc: 0.40797929310960374


 19%|██████████████▊                                                              | 3003/15640 [04:09<15:05, 13.96it/s]

epoch: 2, batch id: 3001, loss: 1.9487087726593018, train acc: 0.4080306564478507


 20%|███████████████▊                                                             | 3203/15640 [04:24<15:03, 13.76it/s]

epoch: 2, batch id: 3201, loss: 2.184830904006958, train acc: 0.40893470790378006


 22%|████████████████▊                                                            | 3403/15640 [04:41<17:39, 11.55it/s]

epoch: 2, batch id: 3401, loss: 2.7112021446228027, train acc: 0.4086298147603646


 23%|█████████████████▋                                                           | 3602/15640 [04:58<20:51,  9.62it/s]

epoch: 2, batch id: 3601, loss: 0.4334290325641632, train acc: 0.40974729241877256


 24%|██████████████████▋                                                          | 3803/15640 [05:14<17:17, 11.41it/s]

epoch: 2, batch id: 3801, loss: 1.1073251962661743, train acc: 0.4102867666403578


 26%|███████████████████▋                                                         | 4003/15640 [05:32<15:34, 12.45it/s]

epoch: 2, batch id: 4001, loss: 2.08813214302063, train acc: 0.4127093226693327


 27%|████████████████████▋                                                        | 4203/15640 [05:48<13:37, 13.99it/s]

epoch: 2, batch id: 4201, loss: 1.928961157798767, train acc: 0.4135324922637467


 28%|█████████████████████▋                                                       | 4403/15640 [06:04<14:02, 13.34it/s]

epoch: 2, batch id: 4401, loss: 3.2695751190185547, train acc: 0.4131447398318564


 29%|██████████████████████▋                                                      | 4602/15640 [06:21<20:26,  9.00it/s]

epoch: 2, batch id: 4601, loss: 1.3160830736160278, train acc: 0.41425777004998915


 31%|███████████████████████▋                                                     | 4803/15640 [06:38<16:09, 11.18it/s]

epoch: 2, batch id: 4801, loss: 2.144050359725952, train acc: 0.41449697979587585


 32%|████████████████████████▌                                                    | 5001/15640 [06:55<16:52, 10.51it/s]

epoch: 2, batch id: 5001, loss: 1.7307188510894775, train acc: 0.41421715656868624


 33%|█████████████████████████▌                                                   | 5202/15640 [07:11<14:24, 12.07it/s]

epoch: 2, batch id: 5201, loss: 1.0694621801376343, train acc: 0.41419919246298786


 35%|██████████████████████████▌                                                  | 5403/15640 [07:28<12:26, 13.70it/s]

epoch: 2, batch id: 5401, loss: 2.328127384185791, train acc: 0.4157563414182559


 36%|███████████████████████████▌                                                 | 5603/15640 [07:43<11:45, 14.22it/s]

epoch: 2, batch id: 5601, loss: 1.461138367652893, train acc: 0.41702374575968576


 37%|████████████████████████████▌                                                | 5803/15640 [08:00<12:52, 12.73it/s]

epoch: 2, batch id: 5801, loss: 0.939501941204071, train acc: 0.416910877434925


 38%|█████████████████████████████▌                                               | 6002/15640 [08:17<15:35, 10.30it/s]

epoch: 2, batch id: 6001, loss: 1.2186328172683716, train acc: 0.4167222129645059


 40%|██████████████████████████████▌                                              | 6203/15640 [08:34<12:12, 12.89it/s]

epoch: 2, batch id: 6201, loss: 2.446932315826416, train acc: 0.41779551685212063


 41%|███████████████████████████████▌                                             | 6402/15640 [08:51<15:17, 10.06it/s]

epoch: 2, batch id: 6401, loss: 2.315028429031372, train acc: 0.4182549601624746


 42%|████████████████████████████████▌                                            | 6602/15640 [09:07<10:14, 14.72it/s]

epoch: 2, batch id: 6601, loss: 2.7442879676818848, train acc: 0.41880018179063777


 43%|█████████████████████████████████▍                                           | 6803/15640 [09:23<10:40, 13.79it/s]

epoch: 2, batch id: 6801, loss: 1.7155091762542725, train acc: 0.4190560211733569


 45%|██████████████████████████████████▍                                          | 7003/15640 [09:39<11:14, 12.80it/s]

epoch: 2, batch id: 7001, loss: 1.8803596496582031, train acc: 0.4189758605913441


 46%|███████████████████████████████████▍                                         | 7202/15640 [09:56<12:30, 11.24it/s]

epoch: 2, batch id: 7201, loss: 2.0613815784454346, train acc: 0.4185182613525899


 47%|████████████████████████████████████▍                                        | 7402/15640 [10:13<13:28, 10.19it/s]

epoch: 2, batch id: 7401, loss: 0.9147982597351074, train acc: 0.41882853668423187


 49%|█████████████████████████████████████▍                                       | 7602/15640 [10:29<11:01, 12.16it/s]

epoch: 2, batch id: 7601, loss: 1.6181203126907349, train acc: 0.4198460728851467


 50%|██████████████████████████████████████▍                                      | 7802/15640 [10:47<10:07, 12.89it/s]

epoch: 2, batch id: 7801, loss: 1.0332086086273193, train acc: 0.4202025381361364


 51%|███████████████████████████████████████▍                                     | 8002/15640 [11:02<08:57, 14.21it/s]

epoch: 2, batch id: 8001, loss: 0.3093584477901459, train acc: 0.4207286589176353


 52%|████████████████████████████████████████▍                                    | 8203/15640 [11:19<09:17, 13.34it/s]

epoch: 2, batch id: 8201, loss: 1.9140816926956177, train acc: 0.4211986343128887


 54%|█████████████████████████████████████████▎                                   | 8401/15640 [11:36<11:00, 10.96it/s]

epoch: 2, batch id: 8401, loss: 1.9494397640228271, train acc: 0.4224794667301512


 55%|██████████████████████████████████████████▎                                  | 8603/15640 [11:53<09:02, 12.98it/s]

epoch: 2, batch id: 8601, loss: 0.786263644695282, train acc: 0.42279967445645855


 56%|███████████████████████████████████████████▎                                 | 8801/15640 [12:09<10:34, 10.78it/s]

epoch: 2, batch id: 8801, loss: 1.294402003288269, train acc: 0.42319054652880356


 58%|████████████████████████████████████████████▎                                | 9003/15640 [12:25<08:50, 12.51it/s]

epoch: 2, batch id: 9001, loss: 1.3457436561584473, train acc: 0.4242584157315854


 59%|█████████████████████████████████████████████▎                               | 9202/15640 [12:42<07:41, 13.96it/s]

epoch: 2, batch id: 9201, loss: 2.1634836196899414, train acc: 0.4247636126507988


 60%|██████████████████████████████████████████████▎                              | 9402/15640 [12:57<08:03, 12.90it/s]

epoch: 2, batch id: 9401, loss: 2.5168297290802, train acc: 0.42506116370598873


 61%|███████████████████████████████████████████████▎                             | 9603/15640 [13:15<08:37, 11.66it/s]

epoch: 2, batch id: 9601, loss: 0.8688086867332458, train acc: 0.42654410998854286


 63%|████████████████████████████████████████████████▎                            | 9801/15640 [13:31<09:50,  9.89it/s]

epoch: 2, batch id: 9801, loss: 2.5599582195281982, train acc: 0.42666564636261606


 64%|████████████████████████████████████████████████▌                           | 10002/15640 [13:47<07:44, 12.13it/s]

epoch: 2, batch id: 10001, loss: 1.6359505653381348, train acc: 0.4270822917708229


 65%|█████████████████████████████████████████████████▌                          | 10202/15640 [14:05<08:03, 11.24it/s]

epoch: 2, batch id: 10201, loss: 2.0116350650787354, train acc: 0.42802176257229685


 67%|██████████████████████████████████████████████████▌                         | 10402/15640 [14:20<06:12, 14.06it/s]

epoch: 2, batch id: 10401, loss: 2.5794177055358887, train acc: 0.4286366695510047


 68%|███████████████████████████████████████████████████▌                        | 10603/15640 [14:37<06:22, 13.16it/s]

epoch: 2, batch id: 10601, loss: 1.4501186609268188, train acc: 0.429251957362513


 69%|████████████████████████████████████████████████████▍                       | 10801/15640 [14:53<07:07, 11.32it/s]

epoch: 2, batch id: 10801, loss: 2.5718374252319336, train acc: 0.43037681696139246


 70%|█████████████████████████████████████████████████████▍                      | 11003/15640 [15:11<05:33, 13.89it/s]

epoch: 2, batch id: 11001, loss: 2.3420515060424805, train acc: 0.4307790200890828


 72%|██████████████████████████████████████████████████████▍                     | 11202/15640 [15:27<06:26, 11.49it/s]

epoch: 2, batch id: 11201, loss: 0.40584808588027954, train acc: 0.4312561378448353


 73%|███████████████████████████████████████████████████████▍                    | 11402/15640 [15:44<06:11, 11.42it/s]

epoch: 2, batch id: 11401, loss: 2.2332184314727783, train acc: 0.431760371897202


 74%|████████████████████████████████████████████████████████▍                   | 11602/15640 [16:01<05:00, 13.42it/s]

epoch: 2, batch id: 11601, loss: 0.971920907497406, train acc: 0.43282906645978797


 75%|█████████████████████████████████████████████████████████▎                  | 11802/15640 [16:16<04:31, 14.12it/s]

epoch: 2, batch id: 11801, loss: 1.9295976161956787, train acc: 0.43339547495974917


 77%|██████████████████████████████████████████████████████████▎                 | 12002/15640 [16:33<04:33, 13.29it/s]

epoch: 2, batch id: 12001, loss: 2.190340518951416, train acc: 0.4336305307891009


 78%|███████████████████████████████████████████████████████████▎                | 12202/15640 [16:50<05:27, 10.51it/s]

epoch: 2, batch id: 12201, loss: 2.6447410583496094, train acc: 0.4343086632243259


 79%|████████████████████████████████████████████████████████████▎               | 12401/15640 [17:07<04:29, 12.02it/s]

epoch: 2, batch id: 12401, loss: 1.6494566202163696, train acc: 0.43474316587371986


 81%|█████████████████████████████████████████████████████████████▏              | 12603/15640 [17:24<05:08,  9.84it/s]

epoch: 2, batch id: 12601, loss: 2.1618459224700928, train acc: 0.43508451710181734


 82%|██████████████████████████████████████████████████████████████▏             | 12803/15640 [17:40<03:19, 14.21it/s]

epoch: 2, batch id: 12801, loss: 0.9838250875473022, train acc: 0.4354542613858292


 83%|███████████████████████████████████████████████████████████████▏            | 13003/15640 [17:57<03:12, 13.72it/s]

epoch: 2, batch id: 13001, loss: 1.36307954788208, train acc: 0.4362164448888547


 84%|████████████████████████████████████████████████████████████████▏           | 13203/15640 [18:13<03:08, 12.91it/s]

epoch: 2, batch id: 13201, loss: 3.476532459259033, train acc: 0.4365957124460268


 86%|█████████████████████████████████████████████████████████████████           | 13401/15640 [18:30<03:15, 11.48it/s]

epoch: 2, batch id: 13401, loss: 1.9668827056884766, train acc: 0.43720617864338485


 87%|██████████████████████████████████████████████████████████████████          | 13602/15640 [18:46<02:57, 11.48it/s]

epoch: 2, batch id: 13601, loss: 1.389037013053894, train acc: 0.4377251672671127


 88%|███████████████████████████████████████████████████████████████████         | 13801/15640 [19:02<02:47, 11.00it/s]

epoch: 2, batch id: 13801, loss: 1.5948200225830078, train acc: 0.4386638649373234


 89%|████████████████████████████████████████████████████████████████████        | 13996/15640 [19:19<02:02, 13.43it/s]

In [None]:
def predict(predict_sentence):

    data = [predict_sentence, '0']
    dataset_another = [data]

    another_test = BERTDataset(dataset_another, 0, 1, tok, vocab, max_len, True, False)
    test_dataloader = torch.utils.data.DataLoader(another_test, batch_size=batch_size)
    
    model.eval()

    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataloader):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)

        valid_length= valid_length
        label = label.long().to(device)

        out = model(token_ids, valid_length, segment_ids)


        test_eval=[]
        for i in out:
            logits=i
            logits = logits.detach().cpu().numpy()

            if np.argmax(logits) == 0:
                test_eval.append("탄2, 단2, 지2")
            elif np.argmax(logits) == 1:
                test_eval.append("탄1, 단1, 지1")
            elif np.argmax(logits) == 2:
                test_eval.append("탄0, 단0, 지0")
            else:
                test_eval.append("응애")
        print(">> " + test_eval[0])