In [1]:
import os
import sys
import json
import torch
import random

from src.model import BiLSTM_CRF, MakeEmbed, textCNN, DAN, EpochLogger, save
from src.dataset import Preprocessing, MakeDataset

In [11]:
class NaturalLanguageUnderstanding:
    """Inference wrapper around an OOD detector, an intent classifier and a slot tagger.

    Typical use::

        nlu = NaturalLanguageUnderstanding()
        nlu.model_load(intent_path, slot_path, ood_path)
        result = nlu.run(query)   # {'INTENT': str, 'SLOT': ['NAME^value', ...]}
        nlu.print_nlu_result(result)
    """

    def __init__(self):
        """Build the dataset helper, load word2vec and construct the three (untrained) models."""
        self.dataset = MakeDataset()
        self.embed = MakeEmbed()
        self.embed.load_word2vec()

        # word2vec vectors as a float tensor; passed as the first constructor
        # argument to all three models below.
        self.weights = torch.FloatTensor(self.embed.word2vec.wv.vectors)

        # NOTE(review): the numeric arguments (256, [3,4,5], 0.5, 128, 2) are passed
        # positionally; their meaning is defined by the constructors in src.model —
        # confirm there before tuning.
        self.intent_clsf = textCNN(self.weights, 256, [3,4,5], 0.5, len(self.dataset.intent_label))
        self.slot_tagger = BiLSTM_CRF(self.weights, self.dataset.entity_label, 256, 128)
        self.ood_detector = DAN(self.weights, 256, 0.5, 2)

    def init_NLU_result(self): #NLU result
        """Return an empty result skeleton: ``{'INTENT': "", 'SLOT': []}``.

        The key is 'INTENT' (upper case) because that is the key written by
        ``convert_nlu_result`` and read by ``print_nlu_result``.
        """
        NLU_result = {
            'INTENT' : "",
            'SLOT' :[

            ]
        }
        return NLU_result

    def model_load(self, intent_path, slot_path, ood_path):
        """Load the three state dicts from disk and switch every model to eval mode.

        Args:
            intent_path: checkpoint path for the intent classifier.
            slot_path: checkpoint path for the slot tagger.
            ood_path: checkpoint path for the out-of-domain detector.
        """
        # The models are never moved off the CPU in this class, so map the stored
        # tensors to CPU; this also lets GPU-saved checkpoints load on CPU-only hosts.
        self.intent_clsf.load_state_dict(torch.load(intent_path, map_location='cpu'))
        self.slot_tagger.load_state_dict(torch.load(slot_path, map_location='cpu'))
        self.ood_detector.load_state_dict(torch.load(ood_path, map_location='cpu'))
        self.intent_clsf.eval()
        self.slot_tagger.eval()
        self.ood_detector.eval()

    def predict(self, query):
        """Run the OOD detector, intent classifier and slot tagger on one query.

        Args:
            query: raw query string.

        Returns:
            ``(intent, predict)`` where ``intent`` is an entry of
            ``self.dataset.intents`` (or the string 'ood' when the detector's
            argmax is 0) and ``predict`` is whatever ``self.slot_tagger.decode``
            returns (``convert_nlu_result`` reads ``predict[0]`` as a sequence
            of label indices).
        """
        token_ids = self.dataset.prep.pad_idx_sequencing(
            self.embed.query2idx(self.dataset.tokenize(query))
        )
        x = torch.tensor(token_ids)
        batch = x.unsqueeze(0)  # add a leading batch dimension of size 1

        # Pure inference: no gradients are needed.
        with torch.no_grad():
            # --- OOD detector ---
            ood = torch.argmax(self.ood_detector(batch)).tolist()
            # Debug output kept on purpose (it is the leading "1" in the recorded
            # cell outputs); a non-zero value is treated as in-domain below.
            print(ood)

            if ood:
                # --- intent classifier ---
                intent_idx = torch.argmax(self.intent_clsf(batch)).tolist()
                intent = self.dataset.intents[intent_idx]
            else:
                intent = 'ood'

            # --- slot tagger ---
            features = self.slot_tagger(batch)

            # Positions with index 0 are masked out (presumably padding — confirm
            # against pad_idx_sequencing). uint8 is kept because the dtype that
            # slot_tagger.decode expects is not visible from this notebook.
            mask = (x > 0).to(torch.uint8)

            predict = self.slot_tagger.decode(features, mask.view(1, -1))
        return intent, predict

    def convert_nlu_result(self, query, intent, predict):
        """Turn decoded label indices into ``{'INTENT': ..., 'SLOT': ['NAME^value', ...]}``.

        Example (from the notebook): for the query '제주도 맛집' the tagger yields
        one label index per token, e.g. ``[[12, 0]]``, where 12 maps to
        'S-LOCATION' and 0 to the "no slot" label, giving ``['LOCATION^제주도']``.

        Labels are looked up in ``self.dataset.entitys`` and paired positionally
        with the whitespace tokens of ``query``. Label prefixes are interpreted as:

        * ``S-`` : single-token slot, emitted immediately.
        * ``B-`` : starts a multi-token slot.
        * ``I-`` : continues the currently open slot.
        * ``E-`` : closes the currently open slot and emits it.
        * anything else (or an ``I-``/``E-`` with no open slot): no slot.

        A slot that is still open when another slot starts, when a non-slot
        label is seen, or when the sentence ends is emitted with the tokens
        collected so far instead of being dropped.

        Args:
            query: the raw query string.
            intent: intent name to store under 'INTENT'.
            predict: decoded labels; only ``predict[0]`` is used.

        Returns:
            dict with keys 'INTENT' and 'SLOT'.
        """
        NLU_result = self.init_NLU_result()
        tokens = query.split()
        labels = [self.dataset.entitys[p] for p in predict[0]]

        slots = []           # finished "NAME^value" strings
        pending_tokens = []  # tokens of the multi-token slot currently being built
        pending_name = ""    # slot name taken from the latest B-/I- label

        # zip() stops at the shorter sequence, so a token/label length mismatch
        # cannot raise IndexError.
        for token, label in zip(tokens, labels):
            prefix, name = label[:2], label[2:]
            continues = prefix in ('I-', 'E-') and bool(pending_tokens)

            # Anything that does not continue the open slot closes it first.
            if pending_tokens and not continues:
                slots.append(pending_name + "^" + " ".join(pending_tokens))
                pending_tokens = []

            if prefix == 'S-':
                slots.append(name + "^" + token)
            elif prefix == 'B-':
                pending_tokens = [token]
                pending_name = name
            elif continues:
                pending_tokens.append(token)
                if prefix == 'E-':
                    slots.append(name + '^' + ' '.join(pending_tokens))
                    pending_tokens = []
                else:
                    pending_name = name

        # Sentence ended while a slot was still open.
        if pending_tokens:
            slots.append(pending_name + "^" + " ".join(pending_tokens))

        NLU_result['INTENT'] = intent
        NLU_result['SLOT'] = slots
        return NLU_result

    def run(self, query):
        """Predict and convert in one step.

        Also stores the raw ``[intent, predict]`` pair on ``self.nlu_predict``
        for later inspection.
        """
        intent, predict = self.predict(query)
        self.nlu_predict = [intent, predict]
        NLU_result = self.convert_nlu_result(query, intent, predict)
        return NLU_result

    def print_nlu_result(self, nlu_result):
        """Print the intent and each slot of a result produced by ``run``."""
        print('speech intention : '+nlu_result.get('INTENT'))
        print('ignition object : ')
        for slot_concat in nlu_result.get('SLOT'):
            # Split on the first '^' only so a value containing '^' is not truncated.
            slot_name, _, slot_value = slot_concat.partition('^')
            print('     '+slot_name+" : "+slot_value)

In [12]:
# Checkpoint files for the three NLU models, given relative to the notebook's
# working directory. They are passed to NLU.model_load() in a later cell.
# NOTE(review): the numbers embedded in the file names look like an evaluation
# score and a step count — confirm against the training notebooks.
intent_pretrain_path = "./chatbot_data/pretraining/1_intent_clsf_model/intent_clsf_99.585_steps_7.pt"
entity_pretrain_path = "./chatbot_data/pretraining/1_entity_recog_model/entity_recog_20.89_steps_2.pt"
ood_pretrain_path = "./chatbot_data/pretraining/1_ood_clsf_model/ood_clsf_99.663_steps_4.pt"

In [13]:
# Build the NLU pipeline (constructs the dataset helper, loads word2vec and
# creates the three models).
NLU = NaturalLanguageUnderstanding()

# Load the pretrained weights; argument order is intent, slot (entity), OOD.
NLU.model_load(intent_pretrain_path, entity_pretrain_path, ood_pretrain_path)

In [14]:
# Demo: two-token query with a single-token location slot.
NLU_result = NLU.run('제주도 맛집')

NLU.print_nlu_result(NLU_result)

1
speech intention : restaurant
ignition object : 
     LOCATION : 제주도


  score = torch.where(mask[i].unsqueeze(1), next_score, score)


In [15]:
# Demo: query carrying two single-token slots (a date and a location).
NLU_result = NLU.run('오늘 제주도 날씨')

NLU.print_nlu_result(NLU_result)

1
speech intention : weather
ignition object : 
     DATE : 오늘
     LOCATION : 제주도


In [19]:
# Demo: one-word query, to see which intent is chosen when only a location is given.
NLU_result = NLU.run('제주도')

NLU.print_nlu_result(NLU_result)

1
speech intention : travel
ignition object : 
     LOCATION : 제주도


In [20]:
# Demo: longer sentence where the slots are surrounded by non-slot tokens.
NLU_result = NLU.run('나 내일 제주도 여행가는데 미세먼지 알려줘')

NLU.print_nlu_result(NLU_result)

1
speech intention : dust
ignition object : 
     DATE : 내일
     LOCATION : 제주도


In [22]:
# Demo: the date spans two whitespace tokens ('이번 주'), exercising the
# multi-token (B-/E-) merge path of convert_nlu_result.
NLU_result = NLU.run('나 이번 주 제주도 여행가는데 미세먼지 알려줘')

NLU.print_nlu_result(NLU_result)

1
speech intention : dust
ignition object : 
     DATE : 이번 주
     LOCATION : 제주도


In [23]:
# Raw [intent, decoded label indices] stored by the most recent NLU.run() call.
NLU.nlu_predict

['dust', [[0, 1, 5, 12, 0, 0, 0]]]