In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from data_utils import *
from model import SDEN
import pickle
import json
import random
import nltk

In [2]:
# Restore the saved checkpoint with every storage mapped onto the CPU, so the
# notebook also runs on machines without the device the checkpoint was saved from.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load('weight/model.pkl', map_location=torch.device('cpu'))

In [3]:
# Configuration object stored with the checkpoint; its embed_size and
# hidden_size attributes are read when the model is rebuilt.
config = checkpoint['config']

In [4]:
# Rebuild the SDEN network from the sizes recorded in the checkpoint, then
# load the trained parameters into it.
model = SDEN(
    len(checkpoint['vocab']),         # number of entries in the word vocabulary
    config.embed_size,
    config.hidden_size,
    len(checkpoint['slot_vocab']),    # number of entries in the slot-label vocabulary
    len(checkpoint['intent_vocab']),  # number of entries in the intent-label vocabulary
)
model.load_state_dict(checkpoint['model'])

In [5]:
# Put the module in evaluation mode for inference. The call is the last
# expression of the cell, so the module's repr is displayed as the cell output.
model.eval()

SDEN(
  (embed): Embedding(1179, 100, padding_idx=0)
  (bigru_m): GRU(100, 64, batch_first=True, bidirectional=True)
  (bigru_c): GRU(100, 64, batch_first=True, bidirectional=True)
  (context_encoder): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): Sigmoid()
  )
  (session_encoder): GRU(128, 128, batch_first=True, bidirectional=True)
  (decoder_1): GRU(100, 128, batch_first=True, bidirectional=True)
  (decoder_2): LSTM(256, 128, batch_first=True, bidirectional=True)
  (intent_linear): Linear(in_features=256, out_features=4, bias=True)
  (slot_linear): Linear(in_features=256, out_features=24, bias=True)
  (dropout): Dropout(p=0.3)
)

In [9]:
# Load the KVRET public test split. The context manager closes the file handle
# as soon as parsing finishes (the bare open() call never closed it), and the
# explicit encoding keeps decoding independent of the platform default.
with open('../dataset/kvret/kvret_test_public.json', 'r', encoding='utf-8') as kvret_file:
    data = json.load(kvret_file)

In [10]:
# Pick one random entry of the test split.
# NOTE(review): `test` is reassigned by the later random.sample cell before it
# is read anywhere in this notebook, so this value appears unused — confirm.
test = random.choice(data)

In [11]:
# Inverse lookup tables (label index -> label string), used to turn predicted
# indices back into readable intent and slot names.
index2intent = dict(
    (label_id, label) for label, label_id in checkpoint['intent_vocab'].items()
)
index2slot = dict(
    (label_id, label) for label, label_id in checkpoint['slot_vocab'].items()
)

In [17]:
# Build a synthetic session: a prefix of one random dialogue followed by the
# whole of a second random dialogue.
sampled_dialogues = random.sample(data, 2)
prefix_source = sampled_dialogues[0]['dialogue']

# Cut points are the even positions of the first dialogue
# (presumably the driver turns — TODO confirm against the dataset).
even_positions = list(range(0, len(prefix_source), 2))

# random.choice raises IndexError on an empty sequence, so an empty first
# dialogue simply contributes an empty prefix instead of crashing the cell.
index = random.choice(even_positions) if even_positions else 0

test = prefix_source[:index] + sampled_dialogues[1]['dialogue']

In [18]:
# Replay the spliced dialogue turn by turn. Every turn is tokenised, run
# through the model together with the accumulated history, and then appended
# to that history; predictions are printed only for turns marked 'driver'.
history = [["<null>"]]  # the history starts with a single placeholder utterance
for turn in test:
    tokens = nltk.word_tokenize(turn['data']['utterance'])

    # Current utterance as an index tensor with an extra leading dimension,
    # plus the history encoded by pad_to_history.
    current = prepare_sequence(tokens, checkpoint['vocab']).unsqueeze(0)
    context = pad_to_history(history, checkpoint['vocab'])

    with torch.no_grad():
        slot_scores, intent_scores = model(context, current)

    # Index of the highest score along dimension 1
    # (assumes that dimension enumerates the labels — TODO confirm in SDEN.forward).
    slot_pred = torch.max(slot_scores, 1)[1]
    intent_pred = torch.max(intent_scores, 1)[1]

    if turn['turn'] == 'driver':
        print(tokens)
        print('intent : ', index2intent[intent_pred.item()])
        print('slot : ', [index2slot[label_id] for label_id in slot_pred.tolist()])
        print("")

    history.append(tokens)

['What', 'is', 'the', 'date', 'and', 'time', 'of', 'my', 'next', 'meeting', 'and', 'who', 'will', 'be', 'attending', 'it', '?']
intent :  schedule
slot :  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-event', 'O', 'O', 'O', 'O', 'O', 'O', 'O']

['Please', 'give', 'me', 'the', 'address', 'and', 'directions', 'via', 'a', 'route', 'with', 'no', 'traffic', 'to', 'the', 'nearest', 'pizza', 'restaurant', '.']
intent :  navigate
slot :  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-traffic_info', 'O', 'O', 'B-distance', 'B-poi_type', 'I-poi_type', 'O']

['Yes', ',', 'let', "'s", 'go', ',', 'thank', 'you', '!']
intent :  thanks
slot :  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']

