In [1]:
import os.path
import time
import pickle
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from kobert import get_tokenizer
from kobert import get_pytorch_kobert_model
from transformers import AdamW, BertModel
from transformers.optimization import get_cosine_schedule_with_warmup
import json
import pandas as pd


In [2]:
# Prefer the first GPU, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Directory containing the fine-tuned KSIC_KoBERT.pt checkpoint.
model_path = 'ksic_model'

In [None]:
# Development aid: IPython `??` displays the source of transformers.BertModel; remove before sharing.
??BertModel

#### 모델 저장의 방법
* PyTorch는 모델을 저장할 때 `torch.save(object, file)` 함수를 사용
    * object : 모델 객체, file : 파일 이름
##### 예시1
* torch.save(model, 'model.pt')
* model = torch.load('model.pt')
    * 모델 객체 전체를 pickle로 저장하므로, 불러올 때 해당 모델 클래스(예: BERTClassifier)가 먼저 정의되어 있어야 함
##### 예시2
* torch.save(model.state_dict(), 'model.pt')
* model.load_state_dict(torch.load('model.pt'))

In [3]:
# Fetch the pretrained KoBERT model and its vocabulary; files are cached under .cache.
kobert_model_and_vocab = get_pytorch_kobert_model(cachedir=".cache")
bertmodel, vocab = kobert_model_and_vocab

using cached model. /home/hdh/PycharmProjects/KoBERT-master/.cache/kobert_v1.zip
using cached model. /home/hdh/PycharmProjects/KoBERT-master/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece


In [None]:
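# torch.save(bertmodel, 'model_save_test/ksic_bert_model.pt')
# model = torch.load('model_save_test/ksic_bert_model.pt')
# # 이건 왜 성공하지? (Why does this succeed?)
# # pickle stores a class by its module path. BertModel lives in an importable module
# # (transformers), so unpickling can find it. A class defined in __main__ (such as
# # BERTClassifier) must be defined in the notebook before torch.load.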

In [4]:
class BERTClassifier(nn.Module):
    """Sequence classifier: BERT pooled output -> optional dropout -> linear layer.

    Args:
        bert: encoder called as ``bert(input_ids=..., token_type_ids=..., attention_mask=...)``.
            It must return a 2-tuple whose second item is fed to the classifier. Presumably
            KoBERT's BertModel with tuple outputs (``return_dict=False``); TODO confirm for
            the installed transformers version.
        hidden_size: width of the encoder output fed to the linear layer.
        num_classes: number of output logits.
        dr_rate: dropout probability on the pooled output; a falsy value disables dropout.
        params: unused; kept so existing call sites keep working.
    """

    def __init__(self,
                 bert,
                 hidden_size=768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size, num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        """Build a float mask with 1.0 on the first ``valid_length[i]`` positions of row ``i``.

        This is vectorized. The former per-row Python loop issued one slice-assignment per
        sample, which means one device kernel per row on GPU.

        Args:
            token_ids: 2-D (batch, seq_len) tensor; only its width and device are used.
            valid_length: per-row number of real tokens (tensor or sequence of ints),
                one entry per row of ``token_ids``.

        Returns:
            float32 tensor of shape (batch, seq_len) on ``token_ids.device``.
        """
        seq_len = token_ids.size(1)
        lengths = torch.as_tensor(valid_length, device=token_ids.device).reshape(-1, 1)
        positions = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
        return (positions < lengths).float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Return classifier logits for a batch.

        The result is (batch, num_classes), assuming the encoder's pooled output is
        (batch, hidden_size).
        """
        # Already float32 and on token_ids.device, so no further conversion is needed.
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        # Only the encoder's second output (`pooler`) is used; the first is discarded.
        _, pooler = self.bert(input_ids=token_ids, token_type_ids=segment_ids.long(),
                              attention_mask=attention_mask)
        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)

In [5]:
# The checkpoint is a whole pickled BERTClassifier (not a state_dict). The class must be
# defined in __main__ (cell above) before loading, otherwise unpickling fails with:
#   Can't get attribute 'BERTClassifier' on <module '__main__'>
# No fresh BERTClassifier is built first: torch.load would replace it immediately, and
# building it would only move `bertmodel` onto the GPU for nothing.
# map_location lets a checkpoint saved on GPU load onto whatever `device` is.
model = torch.load(os.path.join(model_path, 'KSIC_KoBERT.pt'), map_location=device).to(device)
# The pickled module keeps the train/eval mode it was saved in; force inference mode.
model.eval()

In [6]:
# KoBERT SentencePiece tokenizer, wrapped with the KoBERT vocab (case preserved).
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(
    tokenizer,
    vocab,
    lower=False,
)

using cached model. /home/hdh/PycharmProjects/KoBERT-master/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece


In [7]:
# Prediction helper for a single sentence.
def predict(predict_sentence):
    """Classify one sentence with the loaded `model`.

    Relies on notebook globals: `model`, `tok`, `max_len`, `device` and the
    `BERTDataset` class.

    Args:
        predict_sentence: raw input text.

    Returns:
        (max_vals, max_indices): the result of ``torch.max(logits, 1)``, i.e. the top
        logit and its class index for the sentence.
    """
    # BERTDataset expects rows of [sentence, label]; the label is only a placeholder.
    dataset_another = [[predict_sentence, '0']]
    another_test = BERTDataset(dataset_another, 0, 1, tok, max_len, True, False)
    # A single example needs neither larger batches nor worker processes.
    test_dataloader = torch.utils.data.DataLoader(another_test, batch_size=1, num_workers=0)

    model.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for token_ids, valid_length, segment_ids, _label in test_dataloader:
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            max_vals, max_indices = torch.max(out, 1)
    return max_vals, max_indices

In [None]:
# Development aid: IPython `??` help for torch.max (returns (values, indices) when a dim is given); remove before sharing.
??torch.max

In [8]:
class BERTDataset(Dataset):
    """Map-style dataset of BERT-transformed sentences with int32 labels.

    Each row of ``dataset`` supplies a sentence at ``sent_idx`` and a label at
    ``label_idx``. Every sentence is passed through ``nlp.data.BERTSentenceTransform``
    once, up front. ``__getitem__`` returns the transform's output tuple with the label
    appended. The evaluation loop unpacks this as
    ``(token_ids, valid_length, segment_ids, label)``; TODO confirm against the gluonnlp
    version in use.

    Args:
        dataset: iterable of rows. It is consumed exactly once, so one-shot iterables
            (generators) work too.
        sent_idx: index of the sentence field within a row.
        label_idx: index of the label field within a row; must be accepted by ``np.int32``.
        bert_tokenizer: tokenizer handed to BERTSentenceTransform.
        max_len: ``max_seq_length`` for the transform.
        pad: whether the transform pads to ``max_len``.
        pair: whether the transform treats input as sentence pairs.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        # Single pass over `dataset`. Iterating it twice (once for sentences, once for
        # labels) silently left `labels` empty, and `len()` at 0, for one-shot iterables.
        self.sentences = []
        self.labels = []
        for row in tqdm(dataset):
            self.sentences.append(transform([row[sent_idx]]))
            self.labels.append(np.int32(row[label_idx]))

    def __getitem__(self, i):
        # Transform output tuple + (label,)
        return self.sentences[i] + (self.labels[i],)

    def __len__(self):
        return len(self.labels)

In [9]:
def _load_pickle(path):
    """Unpickle and return the object stored at `path` (trusted local cache files only)."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Pre-built test dataset plus the two label lookup tables.
data_test_id = _load_pickle(".cache/ksic_data_test_id.pickle")
ksic_index_dict = _load_pickle('.cache/label_ksic.pickle')
ksic_label_dict = _load_pickle('.cache/ksic_label.pickle')

# Timestamp, e.g. ' 1:36PM EDT on Oct 18, 2010' (%l is a glibc extension).
print(time.strftime('%l:%M%p %Z on %b %d, %Y'))
print('Completed to load BERTDataset')

 8:35PM KST on Jun 26, 2022
Completed to load BERTDataset


In [10]:
# --- Inference settings ---
max_len = 256       # sequence length passed to BERTDataset in predict()
batch_size = 8      # batch size of the test DataLoader

# --- Training hyperparameters (presumably carried over from training; not referenced
#     by the evaluation cells shown here) ---
warmup_ratio = 0.1
num_epochs = 5
max_grad_norm = 1
log_interval = 10000
learning_rate = 5e-5

In [11]:
# Default sequential order (no shuffle) keeps per-batch results aligned with data_test_id.
test_dataloader = DataLoader(
    data_test_id,
    batch_size=batch_size,
    num_workers=8,
)

In [12]:
# (number of samples, fields per sample). np.shape() would first convert the whole dataset
# to an ndarray, iterating every item and warning about ragged rows, just to report this.
(len(data_test_id), len(data_test_id[0]))

  result = asarray(a).shape


(620082, 4)

In [13]:
def calc_accuracy(X, y):
    """Fraction of rows of the logits `X` whose argmax along dim 1 equals the label in `y`.

    Returns a NumPy scalar (correct count divided by batch size).
    """
    _, predicted = torch.max(X, 1)
    n_correct = (predicted == y).sum().data.cpu().numpy()
    n_rows = predicted.size()[0]
    return n_correct / n_rows

In [14]:
%%time
test_acc = 0.0
results = []
for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
    token_ids = token_ids.long().to(device)
    segment_ids = segment_ids.long().to(device)
    valid_length= valid_length
    label = label.long().to(device)
    out = model(token_ids, valid_length, segment_ids)
    max_vals, max_indices = torch.max(out, 1)
#     results.extend([out.tolist(), max_vals.tolist(), max_indices.tolist()])
    results.append([out.tolist(), max_vals.tolist(), max_indices.tolist()])
    test_acc += calc_accuracy(out, label)
    if batch_id % 6000 == 0:
        print(time.strftime('%l:%M%p %Z on %b %d, %Y'), 'batch_id: ', batch_id, ', test acc {}'.format(test_acc / (batch_id+1)), )
#     print(results[-3:])
#     print('label: ', label)    

  0%|          | 2/77511 [00:00<1:24:25, 15.30it/s]

 8:36PM KST on Jun 26, 2022 batch_id:  0 , test acc 0.875


  8%|▊         | 6004/77511 [05:16<1:02:59, 18.92it/s]

 8:41PM KST on Jun 26, 2022 batch_id:  6000 , test acc 0.7739960006665556


 15%|█▌        | 12004/77511 [10:33<57:45, 18.90it/s] 

 8:46PM KST on Jun 26, 2022 batch_id:  12000 , test acc 0.7734355470377469


 23%|██▎       | 18004/77511 [15:50<52:41, 18.82it/s]  

 8:51PM KST on Jun 26, 2022 batch_id:  18000 , test acc 0.7729640019998889


 31%|███       | 24004/77511 [21:09<47:21, 18.83it/s]  

 8:57PM KST on Jun 26, 2022 batch_id:  24000 , test acc 0.773368817965918


 39%|███▊      | 30004/77511 [26:27<42:02, 18.84it/s]  

 9:02PM KST on Jun 26, 2022 batch_id:  30000 , test acc 0.7729492350254992


 46%|████▋     | 36004/77511 [31:45<36:39, 18.87it/s]

 9:07PM KST on Jun 26, 2022 batch_id:  36000 , test acc 0.7736833699063915


 54%|█████▍    | 42004/77511 [37:04<31:27, 18.81it/s]  

 9:13PM KST on Jun 26, 2022 batch_id:  42000 , test acc 0.7740024047046499


 62%|██████▏   | 48004/77511 [42:22<26:09, 18.80it/s]

 9:18PM KST on Jun 26, 2022 batch_id:  48000 , test acc 0.7739187725255724


 70%|██████▉   | 54004/77511 [47:40<20:48, 18.83it/s]

 9:23PM KST on Jun 26, 2022 batch_id:  54000 , test acc 0.7739671487565045


 77%|███████▋  | 60004/77511 [52:59<15:28, 18.86it/s]  

 9:29PM KST on Jun 26, 2022 batch_id:  60000 , test acc 0.7742475125414576


 85%|████████▌ | 66004/77511 [58:17<10:11, 18.82it/s]

 9:34PM KST on Jun 26, 2022 batch_id:  66000 , test acc 0.774692807684732


 93%|█████████▎| 72004/77511 [1:03:36<04:52, 18.85it/s]

 9:39PM KST on Jun 26, 2022 batch_id:  72000 , test acc 0.7748538214747017


100%|██████████| 77511/77511 [1:08:28<00:00, 18.87it/s]

CPU times: user 1h 8min 1s, sys: 20.9 s, total: 1h 8min 22s
Wall time: 1h 8min 28s





In [15]:
# Persist the per-batch [logits, top_values, top_indices] lists from the evaluation loop.
saved_results_path = '.cache/test_result_rework_20220626.pickle'
with open(saved_results_path, "wb") as results_file:
    pickle.dump(results, results_file)

In [None]:
# Reload the predictions written by the save cell above. The previous path
# ('.cache/test_result.pickle') was a different file from the one just written.
# pickle.load can execute arbitrary code, so only load files this project produced.
with open('.cache/test_result_rework_20220626.pickle', "rb") as fr:
    result = pickle.load(fr)

In [None]:
# `result` holds one [logits, top_values, top_indices] entry per batch, and the last batch
# is smaller, so the nested lists are ragged. np.array() on ragged input warns (older NumPy)
# or raises ValueError (NumPy >= 1.24). Fill an explicit object array instead:
# shape (n_batches, 3), where each cell holds the original Python list.
result_np = np.empty((len(result), 3), dtype=object)
for batch_idx, batch_entry in enumerate(result):
    for field_idx, field in enumerate(batch_entry):
        result_np[batch_idx, field_idx] = field

In [None]:
# Dimensions of the stacked results array.
np.shape(result_np)

In [None]:
# Class of the unpickled test dataset object.
data_test_id.__class__

In [None]:
# Number of transformed sentences stored in the test dataset.
n_sentences = len(data_test_id.sentences)
n_sentences

In [None]:
# First ten ground-truth labels of the test dataset.
data_test_id.labels[0:10]

In [None]:
# Results recorded for the first batch.
first_batch_result = result_np[0]
first_batch_result