# Installation

In [1]:
# Install Hugging Face `transformers`; later cells import BertModel / BertTokenizer from it.
! pip install transformers



# Mount

In [2]:
# Mount Google Drive inside the Colab VM, then make the project folder the working
# directory (presumably so relative paths used by later cells resolve there - confirm).
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/TODS

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/TODS


# BERT

## Bert

In [3]:
"""
This module contains the BERT model architecture.
"""

import torch
import torch.nn as nn



# Embeddings
# 1. Cluster similar words together.
# 2. Preserve different relationships between words such as: semantic, syntactic, linear,
# and since BERT is bidirectional it will also preserve contextual relationships as well.
class Embeddings(nn.Module):
    """
    Embedding layer for BERT.

    Sums word, position and token-type (segment) embeddings, then applies layer
    normalization and dropout.

    The sub-modules are registered in the order word -> position -> token_type ->
    layer_norm. ``LoadPretrainedWeights`` assigns pretrained tensors by position,
    so this order must not be changed.

    :param config: BERT configuration; reads ``vocab_size``, ``hidden_size``, ``padding_idx``,
        ``max_seq``, ``token_types``, ``layer_norm_eps`` and ``dropout``.
    :type config: Config
    """
    def __init__(self, config):
        super(Embeddings, self).__init__()
        # Bert uses 3 types of embeddings: word, position, and token_type (segment type).
        # LayerNorm is used to normalize the sum of the embeddings.
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.padding_idx)
        self.position_embeddings = nn.Embedding(config.max_seq, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.token_types, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids, token_type_ids): # input_ids: [batch_size, seq_len] token_type_ids: [batch_size, seq_len]
        """
        Forward pass of the Embeddings layer.

        :param input_ids: The input token IDs, shape [batch_size, seq_len].
        :type input_ids: torch.Tensor
        :param token_type_ids: The token type IDs, shape [batch_size, seq_len].
        :type token_type_ids: torch.Tensor
        :return: The generated embeddings, shape [batch_size, seq_len, hidden_size].
        :rtype: torch.Tensor
        """
        seq_len = input_ids.size(1)
        # Build the position ids on the same device as the inputs. Picking the device from
        # torch.cuda.is_available() breaks a CPU-resident model on a machine that has a GPU.
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand_as(input_ids) # position_ids: [batch_size, seq_len]

        word_embeddings = self.word_embeddings(input_ids)  # [batch_size, seq_len, hidden_size]
        position_embeddings = self.position_embeddings(position_ids) # [batch_size, seq_len, hidden_size]
        token_type_embeddings = self.token_type_embeddings(token_type_ids) # [batch_size, seq_len, hidden_size]
        embeddings = word_embeddings + position_embeddings + token_type_embeddings # [batch_size, seq_len, hidden_size]

        # Normalize across the feature dimension (learned gain and bias), then apply dropout.
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

# Encoder layer
class EncoderLayer(nn.Module):
    """
    One BERT encoder block.

    Self-attention followed by a position-wise feed-forward network. Each sub-layer output
    goes through dropout, is added to its input (residual) and is layer-normalized.

    Sub-modules are registered attention -> attention LayerNorm -> feed-forward ->
    feed-forward LayerNorm. ``LoadPretrainedWeights`` maps weights by position, so the
    registration order is kept as-is.

    :param config: BERT configuration.
    :type config: Config
    """

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        eps = config.layer_norm_eps
        drop_prob = config.dropout

        # Attention sub-layer
        self.self_attention = MultiHeadAttention(config)
        self.self_layer_norm = nn.LayerNorm(hidden_size, eps=eps)
        self.self_dropout = nn.Dropout(drop_prob)

        # Feed-forward sub-layer
        self.position_wise_feed_forward = PositionWiseFeedForward(config)
        self.ffn_layer_norm = nn.LayerNorm(hidden_size, eps=eps)
        self.ffn_dropout = nn.Dropout(drop_prob)

    def forward(self, input, attention_mask):
        """
        Run the encoder block.

        :param input: Hidden states, [batch_size, seq_len, hidden_size].
        :type input: torch.Tensor
        :param attention_mask: Mask forwarded unchanged to the attention module.
        :type attention_mask: torch.Tensor
        :return: The new hidden states [batch_size, seq_len, hidden_size] and the attention
            weights returned by the attention module.
        :rtype: tuple[torch.Tensor, torch.Tensor]
        """
        # Self-attention: the same tensor acts as query, key and value.
        attended, attention_weights = self.self_attention(input, input, input, attention_mask)
        # Residual connection + LayerNorm around the attention output.
        hidden_states = self.self_layer_norm(input + self.self_dropout(attended))

        # Feed-forward network with its own residual connection + LayerNorm.
        transformed = self.position_wise_feed_forward(hidden_states)
        result = self.ffn_layer_norm(hidden_states + self.ffn_dropout(transformed))
        return result, attention_weights

# Multi-head attention
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention layer for BERT.

    Projects query/key/value, splits them into ``config.heads`` heads, applies masked
    scaled dot-product attention and projects the merged heads back to ``hidden_size``.

    The four linear layers are registered in the order w_q, w_k, w_v, w_o.
    ``LoadPretrainedWeights`` relies on that order.

    :param config: BERT configuration; reads ``hidden_size``, ``heads`` and ``dropout``.
    :type config: Config
    """
    def __init__(self, config):
        super(MultiHeadAttention, self).__init__()
        self.config = config
        self.w_q = nn.Linear(config.hidden_size, config.hidden_size)
        self.w_k = nn.Linear(config.hidden_size, config.hidden_size)
        self.w_v = nn.Linear(config.hidden_size, config.hidden_size)
        self.w_o = nn.Linear(config.hidden_size, config.hidden_size)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, query, key, value, attention_mask):
        """
        Forward pass of the MultiHeadAttention.

        :param query: The query tensor, [batch_size, seq_len_q, hidden_size].
        :type query: torch.Tensor
        :param key: The key tensor, [batch_size, seq_len_k, hidden_size].
        :type key: torch.Tensor
        :param value: The value tensor, [batch_size, seq_len_k, hidden_size].
        :type value: torch.Tensor
        :param attention_mask: [batch_size, seq_len_q, seq_len_k]; positions equal to 0 are masked out.
        :type attention_mask: torch.Tensor
        :return: The output context [batch_size, seq_len_q, hidden_size] and the attention
            weights after dropout [batch_size, heads, seq_len_q, seq_len_k].
        :rtype: tuple[torch.Tensor, torch.Tensor]
        """
        batch_size, q_len, hidden_size = query.size()
        # Key/value may be longer or shorter than the query, so they are reshaped with
        # their own length rather than the query's.
        k_len = key.size(1)
        heads = self.config.heads
        head_dim = hidden_size // heads

        query = self.w_q(query).view(batch_size, q_len, heads, head_dim).transpose(1, 2) # [batch_size, heads, seq_len_q, head_dim]
        key = self.w_k(key).view(batch_size, k_len, heads, head_dim).transpose(1, 2) # [batch_size, heads, seq_len_k, head_dim]
        value = self.w_v(value).view(batch_size, k_len, heads, head_dim).transpose(1, 2) # [batch_size, heads, seq_len_k, head_dim]

        # Scaled dot-product attention; a plain Python scalar avoids building a tensor per call.
        attention = torch.matmul(query, key.transpose(-1, -2)) / (head_dim ** 0.5) # [batch_size, heads, seq_len_q, seq_len_k]

        # The added head axis broadcasts, so the mask does not need to be repeated per head.
        masked_positions = attention_mask.unsqueeze(1) == 0 # [batch_size, 1, seq_len_q, seq_len_k]
        attention.masked_fill_(masked_positions, -1e9)

        attention = self.softmax(attention) # [batch_size, heads, seq_len_q, seq_len_k]
        attention = self.dropout(attention)

        context = torch.matmul(attention, value) # [batch_size, heads, seq_len_q, head_dim]
        context = context.transpose(1, 2).contiguous().view(batch_size, q_len, hidden_size) # [batch_size, seq_len_q, hidden_size]
        output = self.w_o(context) # [batch_size, seq_len_q, hidden_size]
        return output, attention

# Position-wise feed-forward network
class PositionWiseFeedForward(nn.Module):
    """
    Feed-forward sub-layer applied independently at every sequence position.

    Expands ``hidden_size`` to ``ff_size``, applies GELU, then projects back to ``hidden_size``.

    :param config: BERT configuration; reads ``hidden_size`` and ``ff_size``.
    :type config: Config
    """
    def __init__(self, config):
        super().__init__()
        model_dim, inner_dim = config.hidden_size, config.ff_size
        # Registration order (linear1, then linear2) is relied upon by LoadPretrainedWeights.
        self.linear1 = nn.Linear(model_dim, inner_dim)
        self.linear2 = nn.Linear(inner_dim, model_dim)
        self.gelu = nn.GELU()

    def forward(self, input):
        """
        Apply linear -> GELU -> linear.

        :param input: Hidden states, [batch_size, seq_len, hidden_size].
        :type input: torch.Tensor
        :return: Transformed hidden states, [batch_size, seq_len, hidden_size].
        :rtype: torch.Tensor
        """
        expanded = self.gelu(self.linear1(input))  # [batch_size, seq_len, ff_size]
        return self.linear2(expanded)  # [batch_size, seq_len, hidden_size]

# Bert
# 1. Puts it all together.
class Bert(nn.Module):
    """
    BERT encoder model.

    Chains the embedding layer, a stack of encoder layers and a pooling head
    (linear + tanh applied to the first token's hidden state).

    :param config: BERT configuration.
    :type config: Config
    """
    def __init__(self, config):
        super().__init__()
        self.embeddings = Embeddings(config)
        layers = []
        for _ in range(config.encoder_layers):
            layers.append(EncoderLayer(config))
        self.encoder = nn.ModuleList(layers)
        # Pooling head
        self.linear = nn.Linear(config.hidden_size, config.hidden_size)
        self.tanh = nn.Tanh()

    def forward(self, input_ids, token_type_ids, attention_mask, return_dict=False):
        """
        Run the full model.

        :param input_ids: The input token IDs, [batch_size, seq_len].
        :type input_ids: torch.Tensor
        :param token_type_ids: The token type IDs, [batch_size, seq_len].
        :type token_type_ids: torch.Tensor
        :param attention_mask: The attention mask, [batch_size, seq_len].
        :type attention_mask: torch.Tensor
        :param return_dict: Accepted but not used by this implementation; a tuple is always returned.
        :type return_dict: bool
        :return: The sequence output [batch_size, seq_len, hidden_size] and the pooled
            output [batch_size, hidden_size].
        :rtype: torch.Tensor, torch.Tensor
        """
        hidden_states = self.embeddings(input_ids, token_type_ids)  # [batch_size, seq_len, hidden_size]

        # Copy the per-token mask once per query position -> [batch_size, seq_len, seq_len]
        seq_len = hidden_states.size(1)
        pairwise_mask = attention_mask.unsqueeze(1).repeat(1, seq_len, 1)

        # The attention weights returned by each layer are not used here.
        for layer in self.encoder:
            hidden_states, _ = layer(hidden_states, pairwise_mask)

        # Pooled output is computed from the first position of the sequence output.
        first_token_state = hidden_states[:, 0]
        pooled_output = self.tanh(self.linear(first_token_state))  # [batch_size, hidden_size]

        return hidden_states, pooled_output

## config

In [4]:
"""
This module contains the configuration class for BERT.
"""

# Bert configuration
class BERTConfig(object):
    """
    Hyper-parameter container for the BERT model.

    Every constructor argument is stored unchanged as an attribute of the same name.

    :param vocab_size: Size of the vocabulary, defaults to 30522.
    :type vocab_size: int
    :param hidden_size: Hidden size of the model, defaults to 768.
    :type hidden_size: int
    :param encoder_layers: Number of encoder layers, defaults to 12.
    :type encoder_layers: int
    :param heads: Number of attention heads, defaults to 12.
    :type heads: int
    :param ff_size: Inner size of the feed-forward layer, defaults to 3072.
    :type ff_size: int
    :param token_types: Number of token (segment) types, defaults to 2.
    :type token_types: int
    :param max_seq: Maximum sequence length, defaults to 512.
    :type max_seq: int
    :param padding_idx: Padding index of the word embedding table, defaults to 0.
    :type padding_idx: int
    :param layer_norm_eps: Epsilon used by layer normalization, defaults to 1e-12.
    :type layer_norm_eps: float
    :param dropout: Dropout rate, defaults to 0.1.
    :type dropout: float
    """
    def __init__(self, vocab_size=30522, hidden_size=768, encoder_layers=12, heads=12, ff_size=3072, token_types=2, max_seq=512, padding_idx=0, layer_norm_eps=1e-12, dropout=0.1):
        # Embedding tables
        self.vocab_size, self.token_types = vocab_size, token_types
        self.max_seq, self.padding_idx = max_seq, padding_idx

        # Encoder geometry
        self.hidden_size, self.ff_size = hidden_size, ff_size
        self.encoder_layers, self.heads = encoder_layers, heads

        # Regularization / numerics
        self.layer_norm_eps, self.dropout = layer_norm_eps, dropout

## utils

In [5]:
"""
This module contains utility functions for the BERT model.
"""

from collections import OrderedDict
from transformers import BertModel
# from botiverse.models.BERT.config import BERTConfig
# from botiverse.models.BERT.BERT import Bert


# Load pre-trained weights from the transformers library
def LoadPretrainedWeights(model):
    """
    Load pre-trained weights from the transformers library.

    Downloads/loads ``bert-base-uncased`` through ``BertModel.from_pretrained`` and copies
    its tensors into ``model``. The two models use different parameter names, so tensors
    are paired purely by their position in the two state dicts: the i-th pretrained tensor
    is loaded into the i-th entry of ``model.state_dict()``.

    :param model: The BERT model to populate (modified in place).
    :type model: Bert
    :raises ValueError: If the two state dicts do not hold the same number of tensors.
        Positional pairing would then silently load tensors into the wrong parameters.
    """

    # Get pre-trained weights from transformers library
    pretrained_model = BertModel.from_pretrained('bert-base-uncased')
    state_dict = pretrained_model.state_dict()

    # position_ids is a buffer with no counterpart in the target model; drop it if present.
    state_dict.pop('embeddings.position_ids', None)

    new_keys = list(model.state_dict().keys())

    # Positional mapping only makes sense when both sides have the same number of entries.
    if len(new_keys) != len(state_dict):
        raise ValueError(
            f"Cannot map pre-trained weights by position: the model has {len(new_keys)} "
            f"tensors but the pre-trained checkpoint has {len(state_dict)}."
        )

    # Re-key the pretrained tensors with the target model's parameter names.
    new_state_dict = OrderedDict(zip(new_keys, state_dict.values()))

    model.load_state_dict(new_state_dict)

# Example comparing the outputs of the from scratch model to the pre-trained model from transformers library
import torch
from transformers import BertModel, BertTokenizer

def Example():
    """
    Print the outputs of the from-scratch model next to those of the pre-trained
    transformers model for the same tokenized batch, followed by both module summaries.
    """
    checkpoint = 'bert-base-uncased'

    # From-scratch architecture, populated with the pre-trained tensors
    scratch_model = Bert(BERTConfig())
    LoadPretrainedWeights(scratch_model)

    # Reference implementation from the transformers library
    reference_model = BertModel.from_pretrained(checkpoint)

    # Tokenize a small padded batch
    tokenizer = BertTokenizer.from_pretrained(checkpoint)
    sentences = ["This is a sample input sequence.", "batyousef is awesome"]
    batch = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    # eval() switches dropout off so both forward passes are deterministic
    scratch_model.eval()
    reference_model.eval()

    input_ids = batch['input_ids']
    segment_ids = batch['token_type_ids']
    mask = batch['attention_mask']

    # Forward both models on identical inputs without tracking gradients
    with torch.no_grad():
        scratch_seq, scratch_pooled = scratch_model(input_ids, segment_ids, mask)
        reference_seq, reference_pooled = reference_model(
            input_ids,
            attention_mask=mask,
            token_type_ids=segment_ids,
            return_dict=False,
        )

    # Sequence outputs
    print(scratch_seq.size(), scratch_seq)
    print(reference_seq.size(), reference_seq)
    print()
    print()
    # Pooled outputs
    print(scratch_pooled.size(), scratch_pooled)
    print(reference_pooled.size(), reference_pooled)

    # Module structure of both models
    print(scratch_model)
    print(reference_model)

# TRIPPY

## config

In [6]:
"""
This Module has the configuration class for TRIPPY.
"""

import tokenizers
import os

# Trippy configuration
class TRIPPYConfig(object):
    """
    Configuration class for TRIPPY.

    Holds the hyper-parameters of the TRIPPY model and builds the WordPiece tokenizer
    from ``vocab_path`` at construction time.

    :param max_len: The maximum sequence length, defaults to 128.
    :type max_len: int
    :param train_batch_size: The batch size for training, defaults to 32.
    :type train_batch_size: int
    :param dev_batch_size: The batch size for development evaluation, defaults to 1.
    :type dev_batch_size: int
    :param test_batch_size: The batch size for testing, defaults to 1.
    :type test_batch_size: int
    :param epochs: The number of training epochs, defaults to 15.
    :type epochs: int
    :param hid_dim: The hidden dimension size, defaults to 768.
    :type hid_dim: int
    :param n_oper: The number of operations, defaults to 7.
    :type n_oper: int
    :param dropout: The dropout rate, defaults to 0.3.
    :type dropout: float
    :param vocab_path: The path to the vocabulary file, defaults to 'vocab.txt'.
        It is resolved relative to the current working directory.
    :type vocab_path: str
    :param ignore_idx: The index value to ignore, defaults to -100.
    :type ignore_idx: int
    :param oper2id: The mapping of operation names to IDs, defaults to {'carryover' : 0, 'dontcare': 1, 'update':2, 'refer':3, 'yes':4, 'no':5, 'inform':6}.
        The mapping is copied, so instances never share (or mutate) the default dict.
    :type oper2id: dict[str, int]
    :param weight_decay: The weight decay value, defaults to 0.0.
    :type weight_decay: float
    :param lr: The learning rate, defaults to 1e-4.
    :type lr: float
    :param adam_epsilon: The epsilon value for Adam optimizer, defaults to 1e-6.
    :type adam_epsilon: float
    :param warmup_proportion: The proportion of warmup steps, defaults to 0.1.
    :type warmup_proportion: float
    :param multiwoz: Stored as-is, defaults to False. NOTE(review): the data helpers forward a
        ``multiwoz`` value to the normalization functions, so this is presumably a MultiWOZ
        on/off flag rather than a path - confirm.
    :type multiwoz: bool
    """
    def __init__(self,
                 max_len=128,
                 train_batch_size=32,
                 dev_batch_size=1,
                 test_batch_size=1,
                 epochs=15,
                 hid_dim=768,
                 n_oper=7,
                 dropout=0.3,
                 vocab_path='vocab.txt',
                 ignore_idx=-100,
                 oper2id={'carryover' : 0, 'dontcare': 1, 'update':2, 'refer':3, 'yes':4, 'no':5, 'inform':6},
                 weight_decay=0.0,
                 lr=1e-4,
                 adam_epsilon=1e-6,
                 warmup_proportion=0.1,
                 multiwoz=False):

        self.max_len = max_len
        self.train_batch_size = train_batch_size
        self.dev_batch_size = dev_batch_size
        self.test_batch_size = test_batch_size
        self.epochs = epochs
        self.hid_dim = hid_dim
        self.n_oper = n_oper
        self.dropout = dropout
        # In a notebook there is no __file__, so the vocabulary is looked up relative to
        # the working directory (module version: os.path.dirname(os.path.abspath(__file__))).
        cur_dir = ''
        self.vocab_path = os.path.join(cur_dir, vocab_path)
        self.tokenizer = tokenizers.BertWordPieceTokenizer(self.vocab_path, lowercase=True)
        self.ignore_idx = ignore_idx
        # The default dict in the signature is created once and shared by every call.
        # Copying it stops a mutation on one config from leaking into all later ones.
        self.oper2id = dict(oper2id)
        self.weight_decay = weight_decay
        self.lr = lr
        self.adam_epsilon = adam_epsilon
        self.warmup_proportion = warmup_proportion
        self.multiwoz = multiwoz

## data

In [7]:
"""
This Module contains the data processing functions for TRIPPY.
"""

import json
import torch
import numpy as np
import re
from tqdm import tqdm

# from botiverse.models.TRIPPY.utils import RawDataInstance, DataInstance, normalize, is_included, included_with_label_maps, match_with_label_maps, mask_utterance

def get_ontology_label_maps(ontology_path, label_maps_path, domains):
  """
  Read ontology and label maps, and filter slots based on domains.

  A slot is kept when at least one entry of ``domains`` occurs in it as a substring.

  :param ontology_path: The path to the ontology JSON file (a list of slot names).
  :type ontology_path: str

  :param label_maps_path: The path to the label maps JSON file.
  :type label_maps_path: str

  :param domains: The list of domains to filter the slots.
  :type domains: list[str]

  :return: The sorted slot list and label maps.
  :rtype: tuple[list[str], dict]
  """

  # read ontology (the context manager closes the handle even if parsing fails)
  with open(ontology_path) as ontology_file:
    slot_list = json.load(ontology_file)

  # keep only the slots that mention one of the requested domains
  slot_list = [slot for slot in slot_list if any(domain in slot for domain in domains)]

  # read label_maps
  with open(label_maps_path) as label_maps_file:
    label_maps = json.load(label_maps_file)

  return sorted(slot_list), label_maps


def read_raw_data(data_path, slot_list, max_len, domains, multiwoz):
  """
  Read raw data from the JSON file and preprocess it.

  For every turn of every dialogue both utterances are normalized, the system utterance is
  masked with '[UNK]' via ``mask_utterance`` using the turn's system acts, and one
  ``RawDataInstance`` is produced. The history passed to an instance holds the utterances
  of all earlier turns of the same dialogue, most recent turn first.

  :param data_path: The path to the JSON data file. Each dialogue needs the keys
      'dialogue_idx' and 'dialogue'; each turn needs 'turn_idx', 'user_utterance',
      'system_utterance', 'turn_slots' and 'system_act'.
  :type data_path: str

  :param slot_list: The list of slots (not used by this function).
  :type slot_list: list[str]

  :param max_len: The maximum length of the input sequence (not used by this function).
  :type max_len: int

  :param domains: The list of domains (not used by this function).
  :type domains: list[str]

  :param multiwoz: Value forwarded unchanged to ``normalize`` and ``mask_utterance``.

  :return: The list of raw data instances.
  :rtype: list[RawDataInstance]
  """

  # read data (the context manager closes the handle even if parsing fails)
  with open(data_path) as data_file:
    parsed_data = json.load(data_file)

  raw_data = []
  # loop over dialogues
  for dial_info in parsed_data:
    dial_idx = dial_info['dialogue_idx']

    # the history starts empty for every dialogue
    history = []
    # loop over dialogue turns
    for turn in dial_info['dialogue']:
      turn_idx = turn['turn_idx']

      # normalize utterances
      user_utter = ' '.join(normalize(turn['user_utterance'], multiwoz))
      sys_utter = ' '.join(normalize(turn['system_utterance'], multiwoz))

      # slots changed in this turn
      turn_slots = turn['turn_slots']

      # system actions, used as the inform memory
      inform_mem = turn['system_act']

      # mask the system utterance by removing labels that appeared in the system acts
      sys_utter = ' '.join(mask_utterance(sys_utter, inform_mem, multiwoz, '[UNK]'))

      raw_data.append(RawDataInstance(dial_idx,
                                      turn_idx,
                                      user_utter,
                                      sys_utter,
                                      history,
                                      turn_slots,
                                      inform_mem))

      # Prepend this turn for the next one. A new list is built on purpose so the
      # history already stored in earlier instances is not modified.
      history = [user_utter, sys_utter] + history

  return raw_data


def create_slot_span(input, target_value, tok_input_offsets, padding_len, label_maps):
  """
  Create a slot span given the input, target value, and token input offsets,
  by matching the target value as tokens with the input sequence.

  Only occurrences that start before the first '[SEP]' token of ``input`` are considered.
  Label variants are tried in order (the target value first, then its ``label_maps``
  entries) and the search stops at the first variant that matches. If that variant occurs
  several times, the last occurrence is used.

  :param input: The input string (whitespace separated, containing at least one '[SEP]').
  :type input: str

  :param target_value: The target value.
  :type target_value: str

  :param tok_input_offsets: Character offsets (start, end) of the tokenized input, one pair per token.
  :type tok_input_offsets: list[tuple[int, int]]

  :param padding_len: Number of zeros appended to the span.
  :type padding_len: int

  :param label_maps: Maps a value to a list of alternative spellings.
  :type label_maps: dict

  :return: The slot span (``len(tok_input_offsets) + 2 + padding_len`` entries), the span
      start index and the span end index (both 0 when nothing matched).
  :rtype: tuple[list[int], int, int]

  :raises ValueError: If ``input`` contains no '[SEP]' token.
  """

  # get all possible variants of the slot value
  label_variants = [target_value]
  if target_value in label_maps:
    label_variants = label_variants + label_maps[target_value]

  input_list = input.split()
  # matches must start before the first [SEP]
  max_idx = input_list.index('[SEP]')

  # match the target value as tokens
  start, end = -1, -1
  found = False
  for label in label_variants:
    if found:
      break
    # Split the label into word / punctuation tokens. The raw string keeps `\W` a regex
    # class instead of an invalid string escape.
    label_list = [item for item in map(str.strip, re.split(r"(\W+)", label)) if len(item) > 0]
    # an empty or whitespace-only variant can never match; skip it instead of indexing into it
    if not label_list:
      continue
    width = len(label_list)
    for idx in range(max_idx):
      if input_list[idx:idx + width] == label_list:
        # no break: a later occurrence overrides an earlier one
        start, end = idx, idx + width - 1
        found = True

  # Convert the token span into character positions of the single-space joined input.
  # `acc_len + idx` is the characters of the preceding tokens plus one space per token.
  input = " ".join(input_list)
  ch_start, ch_end = -1, -1
  acc_len = 0
  for idx, tok in enumerate(input_list):
    if start == idx:
      ch_start = acc_len + idx
    acc_len += len(tok)
    if end == idx:
      ch_end = acc_len + idx - 1

  # mark the target span in the input string (spaces stay unmarked)
  char_target = [0] * len(input)
  if ch_start != -1 and ch_end != -1:
    for j in range(ch_start, ch_end + 1):
      if input[j] != " ":
        char_target[j] = 1

  # a token belongs to the span when it covers at least one marked character
  span = [1 if sum(char_target[offset1:offset2]) > 0 else 0
          for (offset1, offset2) in tok_input_offsets]

  # tok_input_offsets does not include the [CLS] & [SEP] tokens at the
  # start & end of the input, so add one unmarked position on each side
  span = [0] + span + [0]

  # get the start & end index of the span if any, otherwise 0
  span_start = 0
  span_end = 0
  non_zero = np.nonzero(span)[0]
  if len(non_zero) > 0:
    span_start = non_zero[0]
    span_end = non_zero[-1]

  # pad the target span
  span = span + [0] * padding_len

  return span, span_start, span_end


def create_inputs(history, user_utter, sys_utter, tokenizer, max_len):
  """
  Create inputs for BERT using the history, user utterance, system utterance,
  by creating and tokenizing the input seqence and creating the masks.

  :param history: The history of utterances.
  :type history: list[str]

  :param user_utter: The user's utterance.
  :type user_utter: str

  :param sys_utter: The system's utterance.
  :type sys_utter: str

  :param tokenizer: The tokenizer to tokenize the input.
  :type tokenizer: transformers.PreTrainedTokenizer

  :param max_len: The maximum length of the input.
  :type max_len: int

  :return: The input string, token IDs, attention mask, token type IDs,
            token input offsets, tokenized input tokens, and padding length.
  :rtype: tuple[str, list[int], list[int], list[int], list[tuple[int, int]],
                list[str], int]
  """


  # create input string
  history = " ".join(history)
  current_utter = user_utter + ' [SEP] ' + sys_utter + ' [SEP] '
  input = current_utter + history
  input = " ".join(input.split())

  # tokenize and truncate input
  tok_input = tokenizer.encode(input)
  tok_input_tokens = tok_input.tokens[:max_len]
  tok_input_ids = tok_input.ids[:max_len]
  tok_input_offsets = tok_input.offsets[1:max_len-1]
  if tok_input_tokens[-1] != '[SEP]':
    tok_input_tokens[-1] = '[SEP]'
    tok_input_ids[-1] = 102

  # create mask & input type id
  mask = [1] * len(tok_input_ids)
  token_type_ids = []
  cnt = 0
  for i, token in enumerate(tok_input_tokens):
    token_type_ids.append(1 if cnt >= 2 else 0)
    cnt += 1 if token == '[SEP]' else 0

  # pad the inputs
  padding_len = max_len - len(tok_input_ids)
  ids = tok_input_ids + [0] * padding_len
  mask = mask + [0] * padding_len
  token_type_ids = token_type_ids + [0] * padding_len

  return input, ids, mask, token_type_ids, tok_input_offsets, tok_input_tokens, padding_len


def is_informed(value, target, label_maps, multiwoz):
  """
  Check if a value is informed by the system given a target value and label maps.

  The target is normalized first. The value counts as informed when it equals the
  normalized target, or one is included in the other (``is_included``). Only when this
  direct check fails are the label-map variants consulted (``included_with_label_maps``).

  :param value: The value to check (the value mentioned by the system).
  :type value: str

  :param target: The target value.
  :type target: str

  :param label_maps: The label maps.
  :type label_maps: dict

  :param multiwoz: Value forwarded unchanged to ``normalize``.

  :return: A tuple indicating if the value is informed and the informed value
      (``value`` when informed, otherwise 'none').
  :rtype: tuple[bool, str]
  """

  informed = False
  informed_value = 'none'

  target = ' '.join(normalize(target, multiwoz))

  # direct match
  if value == target or is_included(value, target) or is_included(target, value):
    informed = True

  # Label-map variants are only a fallback. Running them unconditionally would let
  # their result overwrite a direct match that already succeeded.
  if not informed:
    if value in label_maps:
      informed = included_with_label_maps(target, value, label_maps)
    elif target in label_maps:
      informed = included_with_label_maps(value, target, label_maps)

  if informed:
    informed_value = value

  return informed, informed_value


def get_refered_slot(target_value, slot, last_state, non_referable_slots, non_referable_pairs, label_maps={}):
    """
    Find another slot of the previous dialogue state whose value the user may be referring to.

    Returns 'none' when ``slot`` itself cannot refer, when ``slot`` already holds
    ``target_value`` in ``last_state``, or when no eligible slot matches. Otherwise returns the
    first slot of ``last_state`` (in iteration order) whose value matches ``target_value``
    according to ``match_with_label_maps``.

    :param target_value: The target value.
    :type target_value: str

    :param slot: The slot being labelled.
    :type slot: str

    :param last_state: The dialogue state before this turn (slot -> value).
    :type last_state: dict

    :param non_referable_slots: Slots that can neither refer nor be referred to.
    :type non_referable_slots: list[str]

    :param non_referable_pairs: Slot pairs that must not refer to each other (checked in both orders).
    :type non_referable_pairs: list[tuple[str, str]]

    :param label_maps: The label maps.
    :type label_maps: dict, optional

    :return: The referred slot, or 'none'.
    :rtype: str
    """
    no_reference = 'none'

    # this slot is not allowed to refer at all
    if slot in non_referable_slots:
        return no_reference

    # the value is already stored under this very slot: nothing to refer to
    if slot in last_state and last_state[slot] == target_value:
        return no_reference

    for candidate, candidate_value in last_state.items():
        excluded = (
            candidate in non_referable_slots
            or (slot, candidate) in non_referable_pairs
            or (candidate, slot) in non_referable_pairs
            or candidate == slot
        )
        if excluded:
            continue

        # first eligible slot carrying a matching value wins
        if match_with_label_maps(candidate_value, target_value, label_maps):
            return candidate

    return no_reference


def create_labels(target_value, slot, last_state, input, tok_input_offsets, inform_mem, label_maps, padding_len, max_len, non_referable_slots, non_referable_pairs, multiwoz):
  """
  Derive the target operation and the span labels of one slot for one turn.

  Operation selection:
  '[NULL]' / 'none' -> 'carryover'; 'dontcare' / 'yes' / 'no' -> that value itself;
  otherwise 'update' if the value was located in the input, else 'inform' if the system
  informed it, else 'refer' if another slot holds it, else 'unpointable'.

  :param target_value: The target value.
  :type target_value: str

  :param slot: The slot.
  :type slot: str

  :param last_state: The last state.
  :type last_state: dict

  :param input: The input string.
  :type input: str

  :param tok_input_offsets: The token input offsets.
  :type tok_input_offsets: list[tuple[int, int]]

  :param inform_mem: The inform memory; when it contains ``slot`` the entry must hold exactly one value.
  :type inform_mem: dict

  :param label_maps: The label maps.
  :type label_maps: dict

  :param padding_len: The padding length.
  :type padding_len: int

  :param max_len: The maximum length of the input (length of the all-zero default span).
  :type max_len: int

  :param non_referable_slots: The list of non-referable slots (slots that can not use referring).
  :type non_referable_slots: list[str]

  :param non_referable_pairs: The list of non-referable slot pairs (slot pairs that can not refer to each other).
  :type non_referable_pairs: list[tuple[str, str]]

  :param multiwoz: Value forwarded unchanged to ``is_informed``.

  :return: The operation, span, span start index, span end index, referred slot, and informed value.
  :rtype: tuple[str, list[int], int, int, str, str]
  """

  empty_span = [0] * max_len

  # value-independent operations: no span, no reference, no inform lookup
  if target_value in ('[NULL]', 'none'):
    return 'carryover', empty_span, 0, 0, 'none', 'none'
  if target_value in ('dontcare', 'yes', 'no'):
    return target_value, empty_span, 0, 0, 'none', 'none'

  # 1) try to locate the value in the input
  span, span_start, span_end = create_slot_span(input,
                                                target_value,
                                                tok_input_offsets,
                                                padding_len,
                                                label_maps)

  # 2) check whether the system informed the value for this slot
  informed, informed_value = False, 'none'
  if slot in inform_mem:
    assert len(inform_mem[slot]) == 1, 'greater than 1'
    informed, informed_value = is_informed(inform_mem[slot][0], target_value, label_maps, multiwoz)

  # 3) check whether another slot of the previous state already holds the value
  refered_slot = get_refered_slot(target_value, slot, last_state, non_referable_slots, non_referable_pairs, label_maps)

  # priority: span in the input > informed by the system > reference to another slot
  if sum(span) != 0:
    oper = 'update'
  elif informed == True:
    oper = 'inform'
  elif refered_slot != 'none':
    oper = 'refer'
  else:
    oper = 'unpointable'

  return oper, span, span_start, span_end, refered_slot, informed_value


def create_data(raw_data, slot_list, label_maps, tokenizer, max_len, non_referable_slots, non_referable_pairs, multiwoz):
  """
  Create the data instances for training or evaluation.

  Walks the raw turns in order while tracking the dialogue state. For every turn it builds
  the model inputs once and then, for every slot in ``slot_list``, derives the target value
  and the labels (operation, span, referred slot) together with two per-slot auxiliary
  feature vectors. The dialogue state is reset whenever the dialogue index changes or the
  turn index is 0.

  :param raw_data: The list of raw data instances.
  :type raw_data: list[RawDataInstance]

  :param slot_list: The list of slots.
  :type slot_list: list[str]

  :param label_maps: The label maps.
  :type label_maps: dict

  :param tokenizer: The tokenizer to tokenize the input.
  :type tokenizer: transformers.PreTrainedTokenizer

  :param max_len: The maximum length of the input.
  :type max_len: int

  :param non_referable_slots: The list of non-referable slots.
  :type non_referable_slots: list[str]

  :param non_referable_pairs: The list of non-referable slot pairs.
  :type non_referable_pairs: list[tuple[str, str]]

  :param multiwoz: Forwarded unchanged to ``create_labels``; presumably a flag that enables
    MultiWOZ-specific handling -- TODO confirm against ``create_labels`` / ``normalize``.

  :return: The list of data instances, one per raw turn.
  :rtype: list[DataInstance]
  """

  data = []

  # last_state: the state at the end of the previous turn (what a 'refer' may point at)
  # cur_state:  the state being built for the turn currently processed
  last_state = {}
  cur_state = {}
  prev_dial_idx = -1
  # loop over raw data
  for turn in tqdm(raw_data):

    # if new dialogue reset the state
    if turn.dial_idx != prev_dial_idx or turn.turn_idx == 0:
      cur_state = {}
      last_state = {}

    # update previous dialogue index
    prev_dial_idx = turn.dial_idx

    # create model inputs
    input, ids, mask, token_type_ids, tok_input_offsets, input_tokens, padding_len = create_inputs(turn.history,
                                                                                                   turn.user_utter,
                                                                                                   turn.sys_utter,
                                                                                                   tokenizer,
                                                                                                   max_len)


    # per-turn label containers, all aligned with slot_list
    target_values = []
    opers = []
    spans = []
    spans_start = []
    spans_end = []
    refer = ['none'] * len(slot_list)
    inform_aux_features = [0] * len(slot_list)
    ds_aux_features = [0] * len(slot_list)

    # for each slot determine its values
    for slot_idx, slot in enumerate(slot_list):

      # get the slot target value
      # priority: value annotated for this turn > value already in the state > '[NULL]'
      target_value = '[NULL]'
      if slot in turn.turn_slots:
        target_value = turn.turn_slots[slot]
      elif slot in cur_state:
        target_value = cur_state[slot]


      # get slot labels
      # NOTE(review): informed_value is returned by create_labels but not used below
      (oper,
       span,
       span_start,
       span_end,
       refered_slot,
       informed_value) = create_labels(target_value,
                                      slot,
                                      last_state,
                                      input,
                                      tok_input_offsets,
                                      turn.inform_mem,
                                      label_maps,
                                      padding_len,
                                      max_len,
                                      non_referable_slots,
                                      non_referable_pairs,
                                      multiwoz)

      # the slot already holds exactly this value, so a dontcare/yes/no/refer label would
      # not change the state: label it as a carryover instead
      if slot in cur_state and target_value == cur_state[slot] and oper in ['dontcare', 'yes', 'no', 'refer']:
        oper = 'carryover'


      # create auxiliary features
      # mark each informed slot as 1
      if slot in turn.inform_mem:
        inform_aux_features[slot_idx] = 1
      # mark each filled slot as 1
      # (checked before cur_state[slot] is written below, so this reflects whether the
      # slot was already filled before this turn)
      if slot in cur_state:
        ds_aux_features[slot_idx] = 1

      # update the state
      if oper != 'carryover':
        cur_state[slot] = target_value
        # 'unpointable' values still enter the state, but the stored label is
        # demoted to 'carryover'
        if oper == 'unpointable':
          oper = 'carryover'

#       if turn.dial_idx == 'MUL2491.json' and turn.turn_idx == 8 and slot == 'restaurant-name':
#         print(oper)
#         print(span)
#         print(refered_slot)
#         print(informed_value)
#         print(last_state)
#         print(cur_state)

      target_values.append(target_value)
      opers.append(oper)
      spans.append(span)
      spans_start.append(span_start)
      spans_end.append(span_end)
      refer[slot_idx] = refered_slot

    # states are copied so later turns cannot mutate what this instance recorded
    data.append(DataInstance(ids,
                             mask,
                             token_type_ids,
                             spans,
                             spans_start,
                             spans_end,
                             padding_len,
                             input_tokens,
                             input,
                             opers,
                             target_values,
                             last_state.copy(),
                             cur_state.copy(),
                             refer,
                             inform_aux_features,
                             ds_aux_features))

    # update last state
    last_state = cur_state.copy()


  return data


def prepare_data(data_path, slot_list, label_maps, tokenizer, max_len, domains, non_referable_slots, non_referable_pairs, multiwoz):
  """
  Run the full TripPy preprocessing pipeline on one JSON data file.

  This is the usual entry point of the module: it first reads the file into raw turns and
  then converts those turns into labelled data instances.

  :param data_path: The path to the JSON data file.
  :type data_path: str

  :param slot_list: The list of slots.
  :type slot_list: list[str]

  :param label_maps: The label maps.
  :type label_maps: dict

  :param tokenizer: The tokenizer to tokenize the input.
  :type tokenizer: transformers.PreTrainedTokenizer

  :param max_len: The maximum length of the input.
  :type max_len: int

  :param domains: The list of domains.
  :type domains: list[str]

  :param non_referable_slots: The list of non-referable slots.
  :type non_referable_slots: list[str]

  :param non_referable_pairs: The list of non-referable slot pairs.
  :type non_referable_pairs: list[tuple[str, str]]

  :param multiwoz: Passed through to both preprocessing stages; presumably toggles
    MultiWOZ-specific handling -- TODO confirm.

  :return: The raw data and the prepared data, in that order.
  :rtype: tuple[list[RawDataInstance], list[DataInstance]]
  """

  # stage 1: parse the JSON file into raw turns
  raw_turns = read_raw_data(
      data_path,
      slot_list,
      max_len,
      domains,
      multiwoz,
  )

  # stage 2: turn the raw turns into labelled instances
  instances = create_data(
      raw_turns,
      slot_list,
      label_maps,
      tokenizer,
      max_len,
      non_referable_slots,
      non_referable_pairs,
      multiwoz,
  )

  return raw_turns, instances


class Dataset(torch.utils.data.Dataset):
  """
  Map-style PyTorch dataset over preprocessed TRIPPY turns.

  All per-turn fields are copied into parallel lists at construction time; operations and
  referred slots are converted from names to integer IDs so that items can be returned
  as tensors.

  :param data: The list of data instances.
  :type data: list[DataInstance]

  :param n_slots: The number of slots; also used as the ID of a referred slot that is not
    in ``slot_list`` (i.e. "none").
  :type n_slots: int

  :param oper2id: The mapping of operations to IDs.
  :type oper2id: dict[str, int]

  :param slot_list: The list of slots.
  :type slot_list: list[str]
  """

  def __init__(self, data, n_slots, oper2id, slot_list):

    def column(field):
      # one value per turn, in dataset order
      return [getattr(instance, field) for instance in data]

    def refer_index(slot_name):
      # position of the referred slot; names outside slot_list mean "none" -> n_slots
      try:
        return slot_list.index(slot_name)
      except ValueError:
        return n_slots

    self.ids = column('ids')
    self.mask = column('mask')
    self.token_type_ids = column('token_type_ids')
    self.spans_start = column('spans_start')
    self.spans_end = column('spans_end')
    self.padding_len = column('padding_len')
    # token and value lists are flattened to single strings
    self.input_tokens = [' '.join(tokens) for tokens in column('input_tokens')]
    self.target_values = ['[VALUESEP]'.join(values) for values in column('target_values')]
    self.opers = [[oper2id[name] for name in names] for names in column('opers')]
    self.refer = [[refer_index(name) for name in names] for names in column('refer')]
    self.inform_aux_features = column('inform_aux_features')
    self.ds_aux_features = column('ds_aux_features')


  def __len__(self):
    """
    Number of turns held by the dataset.

    :return: The length of the dataset.
    :rtype: int
    """
    return len(self.ids)

  def __getitem__(self, idx):
    """
    Build the model-ready sample for one turn.

    :param idx: The index of the item.
    :type idx: int

    :return: Long tensors for the inputs and labels, float tensors for the auxiliary
      features, and plain strings for ``input_tokens`` and ``target_values``.
    :rtype: dict[str, torch.Tensor or str]
    """

    def long_tensor(values):
      return torch.tensor(values, dtype=torch.long)

    def float_tensor(values):
      return torch.tensor(values, dtype=torch.float)

    item = {}
    for field in ('ids', 'mask', 'token_type_ids', 'spans_start', 'spans_end', 'padding_len'):
      item[field] = long_tensor(getattr(self, field)[idx])

    item['input_tokens'] = self.input_tokens[idx]
    item['target_values'] = self.target_values[idx]

    for field in ('opers', 'refer'):
      item[field] = long_tensor(getattr(self, field)[idx])

    for field in ('inform_aux_features', 'ds_aux_features'):
      item[field] = float_tensor(getattr(self, field)[idx])

    return item

## evaluate

In [8]:
"""
This Module has the evaluation functions for TRIPPY.
"""
import torch
import numpy as np
from sklearn.metrics import f1_score
import copy
from tqdm import tqdm

# from botiverse.models.TRIPPY.utils import normalize, is_included, included_with_label_maps, create_span_output


def get_informed_value(value, target, label_maps, multiwoz):
  """
  Get the informed value based on the value and target, taking into account label maps.

  Both strings are normalized and then compared. The value counts as matching the target
  when, in this order of precedence:

  1. the normalized strings are equal, or one is included in the other;
  2. otherwise, the normalized value is a key of ``label_maps`` and the label-map check
     against the target succeeds;
  3. otherwise, the normalized target is a key of ``label_maps`` and the label-map check
     against the value succeeds.

  A successful direct match (1) is final: it is no longer overwritten by a failing
  label-map lookup, which previously turned such matches back into mismatches whenever the
  value happened to be a ``label_maps`` key.

  :param value: The original value.
  :type value: str
  :param target: The target value to compare with.
  :type target: str
  :param label_maps: The mapping of slot values to their variants. Keys are looked up with
    the normalized strings, so the mapping is expected to be normalized already.
  :type label_maps: dict
  :param multiwoz: Forwarded to ``normalize``; presumably toggles MultiWOZ-specific
    normalization -- TODO confirm.
  :return: The normalized target if the value matches it, otherwise the value exactly as
    it was passed in.
  :rtype: str
  """

  norm_value = ' '.join(normalize(value, multiwoz))
  norm_target = ' '.join(normalize(target, multiwoz))

  if norm_value == norm_target or is_included(norm_value, norm_target) or is_included(norm_target, norm_value):
    # direct match: do not consult the label maps at all
    informed = True
  elif norm_value in label_maps:
    informed = included_with_label_maps(norm_target, norm_value, label_maps)
  elif norm_target in label_maps:
    informed = included_with_label_maps(norm_value, norm_target, label_maps)
  else:
    informed = False

  # on a match the caller gets the target's surface form so that it compares equal
  return norm_target if informed else value


def eval(raw_data, data, model, device, n_slots, slot_list, label_maps, oper2id, multiwoz):
  """
  Evaluate the model on the given data.

  Turns are processed one at a time in order. The predicted dialogue state is carried from
  turn to turn (and reset at the start of each dialogue), and the state auxiliary features
  fed to the model are derived from that *predicted* state rather than from the ground
  truth. The ground-truth state of a turn is ``turn.cur_state``.

  NOTE(review): this function shadows the built-in ``eval`` in the notebook namespace.

  :param raw_data: The raw data.
  :type raw_data: list
  :param data: The processed data, aligned one-to-one with ``raw_data``.
  :type data: list
  :param model: The model to evaluate.
  :type model: nn.Module
  :param device: The device to run the evaluation on.
  :type device: torch.device
  :param n_slots: The number of slots.
  :type n_slots: int
  :param slot_list: The list of slots.
  :type slot_list: list
  :param label_maps: The mapping of slot values to their variants.
  :type label_maps: dict
  :param oper2id: The mapping of operations to their IDs.
  :type oper2id: dict
  :param multiwoz: Forwarded to ``normalize`` and ``get_informed_value``; presumably
    toggles MultiWOZ-specific normalization -- TODO confirm.
  :return: The evaluation metrics: mean joint goal accuracy, a dict with the mean accuracy
    of each slot, the macro F1 score over the operation labels, and the per-class F1 scores
    over the operation labels.
  :rtype: tuple
  """

  # the model is left in eval mode when this function returns
  model.eval()

  # normalize label_maps
  # (keys and variants are normalized so they can be compared with normalized values below)
  label_maps_tmp = {}
  for v in label_maps:
      label_maps_tmp[' '.join(normalize(v, multiwoz))] = [' '.join(normalize(nv, multiwoz)) for nv in label_maps[v]]
  label_maps = label_maps_tmp


  # metrics
  # Y_true / Y_pred collect one operation id per (turn, slot) pair
  Y_true, Y_pred = [], []
  per_slot_acc = {slot:[] for slot in slot_list}
  joint_goal_acc = []

  # state
  pred_last_state = {}
  pre_dialogue_idx = -1

  # debugging
  # NOTE(review): these are filled below but never returned; their only consumer is the
  # commented-out block near the end of the function
  states = []
  sentences = []
  indices = []
  prev_idx = -1

  with torch.no_grad():

    for raw_turn, turn in tqdm(zip(raw_data, data), total=len(raw_data)):

      # single-turn batch: every tensor gets a leading batch dimension of 1
      ids = torch.tensor(turn.ids, dtype=torch.long).unsqueeze(0).to(device)
      mask = torch.tensor(turn.mask, dtype=torch.long).unsqueeze(0).to(device)
      token_type_ids = torch.tensor(turn.token_type_ids, dtype=torch.long).unsqueeze(0).to(device)
      inform_aux_features = torch.tensor(turn.inform_aux_features, dtype=torch.float).unsqueeze(0).to(device)
      # ds_aux_features = torch.tensor(turn.ds_aux_features, dtype=torch.float).unsqueeze(0).to(device)
      input_tokens = ' '.join(turn.input_tokens)
      padding_len = turn.padding_len
      turn_idx = raw_turn.turn_idx
      dialogue_idx = raw_turn.dial_idx
      current_state = turn.cur_state
      inform_mem = raw_turn.inform_mem
      opers = turn.opers

      # new dialogue reset the state and the state auxiliary features
      # (this is also where ds_aux_features is first created; otherwise it keeps the value
      # computed from the previous turn's predicted state)
      if turn_idx == 0 or dialogue_idx != pre_dialogue_idx:
        pred_last_state = {}
        ds_aux_features = torch.zeros((1, n_slots)).to(device)

      # get model outputs
      slots_start_logits, slots_end_logits, slots_oper_logits, slots_refer_logits = model(ids=ids,
                                                                                          mask=mask,
                                                                                          token_type_ids=token_type_ids,
                                                                                          inform_aux_features=inform_aux_features,
                                                                                          ds_aux_features=ds_aux_features)

      # update the predicted state of each slot
      pred_current_state = pred_last_state.copy()
      for slot_idx, slot in enumerate(slot_list):

        # get the predicted operation
        pred_oper = slots_oper_logits[slot_idx][0].argmax(dim=-1).item()

        # keep track of operations for f1 score
        Y_pred.append(pred_oper)
        Y_true.append(oper2id[opers[slot_idx]])

        # update the slot based on the operation
        if pred_oper == oper2id['carryover']: # carryover
          continue
        elif pred_oper == oper2id['dontcare']: # dontcare
          pred_current_state[slot] = 'dontcare'
        elif pred_oper == oper2id['update']: # update
          pred_current_state[slot] = create_span_output(slots_start_logits[slot_idx][0].cpu().detach().numpy(),
                                                        slots_end_logits[slot_idx][0].cpu().detach().numpy(),
                                                        padding_len,
                                                        input_tokens)
        elif pred_oper == oper2id['refer']: # refer
          # index n_slots is the "none" class; a referral only takes effect if the
          # referred slot was filled in the previous predicted state
          refered_slot = slots_refer_logits[slot_idx][0].argmax(dim=-1).item()
          if refered_slot != n_slots and slot_list[refered_slot] in pred_last_state:
            pred_current_state[slot] = pred_last_state[slot_list[refered_slot]]
        elif pred_oper == oper2id['yes']: # yes
          pred_current_state[slot] = 'yes'
        elif pred_oper == oper2id['no']: # no
          pred_current_state[slot] = 'no'
        elif pred_oper == oper2id['inform']: # inform
          # the '§§' prefix marks a value copied from the inform memory; it is resolved
          # against the ground truth in the accuracy section below
          if slot in inform_mem:
            pred_current_state[slot] = '§§' + inform_mem[slot][0]

      # update the state auxiliary features
      # (in place; these are the features the model receives on the next turn)
      for slot_idx, slot in enumerate(slot_list):
          ds_aux_features[0, slot_idx] = 1 if slot in pred_current_state else 0

      # calculate accuracy
      joint = True
      for slot_idx, slot in enumerate(slot_list):

        # if not present in both
        if slot not in current_state and slot not in pred_current_state:
          per_slot_acc[slot].append(1.0)
          continue

        # if slot only in one of them then mark as 0
        if (slot in current_state and slot not in pred_current_state) or (slot not in current_state and slot in pred_current_state):
          joint = False
          per_slot_acc[slot].append(0.0)
          continue

        # get values
        val = current_state[slot]
        pred_val = pred_current_state[slot]

        # normalize values
        val = ' '.join(normalize(val, multiwoz))
        pred_val = ' '.join(normalize(pred_val, multiwoz))

        # handle inform
        # both "§§ value" and "§§value" are accepted; presumably normalization may or may
        # not split the marker from the value -- TODO confirm against normalize
        if pred_val[0:3] == "§§ ":
          if pred_val[3:] != 'none':
              pred_val = get_informed_value(pred_val[3:], val, label_maps, multiwoz)
        elif pred_val[0:2] == "§§":
            if pred_val[2:] != 'none':
                pred_val = get_informed_value(pred_val[2:], val, label_maps, multiwoz)

        # match
        # exact match first, then any variant of the ground-truth value listed in label_maps
        if pred_val == val:
          per_slot_acc[slot].append(1.0)
        elif val != 'none' and val != 'dontcare' and val != 'true' and val != 'false' and val in label_maps:
          no_match = True
          for variant in label_maps[val]:
              if variant == pred_val:
                  no_match = False
                  break
          if no_match:
              per_slot_acc[slot].append(0.0)
              joint = False
          else:
              per_slot_acc[slot].append(1.0)
        else:
            per_slot_acc[slot].append(0.0)
            joint = False

      # append joint
      # (a turn counts as correct only if every slot matched)
      joint_goal_acc.append(1.0 if joint else 0.0)


      # update vars for next turn
      pred_last_state = pred_current_state.copy()
      pre_dialogue_idx = dialogue_idx

#       # debugging
#       if per_slot_acc['attraction-name'][-1] < 0.99 and prev_idx != dialogue_idx:
#         print('dialogue_idx', dialogue_idx)
#         print('turn_idx', turn_idx)
#         print('pred_state', dict(sorted(pred_current_state.items())))
#         print('cur_state', dict(sorted(current_state.items())))
#         print('input tok', input_tokens)
#         print('inform_mem', inform_mem)
#         print('inform aux', inform_aux_features)
#         print('oper', opers[slot_list.index('attraction-name')])
#         prev_idx = dialogue_idx

      # debugging
      # record every turn whose joint goal was missed
      if joint == False:
        states.append((copy.deepcopy(pred_current_state), current_state))
        sentences.append(input_tokens)
        indices.append((dialogue_idx, turn_idx))

  # debugging
  # prev = ""
  # for i in range(len(states)):
  #   if prev == indices[i][0] or len(states[i][0]) != len(states[i][1]):
  #     continue
  #   if 'attraction-name' not in states[i][1]:# and 'restaurant-name' not in states[i][1]:
  #     continue
  #   prev = indices[i][0]
  #   print(dict(sorted(states[i][0].items())))
  #   print(dict(sorted(states[i][1].items())))
  #   print(sentences[i])
  #   print(indices[i])
  #   print("\n")

  # calculate per slot accuracy
  per_slot_acc = {slot: np.mean(acc) for slot, acc in per_slot_acc.items()}

  # calculate f1 scores
  # (over the operation ids collected for every turn and slot)
  macro_f1_score = f1_score(Y_true, Y_pred, average='macro')
  all_f1_score = f1_score(Y_true, Y_pred, average=None)

  return np.mean(joint_goal_acc), per_slot_acc, macro_f1_score, all_f1_score

## infer

In [9]:
"""
This Module has the inference functions for TRIPPY.
"""


import torch

# from botiverse.models.TRIPPY.data import create_inputs
# from botiverse.models.TRIPPY.utils import create_span_output


def infer(model, slot_list, current_state, history, sys_utter, user_utter, inform_mem, device, oper2id, tokenizer, max_len):
  """
  Predict the dialogue state after one turn with the TRIPPY model.

  The turn is encoded, passed through the model as a batch of size one, and the predicted
  operation of every slot is applied on top of a copy of ``current_state`` (the input
  dictionary itself is not modified). Values taken from the inform memory are stored with
  a leading ``'§§'`` marker.

  :param model: The TRIPPY model for inference.
  :type model: TRIPPY
  :param slot_list: The list of slots.
  :type slot_list: list
  :param current_state: The current dialogue state.
  :type current_state: dict
  :param history: The dialogue history.
  :type history: list
  :param sys_utter: The system's utterance.
  :type sys_utter: str
  :param user_utter: The user's utterance.
  :type user_utter: str
  :param inform_mem: The inform memory.
  :type inform_mem: dict
  :param device: The device to run the inference on.
  :type device: torch.device
  :param oper2id: The mapping of operations to IDs.
  :type oper2id: dict
  :param tokenizer: The tokenizer to tokenize the input.
  :type tokenizer: transformers.PreTrainedTokenizer
  :param max_len: The maximum length of the input sequence.
  :type max_len: int
  :return: The predicted dialogue state.
  :rtype: dict
  """

  model.eval()

  # encode the turn; the raw text and the token offsets are not needed at inference time
  encoded = create_inputs(history, user_utter, sys_utter, tokenizer, max_len)
  _, ids, mask, token_type_ids, _, input_tokens, padding_len = encoded

  def as_batch(values):
    # long tensor with a leading batch dimension of 1, placed on the target device
    return torch.tensor(values, dtype=torch.long).unsqueeze(0).to(device)

  with torch.no_grad():
    n_slots = len(slot_list)
    ids = as_batch(ids)
    mask = as_batch(mask)
    token_type_ids = as_batch(token_type_ids)

    # auxiliary features: 1 for every slot that was informed / is already filled
    inform_aux_features = torch.tensor([[1.0 if name in inform_mem else 0.0 for name in slot_list]]).to(device)
    ds_aux_features = torch.tensor([[1.0 if name in current_state else 0.0 for name in slot_list]]).to(device)
    token_text = ' '.join(input_tokens)

    # forward pass
    outputs = model(ids=ids,
                    mask=mask,
                    token_type_ids=token_type_ids,
                    inform_aux_features=inform_aux_features,
                    ds_aux_features=ds_aux_features)
    slots_start_logits, slots_end_logits, slots_oper_logits, slots_refer_logits = outputs

    # apply the predicted operation of every slot to a copy of the state
    pred_state = current_state.copy()
    for slot_idx, slot in enumerate(slot_list):

      oper_id = slots_oper_logits[slot_idx][0].argmax(dim=-1).item()

      if oper_id == oper2id['carryover']:
        # keep whatever the state already holds for this slot
        continue
      elif oper_id == oper2id['dontcare']:
        pred_state[slot] = 'dontcare'
      elif oper_id == oper2id['update']:
        # value is read from the predicted span in the input
        start_scores = slots_start_logits[slot_idx][0].cpu().detach().numpy()
        end_scores = slots_end_logits[slot_idx][0].cpu().detach().numpy()
        pred_state[slot] = create_span_output(start_scores, end_scores, padding_len, token_text)
      elif oper_id == oper2id['refer']:
        # copy the value of another slot; index n_slots stands for "none"
        source_idx = slots_refer_logits[slot_idx][0].argmax(dim=-1).item()
        if source_idx != n_slots:
          source_slot = slot_list[source_idx]
          if source_slot in current_state:
            pred_state[slot] = current_state[source_slot]
      elif oper_id == oper2id['yes']:
        pred_state[slot] = 'yes'
      elif oper_id == oper2id['no']:
        pred_state[slot] = 'no'
      elif oper_id == oper2id['inform'] and slot in inform_mem:
        # value offered by the system, tagged with the inform marker
        pred_state[slot] = '§§' + inform_mem[slot][0]

  return pred_state

## run

In [10]:
"""
This Module has the run functions for TRIPPY that train and evaluate the model.
"""

import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup


# from botiverse.models.TRIPPY.data import prepare_data, Dataset
# from botiverse.models.TRIPPY.train import train
# from botiverse.models.TRIPPY.evaluate import eval


def run(model, domains, slot_list, label_maps, train_json, dev_json, test_json, device, non_referable_slots, non_referable_pairs, model_path, TRIPPY_config):
    """
    Train and evaluate the TRIPPY model.

    The three splits are preprocessed, the model is trained for ``TRIPPY_config.epochs``
    epochs with AdamW and a linear warmup/decay schedule, and after every epoch it is
    evaluated on the dev set. Whenever the dev joint goal accuracy improves, the model
    weights are saved to ``model_path``. Finally the weights stored at ``model_path`` are
    loaded back and evaluated on the test set. All metrics are printed; nothing is returned.

    :param model: The TRIPPY model.
    :type model: TRIPPY
    :param domains: The domains to consider in the dataset.
    :type domains: list
    :param slot_list: The list of slots.
    :type slot_list: list
    :param label_maps: The mapping of slot values to their variants.
    :type label_maps: dict
    :param train_json: The path to the training dataset in JSON format.
    :type train_json: str
    :param dev_json: The path to the development dataset in JSON format.
    :type dev_json: str
    :param test_json: The path to the testing dataset in JSON format.
    :type test_json: str
    :param device: The device to train and evaluate the model on.
    :type device: torch.device
    :param non_referable_slots: The slots that are not referable.
    :type non_referable_slots: list
    :param non_referable_pairs: The pairs of slots that are not referable.
    :type non_referable_pairs: list
    :param model_path: The path to save the best model.
    :type model_path: str
    :param TRIPPY_config: The configuration for TRIPPY.
    :type TRIPPY_config: TRIPPYConfig
    """

    n_slots = len(slot_list)

    def _prepare(split_name, json_path):
        # Preprocess one split and return (raw_data, data) as produced by prepare_data.
        print(f'Preprocessing {split_name} set...')
        return prepare_data(json_path, slot_list, label_maps, TRIPPY_config.tokenizer, TRIPPY_config.max_len, domains, non_referable_slots, non_referable_pairs, TRIPPY_config.multiwoz)

    def _evaluate_and_report(raw_data, data):
        # Evaluate on one split, print all metrics, and return the joint goal accuracy.
        joint_goal_acc, per_slot_acc, macro_f1_score, all_f1_score = eval(raw_data, data, model, device, n_slots, slot_list, label_maps, TRIPPY_config.oper2id, TRIPPY_config.multiwoz)
        print(f'Joint Goal Acc: {joint_goal_acc}')
        print(f'Per Slot Acc: {per_slot_acc}')
        print(f'Macro F1 Score: {macro_f1_score}')
        print(f'All f1 score = {all_f1_score}')
        return joint_goal_acc

    # train: the only split that is batched through a DataLoader
    _, train_data = _prepare('train', train_json)
    train_dataset = Dataset(train_data, n_slots, TRIPPY_config.oper2id, slot_list)
    train_sampler = torch.utils.data.RandomSampler(train_dataset)
    train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                    sampler=train_sampler,
                                                    batch_size=TRIPPY_config.train_batch_size)

    # dev / test: eval() consumes the raw and processed turns directly (one turn at a
    # time), so no Dataset or DataLoader is built for these splits
    dev_raw_data, dev_data = _prepare('dev', dev_json)
    test_raw_data, test_data = _prepare('test', test_json)

    # Parameters whose name contains one of these substrings are excluded from weight decay.
    # The match is a case-sensitive substring test, so both the Hugging Face naming
    # ("LayerNorm") and the naming used by the BERT defined in this notebook
    # ("layer_norm") are listed; with only the former, this notebook's layer-norm weights
    # were silently being decayed.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight", "layer_norm.bias", "layer_norm.weight"]
    optimizer_parameters = [
        {
            "params": [
                p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": TRIPPY_config.weight_decay,
        },
        {
            "params": [
                p for n, p in param_optimizer if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    # one scheduler step per batch, over all epochs
    num_train_steps = len(train_data_loader) * TRIPPY_config.epochs
    num_warmup_steps = int(num_train_steps * TRIPPY_config.warmup_proportion)

    optimizer = AdamW(optimizer_parameters, lr=TRIPPY_config.lr, eps=TRIPPY_config.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps
    )

    # -1 is below any valid accuracy, so the first epoch always produces a checkpoint
    best_joint = -1
    for epoch in range(TRIPPY_config.epochs):
        print(f'\nEpoch: {epoch} ---------------------------------------------------------------')
        print('Training the model...')
        train(train_data_loader, model, optimizer, device, scheduler, n_slots, TRIPPY_config.ignore_idx, TRIPPY_config.oper2id)
        print('Evaluating the model on dev set...')
        joint_goal_acc = _evaluate_and_report(dev_raw_data, dev_data)
        if joint_goal_acc > best_joint:
            torch.save(model.state_dict(), model_path)
            best_joint = joint_goal_acc

    print('Loading best model on dev set...')
    # map_location keeps the load working when the checkpoint tensors were saved from a
    # different device than the one used now
    model.load_state_dict(torch.load(model_path, map_location=device))
    print('Evaluating the model on test set...')
    _evaluate_and_report(test_raw_data, test_data)




# from botiverse.bots import Theorizer
# from botiverse.models import NeuralNet
# from botiverse.preprocessors import BertEmbedder

## train

In [11]:
"""
This Module has the training functions for TRIPPY.
"""

import torch
import torch.nn as nn
from tqdm import tqdm

# from botiverse.models.TRIPPY.utils import AverageMeter


def span_loss_fn(start_logits, end_logits, targets_start, targets_end, ignore_idx):
  """
  Compute the per-example span loss.

  The loss is the average of the cross-entropy of the start position and the
  cross-entropy of the end position. No reduction is applied, so one value per
  example is returned; positions whose target equals ``ignore_idx`` contribute 0.

  :param start_logits: The start logits.
  :type start_logits: torch.Tensor
  :param end_logits: The end logits.
  :type end_logits: torch.Tensor
  :param targets_start: The start targets.
  :type targets_start: torch.Tensor
  :param targets_end: The end targets.
  :type targets_end: torch.Tensor
  :param ignore_idx: The index to ignore in the loss calculation.
  :type ignore_idx: int
  :return: The span loss.
  :rtype: torch.Tensor
  """
  boundaries = ((start_logits, targets_start), (end_logits, targets_end))
  start_loss, end_loss = [
      nn.functional.cross_entropy(logits, targets, ignore_index=ignore_idx, reduction='none')
      for logits, targets in boundaries
  ]
  return (start_loss + end_loss) / 2.0

def oper_loss_fn(oper_logits, oper_labels, ignore_idx):
  """
  Compute the per-example operation loss.

  Plain cross-entropy without reduction: one loss value per example, with 0 for
  examples whose label equals ``ignore_idx``.

  :param oper_logits: The operation logits.
  :type oper_logits: torch.Tensor
  :param oper_labels: The operation labels.
  :type oper_labels: torch.Tensor
  :param ignore_idx: The index to ignore in the loss calculation.
  :type ignore_idx: int
  :return: The operation loss.
  :rtype: torch.Tensor
  """
  criterion = nn.CrossEntropyLoss(reduction='none', ignore_index=ignore_idx)
  return criterion(oper_logits, oper_labels)

def refer_loss_fn(refer_logits, refer_labels, ignore_idx):
  """
  Compute the per-example refer loss.

  Plain cross-entropy over the referred-slot classes without reduction: one loss
  value per example, with 0 for examples whose label equals ``ignore_idx``.

  :param refer_logits: The refer logits.
  :type refer_logits: torch.Tensor
  :param refer_labels: The refer labels.
  :type refer_labels: torch.Tensor
  :param ignore_idx: The index to ignore in the loss calculation.
  :type ignore_idx: int
  :return: The refer loss.
  :rtype: torch.Tensor
  """
  return nn.functional.cross_entropy(
      refer_logits,
      refer_labels,
      ignore_index=ignore_idx,
      reduction='none',
  )

def train(data_loader, model, optimizer, device, scheduler, n_slots, ignore_idx, oper2id):
  """
  Run one pass over ``data_loader``, updating the model after every batch.

  For every slot the three losses are combined as
  ``0.8 * oper + 0.1 * span + 0.1 * refer`` and summed over the examples of
  the batch; the per-slot sums are then added up to form the batch loss.
  Gradients are clipped to a norm of 1.0 before each optimizer step, and the
  scheduler is stepped once per batch.

  :param data_loader: The data loader providing the training batches.
  :type data_loader: DataLoader

  :param model: The model to be trained.
  :type model: nn.Module

  :param optimizer: The optimizer used to update the model's parameters.
  :type optimizer: Optimizer

  :param device: The device (e.g., CPU or GPU) on which the training will be performed.
  :type device: torch.device

  :param scheduler: The scheduler for adjusting the learning rate during training.
  :type scheduler: _LRScheduler

  :param n_slots: The number of slots in the task.
  :type n_slots: int

  :param ignore_idx: The index to ignore during loss computation.
  :type ignore_idx: int

  :param oper2id: A dictionary mapping operation names to their corresponding IDs.
  :type oper2id: dict
  """
  # Batch entries that have to be moved to the training device.
  batch_fields = ('ids', 'mask', 'token_type_ids', 'spans_start', 'spans_end',
                  'opers', 'refer', 'inform_aux_features', 'ds_aux_features')

  model.train()

  loss_meter = AverageMeter()

  progress_bar = tqdm(data_loader)
  for batch in progress_bar:
    on_device = {field: batch[field].to(device) for field in batch_fields}
    ids = on_device['ids']
    spans_start = on_device['spans_start']
    spans_end = on_device['spans_end']
    opers = on_device['opers']
    refer = on_device['refer']

    optimizer.zero_grad()
    model_outputs = model(ids=ids,
                          mask=on_device['mask'],
                          token_type_ids=on_device['token_type_ids'],
                          inform_aux_features=on_device['inform_aux_features'],
                          ds_aux_features=on_device['ds_aux_features'])
    slots_start_logits, slots_end_logits, slots_oper_logits, slots_refer_logits = model_outputs

    batch_loss = 0.0
    for slot in range(n_slots):
      slot_opers = opers[:, slot]
      slot_starts = spans_start[:, slot]

      oper_loss = oper_loss_fn(slots_oper_logits[slot], slot_opers, ignore_idx)

      # The span loss only counts for examples whose start target is positive.
      span_loss = span_loss_fn(slots_start_logits[slot],
                               slots_end_logits[slot],
                               slot_starts,
                               spans_end[:, slot],
                               ignore_idx)
      span_loss = span_loss * (slot_starts > 0).float()

      # The refer loss only counts for examples labelled with the 'refer' operation.
      refer_loss = refer_loss_fn(slots_refer_logits[slot], refer[:, slot], ignore_idx)
      refer_loss = refer_loss * (slot_opers == oper2id['refer']).float()

      slot_loss = 0.8 * oper_loss + 0.1 * span_loss + 0.1 * refer_loss
      batch_loss = batch_loss + slot_loss.sum()

    # The batch loss is a sum over slots and examples; it is not divided by n_slots.
    batch_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()

    loss_meter.update(batch_loss.item(), ids.size(0))
    progress_bar.set_postfix(loss=loss_meter.avg)

## TRIPPY_DST

In [12]:
"""
This Module has base code and interfaces for TripPy Dialogue State Tracker.
"""

import torch
from collections import OrderedDict

# from botiverse.models.TRIPPY.utils import normalize, mask_utterance
# from botiverse.models.TRIPPY.data import get_ontology_label_maps, prepare_data, Dataset
# from botiverse.models.TRIPPY.run import run
# from botiverse.models.TRIPPY.infer import infer
# from botiverse.models.TRIPPY.config import TRIPPYConfig
# from botiverse.models.TRIPPY.TRIPPY import TRIPPY


class TRIPPYDST:
    """
    TRIPPYDST is a class that represents the TripPy Dialogue State Tracker.

    It provides methods for loading the model, training the model, updating the dialogue state,
    getting the current dialogue state, deleting slots, resetting the tracker, and displaying the tracker information.

    :param domains: The list of domains to consider.
    :type domains: list[str]
    :param ontology_path: The path to the ontology file.
    :type ontology_path: str
    :param label_maps_path: The path to the label maps file.
    :type label_maps_path: str
    :param non_referable_slots: The list of non-referable slots.
    :type non_referable_slots: list[str]
    :param non_referable_pairs: The list of non-referable slot pairs.
    :type non_referable_pairs: list[tuple[str, str]]
    :param from_scratch: Whether to train the model from scratch.
    :type from_scratch: bool
    :param BERT_config: The configuration forwarded to the TRIPPY model for its BERT encoder.
    :type BERT_config: BERTConfig
    :param TRIPPY_config: The configuration for the TRIPPY model; a fresh ``TRIPPYConfig()``
        is created when omitted or ``None``.
    :type TRIPPY_config: TRIPPYConfig, optional
    """

    def __init__(self, domains, ontology_path, label_maps_path, non_referable_slots, non_referable_pairs, from_scratch, BERT_config, TRIPPY_config=None):
        # Build the default configuration lazily: a default written directly in
        # the signature would be created once at definition time and then be
        # shared by every tracker instance.
        if TRIPPY_config is None:
            TRIPPY_config = TRIPPYConfig()

        self.domains = domains
        self.ontology_path = ontology_path
        self.label_maps_path = label_maps_path
        self.non_referable_slots = non_referable_slots
        self.non_referable_pairs = non_referable_pairs
        self.from_scratch = from_scratch
        self.BERT_config = BERT_config
        self.TRIPPY_config = TRIPPY_config

        slot_list, label_maps = get_ontology_label_maps(ontology_path, label_maps_path, domains)
        self.slot_list = slot_list
        self.n_slots = len(slot_list)
        self.label_maps = label_maps
        self.state = {}
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = TRIPPY(len(slot_list), TRIPPY_config.hid_dim, TRIPPY_config.n_oper, TRIPPY_config.dropout, from_scratch, BERT_config, TRIPPY_config).to(self.device)
        self.history = []

    def save_model(self, model_path):
        """
        Save the trained model.

        :param model_path: The path to save the model.
        :type model_path: str
        """
        torch.save(self.model.state_dict(), model_path)

    def load_model(self, model_path, test_path=None):
        """
        Load the trained model and optionally evaluate it.

        :param model_path: The path to the saved model.
        :type model_path: str
        :param test_path: The path to the test data for evaluation; evaluation is
            skipped when ``None``.
        :type test_path: str, optional
        """
        if self.from_scratch == True:
            # Get saved weights
            state_dict = torch.load(model_path, map_location=self.device)

            # Delete position_ids from the state_dict if available
            if 'bert.embeddings.position_ids' in state_dict.keys():
                del state_dict['bert.embeddings.position_ids']

            # The saved weights are re-keyed purely by position: the i-th saved
            # tensor is assigned to the i-th key of the current model, so both
            # state dicts must list their parameters in the same order.
            new_keys = list(self.model.state_dict().keys())
            weights = list(state_dict.values())

            new_state_dict = OrderedDict()
            for i in range(len(new_keys)):
                new_state_dict[new_keys[i]] = weights[i]

            self.model.load_state_dict(new_state_dict)

        else:
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))

        print('Model loaded successfully.')
        if test_path is not None:
            print('Preprocessing the data...')
            test_raw_data, test_data = prepare_data(test_path, self.slot_list, self.label_maps, self.TRIPPY_config.tokenizer, self.TRIPPY_config.max_len, self.domains, self.non_referable_slots, self.non_referable_pairs, self.TRIPPY_config.multiwoz)
            # NOTE(review): the dataset and loader below are built but never
            # passed to the evaluation call; kept as-is pending confirmation
            # that they can be dropped.
            test_dataset = Dataset(test_data, self.n_slots, self.TRIPPY_config.oper2id, self.slot_list)
            test_data_loader = torch.utils.data.DataLoader(test_dataset,
                                                           batch_size=self.TRIPPY_config.test_batch_size)
            print('Evaluating the model on the data...')
            # NOTE(review): `eval` is expected to be the project's evaluation
            # function shadowing the builtin; it is not among the imports
            # listed for this cell -- confirm it is defined before this runs.
            joint_goal_acc, per_slot_acc, macro_f1_score, all_f1_score = eval(test_raw_data, test_data, self.model, self.device, self.n_slots, self.slot_list, self.label_maps, self.TRIPPY_config.oper2id, self.TRIPPY_config.multiwoz)
            print(f'Joint Goal Acc: {joint_goal_acc}')
            print(f'Per Slot Acc: {per_slot_acc}')
            print(f'Macro F1 Score: {macro_f1_score}')
            print(f'All f1 score = {all_f1_score}')

    def train(self, train_path, dev_path, test_path, model_path):
        """
        Train the model.

        :param train_path: The path to the training data.
        :type train_path: str
        :param dev_path: The path to the development data for evaluation during training.
        :type dev_path: str
        :param test_path: The path to the test data for evaluation after training.
        :type test_path: str
        :param model_path: The path to save the trained model.
        :type model_path: str
        """
        run(self.model, self.domains, self.slot_list, self.label_maps, train_path, dev_path, test_path, self.device, self.non_referable_slots, self.non_referable_pairs, model_path, self.TRIPPY_config)

    def update_state(self, sys_utter, user_utter, inform_mem):
        """
        Update the dialogue state based on the system and user utterances.

        :param sys_utter: The system utterance.
        :type sys_utter: str
        :param user_utter: The user utterance.
        :type user_utter: str
        :param inform_mem: The inform memory containing previous slot-value pairs.
        :type inform_mem: dict[str, list[str]]
        :return: A copy of the updated dialogue state.
        :rtype: dict[str, str]
        """
        # normalize utterances
        user_utter = ' '.join(normalize(user_utter, self.TRIPPY_config.multiwoz))
        sys_utter = ' '.join(normalize(sys_utter, self.TRIPPY_config.multiwoz))
        # delex the system utterance
        sys_utter = ' '.join(mask_utterance(sys_utter, inform_mem, self.TRIPPY_config.multiwoz, '[UNK]'))

        self.state = infer(self.model, self.slot_list, self.state, self.history, sys_utter, user_utter, inform_mem, self.device, self.TRIPPY_config.oper2id, self.TRIPPY_config.tokenizer, self.TRIPPY_config.max_len)
        # The newest turn is placed at the front of the history.
        self.history = [user_utter, sys_utter] + self.history
        return self.state.copy()

    def get_dialogue_state(self):
        """
        Get a copy of the current dialogue state.

        :return: A copy of the dialogue state.
        :rtype: dict[str, str]
        """
        return self.state.copy()

    def delete_slots(self, domain=None, slot=None):
        """
        Delete slots from the dialogue state.

        If a domain is specified, every slot whose key contains the domain string is deleted.
        Otherwise, if a slot is specified, that specific slot is deleted (if present).
        If neither domain nor slot is specified, all slots are deleted.

        :param domain: The domain to delete slots from.
        :type domain: str, optional
        :param slot: The slot to delete.
        :type slot: str, optional
        """
        if domain is not None:
            # Iterate over a snapshot of the keys: deleting entries while
            # iterating the live view raises
            # "RuntimeError: dictionary changed size during iteration".
            for key in list(self.state.keys()):
                if domain in key:
                    del self.state[key]
        elif slot is not None:
            self.state.pop(slot, None)
        else:
            self.state.clear()

    def reset(self):
        """
        Reset the dialogue state.

        Remove all slots from the dialogue state and clear the history.
        """
        self.state.clear()
        self.history = []

    def __str__(self):
        """
        Return a string representation of the TRIPPYDST object.

        :return: A string representation of the object.
        :rtype: str
        """
        # The model itself is deliberately left out of the summary.
        fields = (
            ('domains', self.domains),
            ('ontology_path', self.ontology_path),
            ('label_maps_path', self.label_maps_path),
            ('non_referable_slots', self.non_referable_slots),
            ('non_referable_pairs', self.non_referable_pairs),
            ('from_scratch', self.from_scratch),
            ('slot_list', self.slot_list),
            ('n_slots', self.n_slots),
            ('label_maps', self.label_maps),
            ('state', self.state),
            ('device', self.device),
            ('history', self.history),
        )
        return ''.join('\n' + label + ': ' + str(value) for label, value in fields)

## TRIPPY

In [13]:
"""
This Module has the TRIPPY model.
"""

import torch
import torch.nn as nn
from transformers import BertModel

# from botiverse.models.BERT.BERT import Bert
# from botiverse.models.BERT.config import BERTConfig
# from botiverse.models.BERT.utils import LoadPretrainedWeights

class TRIPPY(nn.Module):
  """
  TripPy ("triple copy") dialogue state tracking model.

  A BERT encoder feeds three kinds of per-slot heads: an operation classifier,
  a span (start/end) predictor over the input tokens, and a refer classifier
  over ``n_slots + 1`` classes. The operation and refer heads additionally
  receive two projected auxiliary feature vectors (inform and dialogue-state).

  :param n_slots: The number of slots, corresponding to the number of dialogue slots to be filled.
  :type n_slots: int
  :param hid_dim: The hidden dimension size; the heads are applied directly to the
      encoder outputs, so it has to match the encoder's hidden size.
  :type hid_dim: int
  :param n_oper: The number of operations.
  :type n_oper: int
  :param dropout: The dropout rate.
  :type dropout: float
  :param from_scratch: If True, build the in-house ``Bert`` model and load the pre-trained
      weights into it; otherwise use ``BertModel.from_pretrained('bert-base-uncased')``.
  :type from_scratch: bool
  :param BERT_config: The configuration for the in-house BERT model; a fresh ``BERTConfig()``
      is created when it is needed and none was given.
  :type BERT_config: BERTConfig, optional
  :param TRIPPY_config: Accepted for interface compatibility; not used by this constructor.
  :type TRIPPY_config: TRIPPYConfig, optional
  """

  def __init__(self, n_slots, hid_dim, n_oper, dropout, from_scratch, BERT_config=None, TRIPPY_config=None):
    super(TRIPPY, self).__init__()

    self.hid_dim = hid_dim
    self.n_oper = n_oper
    self.n_slots = n_slots

    if from_scratch == True:
        # Build a BERT model from scratch. The default configuration is created
        # here, on demand, instead of in the signature, where it would be
        # instantiated once at class-definition time and shared by every model.
        if BERT_config is None:
            BERT_config = BERTConfig()
        self.bert = Bert(BERT_config)
        LoadPretrainedWeights(self.bert)
    else:
        self.bert = BertModel.from_pretrained('bert-base-uncased')

    # Two auxiliary vectors of size n_slots are appended to the pooled output.
    aux_dim = 2 * n_slots
    self.oper_layers = nn.ModuleList([nn.Linear(hid_dim + aux_dim, n_oper) for _ in range(n_slots)])
    self.span_layers = nn.ModuleList([nn.Linear(hid_dim, 2) for _ in range(n_slots)])
    self.refer_layers = nn.ModuleList([nn.Linear(hid_dim + aux_dim, n_slots + 1) for _ in range(n_slots)])
    self.dropout = nn.Dropout(dropout)

    # auxiliary features layers
    self.inform_aux_layer = nn.Linear(n_slots, n_slots)
    self.ds_aux_layer = nn.Linear(n_slots, n_slots)

  def forward(self, ids, mask, token_type_ids, inform_aux_features, ds_aux_features):
    """
    Forward pass of the TRIPPY model.

    :param ids: The input token IDs.
    :type ids: torch.Tensor, shape [batch size, seq len]
    :param mask: The attention mask indicating which tokens are valid.
    :type mask: torch.Tensor, shape [batch size, seq len]
    :param token_type_ids: The token type IDs.
    :type token_type_ids: torch.Tensor, shape [batch size, seq len]
    :param inform_aux_features: The auxiliary features for informing slots.
    :type inform_aux_features: torch.Tensor, shape [batch size, n_slots]
    :param ds_aux_features: The auxiliary features for dialogue state tracking.
    :type ds_aux_features: torch.Tensor, shape [batch size, n_slots]
    :return: Four lists with one tensor per slot: start logits [batch size, seq len],
        end logits [batch size, seq len], operation logits [batch size, n_oper]
        and refer logits [batch size, n_slots + 1].
    :rtype: tuple(list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor])
    """

    # sequence_output = [batch size, seq len, hid_dim]
    # pooled_output = [batch size, hid_dim]
    sequence_output, pooled_output = self.bert(ids,
                                               attention_mask=mask,
                                               token_type_ids=token_type_ids,
                                               return_dict=False
                                               )

    sequence_output = self.dropout(sequence_output)
    pooled_output = self.dropout(pooled_output)

    # concatenate the projected auxiliary features
    # pooled_output = [batch size, hid_dim + 2 * n_slots]
    pooled_output = torch.cat((pooled_output, self.inform_aux_layer(inform_aux_features), self.ds_aux_layer(ds_aux_features)), 1)

    slots_start_logits = []
    slots_end_logits = []
    slots_oper_logits = []
    slots_refer_logits = []
    for slot in range(self.n_slots):

      # oper_logits = [batch size, n_oper]
      oper_logits = self.oper_layers[slot](pooled_output)

      # span_logits = [batch size, seq len, 2]
      span_logits = self.span_layers[slot](sequence_output)

      # start_logits = [batch size, seq len, 1]
      # end_logits = [batch size, seq len, 1]
      start_logits, end_logits = span_logits.split(1, dim=-1)

      # start_logits = [batch size, seq len]
      # end_logits = [batch size, seq len]
      start_logits = start_logits.squeeze(-1)
      end_logits = end_logits.squeeze(-1)

      # refer_logits = [batch size, n_slots + 1]
      refer_logits = self.refer_layers[slot](pooled_output)

      slots_start_logits.append(start_logits)
      slots_end_logits.append(end_logits)
      slots_oper_logits.append(oper_logits)
      slots_refer_logits.append(refer_logits)

    return slots_start_logits, slots_end_logits, slots_oper_logits, slots_refer_logits

## utils

In [14]:
import re
import string
import numpy as np

# from botiverse.models.TRIPPY.config import MULTIWOZ


class RawDataInstance():
  """
  Plain container for one raw (not yet tokenized) dialogue turn.

  :param dial_idx: Dialogue index.
  :type dial_idx: str
  :param turn_idx: Turn index.
  :type turn_idx: int
  :param user_utter: User utterance.
  :type user_utter: str
  :param sys_utter: System utterance.
  :type sys_utter: str
  :param history: Dialogue history.
  :type history: list[str]
  :param turn_slots: Slots for the current turn.
  :type turn_slots: dict[str, str]
  :param inform_mem: Informed slots from previous turns.
  :type inform_mem: dict[str, list[str]]
  """
  def __init__(self, dial_idx, turn_idx, user_utter, sys_utter, history, turn_slots, inform_mem):
    # Dialogue / turn identification.
    self.dial_idx = dial_idx
    self.turn_idx = turn_idx
    # Utterances of the turn and what preceded them.
    self.user_utter = user_utter
    self.sys_utter = sys_utter
    self.history = history
    # Slot annotations.
    self.turn_slots = turn_slots
    self.inform_mem = inform_mem

  def __str__(self):
    """
    Render every attribute as a ``name: value`` line, each preceded by a newline.

    :return: A string representation of the object.
    :rtype: str
    """
    attribute_names = ('dial_idx', 'turn_idx', 'user_utter', 'sys_utter',
                       'history', 'turn_slots', 'inform_mem')
    return ''.join('\n' + name + ': ' + str(getattr(self, name)) for name in attribute_names)

class DataInstance():
  """
  Plain container for one processed (model-ready) dialogue turn.

  :param ids: Input IDs.
  :type ids: list[int]
  :param mask: Attention mask.
  :type mask: list[int]
  :param token_type_ids: Token type IDs.
  :type token_type_ids: list[int]
  :param spans: Spans.
  :type spans: list[int]
  :param spans_start: Start positions of spans.
  :type spans_start: list[int]
  :param spans_end: End positions of spans.
  :type spans_end: list[int]
  :param padding_len: Padding length.
  :type padding_len: int
  :param input_tokens: Input tokens.
  :type input_tokens: str
  :param input: Input text.
  :type input: str
  :param opers: Slot operations.
  :type opers: list[int]
  :param target_values: Target slot values.
  :type target_values: list[str]
  :param last_state: Last dialogue state.
  :type last_state: dict[str, str]
  :param cur_state: Current dialogue state.
  :type cur_state: dict[str, str]
  :param refer: Referenced slots.
  :type refer: list[int]
  :param inform_aux_features: Informed auxiliary features.
  :type inform_aux_features: list[float]
  :param ds_aux_features: Filled slot auxiliary features.
  :type ds_aux_features: list[float]
  """
  def __init__(self, ids, mask, token_type_ids, spans, spans_start, spans_end,
               padding_len, input_tokens, input, opers, target_values,
               last_state, cur_state, refer, inform_aux_features, ds_aux_features):
    # Encoder inputs.
    self.ids = ids
    self.mask = mask
    self.token_type_ids = token_type_ids
    # Span annotations.
    self.spans = spans
    self.spans_start = spans_start
    self.spans_end = spans_end
    # Bookkeeping about the encoded input.
    self.padding_len = padding_len
    self.input_tokens = input_tokens
    self.input = input
    # Labels and dialogue states.
    self.opers = opers
    self.target_values = target_values
    self.last_state = last_state
    self.cur_state = cur_state
    self.refer = refer
    # Auxiliary features.
    self.inform_aux_features = inform_aux_features
    self.ds_aux_features = ds_aux_features

  def __str__(self):
    """
    Render every attribute as a ``name: value`` line, each preceded by a newline.

    :return: A string representation of the object.
    :rtype: str
    """
    attribute_names = ('ids', 'mask', 'token_type_ids', 'spans', 'spans_start',
                       'spans_end', 'padding_len', 'input_tokens', 'input',
                       'opers', 'target_values', 'last_state', 'cur_state',
                       'refer', 'inform_aux_features', 'ds_aux_features')
    return ''.join('\n' + name + ': ' + str(getattr(self, name)) for name in attribute_names)

class AverageMeter():
    """
    Tracks the latest value together with a running weighted average.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """
        Set the latest value, the average, the weighted sum and the count back to zero.
        """
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """
        Record a new value and refresh the running average.

        :param val: New value.
        :type val: float
        :param n: Number of instances the value represents (its weight).
        :type n: int
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def normalize(text, multiwoz):
  """
  Normalize the given text by converting it to lowercase and splitting it into tokens.

  The text is split on runs of non-word characters; the separators themselves are
  kept as tokens, with their spaces removed, and empty tokens are dropped.

  :param text: Input text.
  :type text: str
  :param multiwoz: Whether to additionally apply the MultiWOZ-specific text normalization.
  :type multiwoz: bool
  :return: Normalized tokens.
  :rtype: list[str]
  """
  text_lower = text.lower()
  if multiwoz == True:
    text_norm = normalize_text(text_lower) # for multiwoz only
  else:
    text_norm = text_lower
  # Raw string: "\W" in a normal string literal is an invalid escape sequence.
  # The capturing group makes re.split return the separators as well.
  pieces = re.split(r"(\W+)", text_norm)
  return [tok for tok in (piece.replace(" ", "") for piece in pieces) if len(tok) > 0]


def is_included(value, target):
  """
  Check if the target is included in the value.

  Both strings are split on runs of non-word characters (separators are kept,
  stripped, and dropped when empty); the target is included when its token
  sequence occurs as a contiguous run inside the token sequence of the value.

  :param value: The value to check.
  :type value: str
  :param target: The target value to search for.
  :type target: str
  :return: True if the target is included in the value, False otherwise.
  :rtype: bool
  """
  # Raw strings: "\W" in a normal string literal is an invalid escape sequence.
  value = [item for item in map(str.strip, re.split(r"(\W+)", value)) if len(item) > 0]
  target = [item for item in map(str.strip, re.split(r"(\W+)", target)) if len(item) > 0]

  width = len(target)
  # any() stops at the first matching window instead of scanning the rest.
  return any(value[i:i + width] == target for i in range(len(value)))


def included_with_label_maps(value, target, label_maps):
  """
  Check whether the value and the target (or one of the target's variants) contain each other.

  A variant counts as a hit when it equals the value, when it is included in
  the value, or when the value is included in it.

  :param value: The value to check.
  :type value: str
  :param target: The target value to search for.
  :type target: str
  :param label_maps: Dictionary of label maps.
  :type label_maps: dict[str, list[str]]
  :return: True if the value is included in the target or any of its variants, False otherwise.
  :rtype: bool
  """
  # The target itself always comes first, followed by its known variants.
  candidates = [target]
  if target in label_maps:
    candidates.extend(label_maps[target])

  return any(
      value == candidate or is_included(value, candidate) or is_included(candidate, value)
      for candidate in candidates
  )


def match_with_label_maps(value, target, label_maps={}):
    """
    Check whether the value is exactly the target or exactly one of the target's variants.

    :param value: The value to check.
    :type value: str
    :param target: The target value to match against.
    :type target: str
    :param label_maps: Dictionary of label maps.
    :type label_maps: dict[str, list[str]]
    :return: True if the value matches the target or any of its variants, False otherwise.
    :rtype: bool
    """
    if value == target:
      return True
    # Variants are only consulted when the target itself did not match.
    if target in label_maps:
      return any(value == variant for variant in label_maps[target])
    return False


def create_span_output(output_start, output_end, padding_len, input_tokens):
  """
  Turn start/end scores into the text of the selected span.

  The highest-scoring start and end positions are searched among the
  non-padded positions, skipping position 0. The tokens between them
  (inclusive) are then glued back together into a string.

  :param output_start: Output start positions.
  :type output_start: list[int]
  :param output_end: Output end positions.
  :type output_end: list[int]
  :param padding_len: Padding length.
  :type padding_len: int
  :param input_tokens: Input tokens.
  :type input_tokens: str
  :return: The created span output.
  :rtype: str
  """
  # A stop of None keeps the whole tail when there is no padding to cut off.
  stop = -padding_len if padding_len > 0 else None
  # Position 0 is excluded from the search, hence the +1 offset.
  first = np.argmax(output_start[1:stop]) + 1
  last = np.argmax(output_end[1:stop]) + 1

  # One flag per non-padded position; flags inside [first, last] are set.
  selected = [0] * (len(output_start) - padding_len)
  for pos in range(first, last + 1):
    selected[pos] = 1

  special_tokens = ('[CLS]', '[SEP]')
  span_tokens = [
      tok for pos, tok in enumerate(input_tokens.split())
      if selected[pos] == 1 and tok not in special_tokens
  ]

  text = ''
  for tok in span_tokens:
    if tok.startswith('##'):
      # Word-piece continuation: attach to the previous piece without the marker.
      piece = tok[2:]
    elif (len(tok) == 1 and tok in string.punctuation) or (len(text) > 0 and text[-1] in string.punctuation):
      # No space around punctuation.
      piece = tok
    else:
      piece = " " + tok
    text = text + piece

  return text.strip()


def mask_utterance(utter, inform_mem, multiwoz, replace_with='[UNK]'):
  """
  Mask every occurrence of an informed value inside the utterance.

  The utterance and each informed value are tokenized with ``normalize``;
  wherever the tokens of an informed value occur as a contiguous run in the
  utterance, each of those tokens is replaced by ``replace_with``.

  :param utter: The utterance to mask.
  :type utter: str
  :param inform_mem: The inform memory containing slot-value pairs.
  :type inform_mem: dict[str, list[str]]
  :param multiwoz: Forwarded to ``normalize`` for both the utterance and the values.
  :type multiwoz: bool
  :param replace_with: The replacement token.
  :type replace_with: str
  :return: The masked utterance.
  :rtype: list[str]
  """
  tokens = normalize(utter, multiwoz)
  # Only the informed values matter here; the slot names are not used.
  for informed_values in inform_mem.values():
    for informed_value in informed_values:
      needle = normalize(informed_value, multiwoz)
      width = len(needle)
      for start in range(len(tokens)):
        if tokens[start:start + width] == needle:
          # Same-length replacement keeps all later positions aligned.
          tokens[start:start + width] = [replace_with] * width
  return tokens


def normalize_time(text):
    """
    Normalize the time format in the given text (specific to MultiWoz dataset).

    Times are rewritten towards a zero-padded 24-hour ``HH:MM`` form by applying
    a fixed sequence of regular-expression substitutions; the order matters
    because later steps rely on the output of earlier ones.

    :param text: The input text.
    :type text: str
    :return: The normalized text.
    :rtype: str
    """

    # This code is only related to MultiWoz Dataset

    def _to_24_hour(match):
        # "HH:MM pm" -> 24-hour "HH:MM"; hours of 12 and above are left as they are.
        hour_text, minutes = match.groups()
        hour = int(hour_text)
        if hour < 12:
            hour += 12
        return str(hour) + minutes

    # All patterns are raw strings: "\d" and "\." in normal string literals are
    # invalid escape sequences.
    text = re.sub(r"(\d{1})(a\.?m\.?|p\.?m\.?)", r"\1 \2", text) # am/pm without space
    text = re.sub(r"(^| )(\d{1,2}) (a\.?m\.?|p\.?m\.?)", r"\1\2:00 \3", text) # am/pm short to long form
    text = re.sub(r"(^| )(at|from|by|until|after) ?(\d{1,2}) ?(\d{2})([^0-9]|$)", r"\1\2 \3:\4\5", text) # Missing separator
    text = re.sub(r"(^| )(\d{2})[;.,](\d{2})", r"\1\2:\3", text) # Wrong separator
    text = re.sub(r"(^| )(at|from|by|until|after) ?(\d{1,2})([;., ]|$)", r"\1\2 \3:00\4", text) # normalize simple full hour time
    text = re.sub(r"(^| )(\d{1}:\d{2})", r"\g<1>0\2", text) # Add missing leading 0
    # Map 12 hour times to 24 hour times
    text = re.sub(r"(\d{2})(:\d{2}) ?p\.?m\.?", _to_24_hour, text)
    text = re.sub(r"(^| )24:(\d{2})", r"\g<1>00:\2", text) # Correct times that use 24 as hour
    return text


def normalize_text(text):
    """
    Normalize the text (specific to MultiWoz dataset).

    Times are normalized first; afterwards a fixed list of substitutions is
    applied in order.

    :param text: The input text.
    :type text: str
    :return: The normalized text.
    :rtype: str
    """

    # This code is only related to MultiWoz Dataset

    # (pattern, replacement) pairs, applied from top to bottom.
    substitutions = (
        ("n't", " not"),
        ("(^| )zero(-| )star([s.,? ]|$)", r"\g<1>0 star\3"),
        ("(^| )one(-| )star([s.,? ]|$)", r"\g<1>1 star\3"),
        ("(^| )two(-| )star([s.,? ]|$)", r"\g<1>2 star\3"),
        ("(^| )three(-| )star([s.,? ]|$)", r"\g<1>3 star\3"),
        ("(^| )four(-| )star([s.,? ]|$)", r"\g<1>4 star\3"),
        ("(^| )five(-| )star([s.,? ]|$)", r"\g<1>5 star\3"),
        ("archaelogy", "archaeology"),  # Systematic typo
        ("guesthouse", "guest house"),  # Normalization
        ("(^| )b ?& ?b([.,? ]|$)", r"\1bed and breakfast\2"),  # Normalization
        ("bed & breakfast", "bed and breakfast"),  # Normalization
        ("\t", " "),  # Error
        ("\n", " "),  # Error
    )

    text = normalize_time(text)
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text


# deep_TODS

## utils

In [15]:
"""
This module contains utility code used by the deep TODS module
such as Natural Language Understanding (NLU), Dialogue State Tracker (DST) ... etc.
"""

import random

class PriorityDP:
    """
    Dialogue Policy Optimizer that selects the next action based on priority.

    The templates are scanned in the given order and the first one that shares
    no slot with the slots already filled in the dialogue state is selected.

    """
    def get_action(self, state, templates_slots):
        """
        Get the next action based on the given dialogue state and available action templates.

        :param state: The current dialogue state.
        :type state: dict

        :param templates_slots: The available different combinations of slots that can be filled by templates.
        :type templates_slots: list[tuple]

        :return: The index of the selected action template, or None if no action is available.
        :rtype: int or None
        """
        # Slots that already carry a value.
        filled = tuple(sorted(slot for slot, value in state.items() if value is not None))

        for idx, slots in enumerate(templates_slots):
            # Earlier templates win: return the first one without a filled slot.
            if not any(element in slots for element in filled):
                return idx

        return None

    def __str__(self):
        """
        Return a string representation of the PriorityDP policy.

        :return: A string representation of the PriorityDP policy.
        :rtype: str
        """
        return '\nPriorityDP'


class RandomDP:
    """
    Random dialogue policy.

    Every template that covers none of the slots already filled in the dialogue state
    is a candidate; one candidate is drawn uniformly at random.
    """

    def get_action(self, state, templates_slots):
        """
        Randomly pick one of the templates that do not touch an already-filled slot.

        :param state: The current dialogue state, mapping slot names to values
            (a value of ``None`` means the slot is still unfilled).
        :type state: dict

        :param templates_slots: For every template, the slots it covers.
        :type templates_slots: list[tuple]

        :return: The index of the chosen template, or None when no template is eligible.
        :rtype: int or None
        """
        # Names of the slots that already hold a value, in sorted order.
        known_slots = tuple(sorted(name for name, val in state.items() if val is not None))

        # Indices of the templates that share no slot with the filled ones.
        eligible = [
            position
            for position, template_slots in enumerate(templates_slots)
            if not any(name in template_slots for name in known_slots)
        ]

        if len(eligible) == 0:
            return None

        return eligible[random.randint(0, len(eligible) - 1)]

    def __str__(self):
        """
        Describe the policy as text.

        :return: The policy name preceded by a newline.
        :rtype: str
        """
        return '\nRandomDP'


class TemplateBasedNLG:
    """
    Template-driven Natural Language Generation.

    Holds a fixed list of templates, each a dict with the keys ``'utterance'``,
    ``'slots'`` and ``'system_act'``, and returns the utterance and system act of the
    template selected by index.

    :param templates: The predefined templates for generating responses.
    :type templates: list[dict]
    """

    def __init__(self, templates):
        self.templates = templates
        # Each template's slots are stored as a sorted tuple, in template order.
        self.templates_slots = [tuple(sorted(entry['slots'])) for entry in templates]

    def get_templates(self):
        """
        Return the templates this module was built with.

        :return: The predefined templates.
        :rtype: list[dict]
        """
        return self.templates

    def get_templates_slots(self):
        """
        Return the slots covered by each template, in template order.

        :return: One sorted tuple of slot names per template.
        :rtype: list[tuple]
        """
        return self.templates_slots

    def generate(self, idx):
        """
        Produce the response stored in the template at the given index.

        :param idx: The index of the template to generate a response from.
        :type idx: int

        :return: ``(utterance, system_act)`` of the template, or ``(None, None)`` when
            the index is out of range.
        :rtype: tuple
        """
        if idx < 0 or idx >= len(self.templates):
            return None, None

        chosen = self.templates[idx]
        return chosen['utterance'], chosen['system_act']

    def __str__(self):
        """
        Describe the module as text.

        :return: The templates and their slots, one labelled line each.
        :rtype: str
        """
        pieces = [
            '\ntemplates: ', str(self.templates),
            '\ntemplates_slots: ', str(self.templates_slots),
        ]
        return ''.join(pieces)

## deep_TODS

In [16]:
"""
This module contains the base code and interface of deep TODS.
"""

# from botiverse.bots.deep_TODS.utils import RandomDP, PriorityDP, TemplateBasedNLG
# from botiverse.models.TRIPPY.config import TRIPPYConfig
# from botiverse.models.TRIPPY.TRIPPY_DST import TRIPPYDST

import random

class DeepTODS:
  """
  Instantiate a Deep Task Oriented Dialogue System chat bot.
  It aims to assist the user in completing certain tasks in specific domains.
  The chat bot can use a Deep learning approach for training and inference.

  :param name: The chatbot's name.
  :type name: str

  :param domains: List of domain names.
  :type domains: list[str]

  :param ontology_path: Path to the ontology file.
  :type ontology_path: str

  :param label_maps_path: Path to the label maps file.
  :type label_maps_path: str

  :param policy: The dialogue policy to be used ('Random' or 'Priority'). Any other value
      leaves the bot without a policy, and ``infer`` raises ``ValueError`` as soon as a
      policy decision is needed.
  :type policy: str

  :param start: List of initial system utterances and corresponding system acts.
  :type start: list[dict]

  :param templates: The predefined templates for generating responses.
  :type templates: list[dict]

  :param non_referable_slots: List of non-referable slots, defaults to an empty list.
  :type non_referable_slots: list[str]

  :param non_referable_pairs: List of non-referable slot-value pairs, defaults to an empty list.
  :type non_referable_pairs: list[tuple]

  :param from_scratch: Indicates whether to use BERT model implemented from scratch in the library, defaults to False.
  :type from_scratch: bool

  :param BERT_config: BERT configuration forwarded to the DST; a new ``BERTConfig()`` is
      created when omitted.
  :type BERT_config: BERTConfig

  :param TRIPPY_config: TRIPPY configuration forwarded to the DST; a new ``TRIPPYConfig()``
      is created when omitted.
  :type TRIPPY_config: TRIPPYConfig
  """
  def __init__(self, name, domains, ontology_path, label_maps_path, policy, start, templates,
               non_referable_slots=None, non_referable_pairs=None, from_scratch=False,
               BERT_config=None, TRIPPY_config=None):
    # Defaults are built per call: a default written in the signature is evaluated once
    # at definition time, so every bot would share (and could mutate) the same objects.
    if non_referable_slots is None:
      non_referable_slots = []
    if non_referable_pairs is None:
      non_referable_pairs = []
    if BERT_config is None:
      BERT_config = BERTConfig()
    if TRIPPY_config is None:
      TRIPPY_config = TRIPPYConfig()

    self.name = name
    self.domains = domains
    # Same value reset() restores; None when no domain was given.
    self.domain = domains[0] if domains else None
    self.policy = policy
    self.start = start
    self.is_start = True
    self.dst = TRIPPYDST(domains, ontology_path, label_maps_path, non_referable_slots, non_referable_pairs, from_scratch, BERT_config, TRIPPY_config)
    # An unrecognised policy name yields no policy object; infer() reports it when needed.
    self.dpo = RandomDP() if policy == 'Random' else PriorityDP() if policy == 'Priority' else None
    self.nlg = TemplateBasedNLG(templates)
    self.sys_utter = ''
    self.inform_mem = {}

  def train(self, train_path, dev_path, test_path, model_path):
    """
    Train the chatbot model with the given training data.

    :param train_path: Path to the training data file.
    :type train_path: str

    :param dev_path: Path to the development data file.
    :type dev_path: str

    :param test_path: Path to the testing data file.
    :type test_path: str

    :param model_path: Path to save the trained model.
    :type model_path: str
    """
    self.dst.train(train_path, dev_path, test_path, model_path)

  def load_dst_model(self, model_path, test_path=None):
    """
    Load a trained DST model from the given path.

    :param model_path: Path to the trained DST model.
    :type model_path: str

    :param test_path: Path to the testing data file, if applicable, defaults to None.
    :type test_path: str
    """
    self.dst.load_model(model_path, test_path)

  def infer(self, user_utter):
    """
    Infer a suitable response to the user's utterance.

    On the first turn after construction or ``reset`` (and if ``start`` is not empty) a
    random opening from ``start`` is returned and ``user_utter`` is not passed to the DST.
    On later turns the DST updates the dialogue state and the policy selects a template.

    :param user_utter: The user's input utterance.
    :type user_utter: str

    :return: The chatbot's response, or None when the policy has no action left.
    :rtype: str or None

    :raises ValueError: If a policy decision is needed but ``policy`` was neither
        'Random' nor 'Priority'.
    """
    response = None

    if self.is_start and len(self.start) > 0:
      opening = random.choice(self.start)
      response, inform_mem = opening['utterance'], opening['system_act']
      self.sys_utter = response
      self.inform_mem = inform_mem
    else:
      # Fail before touching the DST so a misconfigured bot does not half-update its state.
      if self.dpo is None:
        raise ValueError("Unknown dialogue policy %r: expected 'Random' or 'Priority'." % (self.policy,))
      state = self.dst.update_state(self.sys_utter, user_utter, self.inform_mem)
      action = self.dpo.get_action(state, self.nlg.get_templates_slots())
      if action is not None:
        response, inform_mem = self.nlg.generate(action)
        self.sys_utter = response
        self.inform_mem = inform_mem

    self.is_start = False
    return response

  def suggest(self, template):
    """
    Set the system utterance and system act to suggest a specific response.

    :param template: The template containing the suggested system utterance and system act.
    :type template: dict
    """
    self.sys_utter = template['utterance']
    self.inform_mem = template['system_act']

  def get_dialogue_state(self):
    """
    Get the dialogue state.

    :return: The dialogue state.
    :rtype: dict
    """
    return self.dst.get_dialogue_state()

  def delete_slots(self, domain=None, slot=None):
    """
    Delete slots from the dialogue state.

    Note that:
    if domain!=None will delete all slots in that domain.
    if slot!=None will delete that slot.
    if both are None will delete all slots in all domains.

    :param domain: The domain from which to delete slots, defaults to None.
    :type domain: str

    :param slot: The slot to delete, defaults to None.
    :type slot: str
    """
    self.dst.delete_slots(domain, slot)

  def reset(self):
    """
    Reset the chatbot's state.

    Resets the DST, clears the last system utterance and system act, and makes the next
    ``infer`` call return an opening utterance again.
    """
    self.dst.reset()
    self.sys_utter = ''
    self.inform_mem = {}
    # Guarded so a bot with no domains can still be reset.
    self.domain = self.domains[0] if self.domains else None
    self.is_start = True

  def __str__(self):
    """
    Return a string representation of the chatbot.

    :return: A string representation of the chatbot.
    :rtype: str
    """
    string = ''
    string = string + '\nname: ' + str(self.name)
    string = string + '\ndomains: ' + str(self.domains)
    string = string + '\npolicy: ' + str(self.policy)
    string = string + '\nstart: ' + str(self.start)
    string = string + '\nis_start: ' + str(self.is_start)
    string = string + '\n\ndst: ' + str(self.dst)
    string = string + '\n\ndpo: ' + str(self.dpo)
    string = string + '\n\nnlg: ' + str(self.nlg)
    string = string + '\n\nsys_utter: ' + str(self.sys_utter)
    string = string + '\ninform_mem: ' + str(self.inform_mem)
    return string

# Demo

## Constants

In [17]:
# --- Bot identity and scope -------------------------------------------------
CHATBOT_NAME = 'Tody'
DOMAINS = ["restaurant"]

# --- Dataset files (all under ./Woz2/fixed) ---------------------------------
ONTOLOGY_PATH = './Woz2/fixed/ontology.json'
LABEL_MAPS_PATH = './Woz2/fixed/label_maps.json'
TRAIN_DATA_PATH = './Woz2/fixed/train_dials.json'
DEV_DATA_PATH = './Woz2/fixed/dev_dials.json'
TEST_DATA_PATH = './Woz2/fixed/test_dials.json'

# --- Where the trained DST model is written and read back --------------------
MODEL_PATH = './Models/model.pt'

# --- Dialogue behaviour -------------------------------------------------------
POLICY = 'Priority' # Priority or Random

# Use the BERT implementation written from scratch in this notebook.
FROM_SCRATCH = True

NON_REFERABLE_SLOTS = []
NON_REFERABLE_PAIRS = []

# Opening utterances; one is picked at random on the first turn.
START = [
    {'utterance': 'Hi I am Tody, I can help you reserve a restaurant?', 'slots': [], 'system_act': {}},
]

# One question per slot; with the Priority policy the listed order is the asking order.
TEMPLATES = [
    {'utterance': question, 'slots': [slot], 'system_act': {}}
    for question, slot in (
        ('what area do you want?', 'restaurant-area'),
        ('what is your preferred price range?', 'restaurant-price_range'),
        ('What kind of food do you want to eat?', 'restaurant-food'),
    )
]

# TRIPPY (DST) settings for this demo: multiwoz disabled and 5 training epochs
# (presumably because the data above is Woz2 rather than MultiWOZ; confirm against TRIPPYConfig).
TRIPPY_CONFIG = TRIPPYConfig(multiwoz=False, epochs=5)

# BERT settings: whatever BERTConfig provides by default.
BERT_CONFIG = BERTConfig()

## Create chatbot

In [18]:
# Keyword arguments make it clear which constant feeds which constructor parameter.
tods = DeepTODS(
    name=CHATBOT_NAME,
    domains=DOMAINS,
    ontology_path=ONTOLOGY_PATH,
    label_maps_path=LABEL_MAPS_PATH,
    policy=POLICY,
    start=START,
    templates=TEMPLATES,
    non_referable_slots=NON_REFERABLE_SLOTS,
    non_referable_pairs=NON_REFERABLE_PAIRS,
    from_scratch=FROM_SCRATCH,
    BERT_config=BERT_CONFIG,
    TRIPPY_config=TRIPPY_CONFIG,
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## train the chatbot

In [19]:
# Train the DST on the train split; the dev and test splits and the model path are passed along.
tods.train(
    train_path=TRAIN_DATA_PATH,
    dev_path=DEV_DATA_PATH,
    test_path=TEST_DATA_PATH,
    model_path=MODEL_PATH,
)

Preprocessing train set...


100%|██████████| 2536/2536 [00:05<00:00, 481.99it/s]


Preprocessing dev set...


100%|██████████| 830/830 [00:01<00:00, 801.95it/s]


Preprocessing test set...


100%|██████████| 1646/1646 [00:02<00:00, 714.80it/s]



Epoch: 0 ---------------------------------------------------------------
Training the model...


100%|██████████| 80/80 [00:52<00:00,  1.52it/s, loss=85.8]


Evaluating the model on dev set...


100%|██████████| 830/830 [00:12<00:00, 66.10it/s]


Joint Goal Acc: 0.6397590361445783
Per Slot Acc: {'restaurant-area': 0.8879518072289156, 'restaurant-food': 0.8674698795180723, 'restaurant-price_range': 0.844578313253012}
Macro F1 Score: 0.6629450218742398
All f1 score = [0.94145902 0.         0.94623656 0.76408451]

Epoch: 1 ---------------------------------------------------------------
Training the model...


100%|██████████| 80/80 [00:53<00:00,  1.50it/s, loss=19.5]


Evaluating the model on dev set...


100%|██████████| 830/830 [00:12<00:00, 64.78it/s]


Joint Goal Acc: 0.8024096385542169
Per Slot Acc: {'restaurant-area': 0.946987951807229, 'restaurant-food': 0.9240963855421687, 'restaurant-price_range': 0.908433734939759}
Macro F1 Score: 0.8569038731906121
All f1 score = [0.96505212 0.64583333 0.97120159 0.84552846]

Epoch: 2 ---------------------------------------------------------------
Training the model...


100%|██████████| 80/80 [00:54<00:00,  1.47it/s, loss=9.38]


Evaluating the model on dev set...


100%|██████████| 830/830 [00:12<00:00, 64.92it/s]


Joint Goal Acc: 0.8409638554216867
Per Slot Acc: {'restaurant-area': 0.9530120481927711, 'restaurant-food': 0.9578313253012049, 'restaurant-price_range': 0.9192771084337349}
Macro F1 Score: 0.9056948296028039
All f1 score = [0.98000615 0.72727273 0.9851925  0.93030794]

Epoch: 3 ---------------------------------------------------------------
Training the model...


100%|██████████| 80/80 [00:55<00:00,  1.44it/s, loss=5.14]


Evaluating the model on dev set...


100%|██████████| 830/830 [00:11<00:00, 71.36it/s]


Joint Goal Acc: 0.8819277108433735
Per Slot Acc: {'restaurant-area': 0.9662650602409638, 'restaurant-food': 0.9734939759036144, 'restaurant-price_range': 0.9349397590361446}
Macro F1 Score: 0.9367088980586386
All f1 score = [0.98426412 0.83333333 0.98711596 0.94212219]

Epoch: 4 ---------------------------------------------------------------
Training the model...


100%|██████████| 80/80 [00:56<00:00,  1.43it/s, loss=2.81]


Evaluating the model on dev set...


100%|██████████| 830/830 [00:12<00:00, 65.65it/s]


Joint Goal Acc: 0.8963855421686747
Per Slot Acc: {'restaurant-area': 0.9710843373493976, 'restaurant-food': 0.9795180722891567, 'restaurant-price_range': 0.9409638554216867}
Macro F1 Score: 0.9392466766749592
All f1 score = [0.98516687 0.83636364 0.99011858 0.94533762]
Loading best model on dev set...
Evaluating the model on test set...


100%|██████████| 1646/1646 [00:27<00:00, 60.44it/s]


Joint Goal Acc: 0.9009720534629405
Per Slot Acc: {'restaurant-area': 0.9939246658566221, 'restaurant-food': 0.9586877278250304, 'restaurant-price_range': 0.945321992709599}
Macro F1 Score: 0.9537028346064818
All f1 score = [0.98610002 0.88301887 0.98693759 0.95875486]


## Load the saved model

In [20]:
# Load the saved DST model; the test split path is passed along with it.
tods.load_dst_model(model_path=MODEL_PATH, test_path=TEST_DATA_PATH)

Model loaded successfully.
Preprocessing the data...


100%|██████████| 1646/1646 [00:01<00:00, 1449.84it/s]


Evaluating the model on the data...


100%|██████████| 1646/1646 [00:24<00:00, 67.99it/s]


Joint Goal Acc: 0.9009720534629405
Per Slot Acc: {'restaurant-area': 0.9939246658566221, 'restaurant-food': 0.9586877278250304, 'restaurant-price_range': 0.945321992709599}
Macro F1 Score: 0.9537028346064818
All f1 score = [0.98610002 0.88301887 0.98693759 0.95875486]


## Infer

In [34]:
# Start a fresh conversation: reset the DST and re-arm the opening greeting.
tods.reset()
print(tods.get_dialogue_state())

{}


In [35]:
# First turn after reset: the bot replies with an utterance from START and the
# user text is not passed to the DST, so the dialogue state stays empty.
response = tods.infer('')
print(tods.get_dialogue_state())
print(response)

{}
Hi I am Tody, I can help you reserve a restaurant?


In [36]:
# The user gives the area, so the Priority policy skips the area question
# and asks for the next unfilled slot in template order.
response = tods.infer('Hi, I want to book a table in city center.')
print(tods.get_dialogue_state())
print(response)

{'restaurant-area': 'center'}
what is your preferred price range?


In [38]:
# Price range and food are given together; with every template's slot filled
# the policy has no action left, so infer() returns None.
response = tods.infer('a cheap egyptain restaurant.')
print(tods.get_dialogue_state())
print(response)

{'restaurant-area': 'center', 'restaurant-food': 'egyptain', 'restaurant-price_range': 'cheap'}
None
