### Check GPU availability

In [1]:
!nvidia-smi

Sat Jan 19 15:54:25 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.145                Driver Version: 384.145                   |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX 108...  Off  | 00000000:02:00.0 Off |                  N/A |
| 23%   28C    P8    17W / 250W |    740MiB / 11172MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce GTX 108...  Off  | 00000000:03:00.0 Off |                  N/A |
| 23%   33C    P8    17W / 250W |    589MiB / 11172MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                            

### Import libraries

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import math
import numpy as np
import time
import json
import plotly
import logging
logging.getLogger().setLevel(logging.INFO)

from pprint import pprint
from tqdm import tqdm_notebook
from idst_util import trivial
from idst_util import dstc2

from plotly.graph_objs import Scatter, Layout
from plotly.graph_objs.layout import Margin
plotly.offline.init_notebook_mode(connected = True)

[nltk_data] Downloading package punkt to /home/is/andrei-
[nltk_data]     cc/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


### Check DSTC2 availability

In [3]:
trivial.print_idst()
dstc2.check()

INFO:root:+--------------------------------+
INFO:root:|         _ ____  ___________    |
INFO:root:|        (_) __ \/ ___/_  __/    |
INFO:root:|       / / / / /\__ \ / /       |
INFO:root:|      / / /_/ /___/ // /        |
INFO:root:|     /_/_____//____//_/         |
INFO:root:|                                |
INFO:root:+--------------------------------+
INFO:root:|Incremental Dialog State Tracker|
INFO:root:+--------------------------------+
INFO:root:+--------------------------------+
INFO:root:|     Dialog State Tracker 2     |
INFO:root:|         Data Checker           |
INFO:root:+--------------------------------+
INFO:root:Looking for dstc2 directory in .
INFO:root:dstc2 was found!
INFO:root:Looking for dstc2_traindev directory in ./dstc2
INFO:root:dstc2_traindev was found!
INFO:root:Looking for dstc2_test directory in ./dstc2
INFO:root:dstc2_test was found!
INFO:root:Looking for dstc2_scripts directory in ./dstc2
INFO:root:dstc2_scripts was found!
INFO:root:Done!


### Retrieve data

In [4]:
raw_X_train, raw_Y_train, \
raw_X_dev, raw_Y_dev, \
raw_X_test, raw_Y_test, \
ontology = dstc2.retrieve_raw_datasets(train_data_augmentation = True)

INFO:root:+--------------------------------+
INFO:root:|     Dialog State Tracker 2     |
INFO:root:|       Dataset Retrieval        |
INFO:root:+--------------------------------+
INFO:root:Reading dstc2_train.flist, dstc2_dev.flist and ontology_dstc2.json
INFO:root:Asserted 1612 dialogs for dstc2_train.flist
INFO:root:Asserted 506 dialogs for dstc2_dev.flist
INFO:root:Extracting raw train features
100%|██████████| 1612/1612 [00:17<00:00, 144.67it/s]
INFO:root:Extracting raw dev features
100%|██████████| 506/506 [00:05<00:00, 95.80it/s] 
INFO:root:Reading dstc2_test.flist
INFO:root:Asserted 1117 dialogs for dstc2_test.flist
INFO:root:Extracting raw test features
100%|██████████| 1117/1117 [00:12<00:00, 86.08it/s]


### Set device

In [5]:
logging.info("+--------------------------------+")
logging.info("|            Baseline            |")
logging.info("+--------------------------------+")

GPU_ID = 0
DEVICE = torch.device("cuda:{}".format(GPU_ID) if torch.cuda.is_available() else "cpu")
#DEVICE = "cpu"
if DEVICE == "cpu":
    logging.warning("Running on CPU")
else:
    logging.info("Running on GPU {}".format(GPU_ID))

INFO:root:+--------------------------------+
INFO:root:|            Baseline            |
INFO:root:+--------------------------------+
INFO:root:Running on GPU 0


### Create vocabularies

In [6]:
logging.info("+--------------------------------+")
logging.info("|          Vocabulary            |")
logging.info("+--------------------------------+")
logging.info("Creating token_to_index, index_to_token and token_to_count dictionaries")

token_to_index = {"<unk>": 0}
index_to_token = {0: "<unk>"}
token_to_count = {"<unk>": 1}

for raw_train_dialog in tqdm_notebook(raw_X_train):
    
    for raw_train_turn in raw_train_dialog["turns"]:
        
        tokens_scores = raw_train_turn["system"] + raw_train_turn["user"]
        
        for token_score in tokens_scores:
            token = token_score[0]
            if token not in token_to_index:
                token_to_index[token] = len(token_to_index)
                index_to_token[len(token_to_index)] = token
                token_to_count[token] = 1
            else:
                token_to_count[token] += 1
                
assert len(token_to_index) == len(index_to_token)
assert len(token_to_index) == len(token_to_count)

INFO:root:+--------------------------------+
INFO:root:|          Vocabulary            |
INFO:root:+--------------------------------+
INFO:root:Creating token_to_index, index_to_token and token_to_count dictionaries


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




### Execution configuration

In [7]:
logging.info("+--------------------------------+")
logging.info("|         Configuration          |")
logging.info("+--------------------------------+")

VOCABULARY_SIZE = len(token_to_index)
logging.info("VOCABULARY_SIZE:\t\t{}".format(VOCABULARY_SIZE))

# NOTE: +2 because of null and dontcare
GOAL_FOOD_DIM = len(ontology["informable"]["food"]) + 2 
GOAL_PRICERANGE_DIM = len(ontology["informable"]["pricerange"]) + 2
GOAL_NAME_DIM = len(ontology["informable"]["name"]) + 2
GOAL_AREA_DIM = len(ontology["informable"]["area"]) + 2
METHOD_DIM = len(ontology["method"])
REQUESTED_DIM = len(ontology["requestable"])
EMBEDDING_DIM = 170
ALTERED_EMBEDDING_DIM = 300
HIDDEN_DIM = 100
NUM_EPOCHS = 30
GOAL_LOSS_FUNCTION = nn.CrossEntropyLoss()
METHOD_LOSS_FUNCTION = nn.CrossEntropyLoss()
REQUESTED_LOSS_FUNCTION = nn.BCELoss()
logging.info("GOAL_FOOD_DIM:\t\t{}".format(GOAL_FOOD_DIM))
logging.info("GOAL_PRICERANGE_DIM:\t\t{}".format(GOAL_PRICERANGE_DIM))
logging.info("GOAL_NAME_DIM:\t\t{}".format(GOAL_NAME_DIM))
logging.info("GOAL_AREA_DIM:\t\t{}".format(GOAL_AREA_DIM))
logging.info("METHOD_DIM:\t\t\t{}".format(METHOD_DIM))
logging.info("REQUESTED_DIM:\t\t{}".format(REQUESTED_DIM))
logging.info("EMBEDDING_DIM:\t\t{}".format(EMBEDDING_DIM))
logging.info("ALTERED_EMBEDDING_DIM:\t{}".format(ALTERED_EMBEDDING_DIM))
logging.info("HIDDEN_DIM:\t\t\t{}".format(HIDDEN_DIM))
logging.info("NUM_EPOCHS:\t\t\t{}".format(NUM_EPOCHS))
logging.info("GOAL_LOSS_FUNCTION:\t\t{}".format(GOAL_LOSS_FUNCTION))
logging.info("METHOD_LOSS_FUNCTION:\t\t{}".format(METHOD_LOSS_FUNCTION))
logging.info("REQUESTED_LOSS_FUNCTION:\t{}".format(REQUESTED_LOSS_FUNCTION))

INFO:root:+--------------------------------+
INFO:root:|         Configuration          |
INFO:root:+--------------------------------+
INFO:root:VOCABULARY_SIZE:		1149
INFO:root:GOAL_FOOD_DIM:		93
INFO:root:GOAL_PRICERANGE_DIM:		5
INFO:root:GOAL_NAME_DIM:		115
INFO:root:GOAL_AREA_DIM:		7
INFO:root:METHOD_DIM:			5
INFO:root:REQUESTED_DIM:		8
INFO:root:EMBEDDING_DIM:		170
INFO:root:ALTERED_EMBEDDING_DIM:	300
INFO:root:HIDDEN_DIM:			100
INFO:root:NUM_EPOCHS:			30
INFO:root:GOAL_LOSS_FUNCTION:		CrossEntropyLoss()
INFO:root:METHOD_LOSS_FUNCTION:		CrossEntropyLoss()
INFO:root:REQUESTED_LOSS_FUNCTION:	BCELoss()


### Models utilities

In [8]:
def get_index_and_score(turn, token_to_index, mode, device):
    indices = []
    scores = []
    if mode == "train":
        for system_token, system_token_score in turn["system"]:
            indices.append(token_to_index[system_token])
            scores.append(system_token_score)
        for user_token, user_token_score in turn["user"]:
            if np.random.binomial(n = 1, p = 0.1) == 1:
                indices.append(token_to_index["<unk>"])
            else:
                indices.append(token_to_index[user_token])
            scores.append(user_token_score)
    else:
        tokens_scores = turn["system"] + turn["user"]
        for token, score in tokens_scores:
            if token not in token_to_index:
                indices.append(token_to_index["<unk>"])
            else:
                indices.append(token_to_index[token])
            scores.append(score)
    assert len(indices) == len(scores)
    return torch.tensor(indices, dtype = torch.long, device = device), torch.tensor(scores, dtype = torch.float, device = device)

def plotly_plot(dev_accuracies, test_accuracies, accuracy_type):
    plotly.offline.iplot({
                            "data": [Scatter(
                                            x = list(range(len(dev_accuracies))),
                                            y = dev_accuracies,
                                            mode = "lines+markers",
                                            name = "Dev {} Accuracy".format(accuracy_type),
                                            marker = dict(color = "#3498db")),
                                    Scatter(
                                            x = list(range(len(test_accuracies))),
                                            y = test_accuracies,
                                            mode = "lines+markers",
                                            name = "Test {} Accuracy".format(accuracy_type),
                                            marker = dict(color = "#9b59b6"))],
                            "layout": Layout(
                                             title = "<b>Dev-Test {} Accuracy</b>".format(accuracy_type),
                                             xaxis = dict(title = "<b>Epoch</b>",
                                                          dtick = 1,
                                                          titlefont = dict(color = "#34495e")),
                                             yaxis = dict(title = "<b>Accuracy</b>",
                                                          titlefont = dict(color = "#34495e")),
                                             margin = Margin(b = 150))
                        })

class EarlyStopping():
    
    def __init__(self, min_delta = 0, patience = 0):
        
        self.min_delta = min_delta
        self.patience = patience
        self.wait = 0
        self.stopped_epoch = 0
        self.best = -np.Inf
        self.stop_training = False
    
    def on_epoch_end(self, epoch, current_value):
        if np.greater((current_value - self.min_delta), self.best):
            self.best = current_value
            self.wait = 0
        else:
            self.wait += 1
            if self.wait > self.patience:
                self.stopped_epoch = epoch
                self.stop_training = True
        return self.stop_training

def make_tracker(model_Goal, model_Requested, model_Method, raw_X, raw_Y, dataset):
    
    model_Goal = model_Goal.eval()
    model_Requested = model_Requested.eval()
    model_Method = model_Method.eval()
    
    with torch.no_grad():
        tracker_json = {}
        tracker_json["dataset"] = dataset
        tracker_json["sessions"] = []

        start_time = time.time()
        for raw_X_dialog, raw_Y_dialog in tqdm_notebook(zip(raw_X, raw_Y), total = len(raw_X)):
            
            model_Goal.hidden = model_Goal.init_hidden()
            model_Requested.hidden = model_Requested.init_hidden()
            model_Method.hidden = model_Method.init_hidden()
            
            session = {}
            session["session-id"] = raw_X_dialog["session-id"]
            session["turns"] = []

            for raw_X_turn, raw_Y_turn in zip(raw_X_dialog["turns"], raw_Y_dialog["turns"]):
                turn = {}
                turn["goal-labels"] = {}

                indices, scores = get_index_and_score(raw_X_turn, token_to_index, mode = "eval", device = DEVICE)
                
                goal_food, goal_pricerange, goal_name, goal_area = model_Goal(indices, scores)
                requested = model_Requested(indices, scores)
                method = model_Method(indices, scores)
                turn["goal-labels"]["food"] = retrieve_output_GoalFood(goal_food, ontology)
                turn["goal-labels"]["pricerange"] = retrieve_output_GoalPricerange(goal_pricerange, ontology)
                turn["goal-labels"]["name"] = retrieve_output_GoalName(goal_name, ontology)
                turn["goal-labels"]["area"] = retrieve_output_GoalArea(goal_area, ontology)
                turn["requested-slots"] = retrieve_output_Requested(requested, ontology)
                turn["method-label"] = retrieve_output_Method(method, ontology)
                
                session["turns"].append(turn)
            tracker_json["sessions"].append(session)
        end_time = time.time()
        tracker_json["wall-time"] = end_time - start_time
        
        return tracker_json

def get_scores(model_Goal, model_Requested, model_Method, raw_X, raw_Y, dataset):
    goal_accuracy = 0
    goal_l2 = 0
    requested_accuracy = 0
    requested_l2 = 0
    method_accuracy = 0
    method_l2 = 0
    
    tracker = make_tracker(model_Goal, model_Requested, model_Method, raw_X, raw_Y, dataset = dataset)
    
    with open("tracker_gmr.json", "w") as tracker_file:
        json.dump(tracker, tracker_file)
    
    if dataset == "dstc2_train":
        !python2 dstc2/dstc2_scripts/score.py\
        --dataset dstc2_train\
        --dataroot dstc2/dstc2_traindev/data\
        --ontology dstc2/dstc2_scripts/config/ontology_dstc2.json\
        --trackfile tracker_gmr.json\
        --scorefile tracker_gmr.score.csv
    elif dataset == "dstc2_dev":
        !python2 dstc2/dstc2_scripts/score.py\
        --dataset dstc2_dev\
        --dataroot dstc2/dstc2_traindev/data\
        --ontology dstc2/dstc2_scripts/config/ontology_dstc2.json\
        --trackfile tracker_gmr.json\
        --scorefile tracker_gmr.score.csv
    else:
        !python2 dstc2/dstc2_scripts/score.py\
        --dataset dstc2_test\
        --dataroot dstc2/dstc2_test/data\
        --ontology dstc2/dstc2_scripts/config/ontology_dstc2.json\
        --trackfile tracker_gmr.json\
        --scorefile tracker_gmr.score.csv

    file_cat = !python2 dstc2/dstc2_scripts/report.py --scorefile tracker_gmr.score.csv
    
    found_accuracies = False
    for line in file_cat:
        if line.startswith("Accuracy"):
            accuracies = line.split("|")
            goal_accuracy = float(accuracies[1])
            requested_accuracy = float(accuracies[2])
            method_accuracy = float(accuracies[3])
            found_accuracies = True
        if found_accuracies and line.startswith("l2"):
            l2s = line.split("|")
            goal_l2 = float(l2s[1])
            method_l2 = float(l2s[2])
            requested_l2 = float(l2s[3])
    
    return goal_accuracy, goal_l2, requested_accuracy, requested_l2, method_accuracy, method_l2

def retrieve_gold_GoalFood(raw_Y, ontology, device):
    ontology_informable_food = ontology["informable"]["food"]
    raw_goal_food = raw_Y["goal"]["food"]
    goal_food = 0
    if raw_goal_food != None:
        if raw_goal_food == "dontcare":
            goal_food = 1
        else:    
            goal_food = ontology_informable_food.index(raw_goal_food) + 2
    return torch.tensor([goal_food], dtype = torch.long, device = device)

def retrieve_output_GoalFood(output_tensor, ontology):
    ontology_informable_food = ontology["informable"]["food"]
    output_tensor = output_tensor.view(-1)
    output_tensor = torch.exp(output_tensor)
    goal_food_dict = {}
    
    index = torch.argmax(output_tensor).item()
    goal_food_dict["dontcare"] = output_tensor[1].item() 
    for index in range(len(output_tensor) - 2):
        goal_food_dict[ontology_informable_food[index]] = output_tensor[index + 2].item()
    return goal_food_dict

def retrieve_gold_GoalPriceRange(raw_Y, ontology, device):
    ontology_informable_pricerange = ontology["informable"]["pricerange"]
    raw_goal_pricerange = raw_Y["goal"]["pricerange"]
    goal_pricerange = 0
    if raw_goal_pricerange != None:
        if raw_goal_pricerange == "dontcare":
            goal_pricerange = 1
        else:    
            goal_pricerange = ontology_informable_pricerange.index(raw_goal_pricerange) + 2
    return torch.tensor([goal_pricerange], dtype = torch.long, device = device)

def retrieve_output_GoalPricerange(output_tensor, ontology):
    ontology_informable_pricerange = ontology["informable"]["pricerange"]
    output_tensor = output_tensor.view(-1)
    output_tensor = torch.exp(output_tensor)
    goal_pricerange_dict = {}
    
    index = torch.argmax(output_tensor).item()
    goal_pricerange_dict["dontcare"] = output_tensor[1].item()
    for index in range(len(output_tensor) - 2):     
        goal_pricerange_dict[ontology_informable_pricerange[index]] = output_tensor[index + 2].item()
    return goal_pricerange_dict

def retrieve_gold_GoalName(raw_Y, ontology, device):
    ontology_informable_name = ontology["informable"]["name"]
    raw_goal_name = raw_Y["goal"]["name"]
    goal_name = 0
    if raw_goal_name != None:
        if raw_goal_name == "dontcare":
            goal_name = 1
        else:    
            goal_name = ontology_informable_name.index(raw_goal_name) + 2
    return torch.tensor([goal_name], dtype = torch.long, device = device)

def retrieve_output_GoalName(output_tensor, ontology):
    ontology_informable_name = ontology["informable"]["name"]
    output_tensor = output_tensor.view(-1)
    output_tensor = torch.exp(output_tensor)
    goal_name_dict = {}
    
    index = torch.argmax(output_tensor).item()
    goal_name_dict["dontcare"] = output_tensor[1].item()
    for index in range(len(output_tensor) - 2):
        goal_name_dict[ontology_informable_name[index]] = output_tensor[index + 2].item()
        
    return goal_name_dict

def retrieve_gold_GoalArea(raw_Y, ontology, device):
    ontology_informable_area = ontology["informable"]["area"]
    raw_goal_area = raw_Y["goal"]["area"]
    goal_area = 0
    if raw_goal_area != None:
        if raw_goal_area == "dontcare":
            goal_area = 1
        else:    
            goal_area = ontology_informable_area.index(raw_goal_area) + 2
    return torch.tensor([goal_area], dtype = torch.long, device = device)

def retrieve_output_GoalArea(output_tensor, ontology):
    ontology_informable_area = ontology["informable"]["area"]
    output_tensor = output_tensor.view(-1)
    output_tensor = torch.exp(output_tensor)
    goal_area_dict = {}
    
    index = torch.argmax(output_tensor).item()
    goal_area_dict["dontcare"] = output_tensor[1].item()
    for index in range(len(output_tensor) - 2):
        goal_area_dict[ontology_informable_area[index]] = output_tensor[index + 2].item()
        
    return goal_area_dict

def retrieve_gold_Method(raw_Y, ontology, device):
    ontology_method = ontology["method"]
    raw_gold_method = raw_Y["method"]
    gold_method = ontology_method.index(raw_gold_method)
    return torch.tensor([gold_method], dtype = torch.long, device = device)

def retrieve_output_Method(output_tensor, ontology):
    ontology_method = ontology["method"]
    output_tensor = output_tensor.view(-1)
    output_tensor = torch.exp(output_tensor)
    method_dict = {}
    
    for index in range(len(output_tensor)):
        method_dict[ontology_method[index]] = output_tensor[index].item()
    
    return method_dict

def retrieve_gold_Requested(raw_Y, ontology, device):
    ontology_requestable = ontology["requestable"]
    raw_gold_requested = raw_Y["requested"]
    gold_requested = np.zeros(len(ontology_requestable), dtype = float)
    if len(raw_gold_requested) != 0:
        for requested in raw_gold_requested:
            gold_requested[ontology_requestable.index(requested)] = 1.0
    return torch.tensor([gold_requested], dtype = torch.float, device = device)

def retrieve_output_Requested(output_tensor, ontology):
    ontology_requestable = ontology["requestable"]
    output_tensor = output_tensor.view(-1)
    requested_dict = {}
    for index in range(len(output_tensor)):
        probability_value = output_tensor[index].item()
        if np.greater_equal(probability_value, 0.5):
            requested_dict[ontology_requestable[index]] = probability_value
    return requested_dict

### Goal Model

In [9]:
class GoalModel(nn.Module):
    
    def __init__(self,
                 vocabulary_size,
                 embedding_dim,
                 altered_embedding_dim,
                 hidden_dim,
                 goal_food_dim,
                 goal_pricerange_dim,
                 goal_name_dim,
                 goal_area_dim,
                 device):
        
        super(GoalModel, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.goal_food_dim = goal_food_dim
        self.goal_pricerange_dim = goal_pricerange_dim
        self.goal_name_dim = goal_name_dim
        self.goal_area_dim = goal_area_dim
        self.device = device
        
        self.embeddings = nn.Embedding(num_embeddings = vocabulary_size,
                                       embedding_dim = embedding_dim)
        
        self.altered_embeddings = nn.Linear(in_features = (embedding_dim + 1),
                                            out_features = altered_embedding_dim)
        
        self.lstm = nn.LSTM(input_size = altered_embedding_dim,
                            hidden_size = hidden_dim)
        
        self.goal_food_classifier = nn.Linear(in_features = hidden_dim,
                                              out_features = goal_food_dim)
        
        self.goal_pricerange_classifier = nn.Linear(in_features = hidden_dim,
                                                    out_features = goal_pricerange_dim)
        
        self.goal_name_classifier = nn.Linear(in_features = hidden_dim,
                                              out_features = goal_name_dim)
        
        self.goal_area_classifier = nn.Linear(in_features = hidden_dim,
                                              out_features = goal_area_dim)
        
        self.hidden = self.init_hidden()
    
    def init_hidden(self):
        
        return (torch.zeros(1, 1, self.hidden_dim, device = self.device),
                torch.zeros(1, 1, self.hidden_dim, device = self.device))
        
    def forward(self, indices, scores):
    
        embeddings = self.embeddings(indices)
        
        embeddings_concat_score = torch.cat((embeddings, scores.unsqueeze(dim = 1)), dim = 1) 
        
        altered_embeddings = F.relu(self.altered_embeddings(embeddings_concat_score))
        
        lstm_out, self.hidden = self.lstm(altered_embeddings.view(len(indices), 1, -1), self.hidden)
        
        goal_food = F.log_softmax(self.goal_food_classifier(self.hidden[0]).view(-1, self.goal_food_dim), dim = 1)        
        
        goal_pricerange = F.log_softmax(self.goal_pricerange_classifier(self.hidden[0]).view(-1, self.goal_pricerange_dim), dim = 1) 
        
        goal_name = F.log_softmax(self.goal_name_classifier(self.hidden[0]).view(-1, self.goal_name_dim), dim = 1)
        
        goal_area = F.log_softmax(self.goal_area_classifier(self.hidden[0]).view(-1, self.goal_area_dim), dim = 1)
        
        return goal_food, goal_pricerange, goal_name, goal_area

model_Goal = GoalModel(vocabulary_size = VOCABULARY_SIZE,
                       embedding_dim = EMBEDDING_DIM,
                       altered_embedding_dim = ALTERED_EMBEDDING_DIM,
                       hidden_dim = HIDDEN_DIM,
                       goal_food_dim = GOAL_FOOD_DIM,
                       goal_pricerange_dim = GOAL_PRICERANGE_DIM,
                       goal_name_dim = GOAL_NAME_DIM,
                       goal_area_dim = GOAL_AREA_DIM,
                       device = DEVICE)

model_Goal = model_Goal.to(DEVICE)

optimizer_GoalModel = optim.Adam(model_Goal.parameters(), lr = 1e-4)

### Requested Model

In [10]:
class RequestedModel(nn.Module):
    
    def __init__(self,
                 vocabulary_size,
                 embedding_dim,
                 altered_embedding_dim,
                 hidden_dim,
                 requested_dim,
                 device):
        
        super(RequestedModel, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.requested_dim = requested_dim
        self.device = device
        
        self.embeddings = nn.Embedding(num_embeddings = vocabulary_size,
                                       embedding_dim = embedding_dim)
        
        self.altered_embeddings = nn.Linear(in_features = (embedding_dim + 1),
                                            out_features = altered_embedding_dim)
        
        self.lstm = nn.LSTM(input_size = altered_embedding_dim,
                            hidden_size = hidden_dim)
        
        self.requested_classifier = nn.Linear(in_features = hidden_dim,
                                              out_features = requested_dim)
        
        self.hidden = self.init_hidden()
    
    def init_hidden(self):
        
        return (torch.zeros(1, 1, self.hidden_dim, device = self.device),
                torch.zeros(1, 1, self.hidden_dim, device = self.device))
        
    def forward(self, indices, scores):
    
        embeddings = self.embeddings(indices)
        
        embeddings_concat_score = torch.cat((embeddings, scores.unsqueeze(dim = 1)), dim = 1) 
        
        altered_embeddings = F.relu(self.altered_embeddings(embeddings_concat_score))
        
        lstm_out, self.hidden = self.lstm(altered_embeddings.view(len(indices), 1, -1), self.hidden)
        
        requested = torch.sigmoid(self.requested_classifier(self.hidden[0]).view(-1, self.requested_dim))
        
        return requested

model_Requested = RequestedModel(vocabulary_size = VOCABULARY_SIZE,
                                 embedding_dim = EMBEDDING_DIM,
                                 altered_embedding_dim = ALTERED_EMBEDDING_DIM,
                                 hidden_dim = HIDDEN_DIM,
                                 requested_dim = REQUESTED_DIM,
                                 device = DEVICE)

model_Requested = model_Requested.to(DEVICE)

optimizer_RequestedModel = optim.Adam(model_Requested.parameters(), lr = 1e-4)

### Method Model

In [11]:
class MethodModel(nn.Module):
    
    def __init__(self,
                 vocabulary_size,
                 embedding_dim,
                 altered_embedding_dim,
                 hidden_dim,
                 method_dim,
                 device):
        
        super(MethodModel, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.method_dim = method_dim
        self.device = device
        
        self.embeddings = nn.Embedding(num_embeddings = vocabulary_size,
                                       embedding_dim = embedding_dim)
        
        self.altered_embeddings = nn.Linear(in_features = (embedding_dim + 1),
                                            out_features = altered_embedding_dim)
        
        self.lstm = nn.LSTM(input_size = altered_embedding_dim,
                            hidden_size = hidden_dim)
        
        self.method_classifier = nn.Linear(in_features = hidden_dim,
                                           out_features = method_dim)
        
        self.hidden = self.init_hidden()
    
    def init_hidden(self):
        
        return (torch.zeros(1, 1, self.hidden_dim, device = self.device),
                torch.zeros(1, 1, self.hidden_dim, device = self.device))
        
    def forward(self, indices, scores):
    
        embeddings = self.embeddings(indices)
        
        embeddings_concat_score = torch.cat((embeddings, scores.unsqueeze(dim = 1)), dim = 1) 
        
        altered_embeddings = F.relu(self.altered_embeddings(embeddings_concat_score))
        
        lstm_out, self.hidden = self.lstm(altered_embeddings.view(len(indices), 1, -1), self.hidden)
        
        method = F.log_softmax(self.method_classifier(self.hidden[0]).view(-1, self.method_dim), dim = 1)
        
        return method

model_Method = MethodModel(vocabulary_size = VOCABULARY_SIZE,
                           embedding_dim = EMBEDDING_DIM,
                           altered_embedding_dim = ALTERED_EMBEDDING_DIM,
                           hidden_dim = HIDDEN_DIM,
                           method_dim = METHOD_DIM,
                           device = DEVICE)

model_Method = model_Method.to(DEVICE)

optimizer_MethodModel = optim.Adam(model_Method.parameters(), lr = 1e-4)

### Train Goal Model

In [12]:
dev_goal_accuracies = []
dev_requested_accuracies = []
dev_method_accuracies = []

test_goal_accuracies = []
test_requested_accuracies = []
test_method_accuracies = []

goal_early_stopping = EarlyStopping(patience = 2)

train_indices = np.arange(raw_X_train.shape[0])

for epoch in range(NUM_EPOCHS):
    
    np.random.shuffle(train_indices)
    
    logging.info("Epoch\t{}/{}".format(epoch + 1, NUM_EPOCHS))
    
    model_Goal = model_Goal.train()
    
    for raw_X_train_dialog, raw_Y_train_dialog in tqdm_notebook(zip(raw_X_train[train_indices], raw_Y_train[train_indices]), total = len(raw_X_train)):

        model_Goal.hidden = model_Goal.init_hidden()

        for raw_X_train_turn, raw_Y_train_turn in zip(raw_X_train_dialog["turns"], raw_Y_train_dialog["turns"]):
            
            optimizer_GoalModel.zero_grad()
            
            indices, scores = get_index_and_score(raw_X_train_turn, token_to_index, mode = "train", device = DEVICE)
            
            goal_food, goal_pricerange, goal_name, goal_area = model_Goal(indices, scores)
            
            loss_goal_food = GOAL_LOSS_FUNCTION(goal_food,
                                                retrieve_gold_GoalFood(raw_Y_train_turn,
                                                                       ontology = ontology,
                                                                       device = DEVICE))
            
            loss_goal_pricerange = GOAL_LOSS_FUNCTION(goal_pricerange,
                                                      retrieve_gold_GoalPriceRange(raw_Y_train_turn,
                                                                                   ontology = ontology,
                                                                                   device = DEVICE))
            
            loss_goal_name = GOAL_LOSS_FUNCTION(goal_name,
                                                retrieve_gold_GoalName(raw_Y_train_turn,
                                                                       ontology = ontology,
                                                                       device = DEVICE))
            
            loss_goal_area = GOAL_LOSS_FUNCTION(goal_area,
                                                retrieve_gold_GoalArea(raw_Y_train_turn,
                                                                       ontology = ontology,
                                                                       device = DEVICE))

            loss = loss_goal_food + loss_goal_pricerange + loss_goal_name + loss_goal_area
            
            loss.backward(retain_graph = True)
            
            optimizer_GoalModel.step()
    
    dev_goal_accuracy, \
    dev_goal_l2, \
    dev_requested_accuracy, \
    dev_requested_l2, \
    dev_method_accuracy, \
    dev_method_l2 = get_scores(model_Goal, model_Requested, model_Method, raw_X_dev, raw_Y_dev, dataset = "dstc2_dev")
    logging.info("DEV Acc:\t\t{}({})\t\t{}({})\t\t{}({})".format(dev_goal_accuracy,
                                                                 np.around(dev_goal_accuracy, decimals = 2),
                                                                 dev_requested_accuracy,
                                                                 np.around(dev_requested_accuracy, decimals = 2),
                                                                 dev_method_accuracy,
                                                                 np.around(dev_method_accuracy, decimals = 2)))
    logging.info("DEV L2:\t\t{}({})\t\t{}({})\t\t{}({})".format(dev_goal_l2,
                                                                np.around(dev_goal_l2, decimals = 2),
                                                                dev_requested_l2,
                                                                np.around(dev_requested_l2, decimals = 2),
                                                                dev_method_l2,
                                                                np.around(dev_method_l2, decimals = 2)))
    dev_goal_accuracies.append(dev_goal_accuracy)
    dev_requested_accuracies.append(dev_requested_accuracy)
    dev_method_accuracies.append(dev_method_accuracy)

    test_goal_accuracy, \
    test_goal_l2, \
    test_requested_accuracy, \
    test_requested_l2, \
    test_method_accuracy, \
    test_method_l2 = get_scores(model_Goal, model_Requested, model_Method, raw_X_test, raw_Y_test, dataset = "dstc2_test")
    logging.info("TEST Acc:\t\t{}({})\t\t{}({})\t\t{}({})".format(test_goal_accuracy,
                                                                  np.around(test_goal_accuracy, decimals = 2),
                                                                  test_requested_accuracy,
                                                                  np.around(test_requested_accuracy, decimals = 2),
                                                                  test_method_accuracy,
                                                                  np.around(test_method_accuracy, decimals = 2)))
    logging.info("TEST L2:\t\t{}({})\t\t{}({})\t\t{}({})".format(test_goal_l2,
                                                                 np.around(test_goal_l2, decimals = 2),
                                                                 test_requested_l2,
                                                                 np.around(test_requested_l2, decimals = 2),
                                                                 test_method_l2,
                                                                 np.around(test_method_l2, decimals = 2)))
    test_goal_accuracies.append(test_goal_accuracy)
    test_requested_accuracies.append(test_requested_accuracy)
    test_method_accuracies.append(test_method_accuracy)
        
    goal_early_stopping.on_epoch_end(epoch = (epoch + 1), current_value = (dev_goal_accuracy))
    
    if goal_early_stopping.wait == 0:
        torch.save(model_Goal.state_dict(), "model_GMR_Goal.pt")
    
    if goal_early_stopping.stop_training:
        break

plotly_plot(dev_goal_accuracies, test_goal_accuracies, "Goal")

INFO:root:Epoch	1/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.3905109(0.39)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.7803549(0.78)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.3186087(0.32)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.8308832(0.83)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	2/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.5448384(0.54)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.6214707(0.62)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.5508308(0.55)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.6095616(0.61)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	3/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.5740355(0.57)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.5692543(0.57)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6024358(0.6)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.5336226(0.53)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	4/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.5920229(0.59)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.5570582(0.56)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6226649(0.62)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.5168681(0.52)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	5/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.5888947(0.59)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.5637573(0.56)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6438229(0.64)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.4939857(0.49)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	6/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.5865485(0.59)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.5641493(0.56)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6124471(0.61)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.5362461(0.54)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	7/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6037539(0.6)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.552765(0.55)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6199814(0.62)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.5294719(0.53)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	8/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6076642(0.61)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.5436291(0.54)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6442357(0.64)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.5065986(0.51)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	9/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6063608(0.61)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.5620762(0.56)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.640933(0.64)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.5156537(0.52)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	10/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6076642(0.61)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.5527126(0.55)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6515636(0.65)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.5071808(0.51)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	11/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6141814(0.61)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.5584425(0.56)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6553824(0.66)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.4994026(0.5)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	12/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6021898(0.6)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.5689128(0.57)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6544535(0.65)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.5090958(0.51)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	13/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6157456(0.62)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.5499661(0.55)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6443389(0.64)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.5147623(0.51)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	14/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.62122(0.62)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.5518452(0.55)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6474352(0.65)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.5070138(0.51)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	15/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6162669(0.62)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.5556958(0.56)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6402105(0.64)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.5191476(0.52)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	16/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6071429(0.61)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.5732511(0.57)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6539375(0.65)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.5131092(0.51)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	17/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.5467848(0.55)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.5148621(0.51)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	18/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.604536(0.6)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.5691399(0.57)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6266901(0.63)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.5437139(0.54)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	19/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6152242(0.62)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.575073(0.58)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6442357(0.64)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.5261277(0.53)		0.8031016(0.8)		0.48694(0.49)
INFO:root:Epoch	20/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6167883(0.62)		0.5268662(0.53)		0.2066736(0.21)
INFO:root:DEV L2:		0.5548768(0.55)		0.8014078(0.8)		0.4973794(0.5)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.655692(0.66)		0.5520056(0.55)		0.1792898(0.18)
INFO:root:TEST L2:		0.5182194(0.52)		0.8031016(0.8)		0.48694(0.49)


### Load Goal Model

In [13]:
model_Goal = GoalModel(vocabulary_size = VOCABULARY_SIZE,
                       embedding_dim = EMBEDDING_DIM,
                       altered_embedding_dim = ALTERED_EMBEDDING_DIM,
                       hidden_dim = HIDDEN_DIM,
                       goal_food_dim = GOAL_FOOD_DIM,
                       goal_pricerange_dim = GOAL_PRICERANGE_DIM,
                       goal_name_dim = GOAL_NAME_DIM,
                       goal_area_dim = GOAL_AREA_DIM,
                       device = DEVICE)
model_Goal = model_Goal.to(DEVICE)
model_Goal.load_state_dict(torch.load("model_GMR_Goal.pt"))
model_Goal.eval()

GoalModel(
  (embeddings): Embedding(1149, 170)
  (altered_embeddings): Linear(in_features=171, out_features=300, bias=True)
  (lstm): LSTM(300, 100)
  (goal_food_classifier): Linear(in_features=100, out_features=93, bias=True)
  (goal_pricerange_classifier): Linear(in_features=100, out_features=5, bias=True)
  (goal_name_classifier): Linear(in_features=100, out_features=115, bias=True)
  (goal_area_classifier): Linear(in_features=100, out_features=7, bias=True)
)

### Print scores

In [14]:
get_scores(model_Goal, model_Requested, model_Method, raw_X_dev, raw_Y_dev, dataset = "dstc2_dev")

HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




(0.6238269, 0.5467848, 0.5268662, 0.8014078, 0.2066736, 0.4973794)

In [15]:
get_scores(model_Goal, model_Requested, model_Method, raw_X_test, raw_Y_test, dataset = "dstc2_test")

HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




(0.6546599, 0.5148621, 0.5520056, 0.8031016, 0.1792898, 0.48694)

### Train Requested Model

In [16]:
dev_goal_accuracies = []
dev_requested_accuracies = []
dev_method_accuracies = []

test_goal_accuracies = []
test_requested_accuracies = []
test_method_accuracies = []

requested_early_stopping = EarlyStopping(patience = 2)

train_indices = np.arange(raw_X_train.shape[0])

for epoch in range(NUM_EPOCHS):
    
    np.random.shuffle(train_indices)
    
    logging.info("Epoch\t{}/{}".format(epoch + 1, NUM_EPOCHS))
    
    model_Requested = model_Requested.train()
    
    for raw_X_train_dialog, raw_Y_train_dialog in tqdm_notebook(zip(raw_X_train[train_indices], raw_Y_train[train_indices]), total = len(raw_X_train)):

        model_Requested.hidden = model_Requested.init_hidden()

        for raw_X_train_turn, raw_Y_train_turn in zip(raw_X_train_dialog["turns"], raw_Y_train_dialog["turns"]):
            
            optimizer_RequestedModel.zero_grad()
            
            indices, scores = get_index_and_score(raw_X_train_turn, token_to_index, mode = "train", device = DEVICE)
            
            requested = model_Requested(indices, scores)
            
            loss_requested = REQUESTED_LOSS_FUNCTION(requested,
                                                     retrieve_gold_Requested(raw_Y_train_turn,
                                                                             ontology = ontology,
                                                                             device = DEVICE))
            loss_requested.backward(retain_graph = True)
            
            optimizer_RequestedModel.step()
    
    dev_goal_accuracy, \
    dev_goal_l2, \
    dev_requested_accuracy, \
    dev_requested_l2, \
    dev_method_accuracy, \
    dev_method_l2 = get_scores(model_Goal, model_Requested, model_Method, raw_X_dev, raw_Y_dev, dataset = "dstc2_dev")
    logging.info("DEV Acc:\t\t{}({})\t\t{}({})\t\t{}({})".format(dev_goal_accuracy,
                                                                 np.around(dev_goal_accuracy, decimals = 2),
                                                                 dev_requested_accuracy,
                                                                 np.around(dev_requested_accuracy, decimals = 2),
                                                                 dev_method_accuracy,
                                                                 np.around(dev_method_accuracy, decimals = 2)))
    logging.info("DEV L2:\t\t{}({})\t\t{}({})\t\t{}({})".format(dev_goal_l2,
                                                                np.around(dev_goal_l2, decimals = 2),
                                                                dev_requested_l2,
                                                                np.around(dev_requested_l2, decimals = 2),
                                                                dev_method_l2,
                                                                np.around(dev_method_l2, decimals = 2)))
    dev_goal_accuracies.append(dev_goal_accuracy)
    dev_requested_accuracies.append(dev_requested_accuracy)
    dev_method_accuracies.append(dev_method_accuracy)

    test_goal_accuracy, \
    test_goal_l2, \
    test_requested_accuracy, \
    test_requested_l2, \
    test_method_accuracy, \
    test_method_l2 = get_scores(model_Goal, model_Requested, model_Method, raw_X_test, raw_Y_test, dataset = "dstc2_test")
    logging.info("TEST Acc:\t\t{}({})\t\t{}({})\t\t{}({})".format(test_goal_accuracy,
                                                              np.around(test_goal_accuracy, decimals = 2),
                                                              test_requested_accuracy,
                                                              np.around(test_requested_accuracy, decimals = 2),
                                                              test_method_accuracy,
                                                              np.around(test_method_accuracy, decimals = 2)))
    logging.info("TEST L2:\t\t{}({})\t\t{}({})\t\t{}({})".format(test_goal_l2,
                                                              np.around(test_goal_l2, decimals = 2),
                                                              test_requested_l2,
                                                              np.around(test_requested_l2, decimals = 2),
                                                              test_method_l2,
                                                              np.around(test_method_l2, decimals = 2)))
    test_goal_accuracies.append(test_goal_accuracy)
    test_requested_accuracies.append(test_requested_accuracy)
    test_method_accuracies.append(test_method_accuracy)
        
    requested_early_stopping.on_epoch_end(epoch = (epoch + 1), current_value = (dev_requested_accuracy))
    
    if requested_early_stopping.wait == 0:
        torch.save(model_Requested.state_dict(), "model_GMR_Requested.pt")
    
    if requested_early_stopping.stop_training:
        break

plotly_plot(dev_requested_accuracies, test_requested_accuracies, "Requested")

INFO:root:Epoch	1/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9509556(0.95)		0.2066736(0.21)
INFO:root:DEV L2:		0.5467848(0.55)		0.8014078(0.8)		0.0914085(0.09)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9644035(0.96)		0.1792898(0.18)
INFO:root:TEST L2:		0.5148621(0.51)		0.8031016(0.8)		0.0668484(0.07)
INFO:root:Epoch	2/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9585287(0.96)		0.2066736(0.21)
INFO:root:DEV L2:		0.5467848(0.55)		0.8014078(0.8)		0.0784542(0.08)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9666609(0.97)		0.1792898(0.18)
INFO:root:TEST L2:		0.5148621(0.51)		0.8031016(0.8)		0.0602641(0.06)
INFO:root:Epoch	3/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9632167(0.96)		0.2066736(0.21)
INFO:root:DEV L2:		0.5467848(0.55)		0.8014078(0.8)		0.0693886(0.07)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9718701(0.97)		0.1792898(0.18)
INFO:root:TEST L2:		0.5148621(0.51)		0.8031016(0.8)		0.0510348(0.05)
INFO:root:Epoch	4/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.961053(0.96)		0.2066736(0.21)
INFO:root:DEV L2:		0.5467848(0.55)		0.8014078(0.8)		0.0722257(0.07)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9739538(0.97)		0.1792898(0.18)
INFO:root:TEST L2:		0.5148621(0.51)		0.8031016(0.8)		0.0469178(0.05)
INFO:root:Epoch	5/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9635774(0.96)		0.2066736(0.21)
INFO:root:DEV L2:		0.5467848(0.55)		0.8014078(0.8)		0.0672695(0.07)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9722174(0.97)		0.1792898(0.18)
INFO:root:TEST L2:		0.5148621(0.51)		0.8031016(0.8)		0.0511782(0.05)
INFO:root:Epoch	6/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9642986(0.96)		0.2066736(0.21)
INFO:root:DEV L2:		0.5467848(0.55)		0.8014078(0.8)		0.0676361(0.07)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9765584(0.98)		0.1792898(0.18)
INFO:root:TEST L2:		0.5148621(0.51)		0.8031016(0.8)		0.0447948(0.04)
INFO:root:Epoch	7/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9628561(0.96)		0.2066736(0.21)
INFO:root:DEV L2:		0.5467848(0.55)		0.8014078(0.8)		0.0695333(0.07)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9743011(0.97)		0.1792898(0.18)
INFO:root:TEST L2:		0.5148621(0.51)		0.8031016(0.8)		0.0463556(0.05)
INFO:root:Epoch	8/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9617743(0.96)		0.2066736(0.21)
INFO:root:DEV L2:		0.5467848(0.55)		0.8014078(0.8)		0.0718451(0.07)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9753429(0.98)		0.1792898(0.18)
INFO:root:TEST L2:		0.5148621(0.51)		0.8031016(0.8)		0.0457573(0.05)
INFO:root:Epoch	9/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9650198(0.97)		0.2066736(0.21)
INFO:root:DEV L2:		0.5467848(0.55)		0.8014078(0.8)		0.0659797(0.07)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9746484(0.97)		0.1792898(0.18)
INFO:root:TEST L2:		0.5148621(0.51)		0.8031016(0.8)		0.045915(0.05)
INFO:root:Epoch	10/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9679048(0.97)		0.2066736(0.21)
INFO:root:DEV L2:		0.5467848(0.55)		0.8014078(0.8)		0.0610563(0.06)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9762111(0.98)		0.1792898(0.18)
INFO:root:TEST L2:		0.5148621(0.51)		0.8031016(0.8)		0.0426277(0.04)
INFO:root:Epoch	11/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9671836(0.97)		0.2066736(0.21)
INFO:root:DEV L2:		0.5467848(0.55)		0.8014078(0.8)		0.0626319(0.06)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9746484(0.97)		0.1792898(0.18)
INFO:root:TEST L2:		0.5148621(0.51)		0.8031016(0.8)		0.0457326(0.05)
INFO:root:Epoch	12/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9646592(0.96)		0.2066736(0.21)
INFO:root:DEV L2:		0.5467848(0.55)		0.8014078(0.8)		0.0648155(0.06)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9743011(0.97)		0.1792898(0.18)
INFO:root:TEST L2:		0.5148621(0.51)		0.8031016(0.8)		0.0450655(0.05)
INFO:root:Epoch	13/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9585287(0.96)		0.2066736(0.21)
INFO:root:DEV L2:		0.5467848(0.55)		0.8014078(0.8)		0.081264(0.08)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9710019(0.97)		0.1792898(0.18)
INFO:root:TEST L2:		0.5148621(0.51)		0.8031016(0.8)		0.0551098(0.06)


### Load Requested Model

In [17]:
model_Requested = RequestedModel(vocabulary_size = VOCABULARY_SIZE,
                                 embedding_dim = EMBEDDING_DIM,
                                 altered_embedding_dim = ALTERED_EMBEDDING_DIM,
                                 hidden_dim = HIDDEN_DIM,
                                 requested_dim = REQUESTED_DIM,
                                 device = DEVICE)

model_Requested = model_Requested.to(DEVICE)
model_Requested.load_state_dict(torch.load("model_GMR_Requested.pt"))
model_Requested.eval()

RequestedModel(
  (embeddings): Embedding(1149, 170)
  (altered_embeddings): Linear(in_features=171, out_features=300, bias=True)
  (lstm): LSTM(300, 100)
  (requested_classifier): Linear(in_features=100, out_features=8, bias=True)
)

### Print scores

In [18]:
get_scores(model_Goal, model_Requested, model_Method, raw_X_dev, raw_Y_dev, dataset = "dstc2_dev")

HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




(0.6238269, 0.5467848, 0.9679048, 0.8014078, 0.2066736, 0.0610563)

In [19]:
get_scores(model_Goal, model_Requested, model_Method, raw_X_test, raw_Y_test, dataset = "dstc2_test")

HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




(0.6546599, 0.5148621, 0.9762111, 0.8031016, 0.1792898, 0.0426277)

### Train Method Model

In [20]:
dev_goal_accuracies = []
dev_requested_accuracies = []
dev_method_accuracies = []

test_goal_accuracies = []
test_requested_accuracies = []
test_method_accuracies = []

method_early_stopping = EarlyStopping(patience = 2)

train_indices = np.arange(raw_X_train.shape[0])

for epoch in range(NUM_EPOCHS):
    
    np.random.shuffle(train_indices)
    
    logging.info("Epoch\t{}/{}".format(epoch + 1, NUM_EPOCHS))
    
    model_Method = model_Method.train()
    
    for raw_X_train_dialog, raw_Y_train_dialog in tqdm_notebook(zip(raw_X_train[train_indices], raw_Y_train[train_indices]), total = len(raw_X_train)):

        model_Method.hidden = model_Method.init_hidden()

        for raw_X_train_turn, raw_Y_train_turn in zip(raw_X_train_dialog["turns"], raw_Y_train_dialog["turns"]):
            
            optimizer_MethodModel.zero_grad()
            
            indices, scores = get_index_and_score(raw_X_train_turn, token_to_index, mode = "train", device = DEVICE)
            
            method = model_Method(indices, scores)
            
            loss_method = METHOD_LOSS_FUNCTION(method,
                                               retrieve_gold_Method(raw_Y_train_turn,
                                                                    ontology = ontology,
                                                                    device = DEVICE))
            
            loss_method.backward(retain_graph = True)
            
            optimizer_MethodModel.step()
    
    dev_goal_accuracy, \
    dev_goal_l2, \
    dev_requested_accuracy, \
    dev_requested_l2, \
    dev_method_accuracy, \
    dev_method_l2 = get_scores(model_Goal, model_Requested, model_Method, raw_X_dev, raw_Y_dev, dataset = "dstc2_dev")
    logging.info("DEV Acc:\t\t{}({})\t\t{}({})\t\t{}({})".format(dev_goal_accuracy,
                                                                 np.around(dev_goal_accuracy, decimals = 2),
                                                                 dev_requested_accuracy,
                                                                 np.around(dev_requested_accuracy, decimals = 2),
                                                                 dev_method_accuracy,
                                                                 np.around(dev_method_accuracy, decimals = 2)))
    logging.info("DEV L2:\t\t{}({})\t\t{}({})\t\t{}({})".format(dev_goal_l2,
                                                                np.around(dev_goal_l2, decimals = 2),
                                                                dev_requested_l2,
                                                                np.around(dev_requested_l2, decimals = 2),
                                                                dev_method_l2,
                                                                np.around(dev_method_l2, decimals = 2)))
    dev_goal_accuracies.append(dev_goal_accuracy)
    dev_requested_accuracies.append(dev_requested_accuracy)
    dev_method_accuracies.append(dev_method_accuracy)

    test_goal_accuracy, \
    test_goal_l2, \
    test_requested_accuracy, \
    test_requested_l2, \
    test_method_accuracy, \
    test_method_l2 = get_scores(model_Goal, model_Requested, model_Method, raw_X_test, raw_Y_test, dataset = "dstc2_test")
    logging.info("TEST Acc:\t\t{}({})\t\t{}({})\t\t{}({})".format(test_goal_accuracy,
                                                                  np.around(test_goal_accuracy, decimals = 2),
                                                                  test_requested_accuracy,
                                                                  np.around(test_requested_accuracy, decimals = 2),
                                                                  test_method_accuracy,
                                                                  np.around(test_method_accuracy, decimals = 2)))
    logging.info("TEST L2:\t\t{}({})\t\t{}({})\t\t{}({})".format(test_goal_l2,
                                                                 np.around(test_goal_l2, decimals = 2),
                                                                 test_requested_l2,
                                                                 np.around(test_requested_l2, decimals = 2),
                                                                 test_method_l2,
                                                                 np.around(test_method_l2, decimals = 2)))
    test_goal_accuracies.append(test_goal_accuracy)
    test_requested_accuracies.append(test_requested_accuracy)
    test_method_accuracies.append(test_method_accuracy)
        
    method_early_stopping.on_epoch_end(epoch = (epoch + 1),
                                    current_value = (dev_method_accuracy))
    
    if method_early_stopping.wait == 0:
        torch.save(model_Method.state_dict(), "model_GMR_Method.pt")
    
    if method_early_stopping.stop_training:
        break

plotly_plot(dev_method_accuracies, test_method_accuracies, "Method")

INFO:root:Epoch	1/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9679048(0.97)		0.8983445(0.9)
INFO:root:DEV L2:		0.5467848(0.55)		0.1678201(0.17)		0.0610563(0.06)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9762111(0.98)		0.9216954(0.92)
INFO:root:TEST L2:		0.5148621(0.51)		0.1251132(0.13)		0.0426277(0.04)
INFO:root:Epoch	2/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9679048(0.97)		0.9050698(0.91)
INFO:root:DEV L2:		0.5467848(0.55)		0.1609077(0.16)		0.0610563(0.06)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9762111(0.98)		0.9275452(0.93)
INFO:root:TEST L2:		0.5148621(0.51)		0.1179819(0.12)		0.0426277(0.04)
INFO:root:Epoch	3/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9679048(0.97)		0.9030005(0.9)
INFO:root:DEV L2:		0.5467848(0.55)		0.1623555(0.16)		0.0610563(0.06)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9762111(0.98)		0.9333949(0.93)
INFO:root:TEST L2:		0.5148621(0.51)		0.1104696(0.11)		0.0426277(0.04)
INFO:root:Epoch	4/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9679048(0.97)		0.9017072(0.9)
INFO:root:DEV L2:		0.5467848(0.55)		0.1613608(0.16)		0.0610563(0.06)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9762111(0.98)		0.9268268(0.93)
INFO:root:TEST L2:		0.5148621(0.51)		0.1196902(0.12)		0.0426277(0.04)
INFO:root:Epoch	5/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9679048(0.97)		0.9107605(0.91)
INFO:root:DEV L2:		0.5467848(0.55)		0.1560546(0.16)		0.0610563(0.06)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9762111(0.98)		0.9294951(0.93)
INFO:root:TEST L2:		0.5148621(0.51)		0.1198075(0.12)		0.0426277(0.04)
INFO:root:Epoch	6/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9679048(0.97)		0.9071392(0.91)
INFO:root:DEV L2:		0.5467848(0.55)		0.1615604(0.16)		0.0610563(0.06)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9762111(0.98)		0.9300082(0.93)
INFO:root:TEST L2:		0.5148621(0.51)		0.1173611(0.12)		0.0426277(0.04)
INFO:root:Epoch	7/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9679048(0.97)		0.9011899(0.9)
INFO:root:DEV L2:		0.5467848(0.55)		0.1652222(0.17)		0.0610563(0.06)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9762111(0.98)		0.9213875(0.92)
INFO:root:TEST L2:		0.5148621(0.51)		0.1268471(0.13)		0.0426277(0.04)
INFO:root:Epoch	8/30


HBox(children=(IntProgress(value=0, max=3224), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:DEV Acc:		0.6238269(0.62)		0.9679048(0.97)		0.9022245(0.9)
INFO:root:DEV L2:		0.5467848(0.55)		0.1659191(0.17)		0.0610563(0.06)


HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




INFO:root:TEST Acc:		0.6546599(0.65)		0.9762111(0.98)		0.9291872(0.93)
INFO:root:TEST L2:		0.5148621(0.51)		0.1164198(0.12)		0.0426277(0.04)


### Load Method Model

In [21]:
model_Method = MethodModel(vocabulary_size = VOCABULARY_SIZE,
                           embedding_dim = EMBEDDING_DIM,
                           altered_embedding_dim = ALTERED_EMBEDDING_DIM,
                           hidden_dim = HIDDEN_DIM,
                           method_dim = METHOD_DIM,
                           device = DEVICE)

model_Method = model_Method.to(DEVICE)
model_Method.load_state_dict(torch.load("model_GMR_Method.pt"))
model_Method.eval()

MethodModel(
  (embeddings): Embedding(1149, 170)
  (altered_embeddings): Linear(in_features=171, out_features=300, bias=True)
  (lstm): LSTM(300, 100)
  (method_classifier): Linear(in_features=100, out_features=5, bias=True)
)

### Print scores

In [22]:
get_scores(model_Goal, model_Requested, model_Method, raw_X_dev, raw_Y_dev, dataset = "dstc2_dev")

HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




(0.6238269, 0.5467848, 0.9679048, 0.1560546, 0.9107605, 0.0610563)

In [23]:
get_scores(model_Goal, model_Requested, model_Method, raw_X_test, raw_Y_test, dataset = "dstc2_test")

HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




(0.6546599, 0.5148621, 0.9762111, 0.1198075, 0.9294951, 0.0426277)