### Check GPU availability

In [1]:
!nvidia-smi

Tue Feb 19 18:46:04 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.145                Driver Version: 384.145                   |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX 108...  Off  | 00000000:02:00.0 Off |                  N/A |
| 23%   29C    P8    17W / 250W |    589MiB / 11172MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce GTX 108...  Off  | 00000000:03:00.0 Off |                  N/A |
| 23%   36C    P8    17W / 250W |     10MiB / 11172MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                            

### Import libraries

In [2]:
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import copy
import numpy as np
import time
import json
import plotly
import logging
logging.getLogger().setLevel(logging.INFO)

from pprint import pprint
from tqdm import tqdm_notebook
from idst_util import trivial
from idst_util import dstc2
from dstc2.dstc2_scripts import score

from plotly.graph_objs import Scatter, Layout, Histogram, Histogram2d
from plotly.graph_objs.layout import Margin
plotly.offline.init_notebook_mode(connected = True)

[nltk_data] Downloading package punkt to /home/is/andrei-
[nltk_data]     cc/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


### Check DSTC2 availability and retrieve data

In [3]:
trivial.print_idst()

dstc2.check()

raw_X_train, raw_Y_train, \
raw_X_dev, raw_Y_dev, \
raw_X_test, raw_Y_test, \
ontology = dstc2.retrieve_raw_datasets()

INFO:root:+--------------------------------+
INFO:root:|         _ ____  ___________    |
INFO:root:|        (_) __ \/ ___/_  __/    |
INFO:root:|       / / / / /\__ \ / /       |
INFO:root:|      / / /_/ /___/ // /        |
INFO:root:|     /_/_____//____//_/         |
INFO:root:|                                |
INFO:root:+--------------------------------+
INFO:root:|Incremental Dialog State Tracker|
INFO:root:+--------------------------------+
INFO:root:+--------------------------------+
INFO:root:|     Dialog State Tracker 2     |
INFO:root:|         Data Checker           |
INFO:root:+--------------------------------+
INFO:root:Looking for dstc2 directory in .
INFO:root:dstc2 was found!
INFO:root:Looking for dstc2_traindev directory in ./dstc2
INFO:root:dstc2_traindev was found!
INFO:root:Looking for dstc2_test directory in ./dstc2
INFO:root:dstc2_test was found!
INFO:root:Looking for dstc2_scripts directory in ./dstc2
INFO:root:dstc2_scripts was found!
INFO:root:Done!
INFO:root:+-

HBox(children=(IntProgress(value=0, max=1612), HTML(value='')))

INFO:root:Extracting raw dev features





HBox(children=(IntProgress(value=0, max=506), HTML(value='')))

INFO:root:Reading dstc2_test.flist
INFO:root:Asserted 1117 dialogs for dstc2_test.flist
INFO:root:Extracting raw test features





HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))

INFO:root:Done!





### Set device

In [4]:
logging.info("+--------------------------------+")
logging.info("|             Device             |")
logging.info("+--------------------------------+")

GPU_ID = 0
DEVICE = torch.device("cuda:{}".format(GPU_ID) if torch.cuda.is_available() else "cpu")
if str(DEVICE) == "cpu":
    logging.warning("Running on CPU")
else:
    logging.info("Running on GPU {}".format(GPU_ID))

INFO:root:+--------------------------------+
INFO:root:|             Device             |
INFO:root:+--------------------------------+
INFO:root:Running on GPU 0


### Create vocabularies

In [5]:
logging.info("+--------------------------------+")
logging.info("|          Vocabulary            |")
logging.info("+--------------------------------+")
logging.info("Creating token_to_index, index_to_token and token_to_count dictionaries")

token_to_index = {"<unk>": 0}
index_to_token = {0: "<unk>"}
token_to_count = {"<unk>": 1}

for raw_train_dialog in tqdm_notebook(raw_X_train):
    for raw_train_turn in raw_train_dialog["turns"]:
        tokens_scores = raw_train_turn["system"] + raw_train_turn["user"]
        for token_score in tokens_scores:
            token = token_score[0]
            if token not in token_to_index:
                token_to_index[token] = len(token_to_index)
                index_to_token[len(token_to_index)] = token
                token_to_count[token] = 1
            else:
                token_to_count[token] += 1
                
assert len(token_to_index) == len(index_to_token)
assert len(token_to_index) == len(token_to_count)

logging.info("Done!")

INFO:root:+--------------------------------+
INFO:root:|          Vocabulary            |
INFO:root:+--------------------------------+
INFO:root:Creating token_to_index, index_to_token and token_to_count dictionaries


HBox(children=(IntProgress(value=0, max=1612), HTML(value='')))

INFO:root:Done!





### Execution configuration

In [6]:
logging.info("+--------------------------------+")
logging.info("|         Configuration          |")
logging.info("+--------------------------------+")

VOCABULARY_SIZE = len(token_to_index)

# NOTE: we add +2 because of null and dontcare cases
GOAL_FOOD_DIM = len(ontology["informable"]["food"]) + 2 
GOAL_PRICERANGE_DIM = len(ontology["informable"]["pricerange"]) + 2
GOAL_NAME_DIM = len(ontology["informable"]["name"]) + 2
GOAL_AREA_DIM = len(ontology["informable"]["area"]) + 2

METHOD_DIM = len(ontology["method"])

REQUESTED_DIM = len(ontology["requestable"])

EMBEDDING_DIM = 170
ALTERED_EMBEDDING_DIM = 300
HIDDEN_DIM = 100

NUM_EPOCHS = 100
BATCH_SIZE = 10
PATIENCE = 4

GOAL_LOSS_FUNCTION = nn.CrossEntropyLoss()
METHOD_LOSS_FUNCTION = nn.CrossEntropyLoss()
REQUESTED_LOSS_FUNCTION = nn.BCELoss()

logging.info("VOCABULARY_SIZE:\t\t\t{}".format(VOCABULARY_SIZE))

logging.info("GOAL_FOOD_DIM:\t\t\t{}".format(GOAL_FOOD_DIM))
logging.info("GOAL_PRICERANGE_DIM:\t\t\t{}".format(GOAL_PRICERANGE_DIM))
logging.info("GOAL_NAME_DIM:\t\t\t{}".format(GOAL_NAME_DIM))
logging.info("GOAL_AREA_DIM:\t\t\t{}".format(GOAL_AREA_DIM))

logging.info("METHOD_DIM:\t\t\t\t{}".format(METHOD_DIM))

logging.info("REQUESTED_DIM:\t\t\t{}".format(REQUESTED_DIM))

logging.info("EMBEDDING_DIM:\t\t\t{}".format(EMBEDDING_DIM))
logging.info("ALTERED_EMBEDDING_DIM:\t\t{}".format(ALTERED_EMBEDDING_DIM))
logging.info("HIDDEN_DIM:\t\t\t\t{}".format(HIDDEN_DIM))

logging.info("NUM_EPOCHS:\t\t\t\t{}".format(NUM_EPOCHS))
logging.info("BATCH_SIZE:\t\t\t\t{}".format(BATCH_SIZE))
logging.info("PATIENCE:\t\t\t\t{}".format(PATIENCE))

logging.info("GOAL_LOSS_FUNCTION:\t\t\t{}".format(GOAL_LOSS_FUNCTION))
logging.info("METHOD_LOSS_FUNCTION:\t\t\t{}".format(METHOD_LOSS_FUNCTION))
logging.info("REQUESTED_LOSS_FUNCTION:\t\t{}".format(REQUESTED_LOSS_FUNCTION))

INFO:root:+--------------------------------+
INFO:root:|         Configuration          |
INFO:root:+--------------------------------+
INFO:root:VOCABULARY_SIZE:			897
INFO:root:GOAL_FOOD_DIM:			93
INFO:root:GOAL_PRICERANGE_DIM:			5
INFO:root:GOAL_NAME_DIM:			115
INFO:root:GOAL_AREA_DIM:			7
INFO:root:METHOD_DIM:				5
INFO:root:REQUESTED_DIM:			8
INFO:root:EMBEDDING_DIM:			170
INFO:root:ALTERED_EMBEDDING_DIM:		300
INFO:root:HIDDEN_DIM:				100
INFO:root:NUM_EPOCHS:				100
INFO:root:BATCH_SIZE:				10
INFO:root:PATIENCE:				4
INFO:root:GOAL_LOSS_FUNCTION:			CrossEntropyLoss()
INFO:root:METHOD_LOSS_FUNCTION:			CrossEntropyLoss()
INFO:root:REQUESTED_LOSS_FUNCTION:		BCELoss()


### Utilities

In [7]:
def get_index_and_score(turn, token_to_index, mode, device):
    
    indices = []
    scores = []
    
    if mode == "train": # LecTrack 4.3: Out-of-Vocabulary Words
        for system_token, system_token_score in turn["system"]:
            indices.append(token_to_index[system_token])
            scores.append(system_token_score)
        for user_token, user_token_score in turn["user"]:
            if np.random.binomial(n = 1, p = 0.1) == 1:
                indices.append(token_to_index["<unk>"])
            else:
                indices.append(token_to_index[user_token])
            scores.append(user_token_score)
    else:
        tokens_scores = turn["system"] + turn["user"]
        for token, score in tokens_scores:
            if token not in token_to_index:
                indices.append(token_to_index["<unk>"])
            else:
                indices.append(token_to_index[token])
            scores.append(score)
            
    assert len(indices) == len(scores)
    
    return torch.tensor(indices, dtype = torch.long, device = device), torch.tensor(scores, dtype = torch.float, device = device)

# --------------------

class EarlyStopping():
    
    def __init__(self, min_delta = 0, patience = 0):
        
        self.min_delta = min_delta
        self.patience = patience
        self.wait = 0
        self.stopped_epoch = 0
        self.best = -np.Inf
        self.stop_training = False
    
    def on_epoch_end(self, epoch, current_value):
        if np.greater((current_value - self.min_delta), self.best):
            self.best = current_value
            self.wait = 0
        else:
            self.wait += 1
            if self.wait > self.patience:
                self.stopped_epoch = epoch
                self.stop_training = True
        return self.stop_training

# --------------------
    
def get_incremental_index_and_percentage(percentage, length):
    
    incremental_index = -1
    new_percentage = None
    
    if length != 0:
        incremental_index = int(np.around(percentage * length)) - 1
        new_percentage = np.around(((incremental_index + 1) / length), decimals = 2)
    
    return incremental_index, new_percentage

# --------------------

def make_tracker(model_All, raw_X, raw_Y, dataset, percentage = 1.0):
    
    model_All = model_All.eval()
    
    percentage_points = []
    
    with torch.no_grad():
        tracker_json = {}
        tracker_json["dataset"] = dataset
        tracker_json["sessions"] = []

        start_time = time.time()
        
        for raw_X_dialog, raw_Y_dialog in tqdm_notebook(zip(raw_X, raw_Y), total = len(raw_X)):
            
            model_All.hidden = model_All.init_hidden()
            
            session = {}
            session["session-id"] = raw_X_dialog["session-id"]
            session["turns"] = []

            for turn_num, (raw_X_turn, raw_Y_turn) in enumerate(zip(raw_X_dialog["turns"], raw_Y_dialog["turns"])):

                indices, scores = get_index_and_score(raw_X_turn, token_to_index, mode = "eval", device = DEVICE)
                
                # NOTE: percentage is based on user utterance
                incremental_index, new_percentage_point = get_incremental_index_and_percentage(percentage = percentage, length = len(raw_X_turn["user"]))
                incremental_index += len(raw_X_turn["system"])
                if new_percentage_point != None:
                    percentage_points.append(new_percentage_point)
                
                goal_priceranges, goal_areas, goal_names, goal_foods, requesteds, methods = model_All(indices, scores)
                goal_pricerange = goal_priceranges[incremental_index]
                goal_name = goal_names[incremental_index]
                goal_area = goal_areas[incremental_index]
                goal_food = goal_foods[incremental_index]
                requested = requesteds[incremental_index]
                method = methods[incremental_index]
                
                turn = {}
                turn["num"] = turn_num
                turn["goal-labels"] = {}
                turn["goal-labels"]["food"] = retrieve_output_GoalFood(goal_food, ontology)
                turn["goal-labels"]["pricerange"] = retrieve_output_GoalPricerange(goal_pricerange, ontology)
                turn["goal-labels"]["name"] = retrieve_output_GoalName(goal_name, ontology)
                turn["goal-labels"]["area"] = retrieve_output_GoalArea(goal_area, ontology)
                turn["requested-slots"] = retrieve_output_Requested(requested, ontology)
                turn["method-label"] = retrieve_output_Method(method, ontology)
                
                session["turns"].append(turn)
                
            tracker_json["sessions"].append(session)
            
        end_time = time.time()
        tracker_json["wall-time"] = end_time - start_time
        
        return tracker_json, np.around(np.mean(np.array(percentage_points)), decimals = 2)

# --------------------
    
def get_scores(tracker, dataset, ontology):
    
    scores_dict = None

    if dataset == "dstc2_train":
        scores_dict = score.compute_score(dataset = "dstc2_train", dataroot = "dstc2/dstc2_traindev/data", tracker_output = tracker, ontology = ontology)
    elif dataset == "dstc2_dev":
        scores_dict = score.compute_score(dataset = "dstc2_dev", dataroot = "dstc2/dstc2_traindev/data", tracker_output = tracker, ontology = ontology)
    else: # dataset == "dstc2_test"
        scores_dict = score.compute_score(dataset = "dstc2_test", dataroot = "dstc2/dstc2_test/data", tracker_output = tracker, ontology = ontology)
            
    return scores_dict

# --------------------

def retrieve_gold_GoalPricerange(raw_Y, ontology, device):
    ontology_informable_pricerange = ontology["informable"]["pricerange"]
    raw_goal_pricerange = raw_Y["goal"]["pricerange"]
    goal_pricerange = 0
    if raw_goal_pricerange != None:
        if raw_goal_pricerange == "dontcare":
            goal_pricerange = 1
        else:    
            goal_pricerange = ontology_informable_pricerange.index(raw_goal_pricerange) + 2
    return torch.tensor([goal_pricerange], dtype = torch.long, device = device)

def retrieve_output_GoalPricerange(output_tensor, ontology):
    ontology_informable_pricerange = ontology["informable"]["pricerange"]
    output_tensor = output_tensor.view(-1)
    output_tensor = torch.exp(output_tensor)
    goal_pricerange_dict = {}
    goal_pricerange_dict["dontcare"] = output_tensor[1].item()
    for index in range(len(output_tensor) - 2):     
        goal_pricerange_dict[ontology_informable_pricerange[index]] = output_tensor[index + 2].item()
    return goal_pricerange_dict

# --------------------

def retrieve_gold_GoalArea(raw_Y, ontology, device):
    ontology_informable_area = ontology["informable"]["area"]
    raw_goal_area = raw_Y["goal"]["area"]
    goal_area = 0
    if raw_goal_area != None:
        if raw_goal_area == "dontcare":
            goal_area = 1
        else:    
            goal_area = ontology_informable_area.index(raw_goal_area) + 2
    return torch.tensor([goal_area], dtype = torch.long, device = device)

def retrieve_output_GoalArea(output_tensor, ontology):
    ontology_informable_area = ontology["informable"]["area"]
    output_tensor = output_tensor.view(-1)
    output_tensor = torch.exp(output_tensor)
    goal_area_dict = {}
    goal_area_dict["dontcare"] = output_tensor[1].item()
    for index in range(len(output_tensor) - 2):
        goal_area_dict[ontology_informable_area[index]] = output_tensor[index + 2].item()
    return goal_area_dict

# --------------------

def retrieve_gold_GoalName(raw_Y, ontology, device):
    ontology_informable_name = ontology["informable"]["name"]
    raw_goal_name = raw_Y["goal"]["name"]
    goal_name = 0
    if raw_goal_name != None:
        if raw_goal_name == "dontcare":
            goal_name = 1
        else:    
            goal_name = ontology_informable_name.index(raw_goal_name) + 2
    return torch.tensor([goal_name], dtype = torch.long, device = device)

def retrieve_output_GoalName(output_tensor, ontology):
    ontology_informable_name = ontology["informable"]["name"]
    output_tensor = output_tensor.view(-1)
    output_tensor = torch.exp(output_tensor)
    goal_name_dict = {}
    goal_name_dict["dontcare"] = output_tensor[1].item()
    for index in range(len(output_tensor) - 2):
        goal_name_dict[ontology_informable_name[index]] = output_tensor[index + 2].item()
    return goal_name_dict

# --------------------

def retrieve_gold_GoalFood(raw_Y, ontology, device):
    ontology_informable_food = ontology["informable"]["food"]
    raw_goal_food = raw_Y["goal"]["food"]
    goal_food = 0
    if raw_goal_food != None:
        if raw_goal_food == "dontcare":
            goal_food = 1
        else:    
            goal_food = ontology_informable_food.index(raw_goal_food) + 2
    return torch.tensor([goal_food], dtype = torch.long, device = device)

def retrieve_output_GoalFood(output_tensor, ontology):
    ontology_informable_food = ontology["informable"]["food"]
    output_tensor = output_tensor.view(-1)
    output_tensor = torch.exp(output_tensor)
    goal_food_dict = {}
    goal_food_dict["dontcare"] = output_tensor[1].item() 
    for index in range(len(output_tensor) - 2):
        goal_food_dict[ontology_informable_food[index]] = output_tensor[index + 2].item()
    return goal_food_dict

# --------------------

def retrieve_gold_Requested(raw_Y, ontology, device):
    ontology_requestable = ontology["requestable"]
    raw_gold_requested = raw_Y["requested"]
    gold_requested = np.zeros(len(ontology_requestable), dtype = float)
    if len(raw_gold_requested) != 0:
        for requested in raw_gold_requested:
            gold_requested[ontology_requestable.index(requested)] = 1.0
    return torch.tensor([gold_requested], dtype = torch.float, device = device)

def retrieve_output_Requested(output_tensor, ontology):
    ontology_requestable = ontology["requestable"]
    output_tensor = output_tensor.view(-1)
    requested_dict = {}
    for index in range(len(output_tensor)):
        probability_value = output_tensor[index].item()
        requested_dict[ontology_requestable[index]] = probability_value
    return requested_dict

# --------------------

def retrieve_gold_Method(raw_Y, ontology, device):
    ontology_method = ontology["method"]
    raw_gold_method = raw_Y["method"]
    gold_method = ontology_method.index(raw_gold_method)
    return torch.tensor([gold_method], dtype = torch.long, device = device)

def retrieve_output_Method(output_tensor, ontology):
    ontology_method = ontology["method"]
    output_tensor = output_tensor.view(-1)
    output_tensor = torch.exp(output_tensor)
    method_dict = {}
    for index in range(len(output_tensor)):
        method_dict[ontology_method[index]] = output_tensor[index].item()
    return method_dict

### iDST All Model

In [8]:
class iDSTAllModel(nn.Module):
    
    def __init__(self, vocabulary_size, embedding_dim, altered_embedding_dim, hidden_dim,
                 goal_pricerange_dim, goal_area_dim, goal_name_dim, goal_food_dim, requested_dim, method_dim, device):
        super(iDSTAllModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.device = device
        self.goal_pricerange_dim = goal_pricerange_dim
        self.goal_area_dim = goal_area_dim
        self.goal_name_dim = goal_name_dim
        self.goal_food_dim = goal_food_dim
        self.requested_dim = requested_dim
        self.method_dim = method_dim
        self.embeddings = nn.Embedding(num_embeddings = vocabulary_size, embedding_dim = embedding_dim)
        self.altered_embeddings = nn.Linear(in_features = (embedding_dim + 1), out_features = altered_embedding_dim) # +1 for the ASR-score
        self.lstm = nn.LSTM(input_size = altered_embedding_dim, hidden_size = hidden_dim)
        self.goal_pricerange_classifier = nn.Linear(in_features = hidden_dim, out_features = goal_pricerange_dim)
        self.goal_area_classifier = nn.Linear(in_features = hidden_dim, out_features = goal_area_dim)
        self.goal_name_classifier = nn.Linear(in_features = hidden_dim, out_features = goal_name_dim)
        self.goal_food_classifier = nn.Linear(in_features = hidden_dim, out_features = goal_food_dim)
        self.requested_classifier = nn.Linear(in_features = hidden_dim, out_features = requested_dim)
        self.method_classifier = nn.Linear(in_features = hidden_dim, out_features = method_dim)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        return (torch.zeros(1, 1, self.hidden_dim, device = self.device),
                torch.zeros(1, 1, self.hidden_dim, device = self.device))

    def forward(self, indices, scores):
        embeddings = self.embeddings(indices)
        embeddings_concat_score = torch.cat((embeddings, scores.unsqueeze(dim = 1)), dim = 1) 
        altered_embeddings = F.relu(self.altered_embeddings(embeddings_concat_score))
        lstm_out, self.hidden = self.lstm(altered_embeddings.view(len(indices), 1, -1), self.hidden)
        goal_pricerange_output = F.log_softmax(self.goal_pricerange_classifier(lstm_out).view(-1, self.goal_pricerange_dim), dim = 1)
        goal_area_output = F.log_softmax(self.goal_area_classifier(lstm_out).view(-1, self.goal_area_dim), dim = 1)
        goal_name_output = F.log_softmax(self.goal_name_classifier(lstm_out).view(-1, self.goal_name_dim), dim = 1)
        goal_food_output = F.log_softmax(self.goal_food_classifier(lstm_out).view(-1, self.goal_food_dim), dim = 1)
        requested_output = torch.sigmoid(self.requested_classifier(lstm_out).view(-1, self.requested_dim))
        method_output = F.log_softmax(self.method_classifier(lstm_out).view(-1, self.method_dim), dim = 1)
        return goal_pricerange_output, goal_area_output, goal_name_output, goal_food_output, requested_output, method_output

### All Model

In [9]:
model_All = iDSTAllModel(vocabulary_size = VOCABULARY_SIZE,
                         embedding_dim = EMBEDDING_DIM,
                         altered_embedding_dim = ALTERED_EMBEDDING_DIM,
                         hidden_dim = HIDDEN_DIM,
                         goal_pricerange_dim = GOAL_PRICERANGE_DIM,
                         goal_area_dim = GOAL_AREA_DIM,
                         goal_name_dim = GOAL_NAME_DIM,
                         goal_food_dim = GOAL_FOOD_DIM,
                         requested_dim = REQUESTED_DIM,
                         method_dim = METHOD_DIM,
                         device = DEVICE).to(DEVICE)
optimizer_All = optim.Adam(model_All.parameters(), lr = 1e-3, amsgrad = True) 

### Train iDST All Model

In [10]:
all_early_stopping = EarlyStopping(patience = PATIENCE)

train_indices_loader = torch.utils.data.DataLoader(np.arange(raw_X_train.shape[0]), batch_size = BATCH_SIZE, shuffle = True)

for epoch in range(NUM_EPOCHS):
    
    logging.info("Epoch\t{}/{}".format(epoch + 1, NUM_EPOCHS))
    
    if not all_early_stopping.stop_training:
        model_All = model_All.train()
    
    for train_indices in tqdm_notebook(train_indices_loader, total = len(train_indices_loader)):
        
        if not all_early_stopping.stop_training:
            optimizer_All.zero_grad()
            all_accumulated_loss = 0
        
        for raw_X_train_dialog, raw_Y_train_dialog in zip(raw_X_train[train_indices], raw_Y_train[train_indices]):
            
            if not all_early_stopping.stop_training:
                model_All.hidden = model_All.init_hidden()
                
            for raw_X_train_turn, raw_Y_train_turn in zip(raw_X_train_dialog["turns"], raw_Y_train_dialog["turns"]):

                indices, scores = get_index_and_score(raw_X_train_turn, token_to_index, mode = "train", device = DEVICE)
                
                if not all_early_stopping.stop_training:
                    goal_pricerange_outputs, goal_area_outputs, goal_name_outputs, goal_food_outputs, requested_outputs, method_outputs = model_All(indices, scores)
                    gold_goal_pricerange = retrieve_gold_GoalPricerange(raw_Y_train_turn, ontology = ontology, device = DEVICE).repeat(len(goal_pricerange_outputs))
                    all_accumulated_loss += GOAL_LOSS_FUNCTION(goal_pricerange_outputs, gold_goal_pricerange)
                    gold_goal_area = retrieve_gold_GoalArea(raw_Y_train_turn, ontology = ontology, device = DEVICE).repeat(len(goal_area_outputs))
                    all_accumulated_loss += GOAL_LOSS_FUNCTION(goal_area_outputs, gold_goal_area)
                    gold_goal_name = retrieve_gold_GoalName(raw_Y_train_turn, ontology = ontology, device = DEVICE).repeat(len(goal_name_outputs))
                    all_accumulated_loss += GOAL_LOSS_FUNCTION(goal_name_outputs, gold_goal_name)
                    gold_goal_food = retrieve_gold_GoalFood(raw_Y_train_turn, ontology = ontology, device = DEVICE).repeat(len(goal_food_outputs))
                    all_accumulated_loss += GOAL_LOSS_FUNCTION(goal_food_outputs, gold_goal_food)
                    gold_requested = retrieve_gold_Requested(raw_Y_train_turn, ontology = ontology, device = DEVICE).repeat(requested_outputs.size(0), 1)
                    all_accumulated_loss += REQUESTED_LOSS_FUNCTION(requested_outputs, gold_requested)
                    gold_method = retrieve_gold_Method(raw_Y_train_turn, ontology = ontology, device = DEVICE).repeat(len(method_outputs))
                    all_accumulated_loss += METHOD_LOSS_FUNCTION(method_outputs, gold_method)
                    
        if not all_early_stopping.stop_training:
            all_accumulated_loss.backward()
            optimizer_All.step()
        
    dev_tracker, _ = make_tracker(model_All, raw_X_dev, raw_Y_dev, dataset = "dstc2_dev", percentage = 1.0)
    
    dev_scores_dict = get_scores(dev_tracker, dataset = "dstc2_dev", ontology = ontology)
    
    logging.info(dev_scores_dict)
    
    current_score_value = dev_scores_dict["goal_pricerange_accuracy"] + dev_scores_dict["goal_area_accuracy"] + dev_scores_dict["goal_name_accuracy"] + dev_scores_dict["goal_food_accuracy"] + dev_scores_dict["requested_all_accuracy"] + dev_scores_dict["method_accuracy"]  

    all_early_stopping.on_epoch_end(epoch = (epoch + 1), current_value = current_score_value)
    
    if all_early_stopping.wait == 0:
        torch.save(model_All.state_dict(), "model_All.pt")
        
    if all_early_stopping.stop_training:
        break

INFO:root:Epoch	1/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.1794274, 'goal_food_l2': 0.904136, 'goal_pricerange_accuracy': 0.3438373, 'goal_pricerange_l2': 0.7788252, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1852795, 'goal_area_accuracy': 0.2961965, 'goal_area_l2': 0.78379, 'goal_joint_accuracy': 0.0573514, 'goal_joint_l2': 0.9873674, 'requested_all_accuracy': 0.6004327, 'requested_all_l2': 0.664727, 'method_accuracy': 0.7899638, 'method_l2': 0.3457163}
INFO:root:Epoch	2/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.2237017, 'goal_food_l2': 0.8669971, 'goal_pricerange_accuracy': 0.6713203, 'goal_pricerange_l2': 0.5439237, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1852623, 'goal_area_accuracy': 0.5142258, 'goal_area_l2': 0.6812801, 'goal_joint_accuracy': 0.1631908, 'goal_joint_l2': 0.9615847, 'requested_all_accuracy': 0.6004327, 'requested_all_l2': 0.6152298, 'method_accuracy': 0.7969477, 'method_l2': 0.3118477}
INFO:root:Epoch	3/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.3049268, 'goal_food_l2': 0.8451595, 'goal_pricerange_accuracy': 0.6980455, 'goal_pricerange_l2': 0.4582337, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1849616, 'goal_area_accuracy': 0.7187781, 'goal_area_l2': 0.5227894, 'goal_joint_accuracy': 0.2546924, 'goal_joint_l2': 0.9200817, 'requested_all_accuracy': 0.6004327, 'requested_all_l2': 0.6363887, 'method_accuracy': 0.8039317, 'method_l2': 0.2810991}
INFO:root:Epoch	4/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.349534, 'goal_food_l2': 0.8096147, 'goal_pricerange_accuracy': 0.832868, 'goal_pricerange_l2': 0.2894157, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1848635, 'goal_area_accuracy': 0.7604073, 'goal_area_l2': 0.4331716, 'goal_joint_accuracy': 0.30683, 'goal_joint_l2': 0.874074, 'requested_all_accuracy': 0.6004327, 'requested_all_l2': 0.5705887, 'method_accuracy': 0.8259183, 'method_l2': 0.2733262}
INFO:root:Epoch	5/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.4094541, 'goal_food_l2': 0.7781223, 'goal_pricerange_accuracy': 0.812126, 'goal_pricerange_l2': 0.28794, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1851325, 'goal_area_accuracy': 0.753519, 'goal_area_l2': 0.3946339, 'goal_joint_accuracy': 0.3373306, 'goal_joint_l2': 0.8439496, 'requested_all_accuracy': 0.6004327, 'requested_all_l2': 0.5682468, 'method_accuracy': 0.8375582, 'method_l2': 0.246441}
INFO:root:Epoch	6/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.4733688, 'goal_food_l2': 0.7309267, 'goal_pricerange_accuracy': 0.8236937, 'goal_pricerange_l2': 0.2656325, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1852007, 'goal_area_accuracy': 0.7849656, 'goal_area_l2': 0.3555857, 'goal_joint_accuracy': 0.3852972, 'goal_joint_l2': 0.8116501, 'requested_all_accuracy': 0.6051208, 'requested_all_l2': 0.5099557, 'method_accuracy': 0.8484221, 'method_l2': 0.2420555}
INFO:root:Epoch	7/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.5129827, 'goal_food_l2': 0.6972842, 'goal_pricerange_accuracy': 0.808935, 'goal_pricerange_l2': 0.2767995, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1855178, 'goal_area_accuracy': 0.8008386, 'goal_area_l2': 0.3173957, 'goal_joint_accuracy': 0.4053702, 'goal_joint_l2': 0.7941512, 'requested_all_accuracy': 0.605842, 'requested_all_l2': 0.5063281, 'method_accuracy': 0.8484221, 'method_l2': 0.2401268}
INFO:root:Epoch	8/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.554261, 'goal_food_l2': 0.6627138, 'goal_pricerange_accuracy': 0.8368568, 'goal_pricerange_l2': 0.2523495, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1847706, 'goal_area_accuracy': 0.7915544, 'goal_area_l2': 0.3232612, 'goal_joint_accuracy': 0.4306569, 'goal_joint_l2': 0.7751808, 'requested_all_accuracy': 0.608727, 'requested_all_l2': 0.4905326, 'method_accuracy': 0.8517848, 'method_l2': 0.2369352}
INFO:root:Epoch	9/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.590213, 'goal_food_l2': 0.6199431, 'goal_pricerange_accuracy': 0.8432389, 'goal_pricerange_l2': 0.2307587, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1853583, 'goal_area_accuracy': 0.7957472, 'goal_area_l2': 0.3144737, 'goal_joint_accuracy': 0.463243, 'goal_joint_l2': 0.7407954, 'requested_all_accuracy': 0.6300036, 'requested_all_l2': 0.4593794, 'method_accuracy': 0.8559234, 'method_l2': 0.2298423}
INFO:root:Epoch	10/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.5955393, 'goal_food_l2': 0.5958338, 'goal_pricerange_accuracy': 0.8492222, 'goal_pricerange_l2': 0.2276292, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1846165, 'goal_area_accuracy': 0.8119197, 'goal_area_l2': 0.2986329, 'goal_joint_accuracy': 0.4627216, 'goal_joint_l2': 0.7321948, 'requested_all_accuracy': 0.6584926, 'requested_all_l2': 0.4364998, 'method_accuracy': 0.8528195, 'method_l2': 0.2355973}
INFO:root:Epoch	11/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.620506, 'goal_food_l2': 0.5777403, 'goal_pricerange_accuracy': 0.8504188, 'goal_pricerange_l2': 0.2222424, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1848736, 'goal_area_accuracy': 0.8131177, 'goal_area_l2': 0.296953, 'goal_joint_accuracy': 0.4887904, 'goal_joint_l2': 0.7082464, 'requested_all_accuracy': 0.7104219, 'requested_all_l2': 0.4059161, 'method_accuracy': 0.8554061, 'method_l2': 0.2251957}
INFO:root:Epoch	12/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6288282, 'goal_food_l2': 0.5608276, 'goal_pricerange_accuracy': 0.8611887, 'goal_pricerange_l2': 0.2174185, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.185227, 'goal_area_accuracy': 0.8065289, 'goal_area_l2': 0.3072923, 'goal_joint_accuracy': 0.491658, 'goal_joint_l2': 0.7038062, 'requested_all_accuracy': 0.7446809, 'requested_all_l2': 0.3838323, 'method_accuracy': 0.8618727, 'method_l2': 0.2187109}
INFO:root:Epoch	13/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.638482, 'goal_food_l2': 0.5493264, 'goal_pricerange_accuracy': 0.8520144, 'goal_pricerange_l2': 0.2270083, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1842459, 'goal_area_accuracy': 0.8167116, 'goal_area_l2': 0.2946073, 'goal_joint_accuracy': 0.4908759, 'goal_joint_l2': 0.7035826, 'requested_all_accuracy': 0.7493689, 'requested_all_l2': 0.3786045, 'method_accuracy': 0.86627, 'method_l2': 0.2114839}
INFO:root:Epoch	14/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6471372, 'goal_food_l2': 0.5210977, 'goal_pricerange_accuracy': 0.8436378, 'goal_pricerange_l2': 0.2351004, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1845227, 'goal_area_accuracy': 0.8011381, 'goal_area_l2': 0.3043612, 'goal_joint_accuracy': 0.4989572, 'goal_joint_l2': 0.6814508, 'requested_all_accuracy': 0.7641543, 'requested_all_l2': 0.3599963, 'method_accuracy': 0.8618727, 'method_l2': 0.2210034}
INFO:root:Epoch	15/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.650466, 'goal_food_l2': 0.5084299, 'goal_pricerange_accuracy': 0.8452333, 'goal_pricerange_l2': 0.2313288, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1861027, 'goal_area_accuracy': 0.8238994, 'goal_area_l2': 0.2796676, 'goal_joint_accuracy': 0.5, 'goal_joint_l2': 0.6777723, 'requested_all_accuracy': 0.7800216, 'requested_all_l2': 0.3519368, 'method_accuracy': 0.8644594, 'method_l2': 0.2216467}
INFO:root:Epoch	16/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6488016, 'goal_food_l2': 0.5019257, 'goal_pricerange_accuracy': 0.8536099, 'goal_pricerange_l2': 0.2259283, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1857074, 'goal_area_accuracy': 0.8149146, 'goal_area_l2': 0.2899239, 'goal_joint_accuracy': 0.4979145, 'goal_joint_l2': 0.6657409, 'requested_all_accuracy': 0.7637937, 'requested_all_l2': 0.364035, 'method_accuracy': 0.8613554, 'method_l2': 0.2211103}
INFO:root:Epoch	17/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6617843, 'goal_food_l2': 0.4916085, 'goal_pricerange_accuracy': 0.8663742, 'goal_pricerange_l2': 0.2113415, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1848615, 'goal_area_accuracy': 0.8227014, 'goal_area_l2': 0.2786273, 'goal_joint_accuracy': 0.5182482, 'goal_joint_l2': 0.6609745, 'requested_all_accuracy': 0.7673999, 'requested_all_l2': 0.3578135, 'method_accuracy': 0.8649767, 'method_l2': 0.2160512}
INFO:root:Epoch	18/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.671771, 'goal_food_l2': 0.4909989, 'goal_pricerange_accuracy': 0.8556043, 'goal_pricerange_l2': 0.2186913, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1855032, 'goal_area_accuracy': 0.8161126, 'goal_area_l2': 0.2827645, 'goal_joint_accuracy': 0.5239833, 'goal_joint_l2': 0.6530299, 'requested_all_accuracy': 0.7951677, 'requested_all_l2': 0.3275933, 'method_accuracy': 0.8579928, 'method_l2': 0.21864}
INFO:root:Epoch	19/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6674434, 'goal_food_l2': 0.4814666, 'goal_pricerange_accuracy': 0.8631831, 'goal_pricerange_l2': 0.213395, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1856993, 'goal_area_accuracy': 0.83528, 'goal_area_l2': 0.2610568, 'goal_joint_accuracy': 0.5286757, 'goal_joint_l2': 0.6430179, 'requested_all_accuracy': 0.787234, 'requested_all_l2': 0.3372282, 'method_accuracy': 0.8603207, 'method_l2': 0.2192113}
INFO:root:Epoch	20/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6667776, 'goal_food_l2': 0.4852017, 'goal_pricerange_accuracy': 0.8456322, 'goal_pricerange_l2': 0.237784, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.186396, 'goal_area_accuracy': 0.818808, 'goal_area_l2': 0.2803331, 'goal_joint_accuracy': 0.5182482, 'goal_joint_l2': 0.6504366, 'requested_all_accuracy': 0.7688424, 'requested_all_l2': 0.3561331, 'method_accuracy': 0.8533368, 'method_l2': 0.2273834}
INFO:root:Epoch	21/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.662783, 'goal_food_l2': 0.4840882, 'goal_pricerange_accuracy': 0.8408456, 'goal_pricerange_l2': 0.2453927, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1859257, 'goal_area_accuracy': 0.8140162, 'goal_area_l2': 0.2858237, 'goal_joint_accuracy': 0.5104275, 'goal_joint_l2': 0.6569661, 'requested_all_accuracy': 0.7706455, 'requested_all_l2': 0.3488621, 'method_accuracy': 0.8579928, 'method_l2': 0.2229681}
INFO:root:Epoch	22/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.662783, 'goal_food_l2': 0.4755407, 'goal_pricerange_accuracy': 0.8611887, 'goal_pricerange_l2': 0.212195, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1860972, 'goal_area_accuracy': 0.8230009, 'goal_area_l2': 0.2690455, 'goal_joint_accuracy': 0.5255474, 'goal_joint_l2': 0.6402881, 'requested_all_accuracy': 0.7832672, 'requested_all_l2': 0.3331629, 'method_accuracy': 0.8595447, 'method_l2': 0.2240503}
INFO:root:Epoch	23/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6707723, 'goal_food_l2': 0.4756857, 'goal_pricerange_accuracy': 0.8591943, 'goal_pricerange_l2': 0.2160794, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1858809, 'goal_area_accuracy': 0.8107218, 'goal_area_l2': 0.2938905, 'goal_joint_accuracy': 0.5250261, 'goal_joint_l2': 0.6400882, 'requested_all_accuracy': 0.7994951, 'requested_all_l2': 0.3182197, 'method_accuracy': 0.8592861, 'method_l2': 0.2167514}
INFO:root:Epoch	24/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6780959, 'goal_food_l2': 0.4668013, 'goal_pricerange_accuracy': 0.8595931, 'goal_pricerange_l2': 0.2172161, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1859454, 'goal_area_accuracy': 0.8265948, 'goal_area_l2': 0.2696827, 'goal_joint_accuracy': 0.5328467, 'goal_joint_l2': 0.6349449, 'requested_all_accuracy': 0.8092319, 'requested_all_l2': 0.3077394, 'method_accuracy': 0.8608381, 'method_l2': 0.2156442}
INFO:root:Epoch	25/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6741012, 'goal_food_l2': 0.4677305, 'goal_pricerange_accuracy': 0.8520144, 'goal_pricerange_l2': 0.2303912, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1864503, 'goal_area_accuracy': 0.8152141, 'goal_area_l2': 0.2851698, 'goal_joint_accuracy': 0.5252868, 'goal_joint_l2': 0.641757, 'requested_all_accuracy': 0.7893978, 'requested_all_l2': 0.3299989, 'method_accuracy': 0.8657527, 'method_l2': 0.2137137}
INFO:root:Epoch	26/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6790945, 'goal_food_l2': 0.4715134, 'goal_pricerange_accuracy': 0.8659753, 'goal_pricerange_l2': 0.2088414, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1862492, 'goal_area_accuracy': 0.8313866, 'goal_area_l2': 0.2635253, 'goal_joint_accuracy': 0.5404067, 'goal_joint_l2': 0.6288591, 'requested_all_accuracy': 0.8067075, 'requested_all_l2': 0.3057918, 'method_accuracy': 0.8579928, 'method_l2': 0.2185461}
INFO:root:Epoch	27/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6767643, 'goal_food_l2': 0.4630026, 'goal_pricerange_accuracy': 0.8556043, 'goal_pricerange_l2': 0.2259956, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.186298, 'goal_area_accuracy': 0.818209, 'goal_area_l2': 0.2805664, 'goal_joint_accuracy': 0.5344108, 'goal_joint_l2': 0.6313006, 'requested_all_accuracy': 0.8023801, 'requested_all_l2': 0.3125405, 'method_accuracy': 0.8528195, 'method_l2': 0.2279978}
INFO:root:Epoch	28/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6677763, 'goal_food_l2': 0.4723655, 'goal_pricerange_accuracy': 0.8484244, 'goal_pricerange_l2': 0.230426, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1861208, 'goal_area_accuracy': 0.8283917, 'goal_area_l2': 0.2646347, 'goal_joint_accuracy': 0.5252868, 'goal_joint_l2': 0.6421508, 'requested_all_accuracy': 0.8031013, 'requested_all_l2': 0.3088536, 'method_accuracy': 0.8629074, 'method_l2': 0.2165852}
INFO:root:Epoch	29/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6800932, 'goal_food_l2': 0.4585046, 'goal_pricerange_accuracy': 0.8496211, 'goal_pricerange_l2': 0.2359391, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1862744, 'goal_area_accuracy': 0.8274933, 'goal_area_l2': 0.2646013, 'goal_joint_accuracy': 0.5393639, 'goal_joint_l2': 0.6211138, 'requested_all_accuracy': 0.8121168, 'requested_all_l2': 0.3023479, 'method_accuracy': 0.8644594, 'method_l2': 0.2157169}
INFO:root:Epoch	30/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.683755, 'goal_food_l2': 0.4556029, 'goal_pricerange_accuracy': 0.8571998, 'goal_pricerange_l2': 0.2166697, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1862709, 'goal_area_accuracy': 0.8286912, 'goal_area_l2': 0.2669732, 'goal_joint_accuracy': 0.5406674, 'goal_joint_l2': 0.6256877, 'requested_all_accuracy': 0.8153624, 'requested_all_l2': 0.2944208, 'method_accuracy': 0.8605794, 'method_l2': 0.218628}
INFO:root:Epoch	31/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6727696, 'goal_food_l2': 0.462218, 'goal_pricerange_accuracy': 0.853211, 'goal_pricerange_l2': 0.2259756, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1858911, 'goal_area_accuracy': 0.8298892, 'goal_area_l2': 0.2675207, 'goal_joint_accuracy': 0.5315433, 'goal_joint_l2': 0.6349649, 'requested_all_accuracy': 0.8225748, 'requested_all_l2': 0.2781162, 'method_accuracy': 0.8623901, 'method_l2': 0.2186153}
INFO:root:Epoch	32/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.671771, 'goal_food_l2': 0.4508894, 'goal_pricerange_accuracy': 0.8671719, 'goal_pricerange_l2': 0.207705, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.186419, 'goal_area_accuracy': 0.8274933, 'goal_area_l2': 0.270431, 'goal_joint_accuracy': 0.5385819, 'goal_joint_l2': 0.6212192, 'requested_all_accuracy': 0.8142806, 'requested_all_l2': 0.2942453, 'method_accuracy': 0.8554061, 'method_l2': 0.2228709}
INFO:root:Epoch	33/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6750999, 'goal_food_l2': 0.4560606, 'goal_pricerange_accuracy': 0.8679697, 'goal_pricerange_l2': 0.2082356, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1864188, 'goal_area_accuracy': 0.8301887, 'goal_area_l2': 0.2687154, 'goal_joint_accuracy': 0.5404067, 'goal_joint_l2': 0.6203148, 'requested_all_accuracy': 0.815723, 'requested_all_l2': 0.2893188, 'method_accuracy': 0.8546301, 'method_l2': 0.2287089}
INFO:root:Epoch	34/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6834221, 'goal_food_l2': 0.4423658, 'goal_pricerange_accuracy': 0.8356602, 'goal_pricerange_l2': 0.2292381, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1863623, 'goal_area_accuracy': 0.8298892, 'goal_area_l2': 0.269634, 'goal_joint_accuracy': 0.536757, 'goal_joint_l2': 0.6186854, 'requested_all_accuracy': 0.8168049, 'requested_all_l2': 0.2874643, 'method_accuracy': 0.8497155, 'method_l2': 0.2261533}
INFO:root:Epoch	35/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6824234, 'goal_food_l2': 0.4507909, 'goal_pricerange_accuracy': 0.8691663, 'goal_pricerange_l2': 0.1980752, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1864809, 'goal_area_accuracy': 0.8340821, 'goal_area_l2': 0.2634912, 'goal_joint_accuracy': 0.5474453, 'goal_joint_l2': 0.616392, 'requested_all_accuracy': 0.8222142, 'requested_all_l2': 0.282803, 'method_accuracy': 0.8564408, 'method_l2': 0.2238429}
INFO:root:Epoch	36/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6790945, 'goal_food_l2': 0.452558, 'goal_pricerange_accuracy': 0.8583965, 'goal_pricerange_l2': 0.2226202, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1865072, 'goal_area_accuracy': 0.8274933, 'goal_area_l2': 0.2644154, 'goal_joint_accuracy': 0.5398853, 'goal_joint_l2': 0.6224608, 'requested_all_accuracy': 0.8175261, 'requested_all_l2': 0.289065, 'method_accuracy': 0.8608381, 'method_l2': 0.2201914}
INFO:root:Epoch	37/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6900799, 'goal_food_l2': 0.4469025, 'goal_pricerange_accuracy': 0.8460311, 'goal_pricerange_l2': 0.2385214, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1865803, 'goal_area_accuracy': 0.8283917, 'goal_area_l2': 0.2627365, 'goal_joint_accuracy': 0.5456204, 'goal_joint_l2': 0.6146105, 'requested_all_accuracy': 0.8200505, 'requested_all_l2': 0.2859265, 'method_accuracy': 0.8548888, 'method_l2': 0.2218904}
INFO:root:Epoch	38/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6787617, 'goal_food_l2': 0.4537631, 'goal_pricerange_accuracy': 0.8583965, 'goal_pricerange_l2': 0.2240368, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1865492, 'goal_area_accuracy': 0.8233004, 'goal_area_l2': 0.268391, 'goal_joint_accuracy': 0.5346715, 'goal_joint_l2': 0.6247631, 'requested_all_accuracy': 0.818608, 'requested_all_l2': 0.2902628, 'method_accuracy': 0.8556648, 'method_l2': 0.2235395}
INFO:root:Epoch	39/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6691079, 'goal_food_l2': 0.4712361, 'goal_pricerange_accuracy': 0.8528121, 'goal_pricerange_l2': 0.214965, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1863158, 'goal_area_accuracy': 0.8134172, 'goal_area_l2': 0.2885391, 'goal_joint_accuracy': 0.528415, 'goal_joint_l2': 0.6372206, 'requested_all_accuracy': 0.8189686, 'requested_all_l2': 0.2883526, 'method_accuracy': 0.8554061, 'method_l2': 0.2280622}
INFO:root:Epoch	40/100


HBox(children=(IntProgress(value=0, max=162), HTML(value='')))




HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




INFO:root:{'goal_food_accuracy': 0.6814248, 'goal_food_l2': 0.4531595, 'goal_pricerange_accuracy': 0.8575987, 'goal_pricerange_l2': 0.2174959, 'goal_name_accuracy': 0.9065421, 'goal_name_l2': 0.1863577, 'goal_area_accuracy': 0.818808, 'goal_area_l2': 0.2816717, 'goal_joint_accuracy': 0.540146, 'goal_joint_l2': 0.623869, 'requested_all_accuracy': 0.8200505, 'requested_all_l2': 0.2836612, 'method_accuracy': 0.8533368, 'method_l2': 0.2307491}


### Load iDST All Model

In [11]:
model_All = iDSTAllModel(vocabulary_size = VOCABULARY_SIZE,
                         embedding_dim = EMBEDDING_DIM,
                         altered_embedding_dim = ALTERED_EMBEDDING_DIM,
                         hidden_dim = HIDDEN_DIM,
                         goal_pricerange_dim = GOAL_PRICERANGE_DIM,
                         goal_area_dim = GOAL_AREA_DIM,
                         goal_name_dim = GOAL_NAME_DIM,
                         goal_food_dim = GOAL_FOOD_DIM,
                         requested_dim = REQUESTED_DIM,
                         method_dim = METHOD_DIM,
                         device = DEVICE).to(DEVICE)
model_All.load_state_dict(torch.load("model_All.pt"))
model_All.eval()

iDSTAllModel(
  (embeddings): Embedding(897, 170)
  (altered_embeddings): Linear(in_features=171, out_features=300, bias=True)
  (lstm): LSTM(300, 100)
  (goal_pricerange_classifier): Linear(in_features=100, out_features=5, bias=True)
  (goal_area_classifier): Linear(in_features=100, out_features=7, bias=True)
  (goal_name_classifier): Linear(in_features=100, out_features=115, bias=True)
  (goal_food_classifier): Linear(in_features=100, out_features=93, bias=True)
  (requested_classifier): Linear(in_features=100, out_features=8, bias=True)
  (method_classifier): Linear(in_features=100, out_features=5, bias=True)
)

### Print scores

In [12]:
dev_tracker, _ = make_tracker(model_All, raw_X_dev, raw_Y_dev, dataset = "dstc2_dev", percentage = 1.0)
get_scores(dev_tracker, dataset = "dstc2_dev", ontology = ontology)

HBox(children=(IntProgress(value=0, max=506), HTML(value='')))




{'goal_food_accuracy': 0.6824234,
 'goal_food_l2': 0.4507909,
 'goal_pricerange_accuracy': 0.8691663,
 'goal_pricerange_l2': 0.1980752,
 'goal_name_accuracy': 0.9065421,
 'goal_name_l2': 0.1864809,
 'goal_area_accuracy': 0.8340821,
 'goal_area_l2': 0.2634912,
 'goal_joint_accuracy': 0.5474453,
 'goal_joint_l2': 0.616392,
 'requested_all_accuracy': 0.8222142,
 'requested_all_l2': 0.282803,
 'method_accuracy': 0.8564408,
 'method_l2': 0.2238429}

In [13]:
test_tracker, _ = make_tracker(model_All, raw_X_test, raw_Y_test, dataset = "dstc2_test", percentage = 1.0)
get_scores(test_tracker, dataset = "dstc2_test", ontology = ontology)

HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




{'goal_food_accuracy': 0.7550174,
 'goal_food_l2': 0.3484232,
 'goal_pricerange_accuracy': 0.8415325,
 'goal_pricerange_l2': 0.246297,
 'goal_name_accuracy': 0.961326,
 'goal_name_l2': 0.077427,
 'goal_area_accuracy': 0.8169552,
 'goal_area_l2': 0.2771145,
 'goal_joint_accuracy': 0.5609454,
 'goal_joint_l2': 0.5856348,
 'requested_all_accuracy': 0.8393818,
 'requested_all_l2': 0.2425617,
 'method_accuracy': 0.8952176,
 'method_l2': 0.1626297}

### Plotting

In [18]:
def frange(start, stop, step):
    i = start
    while i < stop:
        i = np.around(i, decimals = 2)
        yield i
        i += step    
    
#--------------------

def plotly_plot_incremental(goal_pricerange_accuracies, goal_area_accuracies, goal_name_accuracies, goal_food_accuracies,
                            requested_accuracies, method_accuracies, percentages, dataset):
    
    if dataset == "dstc2_train":
        dataset = "TRAIN"
    elif dataset == "dstc2_dev":
        dataset = "DEV"
    else:
        dataset = "TEST"
        
    plotly.offline.iplot({"data": [Scatter(x = percentages, y = goal_pricerange_accuracies, mode = "lines+markers", name = "{} Goal Pricerange Accuracy".format(dataset), marker = dict(color = "#1abc9c")),
                                   Scatter(x = percentages, y = goal_area_accuracies, mode = "lines+markers", name = "{} Goal Area Accuracy".format(dataset), marker = dict(color = "#3498db")),
                                   Scatter(x = percentages, y = goal_name_accuracies, mode = "lines+markers", name = "{} Goal Name Accuracy".format(dataset), marker = dict(color = "#9b59b6")),
                                   Scatter(x = percentages, y = goal_food_accuracies, mode = "lines+markers", name = "{} Goal Food Accuracy".format(dataset), marker = dict(color = "#e74c3c")),
                                   Scatter(x = percentages, y = requested_accuracies, mode = "lines+markers", name = "{} Requested Accuracy".format(dataset), marker = dict(color = "#34495e")),
                                   Scatter(x = percentages, y = method_accuracies, mode = "lines+markers", name = "{} Method Accuracy".format(dataset), marker = dict(color = "#f1c40f"))],
                            "layout": Layout(title = "<b>{} Percentage - Accuracy</b>".format(dataset),
                                             xaxis = dict(title = "<b>Percentage</b>",
                                                          dtick = 0.1,
                                                          titlefont = dict(color = "#34495e")),
                                             yaxis = dict(title = "<b>Accuracy</b>",
                                                          dtick = 0.05,
                                                          titlefont = dict(color = "#34495e")),
                                             margin = Margin(b = 150))})

In [16]:
dev_goal_pricerange_accuracies = []
dev_goal_area_accuracies = []
dev_goal_name_accuracies = []
dev_goal_food_accuracies = []
dev_requested_accuracies = []
dev_method_accuracies = []
dev_percentages = []

test_goal_pricerange_accuracies = []
test_goal_area_accuracies = []
test_goal_name_accuracies = []
test_goal_food_accuracies = []
test_requested_accuracies = []
test_method_accuracies = []
test_percentages = []

percentages = list(frange(0.1, 1.05, 0.1))
for percentage in tqdm_notebook(percentages, total = len(percentages)):
    
    dev_incremental_tracker, dev_incremental_percentage = make_tracker(model_All, raw_X_dev, raw_Y_dev, dataset = "dstc2_dev", percentage = percentage)
    dev_percentages.append(dev_incremental_percentage)
    dev_scores_dict = get_scores(dev_incremental_tracker, dataset = "dstc2_dev", ontology = ontology)
    dev_goal_pricerange_accuracies.append(dev_scores_dict["goal_pricerange_accuracy"])
    dev_goal_area_accuracies.append(dev_scores_dict["goal_area_accuracy"])
    dev_goal_name_accuracies.append(dev_scores_dict["goal_name_accuracy"])
    dev_goal_food_accuracies.append(dev_scores_dict["goal_food_accuracy"])
    dev_requested_accuracies.append(dev_scores_dict["requested_all_accuracy"])
    dev_method_accuracies.append(dev_scores_dict["method_accuracy"])
    
    test_incremental_tracker, test_incremental_percentage = make_tracker(model_All, raw_X_test, raw_Y_test, dataset = "dstc2_test", percentage = percentage)
    test_percentages.append(test_incremental_percentage)
    test_scores_dict = get_scores(test_incremental_tracker, dataset = "dstc2_test", ontology = ontology)
    test_goal_pricerange_accuracies.append(test_scores_dict["goal_pricerange_accuracy"])
    test_goal_area_accuracies.append(test_scores_dict["goal_area_accuracy"])
    test_goal_name_accuracies.append(test_scores_dict["goal_name_accuracy"])
    test_goal_food_accuracies.append(test_scores_dict["goal_food_accuracy"])
    test_requested_accuracies.append(test_scores_dict["requested_all_accuracy"])
    test_method_accuracies.append(test_scores_dict["method_accuracy"])

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=506), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))

HBox(children=(IntProgress(value=0, max=506), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))

HBox(children=(IntProgress(value=0, max=506), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))

HBox(children=(IntProgress(value=0, max=506), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))

HBox(children=(IntProgress(value=0, max=506), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))

HBox(children=(IntProgress(value=0, max=506), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))

HBox(children=(IntProgress(value=0, max=506), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))

HBox(children=(IntProgress(value=0, max=506), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))

HBox(children=(IntProgress(value=0, max=506), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))

HBox(children=(IntProgress(value=0, max=506), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))




In [19]:
plotly_plot_incremental(dev_goal_pricerange_accuracies, dev_goal_area_accuracies, dev_goal_name_accuracies, dev_goal_food_accuracies,
                        dev_requested_accuracies, dev_method_accuracies, dev_percentages, "dstc2_dev")

In [20]:
plotly_plot_incremental(test_goal_pricerange_accuracies, test_goal_area_accuracies, test_goal_name_accuracies, test_goal_food_accuracies,
                        test_requested_accuracies, test_method_accuracies, test_percentages, "dstc2_test")