In [1]:
import os.path
import time
import pickle
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from kobert import get_tokenizer
from kobert import get_pytorch_kobert_model
from transformers import AdamW, BertModel
from transformers.optimization import get_cosine_schedule_with_warmup
import json
import pandas as pd

In [2]:
device = torch.device("cuda:0")
# device = torch.device("cpu")
model_path = 'ksic_model'

In [3]:
# ??BertModel

#### 모델 저장의 방법
* Pytorch는 모델을 저장할 때 torch.save(object, file) 함수 사용
    * object : 모델 객체, file : 파일 이름
##### 예시1
* torch.save(model, 'model.pt')
* model = torch.load('model.pt')
##### 예시2
* torch.save(model.state_dict(), 'model.pt')
* model.load_state_dict(torch.load('model.pt'))

In [4]:
bertmodel, vocab = get_pytorch_kobert_model(cachedir=".cache")

using cached model. /home/hdh/PycharmProjects/KoBERT-master/.cache/kobert_v1.zip
using cached model. /home/hdh/PycharmProjects/KoBERT-master/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece


In [5]:
# torch.save(bertmodel, 'model_save_test/ksic_bert_model.pt')
# model = torch.load('model_save_test/ksic_bert_model.pt')
# # 이건 왜 성공하지?

In [6]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size=768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size, num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        _, pooler = self.bert(input_ids=token_ids, token_type_ids=segment_ids.long(),
                              attention_mask=attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        else:
            out = pooler
        return self.classifier(out)

In [7]:
model = BERTClassifier(bertmodel,  dr_rate=0.5).to(device)
model = torch.load(os.path.join(model_path, 'KSIC_KoBERT.pt'))
# Can't get attribute 'BERTClassifier' on <module '__main__'>

In [8]:
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

using cached model. /home/hdh/PycharmProjects/KoBERT-master/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece


In [9]:
# 예측 모델 설정
def predict(predict_sentence):

    data = [predict_sentence, '0']
    dataset_another = [data]

    another_test = BERTDataset(dataset_another, 0, 1, tok, max_len, True, False)
    test_dataloader = torch.utils.data.DataLoader(another_test, batch_size=batch_size, num_workers=5)
    
    model.eval()

    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataloader):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)

        valid_length= valid_length
        label = label.long().to(device)

        out = model(token_ids, valid_length, segment_ids)
        max_vals, max_indices = torch.max(out, 1)
#         test_eval=[]
#         for i in out:
#             logits=i
#             logits = logits.detach().cpu().numpy()
#             min_v = min(logits)
#             total = 0
#             probability = []
#             logits = np.round(new_softmax(logits), 3).tolist()
#             for logit in logits:
#                 print(logit)
#                 probability.append(np.round(logit, 3))

#             if np.argmax(logits) == 0:  emotion = "기쁨"
#             elif np.argmax(logits) == 1: emotion = "불안"
#             elif np.argmax(logits) == 2: emotion = '당황'
#             elif np.argmax(logits) == 3: emotion = '슬픔'
#             elif np.argmax(logits) == 4: emotion = '분노'
#             elif np.argmax(logits) == 5: emotion = '상처'

#             probability.append(emotion)
#             print(probability)
    return max_vals, max_indices

In [10]:
# ??torch.max

In [11]:
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([i[sent_idx]]) for i in tqdm(dataset)]
        self.labels = [np.int32(i[label_idx]) for i in tqdm(dataset)]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))

In [12]:
test_input = pd.read_csv('.cache/test_input.csv', encoding='utf-8', low_memory=False)

In [13]:
test_input[:2]

Unnamed: 0,ksic,an,ad,pn,pd,rn,rd,ipc,cpc,title,ab,cl,apg,invt,label
0,38230,10-2003-0069760,20031004,특2003-0083663,20031030,,,B02C13/14,B02C23/08 | B02C13/18 | B07B13/04 | B07B1/28 |...,건설폐기물로부터 시멘트 페이스트 및 모르타르가 제거된재생골재 및 모래를 생산하는 방...,건설폐기물의 폐콘크리트를 레미콘용 재생골재 및 모래로 사용이 가능하도록 시멘트 페이...,폐 콘크리트의 골재 표면으로부터 시멘트 페이스트 및 모르타르를 제거하여 재생골재 및...,삼영플랜트주식회사,기준호 | 기형호,395
1,25942,2019830009004,19831021,2019850003356,19850617,,,F16B23/00,F16B39/10,이탈 방지 장치를 갖는 나사,"케이스(1)에 커버(2)를 부착시키는 나사(3)에 있어서, 커버의 두께(t)보다 큰...",\n,엘지전자 주식회사,,154


In [14]:
test_input.shape  # (211133, 15)

(211133, 15)

In [None]:
test_input_copy = test_input.copy()

In [None]:
test_input_copy['ab'].fillna(test_input_copy['title'], inplace=True)
test_input_copy['cl'].fillna(test_input_copy['ab'], inplace=True)

In [None]:
test_title_ds = test_input_copy[['title', 'label']].copy()
test_title_ds.rename(columns={'title': 'text'}, inplace=True)
test_ab_ds = test_input_copy[['ab', 'label']].copy()
test_ab_ds.rename(columns={'title': 'text'}, inplace=True)
test_cl_ds = test_input_copy[['cl', 'label']].copy()
test_cl_ds.rename(columns={'title': 'text'}, inplace=True)

In [None]:
test_title_ds.to_csv('.cache/test_title_ds.tsv', encoding='utf-8', mode='w', index=False, sep='\t')
test_ab_ds.to_csv('.cache/test_ab_ds.tsv', encoding='utf-8', mode='w', index=False, sep='\t')
test_cl_ds.to_csv('.cache/test_cl_ds.tsv', encoding='utf-8', mode='w', index=False, sep='\t')

In [None]:
test_title_ds = nlp.data.TSVDataset('.cache/test_title_ds.tsv', encoding='utf-8',
                              field_indices=[0, 1], num_discard_samples=1)
test_ab_ds = nlp.data.TSVDataset('.cache/test_ab_ds.tsv', encoding='utf-8',
                              field_indices=[0, 1], num_discard_samples=1)
test_cl_ds = nlp.data.TSVDataset('.cache/test_cl_ds.tsv', encoding='utf-8',
                              field_indices=[0, 1], num_discard_samples=1)
print(time.strftime('%l:%M%p %Z on %b %d, %Y'))  # ' 1:36PM EDT on Oct 18, 2010'
print('Loading saved text-label pair dataset completed')
with open('.cache/label_ksic.pickle', 'rb') as f:
    ksic_index_dict = pickle.load(f)
with open('.cache/ksic_label.pickle', 'rb') as f:
    ksic_label_dict = pickle.load(f)

In [15]:
with open('.cache/label_ksic.pickle', 'rb') as f:
    ksic_index_dict = pickle.load(f)
with open('.cache/ksic_label.pickle', 'rb') as f:
    ksic_label_dict = pickle.load(f)

In [107]:
ksic_index_dict

{0: '20411',
 1: '25121',
 2: '20491',
 3: '30110',
 4: '11111',
 5: '23231',
 6: '21210',
 7: '29210',
 8: '19102',
 9: '14192',
 10: '29193',
 11: '24111',
 12: '20132',
 13: '20493',
 14: '15211',
 15: '29180',
 16: '20311',
 17: '27216',
 18: '27192',
 19: '38312',
 20: '13401',
 21: '27199',
 22: '42491',
 23: '24211',
 24: '17110',
 25: '33920',
 26: '10220',
 27: '20321',
 28: '10802',
 29: '10794',
 30: '20111',
 31: '22199',
 32: '29242',
 33: '23232',
 34: '28111',
 35: '13101',
 36: '29150',
 37: '29194',
 38: '23311',
 39: '18121',
 40: '10742',
 41: '24113',
 42: '12000',
 43: '10713',
 44: '19221',
 45: '33992',
 46: '26521',
 47: '20499',
 48: '29191',
 49: '27302',
 50: '19229',
 51: '15219',
 52: '24219',
 53: '25119',
 54: '16102',
 55: '10792',
 56: '10743',
 57: '27219',
 58: '35200',
 59: '20494',
 60: '10403',
 61: '29175',
 62: '20121',
 63: '28121',
 64: '10799',
 65: '28902',
 66: '23119',
 67: '29291',
 68: '29141',
 69: '25922',
 70: '24191',
 71: '10212',
 7

In [16]:
max_len = 256
batch_size = 16
warmup_ratio = 0.1
num_epochs = 5
max_grad_norm = 1
log_interval = 10000
learning_rate = 5e-5

In [None]:
# test_tl_id = BERTDataset(test_title_ds, 0, 1, tok, max_len, True, False)
# test_ab_id = BERTDataset(test_ab_ds, 0, 1, tok, max_len, True, False)
# test_cl_id = BERTDataset(test_cl_ds, 0, 1, tok, max_len, True, False)

In [None]:
# with open('.cache/ksic_test_tl_id.pickle', "wb") as fw:
#     pickle.dump(train_tl_id, fw)
# with open('.cache/ksic_test_ab_id.pickle', "wb") as fw:
#     pickle.dump(train_ab_id, fw)
# with open('.cache/ksic_test_cl_id.pickle', "wb") as fw:
#     pickle.dump(train_cl_id, fw)

In [17]:
with open(".cache/ksic_test_tl_id.pickle", "rb") as fr:
    test_tl_id = pickle.load(fr)
with open('.cache/ksic_test_ab_id.pickle', 'rb') as f:
    test_ab_id = pickle.load(f)
with open('.cache/ksic_test_cl_id.pickle', 'rb') as f:
    test_cl_id = pickle.load(f)

In [18]:
test_tl_dataloader = torch.utils.data.DataLoader(test_tl_id, batch_size=batch_size, num_workers=8, shuffle=False)
test_ab_dataloader = torch.utils.data.DataLoader(test_ab_id, batch_size=batch_size, num_workers=8, shuffle=False)
test_cl_dataloader = torch.utils.data.DataLoader(test_cl_id, batch_size=batch_size, num_workers=8, shuffle=False)

In [19]:
np.shape(test_cl_id)

  result = asarray(a).shape


(211133, 4)

In [20]:
np.asarray(test_cl_id).shape

  np.asarray(test_cl_id).shape


(211133, 4)

In [21]:
def calc_accuracy(X, y):
    max_vals, max_indices = torch.max(X, 1)
    acc = (max_indices == y).sum().data.cpu().numpy()/max_indices.size()[0]
    return acc

In [22]:
%%time
test_acc = 0.0
tl_out = []
tl_max_indices = []
for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(test_tl_dataloader), total=len(test_tl_dataloader)):
    token_ids = token_ids.long().to(device)
    segment_ids = segment_ids.long().to(device)
    valid_length= valid_length
    label = label.long().to(device)
    out = model(token_ids, valid_length, segment_ids)
    max_vals, max_indices = torch.max(out, 1)
#     results.extend([out.tolist(), max_vals.tolist(), max_indices.tolist()])
#     cl_results.append([out.tolist(), max_vals.tolist(), max_indices.tolist()])
    tl_out.extend(out.tolist())
    tl_max_indices.extend(max_indices.tolist())
    test_acc += calc_accuracy(out, label)
    if batch_id % 3000 == 0:
        print(time.strftime('%l:%M%p %Z on %b %d, %Y'), 'batch_id: ', batch_id, ', test acc {}'.format(test_acc / (batch_id+1)), )

  0%|          | 3/13196 [00:00<25:16,  8.70it/s]

 9:14AM KST on Jul 06, 2022 batch_id:  0 , test acc 0.8125


  1%|          | 93/13196 [00:09<21:26, 10.18it/s]


KeyboardInterrupt: 

In [None]:
with open('.cache/test_tl_pred_20220627.pickle', "wb") as fw:
    pickle.dump(tl_results, fw)
with open('.cache/test_tl_pred_20220627.pickle', "rb") as fw:
    tl_results = pickle.load(fw)

In [68]:
%%time
test_acc = 0.0
ab_out = []
ab_max_indices = []
for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(test_ab_dataloader), total=len(test_ab_dataloader)):
    token_ids = token_ids.long().to(device)
    segment_ids = segment_ids.long().to(device)
    valid_length= valid_length
    label = label.long().to(device)
    out = model(token_ids, valid_length, segment_ids)
    max_vals, max_indices = torch.max(out, 1)
#     results.extend([out.tolist(), max_vals.tolist(), max_indices.tolist()])
    ab_out.extend(out.tolist())
    ab_max_indices.extend(max_indices.tolist())
    test_acc += calc_accuracy(out, label)
    if batch_id % 10000 == 0:
        print(time.strftime('%l:%M%p %Z on %b %d, %Y'), 'batch_id: ', batch_id, ', test acc {}'.format(test_acc / (batch_id+1)), )
# with open('.cache/test_ab_pred_20220627.pickle', "wb") as fw:
#     pickle.dump(ab_results, fw)
# with open('.cache/test_ab_pred_20220627.pickle', "rb") as fw:
#     ab_results = pickle.load(fw)

  0%|          | 2/13196 [00:00<23:06,  9.51it/s]

 4:47PM KST on Jun 28, 2022 batch_id:  0 , test acc 0.875


 76%|███████▌  | 10002/13196 [16:08<05:09, 10.30it/s]

 5:03PM KST on Jun 28, 2022 batch_id:  10000 , test acc 0.7741600839916009


100%|██████████| 13196/13196 [21:20<00:00, 10.31it/s]


NameError: name 'ab_results' is not defined

In [71]:
%%time
test_acc = 0.0
cl_out = []
cl_max_indices = []
for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(test_cl_dataloader), total=len(test_cl_dataloader)):
    token_ids = token_ids.long().to(device)
    segment_ids = segment_ids.long().to(device)
    valid_length= valid_length
    label = label.long().to(device)
    out = model(token_ids, valid_length, segment_ids)
    max_vals, max_indices = torch.max(out, 1)
#     results.extend([out.tolist(), max_vals.tolist(), max_indices.tolist()])
    cl_out.extend(out.tolist())
    cl_max_indices.extend(max_indices.tolist())
    test_acc += calc_accuracy(out, label)
    if batch_id % 10000 == 0:
        print(time.strftime('%l:%M%p %Z on %b %d, %Y'), 'batch_id: ', batch_id, ', test acc {}'.format(test_acc / (batch_id+1)), )

  0%|          | 3/13196 [00:00<22:32,  9.75it/s]

 5:24PM KST on Jun 28, 2022 batch_id:  0 , test acc 0.75


 76%|███████▌  | 10003/13196 [16:09<05:09, 10.32it/s]

 5:40PM KST on Jun 28, 2022 batch_id:  10000 , test acc 0.7438818618138187


100%|██████████| 13196/13196 [21:20<00:00, 10.30it/s]

CPU times: user 21min 15s, sys: 4.68 s, total: 21min 20s
Wall time: 21min 21s





In [None]:
# with open('.cache/test_cl_pred_20220627.pickle', "wb") as fw:
#     pickle.dump(cl_results, fw)

In [None]:
# with open('.cache/test_tl_pred_20220627.pickle', "rb") as fw:
#     tl_results = pickle.load(fw)
# with open('.cache/test_ab_pred_20220627.pickle', "rb") as fw:
#     ab_results = pickle.load(fw)
# with open('.cache/test_cl_pred_20220627.pickle', "rb") as fw:
#     cl_results = pickle.load(fw)

In [None]:
tl_result_np = np.array(tl_results)

In [None]:
tl_result_np.shape  # (26392, 3)

In [None]:
np.argmax(tl_result_np[0][0][1])

In [None]:
tl_result_np[0][2]  # batch별 label predict

In [None]:
tl_result_np[0][0]  # batch별 logit

In [None]:
len(tl_result_np[0][0][7])

In [41]:
%%time
test_acc = 0.0
tl_out = []
tl_max_indices = []
for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(test_tl_dataloader), total=len(test_tl_dataloader)):
    token_ids = token_ids.long().to(device)
    segment_ids = segment_ids.long().to(device)
    valid_length= valid_length
    label = label.long().to(device)
    out = model(token_ids, valid_length, segment_ids)
    max_vals, max_indices = torch.max(out, 1)
#     results.extend([out.tolist(), max_vals.tolist(), max_indices.tolist()])
#     cl_results.append([out.tolist(), max_vals.tolist(), max_indices.tolist()])
    tl_out.extend(out.tolist())
    tl_max_indices.extend(max_indices.tolist())
    test_acc += calc_accuracy(out, label)
    if batch_id % 3000 == 0:
        print(time.strftime('%l:%M%p %Z on %b %d, %Y'), 'batch_id: ', batch_id, ', test acc {}'.format(test_acc / (batch_id+1)), )

  0%|          | 2/13196 [00:00<23:27,  9.38it/s]

 3:47PM KST on Jun 28, 2022 batch_id:  0 , test acc 0.8125


 23%|██▎       | 3002/13196 [04:50<16:28, 10.31it/s]

 3:51PM KST on Jun 28, 2022 batch_id:  3000 , test acc 0.7748250583138954


 45%|████▌     | 6002/13196 [09:41<11:37, 10.31it/s]

 3:56PM KST on Jun 28, 2022 batch_id:  6000 , test acc 0.7742042992834528


 68%|██████▊   | 9002/13196 [14:32<06:46, 10.31it/s]

 4:01PM KST on Jun 28, 2022 batch_id:  9000 , test acc 0.7740528830129986


 91%|█████████ | 12002/13196 [19:23<01:55, 10.31it/s]

 4:06PM KST on Jun 28, 2022 batch_id:  12000 , test acc 0.7745239980001667


100%|██████████| 13196/13196 [21:19<00:00, 10.31it/s]

CPU times: user 21min 14s, sys: 4.06 s, total: 21min 18s
Wall time: 21min 19s





In [46]:
label_np = np.array(test_tl_id)[:,3]

In [69]:
tl_out_np = np.array(tl_out)
tl_max_indices_np = np.array(tl_max_indices)
# np.array(tl_out).argmax(axis=1)
# print(tl_out_np.shape, tl_max_indices_np.shape)
tl_out_top3 = tl_out_np.argsort()[:, -3:]
cnt = 0
for i, lab in enumerate(label_np):
    if lab in tl_out_top3[i]:
        cnt += 1 
acc_top3 = cnt / label_np.shape[0]
print('Top-3 Guess: ', acc_top3)
tl_out_top5 = tl_out_np.argsort()[:, -5:]
cnt = 0
for i, lab in enumerate(label_np):
    if lab in tl_out_top5[i]:
        cnt += 1 
acc_top5 = cnt / label_np.shape[0]
print('Top-5 Guess: ', acc_top5)

Top-3 Guess:  0.8985615701950903
Top-5 Guess:  0.9285237267504369


In [70]:
ab_out_np = np.array(ab_out)
ab_max_indices_np = np.array(ab_max_indices)
# np.array(ab_out).argmax(axis=1)
# print(ab_out_np.shape, ab_max_indices_np.shape)
ab_out_top3 = ab_out_np.argsort()[:, -3:]
cnt = 0
for i, lab in enumerate(label_np):
    if lab in ab_out_top3[i]:
        cnt += 1 
ab_acc_top3 = cnt / label_np.shape[0]
print('Top-3 Guess: ', ab_acc_top3)
ab_out_top5 = ab_out_np.argsort()[:, -5:]
cnt = 0
for i, lab in enumerate(label_np):
    if lab in ab_out_top5[i]:
        cnt += 1 
ab_acc_top5 = cnt / label_np.shape[0]
print('Top-5 Guess: ', ab_acc_top5)

Top-3 Guess:  0.8812075800561732
Top-5 Guess:  0.9050551074441229


In [72]:
cl_out_np = np.array(cl_out)
cl_max_indices_np = np.array(cl_max_indices)
# np.array(tl_out).argmax(axis=1)
# print(tl_out_np.shape, tl_max_indices_np.shape)
cl_out_top3 = cl_out_np.argsort()[:, -3:]
cnt = 0
for i, lab in enumerate(label_np):
    if lab in cl_out_top3[i]:
        cnt += 1 
cl_acc_top3 = cnt / label_np.shape[0]
print('Top-3 Guess: ', cl_acc_top3)
cl_out_top5 = cl_out_np.argsort()[:, -5:]
cnt = 0
for i, lab in enumerate(label_np):
    if lab in cl_out_top5[i]:
        cnt += 1 
cl_acc_top5 = cnt / label_np.shape[0]
print('Top-5 Guess: ', cl_acc_top5)

Top-3 Guess:  0.8628163290437781
Top-5 Guess:  0.8921153964562621


In [99]:
test_input_copy = test_input.copy()

In [100]:
test_input_copy[:2]

Unnamed: 0,ksic,an,ad,pn,pd,rn,rd,ipc,cpc,title,ab,cl,apg,invt,label
0,38230,10-2003-0069760,20031004,특2003-0083663,20031030,,,B02C13/14,B02C23/08 | B02C13/18 | B07B13/04 | B07B1/28 |...,건설폐기물로부터 시멘트 페이스트 및 모르타르가 제거된재생골재 및 모래를 생산하는 방...,건설폐기물의 폐콘크리트를 레미콘용 재생골재 및 모래로 사용이 가능하도록 시멘트 페이...,폐 콘크리트의 골재 표면으로부터 시멘트 페이스트 및 모르타르를 제거하여 재생골재 및...,삼영플랜트주식회사,기준호 | 기형호,395
1,25942,2019830009004,19831021,2019850003356,19850617,,,F16B23/00,F16B39/10,이탈 방지 장치를 갖는 나사,"케이스(1)에 커버(2)를 부착시키는 나사(3)에 있어서, 커버의 두께(t)보다 큰...",\n,엘지전자 주식회사,,154


In [114]:
tl_top3_pd = pd.DataFrame(tl_out_top3, columns=['tl_top3', 'tl_top2', 'tl_top1'])
ab_top3_pd = pd.DataFrame(ab_out_top3, columns=['ab_top3', 'ab_top2', 'ab_top1'])
cl_top3_pd = pd.DataFrame(cl_out_top3, columns=['cl_top3', 'cl_top2', 'cl_top1'])
test_input_copy = pd.concat([test_input_copy, tl_top3_pd], axis=1)
test_input_copy = pd.concat([test_input_copy, ab_top3_pd], axis=1)
test_input_copy = pd.concat([test_input_copy, cl_top3_pd], axis=1)
test_input_copy['ksic_tl_top3'] = test_input_copy['tl_top3'].map(ksic_index_dict)
test_input_copy['ksic_tl_top2'] = test_input_copy['tl_top2'].map(ksic_index_dict)
test_input_copy['ksic_tl_top1'] = test_input_copy['tl_top1'].map(ksic_index_dict)
test_input_copy['ksic_ab_top3'] = test_input_copy['ab_top3'].map(ksic_index_dict)
test_input_copy['ksic_ab_top2'] = test_input_copy['ab_top2'].map(ksic_index_dict)
test_input_copy['ksic_ab_top1'] = test_input_copy['ab_top1'].map(ksic_index_dict)
test_input_copy['ksic_cl_top3'] = test_input_copy['cl_top3'].map(ksic_index_dict)
test_input_copy['ksic_cl_top2'] = test_input_copy['cl_top2'].map(ksic_index_dict)
test_input_copy['ksic_cl_top1'] = test_input_copy['cl_top1'].map(ksic_index_dict)

In [115]:
test_input_copy[:2]

Unnamed: 0,ksic,an,ad,pn,pd,rn,rd,ipc,cpc,title,...,cl_top1,ksic_tl_top3,ksic_tl_top2,ksic_tl_top1,ksic_ab_top3,ksic_ab_top2,ksic_ab_top1,ksic_cl_top3,ksic_cl_top2,ksic_cl_top1
0,38230,10-2003-0069760,20031004,특2003-0083663,20031030,,,B02C13/14,B02C23/08 | B02C13/18 | B07B13/04 | B07B1/28 |...,건설폐기물로부터 시멘트 페이스트 및 모르타르가 제거된재생골재 및 모래를 생산하는 방...,...,395,29241,42110,38230,38311,42110,38230,29242,42110,38230
1,25942,2019830009004,19831021,2019850003356,19850617,,,F16B23/00,F16B39/10,이탈 방지 장치를 갖는 나사,...,276,27192,25941,25942,27192,25941,25942,33309,33303,33993


In [116]:
from sklearn.metrics import confusion_matrix

In [123]:
pd.DataFrame(confusion_matrix(test_input_copy['ksic'],
                              test_input_copy['ksic_tl_top1'], labels=True)).to_csv('.cache/confusion_tl.csv',
                                                                       encoding='utf-8', mode='w', index=False, sep='\t')
pd.DataFrame(confusion_matrix(test_input_copy['ksic'],
                              test_input_copy['ksic_ab_top1'], labels=True)).to_csv('.cache/confusion_ab.csv',
                                                                       encoding='utf-8', mode='w', index=False, sep='\t')
pd.DataFrame(confusion_matrix(test_input_copy['ksic'], 
                              test_input_copy['ksic_cl_top1'], labels=True)).to_csv('.cache/confusion_cl.csv',
                                                                       encoding='utf-8', mode='w', index=False, sep='\t')

TypeError: '<' not supported between instances of 'bool' and 'str'

In [124]:
??confusion_matrix

In [120]:
class BERTDataset_no_label(Dataset):
    def __init__(self, dataset, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([i]) for i in tqdm(dataset['text'])]
#         self.labels = [np.int32(i[label_idx]) for i in tqdm(dataset)]

    def __getitem__(self, i):
        return (self.sentences[i])

    def __len__(self):
        return (len(self.sentences))

### 20220602 공개건 대상 테스트

In [54]:
test_input2 = pd.read_excel('.cache/pd20220602_testset.xlsx')
test_input2_copy = test_input2.copy()
# test_input2_copy['요약'].fillna(test_input2_copy['발명의명칭'], inplace=True)
# test_input2_copy['대표청구항'].fillna(test_input2_copy['요약'], inplace=True)
test_title2_ds = test_input2_copy[['발명의명칭']].copy()
test_title2_ds.rename(columns={'발명의명칭': 'text'}, inplace=True)
test_tl_id2 = BERTDataset_no_label(test_title2_ds, tok, max_len, True, False)
test_tl_dataloader2 = torch.utils.data.DataLoader(test_tl_id2, batch_size=batch_size, num_workers=8, shuffle=False)


  0%|          | 0/1462 [05:24<?, ?it/s]
100%|██████████| 1462/1462 [00:00<00:00, 20946.16it/s]


In [56]:
%%time
test_acc2 = 0.0
tl_out2 = []
tl_max_indices2 = []
for batch_id, (token_ids, valid_length, segment_ids) in tqdm(enumerate(test_tl_dataloader2), total=len(test_tl_dataloader2)):
    token_ids = token_ids.long().to(device)
    segment_ids = segment_ids.long().to(device)
    valid_length= valid_length
#     label = label.long().to(device)
    out = model(token_ids, valid_length, segment_ids)
    max_vals, max_indices = torch.max(out, 1)
#     results.extend([out.tolist(), max_vals.tolist(), max_indices.tolist()])
#     cl_results.append([out.tolist(), max_vals.tolist(), max_indices.tolist()])
    tl_out2.extend(out.tolist())
    tl_max_indices2.extend(max_indices.tolist())
#     test_acc += calc_accuracy(out, label)
#     if batch_id % 3000 == 0:
#         print(time.strftime('%l:%M%p %Z on %b %d, %Y'), 'batch_id: ', batch_id, ', test acc {}'.format(test_acc / (batch_id+1)), )

100%|██████████| 92/92 [00:08<00:00, 10.36it/s]

CPU times: user 8.77 s, sys: 202 ms, total: 8.97 s
Wall time: 9.11 s





In [79]:
tl_out2_topn = np.array(tl_out2).argsort()
tl_out2_top3 = tl_out2_topn[:,-3:]
tl_top3_pd = pd.DataFrame(tl_out2_top3, columns=['tl_top3', 'tl_top2', 'tl_top1'])
test_input_copy2 = pd.concat([test_input2_copy, tl_top3_pd], axis=1)
test_input_copy2['ksic_tl_top3'] = test_input_copy2['tl_top3'].map(ksic_index_dict)
test_input_copy2['ksic_tl_top2'] = test_input_copy2['tl_top2'].map(ksic_index_dict)
test_input_copy2['ksic_tl_top1'] = test_input_copy2['tl_top1'].map(ksic_index_dict)


In [82]:
test_input_copy2.to_excel('.cache/pd20220602_testset_result.xlsx')

### 한문장씩 입력으로 받아서 예측하는 함수

In [165]:
ksic_label = pd.read_csv('.cache/KSIC_567_label.csv', encoding='utf-8', delimiter='\t')

In [140]:
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)
transform = nlp.data.BERTSentenceTransform(tok, max_seq_length=256, pad=True, pair=False)

In [197]:
def sample_pred(sample):
    sample = [sample]
    sample_id = transform(sample)
    # sample_id[0]
    # test_tl_dataloader = torch.utils.data.DataLoader(sample_id, batch_size=batch_size, num_workers=8, shuffle=False)

    out = model(torch.tensor([sample_id[0]]).long().to(device),
                [sample_id[1]],
                torch.tensor([sample_id[2]]).long().to(device))
    result = out.tolist()
    result_top3 = np.array(result[0]).argsort()[-3:]  # ['27199', '26329', '27112']
    for i, pred in enumerate(result_top3[::-1]):
        print('top ', i, ': ', ksic_label[ksic_label['코드']==ksic_index_dict[pred]]['코드'].to_string(index=False),
             ksic_label[ksic_label['코드']==ksic_index_dict[pred]]['항목명'].to_string(index=False))

In [223]:
sample_pred(input('텍스트를 입력해주세요\n'))

텍스트를 입력해주세요
본 발명은 천마를 주성분으로 하는 건강음료이다. 더 상세하게는 천마의 불쾌취 성분을 최소화함으로써 유용 성분을 다량 함유한 천마를 부담 없이 섭취할 수 있도록 하고, 또한 이를 전통적인 한약재와 적절히 배합하여 우수한 기능성을 부여한 건강음료에 관한 것이다. 본 발명은 천마 추출액 및, 백복령, 감초, 숙지황, 백작약, 천궁, 당귀, 대추, 오미자, 구기자로 구성된 한약재의 추출액을 포함하는 것을 특징으로 한다.
top  0 :  10309 기타 과실ㆍ채소 가공 및 저장 처리업
top  1 :  21220 한의약품 제조업
top  2 :  10792 차류 가공업


In [158]:
# 간섭 패턴에 기초하여 오브젝트에 대한 복수의 깊이 정보들을 획득함으로써 오브젝트를 스캔한다.


['27199', '26329', '27112']

Unnamed: 0,연번,코드,항목명
312,313,27112,전기식 진단 및 요법 기기 제조업


### 20220331 국가핵심 납품 목록 테스트
* 2022.01~03까지 납품된 34314건 중, KSIC는 1565건에 부여되어 있으며,
* 출원번호로 키위에서 확인 결과 501건이 공개됨(공개, 등록 모두 포함, 출원번호 10-2022-0007684 같은 중복 존재)

#### 발명의 명칭으로 예측

In [130]:
test_input3 = pd.read_excel('.cache/20220331_ksic_testset.xlsx')
test_input3_copy = test_input3.copy()
# test_input2_copy['요약'].fillna(test_input2_copy['발명의명칭'], inplace=True)
# test_input2_copy['대표청구항'].fillna(test_input2_copy['요약'], inplace=True)
test_title3_ds = test_input3_copy[['발명의명칭']].copy()
test_title3_ds.rename(columns={'발명의명칭': 'text'}, inplace=True)
test_tl_id3 = BERTDataset_no_label(test_title3_ds, tok, max_len, True, False)
test_tl_dataloader3 = torch.utils.data.DataLoader(test_tl_id3, batch_size=batch_size, num_workers=8, shuffle=False)


100%|██████████| 501/501 [00:00<00:00, 20081.48it/s]


In [131]:
%%time
test_acc3 = 0.0
tl_out3 = []
tl_max_indices3 = []
for batch_id, (token_ids, valid_length, segment_ids) in tqdm(enumerate(test_tl_dataloader3), total=len(test_tl_dataloader3)):
    token_ids = token_ids.long().to(device)
    segment_ids = segment_ids.long().to(device)
    valid_length= valid_length
#     label = label.long().to(device)
    out = model(token_ids, valid_length, segment_ids)
    max_vals, max_indices = torch.max(out, 1)
#     results.extend([out.tolist(), max_vals.tolist(), max_indices.tolist()])
#     cl_results.append([out.tolist(), max_vals.tolist(), max_indices.tolist()])
    tl_out3.extend(out.tolist())
    tl_max_indices3.extend(max_indices.tolist())
#     test_acc += calc_accuracy(out, label)
#     if batch_id % 3000 == 0:
#         print(time.strftime('%l:%M%p %Z on %b %d, %Y'), 'batch_id: ', batch_id, ', test acc {}'.format(test_acc / (batch_id+1)), )

100%|██████████| 32/32 [00:03<00:00, 10.44it/s]

CPU times: user 3.03 s, sys: 190 ms, total: 3.22 s
Wall time: 3.31 s





In [132]:
tl_out3_topn = np.array(tl_out3).argsort()
tl_out3_top3 = tl_out3_topn[:,-3:]
tl_top3_pd = pd.DataFrame(tl_out3_top3, columns=['tl_top3', 'tl_top2', 'tl_top1'])
test_input_copy3 = pd.concat([test_input3_copy, tl_top3_pd], axis=1)
test_input_copy3['ksic_tl_top3'] = test_input_copy3['tl_top3'].map(ksic_index_dict)
test_input_copy3['ksic_tl_top2'] = test_input_copy3['tl_top2'].map(ksic_index_dict)
test_input_copy3['ksic_tl_top1'] = test_input_copy3['tl_top1'].map(ksic_index_dict)


In [237]:
test_input_copy3.to_excel('.cache/20220331_ksic_testset_result.xlsx')

In [233]:
test_input3[:2]

Unnamed: 0,특허번호,출원번호,발행번호,Original CPC,Current CPC,요약,대표청구항,발명의명칭
0,KR102380739 B1,KR102022000001315,KR10002380739B1,B60L53/11 | B60L53/24 | B60L53/35 | H02M1/0064...,B60L53/11 | B60L53/24 | B60L53/35 | H02M1/0064...,"\t\t본 발명에서는, 3상4선식 필터를 적용한 전기자동차 급속충전장치에 있어서, ...","3상4선식 필터를 적용한 전기자동차 급속충전장치에 있어서,상기 전기자동차 급속충전장...",3상4선식 필터를 적용한 전기자동차 급속충전장치
1,KR1020220010046 A,KR102022000002985,KR102022000010046A,H04W56/0015 | H04J11/0073 | H04J11/0076 | H04J...,H04W56/0015 | H04W92/18 | H04J11/0073 | H04J11...,\t\t본 발명은 D2D 통신을 위한 동기화 신호 구성 방법 및 장치에 관한 것이다...,"제1 무선 사용자 장치(wireless user device)에 있어서,제2 무선 ...",D2D 통신을 위한 동기화 신호 구성 방법 및 장치


In [133]:
test_result3 = pd.read_excel('.cache/20220331_ksic_testset_result2.xlsx')

In [160]:
test_result3_copy=test_result3[['특허번호', '출원번호', 'Original CPC', 'Current CPC', '발명의명칭', '요약', '대표청구항', '분류원 결과',
             'tl_top3', 'tl_top2', 'tl_top1', 'ksic_tl_top3', 'ksic_tl_top2', 'ksic_tl_top1']].copy()

In [135]:
# test_result3_copy[:2]

In [136]:
test_result3_copy['분류원 결과2'] = test_result3_copy['분류원 결과'].str.lower()
test_result3_copy['분류원 결과2'] = test_result3_copy['분류원 결과2'].str.replace('c','')
# test_result3_copy['분류원 결과2'] = test_result3_copy['분류원 결과2'].str.split(',')

In [137]:
test_result3_copy['top1_acc'] = test_result3_copy.apply(
    lambda x:'30332' in test_result3_copy['분류원 결과'], axis=1)
# test_result3_copy.loc[test_result3_copy['분류원 결과'].str.contains(str(test_result3_copy['ksic_tl_top1']))]

In [161]:
test_result3_copy[:1]

Unnamed: 0,특허번호,출원번호,Original CPC,Current CPC,발명의명칭,요약,대표청구항,분류원 결과,tl_top3,tl_top2,tl_top1,ksic_tl_top3,ksic_tl_top2,ksic_tl_top1
0,KR102380739 B1,KR102022000001315,B60L53/11 | B60L53/24 | B60L53/35 | H02M1/0064...,B60L53/11 | B60L53/24 | B60L53/35 | H02M1/0064...,3상4선식 필터를 적용한 전기자동차 급속충전장치,"\t\t본 발명에서는, 3상4선식 필터를 적용한 전기자동차 급속충전장치에 있어서, ...","3상4선식 필터를 적용한 전기자동차 급속충전장치에 있어서,상기 전기자동차 급속충전장...",C30332,75,285,349,28119,30331,30332


#### 요약으로 예측

In [139]:
# test_input3 = pd.read_excel('.cache/20220331_ksic_testset.xlsx')
# test_input3_copy = test_input3.copy()
# test_input2_copy['요약'].fillna(test_input2_copy['발명의명칭'], inplace=True)
# test_input2_copy['대표청구항'].fillna(test_input2_copy['요약'], inplace=True)
test_ab3_ds = test_input3_copy[['요약']].copy()
test_ab3_ds.rename(columns={'요약': 'text'}, inplace=True)
test_ab_id3 = BERTDataset_no_label(test_ab3_ds, tok, max_len, True, False)
test_ab_dataloader3 = torch.utils.data.DataLoader(test_ab_id3, batch_size=batch_size, num_workers=8, shuffle=False)

100%|██████████| 501/501 [00:00<00:00, 2810.33it/s]


In [140]:
%%time
test_acc3 = 0.0
ab_out3 = []
ab_max_indices3 = []
for batch_id, (token_ids, valid_length, segment_ids) in tqdm(enumerate(test_ab_dataloader3), total=len(test_ab_dataloader3)):
    token_ids = token_ids.long().to(device)
    segment_ids = segment_ids.long().to(device)
    valid_length= valid_length
#     label = label.long().to(device)
    out = model(token_ids, valid_length, segment_ids)
    max_vals, max_indices = torch.max(out, 1)
#     results.extend([out.tolist(), max_vals.tolist(), max_indices.tolist()])
#     cl_results.append([out.tolist(), max_vals.tolist(), max_indices.tolist()])
    ab_out3.extend(out.tolist())
    ab_max_indices3.extend(max_indices.tolist())
#     test_acc += calc_accuracy(out, label)
#     if batch_id % 3000 == 0:
#         print(time.strftime('%l:%M%p %Z on %b %d, %Y'), 'batch_id: ', batch_id, ', test acc {}'.format(test_acc / (batch_id+1)), )

100%|██████████| 32/32 [00:03<00:00, 10.48it/s]

CPU times: user 3.02 s, sys: 200 ms, total: 3.22 s
Wall time: 3.29 s





In [162]:
ab_out3_topn = np.array(ab_out3).argsort()
ab_out3_top3 = ab_out3_topn[:,-3:]
ab_top3_pd = pd.DataFrame(ab_out3_top3, columns=['ab_top3', 'ab_top2', 'ab_top1'])
test_result3_copy = pd.concat([test_result3_copy, ab_top3_pd], axis=1)
test_result3_copy['ksic_ab_top3'] = test_result3_copy['ab_top3'].map(ksic_index_dict)
test_result3_copy['ksic_ab_top2'] = test_result3_copy['ab_top2'].map(ksic_index_dict)
test_result3_copy['ksic_ab_top1'] = test_result3_copy['ab_top1'].map(ksic_index_dict)


In [163]:
test_result3_copy[:2]

Unnamed: 0,특허번호,출원번호,Original CPC,Current CPC,발명의명칭,요약,대표청구항,분류원 결과,tl_top3,tl_top2,tl_top1,ksic_tl_top3,ksic_tl_top2,ksic_tl_top1,ab_top3,ab_top2,ab_top1,ksic_ab_top3,ksic_ab_top2,ksic_ab_top1
0,KR102380739 B1,KR102022000001315,B60L53/11 | B60L53/24 | B60L53/35 | H02M1/0064...,B60L53/11 | B60L53/24 | B60L53/35 | H02M1/0064...,3상4선식 필터를 적용한 전기자동차 급속충전장치,"\t\t본 발명에서는, 3상4선식 필터를 적용한 전기자동차 급속충전장치에 있어서, ...","3상4선식 필터를 적용한 전기자동차 급속충전장치에 있어서,상기 전기자동차 급속충전장...",C30332,75,285,349,28119,30331,30332,338,285,349,29161,30331,30332
1,KR1020220010046 A,KR102022000002985,H04W56/0015 | H04J11/0073 | H04J11/0076 | H04J...,H04W56/0015 | H04W92/18 | H04J11/0073 | H04J11...,D2D 통신을 위한 동기화 신호 구성 방법 및 장치,\t\t본 발명은 D2D 통신을 위한 동기화 신호 구성 방법 및 장치에 관한 것이다...,"제1 무선 사용자 장치(wireless user device)에 있어서,제2 무선 ...",C26422,225,447,284,26421,42321,26429,225,447,284,26421,42321,26429


#### TL+AB 예측

In [147]:
test_result3_copy['tl_ab'] = test_result3_copy[['발명의명칭', '요약']].agg(' '.join, axis=1)

In [149]:
test_result3_copy[:1]

Unnamed: 0,특허번호,Original CPC,Current CPC,발명의명칭,요약,대표청구항,tl_top3,tl_top2,tl_top1,ksic_tl_top3,...,분류원 결과,분류원 결과2,top1_acc,ab_top3,ab_top2,ab_top1,ksic_ab_top3,ksic_ab_top2,ksic_ab_top1,tl_ab
0,KR102380739 B1,B60L53/11 | B60L53/24 | B60L53/35 | H02M1/0064...,B60L53/11 | B60L53/24 | B60L53/35 | H02M1/0064...,3상4선식 필터를 적용한 전기자동차 급속충전장치,"\t\t본 발명에서는, 3상4선식 필터를 적용한 전기자동차 급속충전장치에 있어서, ...","3상4선식 필터를 적용한 전기자동차 급속충전장치에 있어서,상기 전기자동차 급속충전장...",75,285,349,28119,...,C30332,30332,False,338,285,349,29161,30331,30332,"3상4선식 필터를 적용한 전기자동차 급속충전장치 \t\t본 발명에서는, 3상4선식 ..."


In [151]:
test_tl_ab3_ds = test_result3_copy[['tl_ab']].copy()
test_tl_ab3_ds.rename(columns={'tl_ab': 'text'}, inplace=True)
test_tl_ab_id3 = BERTDataset_no_label(test_tl_ab3_ds, tok, max_len, True, False)
test_tl_ab_dataloader3 = torch.utils.data.DataLoader(test_tl_ab_id3, batch_size=batch_size, num_workers=8, shuffle=False)
# %%time
test_acc3 = 0.0
tl_ab_out3 = []
tl_ab_max_indices3 = []
for batch_id, (token_ids, valid_length, segment_ids) in tqdm(enumerate(test_tl_ab_dataloader3), total=len(test_tl_ab_dataloader3)):
    token_ids = token_ids.long().to(device)
    segment_ids = segment_ids.long().to(device)
    valid_length= valid_length
#     label = label.long().to(device)
    out = model(token_ids, valid_length, segment_ids)
    max_vals, max_indices = torch.max(out, 1)
#     results.extend([out.tolist(), max_vals.tolist(), max_indices.tolist()])
#     cl_results.append([out.tolist(), max_vals.tolist(), max_indices.tolist()])
    tl_ab_out3.extend(out.tolist())
    tl_ab_max_indices3.extend(max_indices.tolist())
#     test_acc += calc_accuracy(out, label)
#     if batch_id % 3000 == 0:
#         print(time.strftime('%l:%M%p %Z on %b %d, %Y'), 'batch_id: ', batch_id, ', test acc {}'.format(test_acc / (batch_id+1)), )

100%|██████████| 501/501 [00:00<00:00, 2677.58it/s]
100%|██████████| 32/32 [00:03<00:00, 10.41it/s]


In [164]:
tl_ab_out3_topn = np.array(tl_ab_out3).argsort()
tl_ab_out3_top3 = tl_ab_out3_topn[:,-3:]
tl_ab_top3_pd = pd.DataFrame(tl_ab_out3_top3, columns=['tl_ab_top3', 'tl_ab_top2', 'tl_ab_top1'])
test_result3_copy = pd.concat([test_result3_copy, tl_ab_top3_pd], axis=1)
test_result3_copy['ksic_tl_ab_top3'] = test_result3_copy['tl_ab_top3'].map(ksic_index_dict)
test_result3_copy['ksic_tl_ab_top2'] = test_result3_copy['tl_ab_top2'].map(ksic_index_dict)
test_result3_copy['ksic_tl_ab_top1'] = test_result3_copy['tl_ab_top1'].map(ksic_index_dict)


In [165]:
test_result3_copy[:2]

Unnamed: 0,특허번호,출원번호,Original CPC,Current CPC,발명의명칭,요약,대표청구항,분류원 결과,tl_top3,tl_top2,...,ab_top1,ksic_ab_top3,ksic_ab_top2,ksic_ab_top1,tl_ab_top3,tl_ab_top2,tl_ab_top1,ksic_tl_ab_top3,ksic_tl_ab_top2,ksic_tl_ab_top1
0,KR102380739 B1,KR102022000001315,B60L53/11 | B60L53/24 | B60L53/35 | H02M1/0064...,B60L53/11 | B60L53/24 | B60L53/35 | H02M1/0064...,3상4선식 필터를 적용한 전기자동차 급속충전장치,"\t\t본 발명에서는, 3상4선식 필터를 적용한 전기자동차 급속충전장치에 있어서, ...","3상4선식 필터를 적용한 전기자동차 급속충전장치에 있어서,상기 전기자동차 급속충전장...",C30332,75,285,...,349,29161,30331,30332,338,285,349,29161,30331,30332
1,KR1020220010046 A,KR102022000002985,H04W56/0015 | H04J11/0073 | H04J11/0076 | H04J...,H04W56/0015 | H04W92/18 | H04J11/0073 | H04J11...,D2D 통신을 위한 동기화 신호 구성 방법 및 장치,\t\t본 발명은 D2D 통신을 위한 동기화 신호 구성 방법 및 장치에 관한 것이다...,"제1 무선 사용자 장치(wireless user device)에 있어서,제2 무선 ...",C26422,225,447,...,284,26421,42321,26429,225,447,284,26421,42321,26429


In [166]:
test_result3_copy.to_excel('.cache/20220331_ksic_tl_ab_testset_result.xlsx')