In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from copy import deepcopy
import nltk
import pickle
flatten = lambda l: [item for sublist in l for item in sublist]
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.nn.utils.rnn import pack_padded_sequence as pack
from data_utils import *

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
train_data, word2index, slot2index, intent2index = prepare_dataset('data/train.iob')

100%|██████████| 43474/43474 [00:02<00:00, 21261.77it/s]


In [184]:
class SDEN(nn.Module):
    def __init__(self,vocab_size,embed_size,hidden_size,slot_size,intent_size,dropout=0.3,pad_idx=0):
        super(SDEN,self).__init__()
        
        self.pad_idx = 0
        self.embed = nn.Embedding(vocab_size,embed_size,padding_idx=self.pad_idx)
        self.bigru_m = nn.GRU(embed_size,hidden_size,batch_first=True,bidirectional=True)
        self.bigru_c = nn.GRU(embed_size,hidden_size,batch_first=True,bidirectional=True)
        self.context_encoder = nn.Sequential(nn.Linear(hidden_size*4,hidden_size*2),
                                                               nn.Sigmoid())
        self.session_encoder = nn.GRU(hidden_size*2,hidden_size*2,batch_first=True,bidirectional=True)
        
        self.decoder_1 = nn.GRU(embed_size,hidden_size*2,batch_first=True,bidirectional=True)
        self.decoder_2 = nn.LSTM(hidden_size*4,hidden_size*2,batch_first=True,bidirectional=True)
        
        self.intent_linear = nn.Linear(hidden_size*4,intent_size)
        self.slot_linear = nn.Linear(hidden_size*4,slot_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self,history,current):
        batch_size = len(history)
        H= [] # encoded history
        for h in history:
            mask = torch.ByteTensor(h!=self.pad_idx)
            length = mask.sum(1).long()
            embeds = self.embed(h)
            embeds = self.dropout(embeds)
            lens, indices = torch.sort(length, 0, True)
            lens = [l if l>0 else 1 for l in lens.tolist()] # all zero-input
            packed_h = pack(embeds[indices], lens, batch_first=True)
            outputs, hidden = self.bigru_m(packed_h)
            _, _indices = torch.sort(indices, 0)
            hidden = torch.cat([hh for hh in hidden],-1)
            hidden = hidden[_indices].unsqueeze(0)
            H.append(hidden)
        
        M = torch.cat(H) # B,T_C,2H
        M = self.dropout(M)
        
        embeds = self.embed(current)
        embeds = self.dropout(embeds)
        mask = torch.ByteTensor(current!=self.pad_idx)
        length = mask.sum(1).long()
        lens, indices = torch.sort(length, 0, True)
        packed_h = pack(embeds[indices], lens.tolist(), batch_first=True)
        outputs, hidden = self.bigru_c(packed_h)
        _, _indices = torch.sort(indices, 0)
        hidden = torch.cat([hh for hh in hidden],-1)
        C = hidden[_indices].unsqueeze(1) # B,1,2H
        C = self.dropout(C)
        
        C = C.repeat(1,M.size(1),1) 
        CONCAT = torch.cat([M,C],-1) # B,T_c,4H
        
        G = self.context_encoder(CONCAT)
        
        _,H = self.session_encoder(G) # 2,B,2H
        weight = next(self.parameters())
        cell_state = weight.new_zeros(H.size())
        O_1,_ = self.decoder_1(embeds)
        O_1 = self.dropout(O_1)
        
        O_2,(S_2,_) = self.decoder_2(O_1,(H,cell_state))
        O_2 = self.dropout(O_2)
        S = torch.cat([s for s in S_2],1)
        
        intent_prob = self.intent_linear(S)
        slot_prob = self.slot_linear(O_2.contiguous().view(O_2.size(0)*O_2.size(1),-1))
        
        return slot_prob, intent_prob

In [185]:
EPOCH = 20
BATCH = 32
LR = 0.001

In [186]:
model = SDEN(len(word2index),100,64,len(slot2index),len(intent2index),word2index['<pad>'])
slot_loss_function = nn.CrossEntropyLoss(ignore_index=0)
intent_loss_function = nn.CrossEntropyLoss()
model.to(device)
optimizer = optim.Adam(model.parameters(),lr=LR)
scheduler = optim.lr_scheduler.MultiStepLR(gamma=0.1,milestones=[EPOCH//4,EPOCH//2],optimizer=optimizer)

In [187]:
model.train()
for epoch in range(EPOCH):
    losses=[]
    scheduler.step()
    for i,batch in enumerate(data_loader(train_data,BATCH,True)):
        h,c,slot,intent = pad_to_batch(batch,word2index,slot2index)
        h = [hh.to(device) for hh in h]
        c = c.to(device)
        slot = slot.to(device)
        intent = intent.to(device)
        model.zero_grad()
        slot_p, intent_p = model(h,c)

        loss_s = slot_loss_function(slot_p,slot.view(-1))
        loss_i = intent_loss_function(intent_p,intent.view(-1))
        loss = loss_s + loss_i
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        
        if i % 100 == 0:
            print("[%d/%d] [%d/%d] mean_loss : %.3f" % (epoch,EPOCH,i,len(train_data)//BATCH,np.mean(losses)))
            losses=[]

[0/20] [0/1358] mean_loss : 4.613
[0/20] [100/1358] mean_loss : 1.710
[0/20] [200/1358] mean_loss : 0.759
[0/20] [300/1358] mean_loss : 0.587
[0/20] [400/1358] mean_loss : 0.497
[0/20] [500/1358] mean_loss : 0.484
[0/20] [600/1358] mean_loss : 0.422
[0/20] [700/1358] mean_loss : 0.419
[0/20] [800/1358] mean_loss : 0.354
[0/20] [900/1358] mean_loss : 0.330
[0/20] [1000/1358] mean_loss : 0.332
[0/20] [1100/1358] mean_loss : 0.332


KeyboardInterrupt: 

## Test

In [188]:
index2slot = {v:k for k,v in slot2index.items()}
index2intent = {v:k for k,v in intent2index.items()}

In [189]:
data = open("data/dev.iob","r",encoding="utf-8").readlines()

In [191]:
dev_data=[]
history=[["<null>"]]
for d in data:
    if d=="\n":
        history=[["<null>"]]
        continue
    dd = d.replace("\n","").split("|||")
    if len(dd)==1:
        pass
        #bot = tagger.morphs(dd[0])
        #history.append(bot)
    else:
        user = dd[0].split()
        tag = dd[1].split()
        intent = dd[2]
        temp = deepcopy(history)
        dev_data.append([temp,user,tag,intent])
        history.append(user)

In [192]:
for t in dev_data:
    for i,history in enumerate(t[0]):
        t[0][i] = prepare_sequence(history, word2index).view(1, -1)

    t[1] = prepare_sequence(t[1], word2index).view(1, -1)
    t[2] = prepare_sequence(t[2], slot2index).view(1, -1)
    t[3] = torch.LongTensor([intent2index[t[3]]]).view(1,-1)

In [193]:
model.eval()
preds=[]
labels=[]
hits=0
with torch.no_grad():
    for i,batch in enumerate(data_loader(dev_data,BATCH,True)):
        h,c,slot,intent = pad_to_batch(batch,word2index,slot2index)
        h = [hh.to(device) for hh in h]
        c = c.to(device)
        slot = slot.to(device)
        intent = intent.to(device)
        slot_p, intent_p = model(h,c)
        
        preds.extend([index2slot[i] for i in slot_p.max(1)[1].tolist()])
        labels.extend([index2slot[i] for i in slot.view(-1).tolist()])
        hits+=torch.eq(intent_p.max(1)[1],intent.view(-1)).sum().item()
        
        
print(hits/len(dev_data))

0.9475612548660408


In [194]:
sorted_labels = sorted(
    list(set(labels) - {'O','<pad>'}),
    key=lambda name: (name[1:], name[0])
)

In [195]:
from sklearn_crfsuite import metrics

In [196]:
preds = [[y] for y in preds] # this is because sklearn_crfsuite.metrics function flatten inputs
labels = [[y] for y in labels]

In [197]:
print(metrics.flat_classification_report(
    labels, preds, labels = sorted_labels, digits=3
))

                     precision    recall  f1-score   support

           B-agenda      0.625     0.278     0.385        36
           I-agenda      0.011     0.778     0.023        54
             B-date      0.813     0.859     0.836       911
             I-date      0.375     0.898     0.529       549
         B-distance      0.605     0.493     0.543       487
         I-distance      0.463     0.263     0.336       167
            B-event      0.838     0.803     0.820       517
            I-event      0.664     0.817     0.733       367
         B-location      0.237     0.895     0.374       572
         I-location      0.013     0.904     0.026       280
            B-party      0.647     0.882     0.747       187
            I-party      0.667     0.471     0.552        17
         B-poi_type      0.830     0.657     0.734       534
         I-poi_type      0.517     0.748     0.611       301
             B-room      0.692     0.257     0.375        35
             I-room    

In [53]:
import json

In [54]:
data = json.load(open('../dataset/kvret/kvret_test_public.json','r'))

In [55]:
index = random.choice(range(len(data)))

In [66]:
index = random.choice(range(len(data)))
history=[prepare_sequence(["<null>"],word2index).view(1,-1)]
for d in data[index]['dialogue']:
    
    if d['turn']=='assistant':
        phrase = nltk.word_tokenize(d['data']['utterance'])
        phrase = prepare_sequence(phrase,word2index).view(1,-1)
        history.append(phrase)
    else:
        h = pad_to_history(history,word2index)
        c = nltk.word_tokenize(d['data']['utterance'])
        c = prepare_sequence(c,word2index).view(1,-1)
        with torch.no_grad():
            slot_p, intent_p = model(h,c)
        
        slots = slot_p.max(1)[1]
        intent = intent_p.max(1)[1]
        slots = [index2slot[i] for i in slots.tolist()]
        intent = index2intent[intent.item()]
        print(d['data']['utterance'])
        print(slots)
        print(intent)
        print("\n")
        history.append(c)

get me directions to a local cafe
['O', 'O', 'O', 'O', 'O', 'O', 'B-poi_type']
navigate


Yes.
['O', 'O']
navigate


Thanks can you provide the address and is this the fastest route?
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-distance', 'O', 'O']
navigate




In [35]:
import pickle

In [36]:
model = model.cpu()
torch.save(model.state_dict(),'sden.pkl')

In [37]:
pickle.dump(word2index,open('vocab.pkl','wb'))
pickle.dump(slot2index,open('slot.pkl','wb'))
pickle.dump(intent2index,open('intent.pkl','wb'))