In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from copy import deepcopy
import nltk
import pickle
flatten = lambda l: [item for sublist in l for item in sublist]
from data_utils import *

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
data = open("data/train.iob","r",encoding="utf-8").readlines()

In [5]:
train_data=[]
history=[["<null>"]]
for d in data:
    if d=="\n":
        history=[["<null>"]]
        continue
    dd = d.replace("\n","").split("|||")
    if len(dd)==1:
        bot = tagger.morphs(dd[0])
        history.append(bot)
    else:
        user = dd[0].split()
        tag = dd[1].split()
        intent = dd[2]
        temp = deepcopy(history)
        train_data.append([temp,user,tag,intent])
        history.append(user)

In [7]:
historys, currents, slots, intents = list(zip(*train_data))

In [8]:
vocab = list(set(flatten(currents)))
slot_vocab = list(set(flatten(slots)))
intent_vocab = list(set(intents))

In [9]:
word2index={"<pad>" : 0, "<unk>" : 1, "<null>" : 2, "<s>" : 3, "</s>" : 4}
for vo in vocab:
    if word2index.get(vo)==None:
        word2index[vo] = len(word2index)
        
slot2index={"<pad>" : 0}
for vo in slot_vocab:
    if slot2index.get(vo)==None:
        slot2index[vo] = len(slot2index)
        
intent2index={}
for vo in intent_vocab:
    if intent2index.get(vo)==None:
        intent2index[vo] = len(intent2index)

In [10]:
for t in train_data:
    for i,history in enumerate(t[0]):
        t[0][i] = prepare_sequence(history, word2index).view(1, -1)

    t[1] = prepare_sequence(t[1], word2index).view(1, -1)
    t[2] = prepare_sequence(t[2], slot2index).view(1, -1)
    t[3] = torch.LongTensor([intent2index[t[3]]]).view(1,-1)

In [11]:
class SDEN(nn.Module):
    def __init__(self,vocab_size,embed_size,hidden_size,slot_size,intent_size,pad_idx=0):
        super(SDEN,self).__init__()
        
        self.pad_idx = 0
        self.embed = nn.Embedding(vocab_size,embed_size,padding_idx=self.pad_idx)
        self.bigru_m = nn.GRU(embed_size,hidden_size,batch_first=True,bidirectional=True)
        self.bigru_c = nn.GRU(embed_size,hidden_size,batch_first=True,bidirectional=True)
        self.context_encoder = nn.Sequential(nn.Linear(hidden_size*4,hidden_size*2),
                                                               nn.Sigmoid())
        self.session_encoder = nn.GRU(hidden_size*2,hidden_size*2,batch_first=True,bidirectional=True)
        
        self.decoder_1 = nn.GRU(embed_size,hidden_size*2,batch_first=True,bidirectional=True)
        self.decoder_2 = nn.LSTM(hidden_size*4,hidden_size*2,batch_first=True,bidirectional=True)
        
        self.intent_linear = nn.Linear(hidden_size*4,intent_size)
        self.slot_linear = nn.Linear(hidden_size*4,slot_size)
        self.dropout = nn.Dropout(0.5)
    
    def forward(self,history,current):
        batch_size = len(history)
        H= [] # encoded history
        for h in history:
            mask = h.eq(self.pad_idx)
            embeds = self.embed(h)
            embeds = self.dropout(embeds)
            outputs, hidden = self.bigru_m(embeds)
            real_hidden = []

            for i, o in enumerate(outputs): # B,T,D
                real_length = mask[i].tolist().count(0) 
                real_hidden.append(o[real_length - 1])

            H.append(torch.cat(real_hidden).view(h.size(0), -1).unsqueeze(0))
        
        M = torch.cat(H) # B,T_C,2H
        M = self.dropout(M)
        embeds = self.embed(current)
        embeds = self.dropout(embeds)
        mask = current.eq(self.pad_idx)
        outputs, hidden = self.bigru_c(embeds)
        real_hidden=[]
        for i, o in enumerate(outputs): # B,T,D
            real_length = mask[i].tolist().count(0) 
            real_hidden.append(o[real_length - 1])
        C = torch.cat(real_hidden).view(current.size(0),1, -1) # B,1,2H
        C = self.dropout(C)
        
        C = C.repeat(1,M.size(1),1) 
        CONCAT = torch.cat([M,C],-1) # B,T_c,4H
        
        G = self.context_encoder(CONCAT)
        
        _,H = self.session_encoder(G) # 2,B,2H
        weight = next(self.parameters())
        cell_state = weight.new_zeros(H.size())
        O_1,_ = self.decoder_1(embeds)
        O_1 = self.dropout(O_1)
        
        O_2,(S_2,_) = self.decoder_2(O_1,(H,cell_state))
        O_2 = self.dropout(O_2)
        S = torch.cat([s for s in S_2],1)
        
        intent_prob = self.intent_linear(S)
        slot_prob = self.slot_linear(O_2.contiguous().view(O_2.size(0)*O_2.size(1),-1))
        
        return slot_prob, intent_prob

In [12]:
EPOCH = 20
BATCH = 32
LR = 0.001

In [13]:
model = SDEN(len(word2index),100,64,len(slot2index),len(intent2index),word2index['<pad>'])
slot_loss_function = nn.CrossEntropyLoss(ignore_index=0)
intent_loss_function = nn.CrossEntropyLoss()
model.to(device)
optimizer = optim.Adam(model.parameters(),lr=LR)
scheduler = optim.lr_scheduler.MultiStepLR(gamma=0.1,milestones=[EPOCH//4,EPOCH//2],optimizer=optimizer)

In [14]:
model.train()
for epoch in range(EPOCH):
    losses=[]
    scheduler.step()
    for i,batch in enumerate(data_loader(train_data,BATCH,True)):
        h,c,slot,intent = pad_to_batch(batch,word2index,slot2index)
        h = [hh.to(device) for hh in h]
        c = c.to(device)
        slot = slot.to(device)
        intent = intent.to(device)
        model.zero_grad()
        slot_p, intent_p = model(h,c)

        loss_s = slot_loss_function(slot_p,slot.view(-1))
        loss_i = intent_loss_function(intent_p,intent.view(-1))
        loss = loss_s + loss_i
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        
        if i % 100 == 0:
            print("[%d/%d] [%d/%d] mean_loss : %.3f" % (epoch,EPOCH,i,len(train_data)//BATCH,np.mean(losses)))
            losses=[]

[0/20] [0/961] mean_loss : 4.543
[0/20] [100/961] mean_loss : 2.027
[0/20] [200/961] mean_loss : 1.189
[0/20] [300/961] mean_loss : 0.965
[0/20] [400/961] mean_loss : 0.848
[0/20] [500/961] mean_loss : 0.757
[0/20] [600/961] mean_loss : 0.716
[0/20] [700/961] mean_loss : 0.658
[0/20] [800/961] mean_loss : 0.634
[0/20] [900/961] mean_loss : 0.569
[1/20] [0/961] mean_loss : 0.622
[1/20] [100/961] mean_loss : 0.556
[1/20] [200/961] mean_loss : 0.538
[1/20] [300/961] mean_loss : 0.536
[1/20] [400/961] mean_loss : 0.517
[1/20] [500/961] mean_loss : 0.512
[1/20] [600/961] mean_loss : 0.498
[1/20] [700/961] mean_loss : 0.472
[1/20] [800/961] mean_loss : 0.482
[1/20] [900/961] mean_loss : 0.487
[2/20] [0/961] mean_loss : 0.331
[2/20] [100/961] mean_loss : 0.437
[2/20] [200/961] mean_loss : 0.458
[2/20] [300/961] mean_loss : 0.432
[2/20] [400/961] mean_loss : 0.437
[2/20] [500/961] mean_loss : 0.437
[2/20] [600/961] mean_loss : 0.443
[2/20] [700/961] mean_loss : 0.436
[2/20] [800/961] mean_loss

KeyboardInterrupt: 

## Test

In [17]:
index2slot = {v:k for k,v in slot2index.items()}
index2intent = {v:k for k,v in intent2index.items()}

In [18]:
data = open("data/dev.iob","r",encoding="utf-8").readlines()

In [19]:
dev_data=[]
history=[["<null>"]]
for d in data:
    if d=="\n":
        history=[["<null>"]]
        continue
    dd = d.replace("\n","").split("|||")
    if len(dd)==1:
        bot = tagger.morphs(dd[0])
        history.append(bot)
    else:
        user = dd[0].split()
        tag = dd[1].split()
        intent = dd[2]
        temp = deepcopy(history)
        dev_data.append([temp,user,tag,intent])
        history.append(user)

In [20]:
for t in dev_data:
    for i,history in enumerate(t[0]):
        t[0][i] = prepare_sequence(history, word2index).view(1, -1)

    t[1] = prepare_sequence(t[1], word2index).view(1, -1)
    t[2] = prepare_sequence(t[2], slot2index).view(1, -1)
    t[3] = torch.LongTensor([intent2index[t[3]]]).view(1,-1)

In [44]:
model.eval()
preds=[]
labels=[]
hits=0
with torch.no_grad():
    for i,batch in enumerate(data_loader(dev_data,BATCH,True)):
        h,c,slot,intent = pad_to_batch(batch,word2index,slot2index)
        h = [hh.to(device) for hh in h]
        c = c.to(device)
        slot = slot.to(device)
        intent = intent.to(device)
        slot_p, intent_p = model(h,c)
        
        preds.extend([index2slot[i] for i in slot_p.max(1)[1].tolist()])
        labels.extend([index2slot[i] for i in slot.view(-1).tolist()])
        hits+=torch.eq(intent_p.max(1)[1],intent.view(-1)).sum().item()
        
        
print(hits/len(dev_data))

0.9588724584103512


In [47]:
sorted_labels = sorted(
    list(set(labels) - {'O','<pad>'}),
    key=lambda name: (name[1:], name[0])
)

In [49]:
from sklearn_crfsuite import metrics

In [51]:
preds = [[y] for y in preds] # this is because sklearn_crfsuite.metrics function flatten inputs
labels = [[y] for y in labels]

In [52]:
print(metrics.flat_classification_report(
    labels, preds, labels = sorted_labels, digits=3
))

                     precision    recall  f1-score   support

           B-agenda      1.000     0.333     0.500         9
           I-agenda      0.004     0.545     0.008        11
             B-date      0.312     0.818     0.452       461
             I-date      0.124     0.982     0.220       283
         B-distance      0.583     0.744     0.653       242
         I-distance      0.341     0.311     0.326        90
            B-event      0.837     0.815     0.826       259
            I-event      0.799     0.924     0.857       172
         B-location      0.841     0.898     0.869       325
         I-location      0.703     0.981     0.819       157
            B-party      0.883     0.978     0.929        93
            I-party      0.438     0.700     0.538        10
         B-poi_type      0.645     0.789     0.710       265
         I-poi_type      0.413     0.842     0.554       146
             B-room      1.000     0.364     0.533        11
             I-room    

In [53]:
import json

In [54]:
data = json.load(open('../dataset/kvret/kvret_test_public.json','r'))

In [55]:
index = random.choice(range(len(data)))

In [66]:
index = random.choice(range(len(data)))
history=[prepare_sequence(["<null>"],word2index).view(1,-1)]
for d in data[index]['dialogue']:
    
    if d['turn']=='assistant':
        phrase = nltk.word_tokenize(d['data']['utterance'])
        phrase = prepare_sequence(phrase,word2index).view(1,-1)
        history.append(phrase)
    else:
        h = pad_to_history(history,word2index)
        c = nltk.word_tokenize(d['data']['utterance'])
        c = prepare_sequence(c,word2index).view(1,-1)
        with torch.no_grad():
            slot_p, intent_p = model(h,c)
        
        slots = slot_p.max(1)[1]
        intent = intent_p.max(1)[1]
        slots = [index2slot[i] for i in slots.tolist()]
        intent = index2intent[intent.item()]
        print(d['data']['utterance'])
        print(slots)
        print(intent)
        print("\n")
        history.append(c)

get me directions to a local cafe
['O', 'O', 'O', 'O', 'O', 'O', 'B-poi_type']
navigate


Yes.
['O', 'O']
navigate


Thanks can you provide the address and is this the fastest route?
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-distance', 'O', 'O']
navigate




In [35]:
import pickle

In [36]:
model = model.cpu()
torch.save(model.state_dict(),'sden.pkl')

In [37]:
pickle.dump(word2index,open('vocab.pkl','wb'))
pickle.dump(slot2index,open('slot.pkl','wb'))
pickle.dump(intent2index,open('intent.pkl','wb'))