<a href="https://colab.research.google.com/github/Johyeyoung/capstone-2021-35/blob/master/Transaction/KoBERT/fine_tunning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## 라이브러리 설치

In [2]:
# 이는 구글 코랩으로 돌린 버전입니다. 그리고 기본 코드는 SKT Brain의 KoBERT를 그대로 가져왔고, 학습 및 테스트 데이터셋만 따로 준비한 것입니다.
# SKT Brain github 주소는 다음과 같습니다. https://github.com/SKTBrain/KoBERT

!pip install mxnet
!pip install gluonnlp pandas tqdm
!pip install sentencepiece
!pip install transformers==3 # 최신 버전으로 설치하면 "Input: must be Tensor, not str" 라는 에러 발생
!pip install torch

!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

Collecting mxnet
[?25l  Downloading https://files.pythonhosted.org/packages/30/07/66174e78c12a3048db9039aaa09553e35035ef3a008ba3e0ed8d2aa3c47b/mxnet-1.8.0.post0-py2.py3-none-manylinux2014_x86_64.whl (46.9MB)
[K     |████████████████████████████████| 46.9MB 64kB/s 
[?25hCollecting graphviz<0.9.0,>=0.8.1
  Downloading https://files.pythonhosted.org/packages/53/39/4ab213673844e0c004bed8a0781a0721a3f6bb23eb8854ee75c236428892/graphviz-0.8.4-py2.py3-none-any.whl
Installing collected packages: graphviz, mxnet
  Found existing installation: graphviz 0.10.1
    Uninstalling graphviz-0.10.1:
      Successfully uninstalled graphviz-0.10.1
Successfully installed graphviz-0.8.4 mxnet-1.8.0.post0
Collecting gluonnlp
[?25l  Downloading https://files.pythonhosted.org/packages/9c/81/a238e47ccba0d7a61dcef4e0b4a7fd4473cb86bed3d84dd4fe28d45a0905/gluonnlp-0.10.0.tar.gz (344kB)
[K     |████████████████████████████████| 348kB 29.9MB/s 
Building wheels for collected packages: gluonnlp
  Building wheel fo

## KoBERT 불러오기

In [3]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook

from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

##GPU 사용 시
device = torch.device("cuda:0")
bertmodel, vocab = get_pytorch_kobert_model()

# 기본 Bert tokenizer 사용
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)


class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair) 

        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))
        

[██████████████████████████████████████████████████]
[██████████████████████████████████████████████████]
using cached model


## 데이터 로드

In [4]:
# 학습용 데이터셋 불러오기
import pandas as pd
file_path = '/content/drive/MyDrive/Trigger_list_final.csv'
new_data = pd.read_csv(file_path, names=['type_num', 'content', 'type'])

# 데이터 전처리
new_data = new_data.dropna(axis=0) # 결측값 제거
new_data.drop(['type_num'], axis=1, inplace=True)
new_data.head()

Unnamed: 0,content,type
1,ADAS 꺼달라함,ADAS_고장
2,ADAS 시스템 점검경고등 및 시스템점검 문구 뜸,ADAS_고장
3,ADAS 장애 운행에는 지장이 없음,ADAS_고장
4,ADAS 프로그램 점검 안내 점등되었다고인입,ADAS_고장
5,ADSD경고등으로 주행에는 문제 없으나 이용에 불편하실지 질의 괜찮음,ADAS_고장


In [5]:

# 같은 내용의 카테고리 합치기
new_data["type"] = new_data["type"].str.replace(pat=r'하이패스_카드분실', repl=r'하이패스_분실', regex=True)
new_data["type"] = new_data["type"].str.replace(pat=r'트렁크_파손', repl=r'트렁크_고장', regex=True)
new_data["type"] = new_data["type"].str.replace(pat=r'거치대_(분실|파손)', repl=r'내비게이션_파손', regex=True)
new_data["type"] = new_data["type"].str.replace(pat=r'타이어_파스/펑크\(후우\)', repl=r'타이어_파스/펑크', regex=True)
new_data["type"] = new_data["type"].str.replace(pat=r'타이어_파스/펑크\(후좌\)', repl=r'타이어_파스/펑크', regex=True)
new_data["type"] = new_data["type"].str.replace(pat=r'타이어_파스/펑크\(전우\)', repl=r'타이어_파스/펑크', regex=True)
new_data["type"] = new_data["type"].str.replace(pat=r'타이어_파스/펑크\(전좌\)', repl=r'타이어_파스/펑크', regex=True)
new_data["type"] = new_data["type"].str.replace(pat=r'블랙박스_탈거', repl=r'블랙박스_고장', regex=True)


# 장애요인 라벨링
from sklearn.preprocessing import LabelEncoder
encoder = LabelEncoder()
encoder.fit(new_data['type'])
new_data['type'] = encoder.transform(new_data['type'])
new_data.head()


# 라벨링된 type 매핑 ex) {0: '타이어_파스/펑크', 1: '내비게이션_고장', 2: '블랙박스고장', 3: '하이패스 분실', 4: '주행_차량떨림'}
mapping = dict(zip(range(len(encoder.classes_)), encoder.classes_))
print(mapping)

# Train / Test set 분리
from sklearn.model_selection import train_test_split
train, test = train_test_split(new_data, test_size=0.2, random_state=42)
print("train shape is:", len(train))
print("test shape is:", len(test))


{0: 'ADAS_고장', 1: 'ADAS_파손', 2: 'USB포트_고장', 3: '경고등_ABS', 4: '경고등_DPF(배출가스)', 5: '경고등_ESC(차세제어)', 6: '경고등_배터리', 7: '경고등_브레이크', 8: '경고등_스패너(서비스알림)', 9: '경고등_에어백', 10: '경고등_엔진오일', 11: '경고등_엔진체크', 12: '경고등_연료부족', 13: '경고등_온도게이지', 14: '경고등_요소수부족', 15: '경고등_전/후측방 충돌 경고', 16: '경고등_통합경고등(느낌표)', 17: '글로브박스_파손', 18: '내비게이션_고장', 19: '내비게이션_분실', 20: '내비게이션_업데이트', 21: '내비게이션_인터넷연결불가', 22: '내비게이션_파손', 23: '내장제_파손', 24: '단말기_배선', 25: '단말기_신호불량', 26: '단말기_탈거', 27: '단말기_파손', 28: '도어_고장', 29: '도어_소음', 30: '도어_파손', 31: '라이트/전구_경고등(고장)', 32: '라이트/전구_파손', 33: '룸미러_파손', 34: '목받침대_고장/분실', 35: '문제어불가_옥스트라', 36: '문제어불가_카드태깅', 37: '발판매트_분실', 38: '번호판_분실', 39: '번호판_파손', 40: '본넷_고장', 41: '본넷_파손', 42: '브레이크_고장', 43: '브레이크_밀림', 44: '브레이크_소음', 45: '블랙박스_고장', 46: '블랙박스_분실', 47: '블랙박스_파손', 48: '사고조사_요청', 49: '사이드미러_고장', 50: '사이드미러_파손', 51: '사이드브레이크_해제불가', 52: '선바이저_파손', 53: '시거잭_고장', 54: '시동_고장', 55: '시동_꺼짐', 56: '시동_배터리방전', 57: '시동_키인식불가(키아웃)', 58: '시트_시트조절_고장', 59: '시트_열선시트_고장', 60: '시트_파손', 61: '실내등_고장', 62: '안심번호

In [6]:
type_num = train['type'].values.tolist()
content = train['content'].values.tolist()
train_list = [[str(content[i]), type_num[i]] for i in range(len(type_num))]
# Train 리스트 비율 확인
print('Train ratio :', [type_num.count(i) for i in range(len(mapping))])

type_num = test['type'].values.tolist()
content = test['content'].values.tolist()
test_list = [[str(content[i]), type_num[i]] for i in range(len(type_num))]
# Test 리스트 비율 확인
print('Test ratio: ', [type_num.count(i) for i in range(len(mapping))])


Train ratio : [22, 1, 24, 7, 6, 8, 14, 62, 28, 14, 30, 331, 18, 12, 18, 84, 191, 5, 806, 6, 73, 536, 130, 55, 17, 133, 10, 1, 74, 14, 16, 843, 48, 5, 2, 100, 16, 7, 14, 7, 16, 4, 41, 66, 138, 125, 5, 4, 13, 38, 15, 27, 9, 25, 63, 31, 412, 147, 112, 30, 9, 24, 4, 5, 63, 2, 3, 13, 2, 40, 82, 135, 9, 29, 27, 21, 13, 27, 33, 3, 10, 355, 2, 18, 9, 6, 6, 4, 11, 14, 13, 8, 13, 15, 33, 42, 91, 7, 7, 8, 37, 4, 0, 80, 43, 2, 281, 106, 102, 25, 23, 31, 46, 58, 11, 69, 6, 57, 510, 34, 9, 285, 32, 104, 26, 31, 15, 29, 22, 27, 2, 12, 3, 17, 205, 42, 7, 25, 18, 11]
Test ratio:  [8, 0, 9, 3, 1, 0, 3, 10, 3, 6, 10, 103, 7, 2, 6, 11, 39, 0, 202, 2, 20, 111, 46, 22, 6, 23, 6, 0, 19, 5, 2, 235, 10, 1, 2, 31, 3, 2, 2, 1, 1, 0, 6, 19, 32, 24, 1, 0, 1, 13, 3, 9, 3, 8, 18, 10, 88, 31, 29, 9, 3, 8, 0, 1, 12, 0, 2, 3, 0, 8, 22, 44, 4, 7, 9, 5, 2, 8, 9, 1, 0, 73, 2, 1, 1, 1, 0, 1, 2, 5, 3, 5, 3, 4, 9, 9, 19, 0, 2, 1, 10, 0, 1, 19, 11, 0, 60, 22, 30, 3, 11, 9, 17, 13, 1, 16, 2, 18, 136, 14, 2, 62, 7, 23, 5, 10, 0

In [7]:
for i in range(len(test_list)):
  print(train_list[i])

['지난 이용간 후방카메라 없었음 다른 차량에서는 작동되는지 문의', 134]
['헤드라이트 기스 예약취소요청', 32]
['브레이크 밟을때마다 조수석 앞바퀴에서 쇠긁히는 소음 발생', 44]
['히터가 안나와서 너무 추워 대차 요청', 137]
['차량 조수석 문계패 불가', 28]
['배터리 수치', 56]
['내비게이션 인터넷 및 와이파이 연결불가항의', 21]
['주행 중 워셔액 안나옴', 81]
['현장조치 이후 기사님이 보조석 앞타이어 휠 쪽 확인 요청하였다고함', 136]
['타이어공기압 경고등 ', 118]
['내비게이션 SD카드 오류나서 인입', 21]
['악취난다며인입', 87]
['시동 거니 경고등이 많이 떠있다함', 15]
['운행 중 차량떨림이 심함', 107]
['엔진 소음이 심하고 악셀레이터를 밟을 때 차가 튕기고 불안정함', 106]
['안나오는 라이트 사진촬영후 인입', 31]
['안전벨트가 제대로 안된다고하심', 64]
['차량 경고등 문의', 31]
['차량 바리게이트가 쳐져있어 지상층으로 입차 불가', 99]
['제동등 경고등', 31]
['워셔액 얼어있음', 81]
['키아웃으로 긴급출동 점검 받으셧으나 조치 불가로 확인되어 차량 잔여시간 환불요청', 57]
['기타', 7]
['타이어공기압경고등점등 및 제동등 점검 문구로 인입', 118]
['카드가 빠져있다는 알람 지속 출력 카드 재삽입 시에도 동일했다고 함 ', 123]
['내비게이션 음성 안내 나오지 않음', 18]
['스마트키를 인식할수 없다고 시동안걸려서 인입', 57]
['액셀 밞으면 소리난다고함', 106]
['고객 시동걸었으나 지속 소음발생인입', 111]
['어제 이용한 코나 차량 후방카메라 고장', 135]
['차량 사이드 미러 각도 조작이 안되서 문의', 49]
['내비 탈거되있다고 인입', 22]
['인터넷 연결 불가 상태', 18]
['경고등이 점등되었다고 함', 32]
['눈때문에 출차 불가 문의', 101]
['내비게이션 고정이 안되고 자꾸 앞으

In [8]:
# CUDA memory 비우기
import torch, gc
gc.collect()
torch.cuda.empty_cache()

In [9]:
# Setting parameters
max_len = 164 # 해당 길이를 초과하는 단어에 대해선 bert가 학습하지 않음
batch_size = 4 # 기본64 (메모리초과로 인해 4로 설정)
warmup_ratio = 0.1
num_epochs = 20
max_grad_norm = 1
log_interval = 200
learning_rate = 5e-5


# 분류할 type의 총 가짓수
type_cnt = len(list(set(new_data['type'].values.tolist()))) # 존재하는 카테고리 종류 모음

In [10]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes = type_cnt, # softmax 사용 <- binary일 경우는 2
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)


# pytorch용 DataLoader 사용
data_train = BERTDataset(train_list, 0, 1, tok, max_len, True, False)
data_test = BERTDataset(test_list, 0, 1, tok, max_len, True, False)
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=5)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5)


# 모델 불러오기
model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)

# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

# 옵티마이저 선언
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss() # softmax용 Loss Function 정하기 <- binary classification도 해당 loss function 사용 가능

t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

# 학습 평가 지표인 accuracy 계산 -> 얼마나 타겟값을 많이 맞추었는가
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc
  
# 모델 학습 시작
for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) # gradient clipping
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    


  cpuset_checked))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 1 batch id 1 loss 4.785386562347412 train acc 0.0
epoch 1 batch id 201 loss 5.2790327072143555 train acc 0.0037313432835820895
epoch 1 batch id 401 loss 4.631361961364746 train acc 0.029925187032418952
epoch 1 batch id 601 loss 4.1583685874938965 train acc 0.050748752079866885
epoch 1 batch id 801 loss 3.785870313644409 train acc 0.0749063670411985
epoch 1 batch id 1001 loss 4.129068374633789 train acc 0.09065934065934066
epoch 1 batch id 1201 loss 3.974503993988037 train acc 0.09970857618651124
epoch 1 batch id 1401 loss 3.902794599533081 train acc 0.10902926481084939
epoch 1 batch id 1601 loss 3.5897376537323 train acc 0.11789506558401
epoch 1 batch id 1801 loss 3.7837605476379395 train acc 0.12881732370905052
epoch 1 batch id 2001 loss 4.115349769592285 train acc 0.14042978510744628

epoch 1 train acc 0.1515047879616963


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 2 batch id 1 loss 4.067741394042969 train acc 0.25
epoch 2 batch id 201 loss 4.4552998542785645 train acc 0.3345771144278607
epoch 2 batch id 401 loss 3.3815860748291016 train acc 0.31982543640897754
epoch 2 batch id 601 loss 2.975615978240967 train acc 0.3198835274542429
epoch 2 batch id 801 loss 2.475768566131592 train acc 0.33083645443196
epoch 2 batch id 1001 loss 3.83709716796875 train acc 0.32867132867132864
epoch 2 batch id 1201 loss 1.4712629318237305 train acc 0.3257701915070774
epoch 2 batch id 1401 loss 3.090872287750244 train acc 0.32780157030692364
epoch 2 batch id 1601 loss 3.007084369659424 train acc 0.32948157401623984
epoch 2 batch id 1801 loss 2.4226455688476562 train acc 0.3337034980566352
epoch 2 batch id 2001 loss 3.485016345977783 train acc 0.33758120939530234

epoch 2 train acc 0.3385772913816689


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 3 batch id 1 loss 3.2731142044067383 train acc 0.25
epoch 3 batch id 201 loss 4.167140007019043 train acc 0.39925373134328357
epoch 3 batch id 401 loss 2.083172082901001 train acc 0.3759351620947631
epoch 3 batch id 601 loss 2.9739437103271484 train acc 0.38019966722129783
epoch 3 batch id 801 loss 2.371427059173584 train acc 0.3910736579275905
epoch 3 batch id 1001 loss 3.8935515880584717 train acc 0.3841158841158841
epoch 3 batch id 1201 loss 1.5002387762069702 train acc 0.3726061615320566
epoch 3 batch id 1401 loss 3.867370128631592 train acc 0.37205567451820126
epoch 3 batch id 1601 loss 3.676703453063965 train acc 0.37008119925046845
epoch 3 batch id 1801 loss 2.6113812923431396 train acc 0.3707662409772349
epoch 3 batch id 2001 loss 3.2313809394836426 train acc 0.37243878060969515

epoch 3 train acc 0.37209302325581395


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 4 batch id 1 loss 2.3758020401000977 train acc 0.25
epoch 4 batch id 201 loss 4.477324962615967 train acc 0.4141791044776119
epoch 4 batch id 401 loss 1.6718165874481201 train acc 0.40523690773067333
epoch 4 batch id 601 loss 2.7351059913635254 train acc 0.4101497504159734
epoch 4 batch id 801 loss 1.8699690103530884 train acc 0.42134831460674155
epoch 4 batch id 1001 loss 4.00813627243042 train acc 0.41483516483516486
epoch 4 batch id 1201 loss 1.1107622385025024 train acc 0.4052872606161532
epoch 4 batch id 1401 loss 3.89259672164917 train acc 0.40649536045681656
epoch 4 batch id 1601 loss 3.6980361938476562 train acc 0.4080262336039975
epoch 4 batch id 1801 loss 2.0846943855285645 train acc 0.412965019433648
epoch 4 batch id 2001 loss 2.1179540157318115 train acc 0.4142928535732134

epoch 4 train acc 0.41393068855449155


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 5 batch id 1 loss 2.401393175125122 train acc 0.25
epoch 5 batch id 201 loss 4.31713342666626 train acc 0.44029850746268656
epoch 5 batch id 401 loss 1.616711974143982 train acc 0.428927680798005
epoch 5 batch id 601 loss 2.8444318771362305 train acc 0.435108153078203
epoch 5 batch id 801 loss 1.6055357456207275 train acc 0.44818976279650435
epoch 5 batch id 1001 loss 4.031868934631348 train acc 0.4448051948051948
epoch 5 batch id 1201 loss 1.02219820022583 train acc 0.43984179850124894
epoch 5 batch id 1401 loss 3.8411402702331543 train acc 0.4377230549607423
epoch 5 batch id 1601 loss 3.1615679264068604 train acc 0.43707058088694567
epoch 5 batch id 1801 loss 1.8374886512756348 train acc 0.43725707940033315
epoch 5 batch id 2001 loss 2.6735572814941406 train acc 0.4397801099450275

epoch 5 train acc 0.43969448244414044


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 6 batch id 1 loss 1.474848747253418 train acc 0.75
epoch 6 batch id 201 loss 4.156374931335449 train acc 0.4664179104477612
epoch 6 batch id 401 loss 1.5376741886138916 train acc 0.456359102244389
epoch 6 batch id 601 loss 2.3775930404663086 train acc 0.456738768718802
epoch 6 batch id 801 loss 1.5947645902633667 train acc 0.4681647940074906
epoch 6 batch id 1001 loss 3.8135509490966797 train acc 0.4617882117882118
epoch 6 batch id 1201 loss 0.8096612691879272 train acc 0.4587843463780183
epoch 6 batch id 1401 loss 4.121147155761719 train acc 0.4603854389721627
epoch 6 batch id 1601 loss 3.189771890640259 train acc 0.46158650843222987
epoch 6 batch id 1801 loss 1.5485589504241943 train acc 0.4655746807329261
epoch 6 batch id 2001 loss 2.2746434211730957 train acc 0.4681409295352324

epoch 6 train acc 0.468422252621979


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 7 batch id 1 loss 1.7717170715332031 train acc 0.5
epoch 7 batch id 201 loss 3.811843156814575 train acc 0.48507462686567165
epoch 7 batch id 401 loss 1.4314892292022705 train acc 0.4756857855361596
epoch 7 batch id 601 loss 2.193575620651245 train acc 0.46921797004991683
epoch 7 batch id 801 loss 1.5537700653076172 train acc 0.48408239700374533
epoch 7 batch id 1001 loss 3.2258942127227783 train acc 0.4825174825174825
epoch 7 batch id 1201 loss 0.8356752395629883 train acc 0.4804329725228976
epoch 7 batch id 1401 loss 2.6419992446899414 train acc 0.4851891506067095
epoch 7 batch id 1601 loss 3.026862859725952 train acc 0.4871955028107433
epoch 7 batch id 1801 loss 1.2855572700500488 train acc 0.49125485841199334
epoch 7 batch id 2001 loss 2.092928409576416 train acc 0.49462768615692154

epoch 7 train acc 0.49851801185590516


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 8 batch id 1 loss 1.9380874633789062 train acc 0.75
epoch 8 batch id 201 loss 3.575437307357788 train acc 0.5509950248756219
epoch 8 batch id 401 loss 1.1938687562942505 train acc 0.5399002493765586
epoch 8 batch id 601 loss 2.0743255615234375 train acc 0.5320299500831946
epoch 8 batch id 801 loss 1.5027507543563843 train acc 0.5433832709113608
epoch 8 batch id 1001 loss 3.7945876121520996 train acc 0.537962037962038
epoch 8 batch id 1201 loss 1.282543420791626 train acc 0.5349708576186512
epoch 8 batch id 1401 loss 3.599787712097168 train acc 0.5362241256245539
epoch 8 batch id 1601 loss 0.5613811016082764 train acc 0.5365396627108058
epoch 8 batch id 1801 loss 0.782213568687439 train acc 0.5412270960577457
epoch 8 batch id 2001 loss 1.9601629972457886 train acc 0.5446026986506747

epoch 8 train acc 0.5463976288189695


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 9 batch id 1 loss 1.304030179977417 train acc 0.5
epoch 9 batch id 201 loss 3.3791496753692627 train acc 0.5845771144278606
epoch 9 batch id 401 loss 0.2901570200920105 train acc 0.5766832917705735
epoch 9 batch id 601 loss 2.1377015113830566 train acc 0.5827787021630616
epoch 9 batch id 801 loss 0.9037314653396606 train acc 0.5923845193508115
epoch 9 batch id 1001 loss 4.035294532775879 train acc 0.5861638361638362
epoch 9 batch id 1201 loss 0.9165542125701904 train acc 0.5838884263114071
epoch 9 batch id 1401 loss 3.154783010482788 train acc 0.5835117773019272
epoch 9 batch id 1601 loss 0.830683171749115 train acc 0.5843222985633979
epoch 9 batch id 1801 loss 0.660473108291626 train acc 0.5867573570238757
epoch 9 batch id 2001 loss 1.8304531574249268 train acc 0.5913293353323338

epoch 9 train acc 0.5951892384860921


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 10 batch id 1 loss 1.8645906448364258 train acc 0.75
epoch 10 batch id 201 loss 2.9235410690307617 train acc 0.6218905472636815
epoch 10 batch id 401 loss 0.06109699606895447 train acc 0.6246882793017456
epoch 10 batch id 601 loss 1.7349987030029297 train acc 0.6289517470881864
epoch 10 batch id 801 loss 0.456693559885025 train acc 0.6376404494382022
epoch 10 batch id 1001 loss 3.6422314643859863 train acc 0.6363636363636364
epoch 10 batch id 1201 loss 0.32078367471694946 train acc 0.6357202331390508
epoch 10 batch id 1401 loss 2.9271693229675293 train acc 0.6349036402569593
epoch 10 batch id 1601 loss 0.37528538703918457 train acc 0.6347595252966896
epoch 10 batch id 1801 loss 0.5838004946708679 train acc 0.6370072182121044
epoch 10 batch id 2001 loss 1.7120012044906616 train acc 0.6405547226386806

epoch 10 train acc 0.6439808481532148


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 11 batch id 1 loss 1.7373239994049072 train acc 0.75
epoch 11 batch id 201 loss 2.6793601512908936 train acc 0.6853233830845771
epoch 11 batch id 401 loss 0.3052716851234436 train acc 0.6851620947630923
epoch 11 batch id 601 loss 2.765665054321289 train acc 0.6855241264559068
epoch 11 batch id 801 loss 1.0368224382400513 train acc 0.6931960049937578
epoch 11 batch id 1001 loss 3.5592803955078125 train acc 0.6875624375624375
epoch 11 batch id 1201 loss 0.24774929881095886 train acc 0.6850541215653622
epoch 11 batch id 1401 loss 1.9917082786560059 train acc 0.6834403997144897
epoch 11 batch id 1601 loss 0.6238046884536743 train acc 0.683635227982511
epoch 11 batch id 1801 loss 0.35579368472099304 train acc 0.6853137146029983
epoch 11 batch id 2001 loss 1.6029688119888306 train acc 0.6877811094452774

epoch 11 train acc 0.6893524851801186


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 12 batch id 1 loss 1.5545389652252197 train acc 0.75
epoch 12 batch id 201 loss 2.1964774131774902 train acc 0.7238805970149254
epoch 12 batch id 401 loss 0.09127737581729889 train acc 0.7275561097256857
epoch 12 batch id 601 loss 2.0473365783691406 train acc 0.7217138103161398
epoch 12 batch id 801 loss 0.2831624150276184 train acc 0.7275280898876404
epoch 12 batch id 1001 loss 3.4471607208251953 train acc 0.7222777222777222
epoch 12 batch id 1201 loss 0.24063368141651154 train acc 0.7212739383846795
epoch 12 batch id 1401 loss 1.085897445678711 train acc 0.7235902926481085
epoch 12 batch id 1601 loss 0.15078751742839813 train acc 0.721736414740787
epoch 12 batch id 1801 loss 0.08310998231172562 train acc 0.7241810105496946
epoch 12 batch id 2001 loss 1.6473948955535889 train acc 0.7266366816591704

epoch 12 train acc 0.7281121751025992


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 13 batch id 1 loss 1.3968074321746826 train acc 0.75
epoch 13 batch id 201 loss 2.5044991970062256 train acc 0.7599502487562189
epoch 13 batch id 401 loss 0.014846984297037125 train acc 0.7630922693266833
epoch 13 batch id 601 loss 2.0952510833740234 train acc 0.7658069883527454
epoch 13 batch id 801 loss 0.15219078958034515 train acc 0.7709113607990012
epoch 13 batch id 1001 loss 3.821960687637329 train acc 0.764985014985015
epoch 13 batch id 1201 loss 0.1870517134666443 train acc 0.763946711074105
epoch 13 batch id 1401 loss 0.06111183017492294 train acc 0.7630264097073519
epoch 13 batch id 1601 loss 0.058943964540958405 train acc 0.7632729544034978
epoch 13 batch id 1801 loss 0.0626075342297554 train acc 0.7629094947251527
epoch 13 batch id 2001 loss 1.3575828075408936 train acc 0.7658670664667666

epoch 13 train acc 0.7676698586411309


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 14 batch id 1 loss 1.0829901695251465 train acc 0.75
epoch 14 batch id 201 loss 1.6687068939208984 train acc 0.7823383084577115
epoch 14 batch id 401 loss 0.02197442576289177 train acc 0.7917705735660848
epoch 14 batch id 601 loss 1.3101404905319214 train acc 0.7911813643926788
epoch 14 batch id 801 loss 0.14944197237491608 train acc 0.7971285892634207
epoch 14 batch id 1001 loss 3.5076184272766113 train acc 0.7939560439560439
epoch 14 batch id 1201 loss 0.1875700205564499 train acc 0.7926727726894255
epoch 14 batch id 1401 loss 0.6794179677963257 train acc 0.7947894361170592
epoch 14 batch id 1601 loss 0.03892911598086357 train acc 0.7952841973766396
epoch 14 batch id 1801 loss 0.03955342620611191 train acc 0.7970571904497501
epoch 14 batch id 2001 loss 1.927363395690918 train acc 0.7981009495252374

epoch 14 train acc 0.7990196078431373


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 15 batch id 1 loss 1.1721537113189697 train acc 0.75
epoch 15 batch id 201 loss 0.6431078910827637 train acc 0.8146766169154229
epoch 15 batch id 401 loss 0.01737729087471962 train acc 0.8204488778054863
epoch 15 batch id 601 loss 1.227211594581604 train acc 0.8232113144758736
epoch 15 batch id 801 loss 0.21345333755016327 train acc 0.8295880149812734
epoch 15 batch id 1001 loss 3.3350510597229004 train acc 0.8226773226773226
epoch 15 batch id 1201 loss 0.09761207550764084 train acc 0.8180682764363031
epoch 15 batch id 1401 loss 0.02249465510249138 train acc 0.8187009279086367
epoch 15 batch id 1601 loss 0.03131985291838646 train acc 0.8202685821361649
epoch 15 batch id 1801 loss 0.04484298825263977 train acc 0.8206551915602444
epoch 15 batch id 2001 loss 2.0806055068969727 train acc 0.8217141429285357

epoch 15 train acc 0.8226174190606476


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 16 batch id 1 loss 0.7144769430160522 train acc 0.75
epoch 16 batch id 201 loss 0.755109965801239 train acc 0.8370646766169154
epoch 16 batch id 401 loss 0.015835925936698914 train acc 0.8403990024937655
epoch 16 batch id 601 loss 1.0360181331634521 train acc 0.841514143094842
epoch 16 batch id 801 loss 0.21452847123146057 train acc 0.8470661672908864
epoch 16 batch id 1001 loss 3.0537009239196777 train acc 0.8431568431568431
epoch 16 batch id 1201 loss 0.03239971771836281 train acc 0.8401332223147377
epoch 16 batch id 1401 loss 0.11632437258958817 train acc 0.8394004282655246
epoch 16 batch id 1601 loss 0.02605225332081318 train acc 0.8411930043722673
epoch 16 batch id 1801 loss 0.02454887144267559 train acc 0.8393947806774015
epoch 16 batch id 2001 loss 1.5644710063934326 train acc 0.84007996001999

epoch 16 train acc 0.8399452804377565


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 17 batch id 1 loss 0.6302294731140137 train acc 0.75
epoch 17 batch id 201 loss 0.0626457929611206 train acc 0.8507462686567164
epoch 17 batch id 401 loss 0.01343977078795433 train acc 0.8553615960099751
epoch 17 batch id 601 loss 1.0544477701187134 train acc 0.8585690515806988
epoch 17 batch id 801 loss 0.146515890955925 train acc 0.864856429463171
epoch 17 batch id 1001 loss 3.242792844772339 train acc 0.8593906093906094
epoch 17 batch id 1201 loss 0.025120418518781662 train acc 0.8563696919233972
epoch 17 batch id 1401 loss 0.015272149816155434 train acc 0.8559957173447538
epoch 17 batch id 1601 loss 0.02337639033794403 train acc 0.8555590256089943
epoch 17 batch id 1801 loss 0.018648970872163773 train acc 0.85494169905608
epoch 17 batch id 2001 loss 1.1226822137832642 train acc 0.856071964017991

epoch 17 train acc 0.8557911536707706


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 18 batch id 1 loss 0.3971105217933655 train acc 1.0
epoch 18 batch id 201 loss 0.08034761995077133 train acc 0.8706467661691543
epoch 18 batch id 401 loss 0.01346304826438427 train acc 0.8721945137157108
epoch 18 batch id 601 loss 0.8564375638961792 train acc 0.872712146422629
epoch 18 batch id 801 loss 0.1455828696489334 train acc 0.8773408239700374
epoch 18 batch id 1001 loss 2.3816957473754883 train acc 0.8713786213786214
epoch 18 batch id 1201 loss 0.021759124472737312 train acc 0.8676103247293921
epoch 18 batch id 1401 loss 0.02246900461614132 train acc 0.8661670235546038
epoch 18 batch id 1601 loss 0.019250622019171715 train acc 0.8663335415365396
epoch 18 batch id 1801 loss 0.6779094338417053 train acc 0.8656302054414214
epoch 18 batch id 2001 loss 1.6206270456314087 train acc 0.8656921539230384

epoch 18 train acc 0.8650250797993616


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 19 batch id 1 loss 0.4349147081375122 train acc 0.75
epoch 19 batch id 201 loss 0.042106978595256805 train acc 0.8756218905472637
epoch 19 batch id 401 loss 0.0106861786916852 train acc 0.8827930174563591
epoch 19 batch id 601 loss 0.3702848553657532 train acc 0.8835274542429284
epoch 19 batch id 801 loss 0.11101645976305008 train acc 0.8876404494382022
epoch 19 batch id 1001 loss 2.540135383605957 train acc 0.8801198801198801
epoch 19 batch id 1201 loss 0.012471480295062065 train acc 0.8763530391340549
epoch 19 batch id 1401 loss 0.018440403044223785 train acc 0.8745538900785154
epoch 19 batch id 1601 loss 0.023912338539958 train acc 0.8749219237976265
epoch 19 batch id 1801 loss 0.010897781699895859 train acc 0.8740977234869517
epoch 19 batch id 2001 loss 1.4328497648239136 train acc 0.8740629685157422

epoch 19 train acc 0.8738030095759234


HBox(children=(FloatProgress(value=0.0, max=2193.0), HTML(value='')))

epoch 20 batch id 1 loss 0.43762317299842834 train acc 1.0
epoch 20 batch id 201 loss 0.10669133067131042 train acc 0.8830845771144279
epoch 20 batch id 401 loss 0.010151270776987076 train acc 0.8859102244389028
epoch 20 batch id 601 loss 0.18493294715881348 train acc 0.8851913477537438
epoch 20 batch id 801 loss 0.05842973664402962 train acc 0.8904494382022472
epoch 20 batch id 1001 loss 2.6178412437438965 train acc 0.8851148851148851
epoch 20 batch id 1201 loss 0.04862089455127716 train acc 0.8811407160699417
epoch 20 batch id 1401 loss 0.014654329046607018 train acc 0.8793718772305497
epoch 20 batch id 1601 loss 0.022285221144557 train acc 0.8791380387257963
epoch 20 batch id 1801 loss 0.016695531085133553 train acc 0.8784008883953359
epoch 20 batch id 2001 loss 1.6196684837341309 train acc 0.8784357821089456

epoch 20 train acc 0.8769949840401277


In [11]:
# 테스트 문장 예측
test_sentence = '네비게이션 분실로 인입.'
test_label = 0 # 장애 번호

unseen_test = pd.DataFrame([[test_sentence, test_label]], columns = [['질문 내용', '장애요인']])
unseen_values = unseen_test.values
test_set = BERTDataset(unseen_values, 0, 1, tok, max_len, True, False)
test_input = torch.utils.data.DataLoader(test_set, batch_size=1, num_workers=5)

right = []
for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_input)):
  token_ids = token_ids.long().to(device)
  segment_ids = segment_ids.long().to(device)
  valid_length= valid_length
  with torch.no_grad():
    out = model(token_ids, valid_length, segment_ids)
  right.append(out.to('cpu'))
  del out
  torch.cuda.empty_cache()
  # 라벨링된 type 매핑 ex) {0: '타이어_파스/펑크', 1: '내비게이션_고장', 2: '블랙박스고장', 3: '하이패스 분실', 4: '주행_차량떨림'}
  print(mapping)
#  print(out)    
  print(right[-1])    
for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_input)):
  token_ids = token_ids.long().to(device)
  segment_ids = segment_ids.long().to(device)
  valid_length= valid_length
  out = model(token_ids, valid_length, segment_ids)
  # 라벨링된 type 매핑 ex) {0: '타이어_파스/펑크', 1: '내비게이션_고장', 2: '블랙박스고장', 3: '하이패스 분실', 4: '주행_차량떨림'}
  print(mapping)
  print(out)
    

  cpuset_checked))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

{0: 'ADAS_고장', 1: 'ADAS_파손', 2: 'USB포트_고장', 3: '경고등_ABS', 4: '경고등_DPF(배출가스)', 5: '경고등_ESC(차세제어)', 6: '경고등_배터리', 7: '경고등_브레이크', 8: '경고등_스패너(서비스알림)', 9: '경고등_에어백', 10: '경고등_엔진오일', 11: '경고등_엔진체크', 12: '경고등_연료부족', 13: '경고등_온도게이지', 14: '경고등_요소수부족', 15: '경고등_전/후측방 충돌 경고', 16: '경고등_통합경고등(느낌표)', 17: '글로브박스_파손', 18: '내비게이션_고장', 19: '내비게이션_분실', 20: '내비게이션_업데이트', 21: '내비게이션_인터넷연결불가', 22: '내비게이션_파손', 23: '내장제_파손', 24: '단말기_배선', 25: '단말기_신호불량', 26: '단말기_탈거', 27: '단말기_파손', 28: '도어_고장', 29: '도어_소음', 30: '도어_파손', 31: '라이트/전구_경고등(고장)', 32: '라이트/전구_파손', 33: '룸미러_파손', 34: '목받침대_고장/분실', 35: '문제어불가_옥스트라', 36: '문제어불가_카드태깅', 37: '발판매트_분실', 38: '번호판_분실', 39: '번호판_파손', 40: '본넷_고장', 41: '본넷_파손', 42: '브레이크_고장', 43: '브레이크_밀림', 44: '브레이크_소음', 45: '블랙박스_고장', 46: '블랙박스_분실', 47: '블랙박스_파손', 48: '사고조사_요청', 49: '사이드미러_고장', 50: '사이드미러_파손', 51: '사이드브레이크_해제불가', 52: '선바이저_파손', 53: '시거잭_고장', 54: '시동_고장', 55: '시동_꺼짐', 56: '시동_배터리방전', 57: '시동_키인식불가(키아웃)', 58: '시트_시트조절_고장', 59: '시트_열선시트_고장', 60: '시트_파손', 61: '실내등_고장', 62: '안심번호

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

{0: 'ADAS_고장', 1: 'ADAS_파손', 2: 'USB포트_고장', 3: '경고등_ABS', 4: '경고등_DPF(배출가스)', 5: '경고등_ESC(차세제어)', 6: '경고등_배터리', 7: '경고등_브레이크', 8: '경고등_스패너(서비스알림)', 9: '경고등_에어백', 10: '경고등_엔진오일', 11: '경고등_엔진체크', 12: '경고등_연료부족', 13: '경고등_온도게이지', 14: '경고등_요소수부족', 15: '경고등_전/후측방 충돌 경고', 16: '경고등_통합경고등(느낌표)', 17: '글로브박스_파손', 18: '내비게이션_고장', 19: '내비게이션_분실', 20: '내비게이션_업데이트', 21: '내비게이션_인터넷연결불가', 22: '내비게이션_파손', 23: '내장제_파손', 24: '단말기_배선', 25: '단말기_신호불량', 26: '단말기_탈거', 27: '단말기_파손', 28: '도어_고장', 29: '도어_소음', 30: '도어_파손', 31: '라이트/전구_경고등(고장)', 32: '라이트/전구_파손', 33: '룸미러_파손', 34: '목받침대_고장/분실', 35: '문제어불가_옥스트라', 36: '문제어불가_카드태깅', 37: '발판매트_분실', 38: '번호판_분실', 39: '번호판_파손', 40: '본넷_고장', 41: '본넷_파손', 42: '브레이크_고장', 43: '브레이크_밀림', 44: '브레이크_소음', 45: '블랙박스_고장', 46: '블랙박스_분실', 47: '블랙박스_파손', 48: '사고조사_요청', 49: '사이드미러_고장', 50: '사이드미러_파손', 51: '사이드브레이크_해제불가', 52: '선바이저_파손', 53: '시거잭_고장', 54: '시동_고장', 55: '시동_꺼짐', 56: '시동_배터리방전', 57: '시동_키인식불가(키아웃)', 58: '시트_시트조절_고장', 59: '시트_열선시트_고장', 60: '시트_파손', 61: '실내등_고장', 62: '안심번호