In [12]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import gluonnlp as nlp
import numpy as np
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model
from transformers import AdamW, AutoModel, AutoTokenizer
from transformers.optimization import get_cosine_schedule_with_warmup

# Run on the first GPU when CUDA is available; otherwise fall back to the CPU
# so the notebook still executes on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fetch the pretrained KoBERT model together with its vocabulary
bertmodel, vocab = get_pytorch_kobert_model()

class BERTClassifier(nn.Module):
    """Classification head on top of a BERT encoder.

    The wrapped ``bert`` module is called with ``input_ids``,
    ``token_type_ids`` and ``attention_mask`` keyword arguments and must
    return a 2-tuple; the second item is fed (optionally through dropout)
    into a single linear layer that produces ``num_classes`` logits.

    Args:
        bert: Encoder module to wrap.
        hidden_size: Size of the last dimension of the encoder's second output.
        num_classes: Number of output logits.
        dr_rate: Dropout probability applied before the linear layer.
            ``None`` or ``0`` disables dropout.
        params: Accepted for interface compatibility; not used.
    """

    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        """Build a float mask with ones on the first ``valid_length[i]`` positions of row ``i``.

        Args:
            token_ids: 2-D tensor ``(batch, seq_len)``; only its shape and
                device are used.
            valid_length: Per-row count of real (non-padding) tokens, as a
                tensor or a sequence of ints.

        Returns:
            Float tensor shaped like ``token_ids`` on the same device.
        """
        # Compare every position index against its row's length in a single
        # broadcasted operation instead of writing slices row by row.
        lengths = torch.as_tensor(valid_length, device=token_ids.device).reshape(-1, 1)
        positions = torch.arange(token_ids.size(1), device=token_ids.device).unsqueeze(0)
        return (positions < lengths).float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Return classification logits of shape ``(batch, num_classes)``."""
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        # The first element returned by the encoder is not needed here.
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask)
        # Without dropout the encoder output is classified directly; previously
        # `out` was left unassigned in that case and forward() crashed.
        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)
    
class BERTDataset(Dataset):
    """Dataset that pre-tokenizes sentences and pairs them with int32 labels.

    Every row of ``dataset`` is indexed with ``sent_idx`` for the text and
    ``label_idx`` for the label. The text is passed through
    ``nlp.data.BERTSentenceTransform`` once, at construction time.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        to_features = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        # First pass: transform every sentence.
        features = []
        for row in dataset:
            features.append(to_features([row[sent_idx]]))
        self.sentences = features

        # Second pass: collect the labels as int32 scalars.
        targets = []
        for row in dataset:
            targets.append(np.int32(row[label_idx]))
        self.labels = targets

    def __getitem__(self, i):
        """Return the transformed sentence tuple for item ``i`` with its label appended."""
        encoded = self.sentences[i]
        target = self.labels[i]
        return encoded + (target,)

    def __len__(self):
        """Return the number of stored labels."""
        return len(self.labels)


# Setting parameters
max_len = 32  # passed to BERTDataset as the maximum sequence length
batch_size = 32
# NOTE(review): the next five values look like training hyperparameters;
# nothing in the inference cells shown here reads them.
warmup_ratio = 0.1
num_epochs = 2
max_grad_norm = 1
log_interval = 200
learning_rate = 5e-5

# Tokenization
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

# Restore the saved fine-tuned weights
model_pt = BERTClassifier(bertmodel,  dr_rate=0.5)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from one GPU can still be restored on another device (or on the CPU).
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model_pt.load_state_dict(torch.load('KoBERT_PN_v.0.4.4_dani.pt', map_location=device))
model_pt.to(device)

using cached model. /home/adminuser/notebooks/modeling/PNclassfication/[4차] KoBERT 모델 입출력 테스트/.cache/kobert_v1.zip
using cached model. /home/adminuser/notebooks/modeling/PNclassfication/[4차] KoBERT 모델 입출력 테스트/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece
using cached model. /home/adminuser/notebooks/modeling/PNclassfication/[4차] KoBERT 모델 입출력 테스트/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece


BERTClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(8002, 768, padding_idx=1)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True

In [130]:
def softmax(arr):
    """Return the softmax of ``arr``, normalised over all of its elements.

    Args:
        arr: Array-like of logits.

    Returns:
        numpy array of the same shape whose entries are positive and sum to 1.
    """
    arr = np.asarray(arr)
    # Shift by the maximum *value* (not its index, as np.argmax would give) so
    # the largest exponent is exp(0) == 1 and large logits cannot overflow.
    shifted = arr - np.max(arr)
    exps = np.exp(shifted)
    return exps / np.sum(exps)

def predict(predict_sentence, threshold=0.95):
    """Classify one chat sentence with ``model_pt`` and print the verdict.

    The sentence is wrapped in a one-row dataset with a dummy label, run
    through the model in eval mode, and the softmax of its logits is printed
    after the verdict line.

    Args:
        predict_sentence: Text to classify.
        threshold: Minimum softmax probability of class index 1 for the
            sentence to be reported as "부정"; below it, "일반" is reported.

    Returns:
        None. The verdict and the probability vector are printed.
    """
    # BERTDataset reads a label column, so supply a placeholder label.
    data = [predict_sentence, '0']
    dataset_another = [data]

    another_test = BERTDataset(dataset_another, 0, 1, tok, max_len, True, False)
    # num_workers=0 loads in the main process; starting worker processes for a
    # single sample on every call only adds start-up overhead.
    test_dataloader = torch.utils.data.DataLoader(another_test, batch_size=batch_size, num_workers=0)

    model_pt.eval()

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for token_ids, valid_length, segment_ids, _label in test_dataloader:
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)

            out = model_pt(token_ids, valid_length, segment_ids)

            for logits in out:
                probs = softmax(logits.detach().cpu().numpy())
                verdict = "부정" if probs[1] >= threshold else "일반"

                print(">> 입력하신 내용은 " + verdict + "채팅이라고 판단됩니다.")
                print(probs)

In [131]:
# Negative chat examples
negative_chat_samples = [
    '너무 싫어요',
    '개별로임',
    '이런걸 왜 사죠?',
    '쇼호스트 왜저래?',
    '진심 별로다',
    '목소리가 너무 커요',
    '방장 뭐함?',
    '채팅관리 안하시나요?',
]

# Run the classifier on each example in order.
for chat_sample in negative_chat_samples:
    predict(chat_sample)

0.99186975

In [128]:
# Run the classifier on a short phrase and on a longer sentence containing it.
for chat_sample in ('때탈듯', '흰색은 빨리때탈듯'):
    predict(chat_sample)

>> 입력하신 내용은 일반채팅이라고 판단됩니다.
[0.15083203 0.84916794]
>> 입력하신 내용은 일반채팅이라고 판단됩니다.
[0.57633644 0.42366353]


In [129]:
# Everyday / positive chat examples
general_chat_samples = [
    '너무 좋아요',
    '빨리 구매하고 싶어요',
    '안녕하세용',
    '유하',
    '오',
    '괜찮은데요?',
    '좋은데요?',
    '좋아요',
    '배송은 언제쯤 오나요?',
]

# Run the classifier on each example in order.
for chat_sample in general_chat_samples:
    predict(chat_sample)

>> 입력하신 내용은 일반채팅이라고 판단됩니다.
[0.80724126 0.19275877]
>> 입력하신 내용은 일반채팅이라고 판단됩니다.
[0.78298664 0.21701334]
>> 입력하신 내용은 일반채팅이라고 판단됩니다.
[0.2780534 0.7219466]
>> 입력하신 내용은 일반채팅이라고 판단됩니다.
[0.47794172 0.5220583 ]
>> 입력하신 내용은 일반채팅이라고 판단됩니다.
[0.35243535 0.64756465]
>> 입력하신 내용은 일반채팅이라고 판단됩니다.
[0.5294268 0.4705732]
>> 입력하신 내용은 일반채팅이라고 판단됩니다.
[0.7072413  0.29275873]
>> 입력하신 내용은 일반채팅이라고 판단됩니다.
[0.5318021  0.46819788]
>> 입력하신 내용은 일반채팅이라고 판단됩니다.
[0.10898361 0.8910164 ]
