In [1]:
import os
import ast
import json
import pandas as pd

from tqdm import tqdm
from langchain import PromptTemplate

from src.DST.evaluate_utils import remapping
from src.DST.dst import SLOTS_DESCRIPTIONS
from src.config import CONFIG

from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
from src.DST.evaluate_utils import unpack_belief_states


# Widen pandas' display limits so long prompts and wide frames are not truncated
# when rendered in the notebook. set_option accepts several (name, value) pairs.
pd.set_option(
    'display.max_rows', 500,
    'display.max_columns', 500,
    'display.max_colwidth', 500,
)



@dataclass
class ModelArguments:
    """
    Options selecting the model to load and how to load it.

    Every attribute is optional and carries its command-line help text in the
    field metadata under the ``"help"`` key.
    """
    # Location of the HuggingFace model; None when not provided.
    model_name_or_path: Optional[str] = field(default=None, metadata={"help": "The path of the HuggingFace model."})
    # Toggle for the int8 variant of the model.
    use_int8: Optional[bool] = field(default=False, metadata={"help": "Whether to use int8 model or not."})
    # Toggle for the deepspeed variant of the model.
    use_deepspeed: Optional[bool] = field(default=False, metadata={"help": "Whether to use deepspeed model or not."})
    

@dataclass
class DataArguments:
    """
    Options controlling where the data is read from and how samples are built.

    Every attribute is optional; its help text is stored in the field metadata
    under the ``"help"`` key.
    """
    # --- dataset locations ---------------------------------------------------
    dataset_name: Optional[str] = field(default=None, metadata={"help": "Train dataset path"})
    dataset_names: Optional[str] = field(default=None, metadata={"help": "Train dataset paths"})
    root_data_path: Optional[str] = field(default="./data", metadata={"help": "The path to the data directory."})
    # NOTE(review): the default below is a machine-specific absolute path;
    # pass mwoz_path explicitly when running anywhere else.
    mwoz_path: Optional[str] = field(default="/home/willy/instructod/MultiWOZ_2.1/", metadata={"help": "MWOZ path"})

    # --- dialogue history window per task ------------------------------------
    dialog_history_limit_dst: Optional[int] = field(default=0, metadata={"help": "Lenght of dialogue history for dst"})
    dialog_history_limit_dst_recorrect: Optional[int] = field(default=0, metadata={"help": "Lenght of dialogue history for dst update"})
    dialog_history_limit_rg: Optional[int] = field(default=20, metadata={"help": "Lenght of dialogue history for response generation"})
    dialog_history_limit_e2e: Optional[int] = field(default=20, metadata={"help": "Lenght of dialogue history for e2e"})

    # --- sample selection and slot presentation ------------------------------
    single_domain_only: Optional[bool] = field(default=False, metadata={"help": "Whether to keep only the single domain sample or not"})
    with_slot_description: Optional[bool] = field(default=False, metadata={"help": "Whether to use slot description or not for DST"})
    with_req_inf_differentiation: Optional[bool] = field(default=False, metadata={"help": "Whether to differentiate between require and inform slot for DST"})
    with_all_slots: Optional[bool] = field(default=True, metadata={"help": "Whether to use all slots or not"})

    # --- run control and persistence -----------------------------------------
    debug_mode: Optional[bool] = field(default=False, metadata={"help": "debug mode to only try 20 samples"})
    start_idx: Optional[int] = field(default=0, metadata={"help": "Starting index to restart the prediction if needed"})
    save_path: Optional[str] = field(default="results/", metadata={"help": "save path"})
    save_every: Optional[int] = field(default=5, metadata={"help": "every step to save in case api fail"})
    db_format_type: Optional[str] = field(default="1", metadata={"help": "1 is more precise, 2 is more concise for db integration"})

@dataclass
class PromptingArguments(TrainingArguments):
    """
    Options for the prompting pipeline, layered on top of ``TrainingArguments``.

    Each attribute stores its help text in the field metadata under ``"help"``.
    """
    # Directory that receives the outputs.
    output_dir: Optional[str] = field(default="./out", metadata={"help": "Output directory"})
    # Name of the task to run.
    task: Optional[str] = field(default="dst", metadata={"help": "Task to perform"})
    # Request budget per minute for the OpenAI API.
    max_requests_per_minute: Optional[int] = field(default=20, metadata={"help": "Max number of requests for OpenAI API."})
    # Name under which the OpenAI API key is looked up (the name, not the key itself).
    openai_api_key_name: Optional[str] = field(default="OPENAI_API_KEY", metadata={"help": "OpenAI API key name."})


# Use the default data options, except that response generation and e2e prompts
# get a history limit of -1 (treated downstream as "keep the whole dialogue").
data_args = DataArguments(
    dialog_history_limit_e2e=-1,
    dialog_history_limit_rg=-1,
)


In [2]:
class PromptConstructor():
    """Assemble task prompts from a configuration mapping.

    The configuration must provide ``"INSTRUCTIONS"`` and ``"PROMPT_TEMPLATES"``
    (read at construction time). ``"multiwoz21"`` slot lists and ``"EXAMPLES"``
    are read lazily by the methods that need them.
    """
    def __init__(self, 
                 config):
        """Keep the config and cache its instruction and template sections."""
        self.config = config
        self.instructions = config["INSTRUCTIONS"]
        self.prompt_templates = config["PROMPT_TEMPLATES"]
        
    def _get_slots_from_domains(self, domains, with_slot_description, with_req_inf_differentiation, with_all_slots):
        """Describe the slots relevant to ``domains`` as prompt-ready text.

        Args:
            domains: ``"all"`` or a list of domain names used as keys of the
                per-domain slot tables in ``config["multiwoz21"]``.
            with_slot_description: emit one ``name: ..., description: ...`` line
                per informable slot. Takes precedence over
                ``with_req_inf_differentiation``.
            with_req_inf_differentiation: list requestable and informable slots
                on two separate labelled lines.
            with_all_slots: ignore ``domains`` and behave as if it were ``"all"``.

        Returns:
            The slot description string.

        Raises:
            ValueError: if ``domains`` is neither ``"all"`` nor a list.
            KeyError: if a described slot has no entry in ``SLOTS_DESCRIPTIONS``
                or a domain is missing from the slot tables.
        """
        ontology = self.config["multiwoz21"]

        if with_all_slots:
            domains = "all"

        if with_slot_description:
            with_req_inf_differentiation = False #Slot description is the discriminator

        if domains == "all":
            req_slots = list(ontology["all_requestable_slots"])
            inf_slots = list(ontology["all_informable_slots"])
        elif not isinstance(domains, list):
            raise ValueError("""Provided domain should be either 'all' or list of valid domain names:
                                - for multiwoz2.1 and 2.4: taxi, restaurant, hotel, train, attraction 
                                - for SGD: To-do""")
        else:
            req_slots = []
            inf_slots = []
            for domain in domains:
                req_slots += ontology["requestable_slots"][domain]
                inf_slots += ontology["informable_slots"][domain]

        if with_req_inf_differentiation:
            if domains != "all":
                # Per-domain tables may share slots; drop repeats while keeping
                # first-seen order so the prompt is identical on every run.
                req_slots = list(dict.fromkeys(req_slots))
                inf_slots = list(dict.fromkeys(inf_slots))
            return f"Requestable slots: {', '.join(req_slots)}\nInformable slots: {', '.join(inf_slots)}"

        # dict.fromkeys de-duplicates like set() but keeps a deterministic order
        # (set ordering of strings changes between interpreter runs).
        slots = list(dict.fromkeys(req_slots + inf_slots))

        if not with_slot_description:
            return ", ".join(slots)

        # Only informable slots are described. Joining with newlines avoids
        # having to trim a trailing separator afterwards.
        described = [f"name: {slot}, description: {SLOTS_DESCRIPTIONS[slot]}"
                     for slot in slots
                     if slot in ontology["all_informable_slots"]]
        return "\n".join(described)
    
    def _render_template(self, instruction_key, template_key, **fields):
        """Format the template stored under ``template_key`` with its instruction and ``fields``."""
        instruction = self.instructions[instruction_key]
        template_variables = self.prompt_templates[template_key]
        template = PromptTemplate(input_variables=template_variables["input_variables"],
                                  template=template_variables["template"])
        return template.format(instruction=instruction, **fields)
    
    def _build_prompt(self, mode="", dialogue_context="", ontology="", slots="", dialogue_acts="", belief_states="", database=""):
        """Build the prompt for one task.

        Args:
            mode: one of ``dst``, ``dst_recorrect``, ``database_query``,
                ``response_generation`` or ``e2e``.
            dialogue_context: dialogue text inserted into the template.
            ontology: accepted for interface compatibility; not used.
            slots: slot description text (``dst`` and ``dst_recorrect``).
            dialogue_acts: accepted for interface compatibility; not used.
            belief_states: belief state text (``dst_recorrect`` and
                ``database_query``).
            database: database text (``e2e``).

        Returns:
            The formatted prompt string.

        Raises:
            ValueError: if ``mode`` is not one of the supported values.
        """
        if mode == "dst":
            return self._render_template("instruction_with_slots",
                                         "template_with_slots",
                                         slots=slots,
                                         dialogue_context=dialogue_context)

        if mode == "dst_recorrect":
            return self._render_template("instruction_with_slots_recorrect",
                                         "template_with_slots_recorrect",
                                         slots=slots,
                                         dialogue_context=dialogue_context,
                                         belief_states=belief_states)

        if mode == "database_query":
            return self._render_template("instruction_query_database",
                                         "template_query_database",
                                         belief_states=belief_states)

        if mode == "response_generation":
            return self._render_template("instruction_response_generation",
                                         "template_response_generation",
                                         example=self.config["EXAMPLES"]["response_generation"],
                                         dialogue_context=dialogue_context)

        if mode == "e2e":
            return self._render_template("instruction_e2e",
                                         "template_e2e",
                                         database=database,
                                         dialogue_context=dialogue_context)

        raise ValueError("'mode' should be one of: [dst, dst_recorrect, database_query, response_generation, e2e]")


class MWOZ_Dataset(PromptConstructor):
    """Load the MultiWOZ test dialogues and turn them into prompt/gold rows.

    After construction ``self.dataset`` is a ``pandas.DataFrame`` with one row
    per USER turn (prompts, context, domains) completed with the gold values
    taken from the SYSTEM turn that follows it.
    """
    def __init__(self,
                 config,
                 data_args):
        """Read the corpus from ``data_args.mwoz_path`` and build ``self.dataset``.

        Args:
            config: configuration mapping forwarded to ``PromptConstructor``.
            data_args: object exposing ``mwoz_path``, the three
                ``dialog_history_limit_*`` values read below,
                ``single_domain_only``, ``with_slot_description``,
                ``with_req_inf_differentiation`` and ``with_all_slots``.
        """
        super().__init__(config)
        self.dataset = {"id":[],
                        "dialogue_id":[],
                        "dialogue_context":[],
                        "turn":[],
                        "prompt_dst":[],
                        "prompt_dst_update":[],
                        "prompt_rg":[],
                        "prompt_e2e":[],
                        "domains":[],
                        "turn_domain":[],
                        "gold_turn_bs":[],
                        "gold_bs":[],
                        "gold_act":[],
                        "gold_response":[],
                        "gold_database_result":[],
                        }
        
        print("Loading data...")
        self.all_data, self.testfiles, self.system_acts = self._get_mwoz_data(data_args.mwoz_path)
        print("Loading databases...")
        self.dbs_lexicalized = self._get_dbs_lexicalized(data_args.mwoz_path)
        self.idx = 0
        self.dialog_history_limit_dst = data_args.dialog_history_limit_dst
        self.dialog_history_limit_rg = data_args.dialog_history_limit_rg
        self.dialog_history_limit_e2e = data_args.dialog_history_limit_e2e
        self.single_domain_only = data_args.single_domain_only
        self.with_slot_description = data_args.with_slot_description
        self.with_req_inf_differentiation = data_args.with_req_inf_differentiation
        self.with_all_slots = data_args.with_all_slots
        self.all_domains = ["restaurant", "taxi", "hotel", "train", "attraction"]

        print("Processing mwoz...")
        # Set lookup: the membership test runs once per dialogue in the corpus.
        test_dialogue_ids = set(self.testfiles)
        for sample in tqdm(self.all_data):
            if sample in test_dialogue_ids:
                self._process_dialogue_log(sample=sample,
                                           dialogue_log=self.all_data[sample]["log"])

        self.dataset = pd.DataFrame(self.dataset)
        if self.single_domain_only:
            # Vectorised filter; the original index labels are kept, exactly as
            # a row-by-row drop would leave them.
            single_domain_mask = self.dataset["domains"].map(len) == 1
            self.dataset = self.dataset[single_domain_mask].copy()

    def _get_mwoz_data(self, mwoz_path):
        """Read ``data.json``, ``testListFile.txt`` and ``system_acts.json``.

        Returns:
            ``(all_data, testfiles, system_acts)`` where ``testfiles`` is the
            list of non-blank lines of ``testListFile.txt``.
        """
        data_path = os.path.join(mwoz_path, "data.json")
        testListFile_path = os.path.join(mwoz_path, "testListFile.txt")
        system_acts_path = os.path.join(mwoz_path, "system_acts.json")

        with open(data_path, "r", encoding="utf-8") as f:
            all_data = json.load(f)

        with open(testListFile_path, "r", encoding="utf-8") as f:
            # Skip blank lines so a trailing newline does not add an empty id.
            testfiles = [line.strip() for line in f if line.strip()]

        with open(system_acts_path, "r", encoding="utf-8") as f:
            system_acts = json.load(f)

        return all_data, testfiles, system_acts
    
    def _get_dbs_lexicalized(self, mwoz_path):
        """Render each ``<domain>_db.json`` as text, one ``key: value, ...`` line per distinct row.

        Only the columns listed in ``keep_data`` are kept. Rows that become
        identical after this projection are emitted once, in file order.

        Returns:
            Dict mapping domain name to its newline-joined database text.
        """
        keep_data = {"restaurant":["address", "area", "food", "name", "pricerange", "phone", "postcode"],
                     "attraction":["name", "area", "address", "type", "postcode"],
                     "hotel":["name", "address", "area", "phone", "postcode", "pricerange", "stars"],
                     "train":["departure", "destination"]}
        dbs_lexicalized = {}
        for domain in ["restaurant", "hotel", "train", "attraction"]:
            db_path = os.path.join(mwoz_path, f"{domain}_db.json")
            with open(db_path, "r", encoding="utf-8") as f:
                db_data = json.load(f)

            db_lexicalized = [
                ", ".join(f"{key}: {row[key]}" for key in keep_data[domain] if key in row)
                for row in db_data
            ]
            # dict.fromkeys de-duplicates while preserving order; a set would
            # make the prompt text differ from one interpreter run to the next.
            dbs_lexicalized[domain] = "\n".join(dict.fromkeys(db_lexicalized))

        return dbs_lexicalized
    
    def _process_dialogue_log(self, sample, dialogue_log):
        """Append the rows for one dialogue to the column lists in ``self.dataset``.

        Even-indexed turns are treated as USER turns and contribute the prompt
        and context columns; odd-indexed turns are treated as SYSTEM turns and
        contribute ``gold_response``, ``gold_bs`` and ``gold_act``.

        Args:
            sample: dialogue id as used in ``data.json``; the part before the
                first ``"."`` is used as key into ``self.system_acts``.
            dialogue_log: list of turn dicts with ``"text"``, ``"dialog_act"``
                and ``"metadata"`` entries.
        """
        dialog_history_memory_dst = []
        dialog_history_memory_rg = []
        dialog_history_memory_e2e = []
        dialog_history_dst = ""
        dialog_history_rg = ""
        dialog_history_e2e = ""
        turn_domain = ""
        # Initialised up front so a SYSTEM turn with empty metadata reuses the
        # latest known belief state instead of raising UnboundLocalError.
        bspn = ""
        domains = self._get_domains_from_log(dialogue_log)
        slots = self._get_slots_from_domains(domains, 
                                             self.with_slot_description,
                                             self.with_req_inf_differentiation,
                                             self.with_all_slots) # or all

        for turn_nb, turn in enumerate(dialogue_log):
            is_user_turn = turn_nb % 2 == 0
            speaker = "USER" if is_user_turn else "SYSTEM"

            utterance = f"""{speaker}: {turn["text"]}\n"""
            dialog_act = turn["dialog_act"]

            if is_user_turn:
                # Prompts are only stored for USER turns, so they are only built here.
                dialogue_context_dst = dialog_history_dst + utterance
                prompt_dst = self._build_prompt(mode="dst",
                                                slots=slots,
                                                dialogue_context=dialogue_context_dst)

                # NOTE(review): the ACT line is built from this USER turn's
                # dialog_act, not from the upcoming SYSTEM act; confirm intended.
                lexicalized_act = self._lexicalize_act(dialog_act)
                dialogue_context_rg = dialog_history_rg + utterance + f"ACT:{lexicalized_act}\nSYSTEM:"
                prompt_rg = self._build_prompt(mode="response_generation",
                                               dialogue_context=dialogue_context_rg)

                dialogue_context_e2e = dialog_history_e2e + utterance + "SYSTEM:"
                # need to have utterance level domain here
                cur_system_act = self.system_acts[sample.split(".")[0]][str((turn_nb//2)+1)]
                turn_domain = self._get_domain_from_turn(turn_domain, cur_system_act)
                # No database text is loaded for taxi, nor when no domain is known yet.
                if turn_domain and turn_domain != "taxi":
                    database = self.dbs_lexicalized[turn_domain]
                else:
                    database = ""
                prompt_e2e = self._build_prompt(mode="e2e",
                                                database=database,
                                                dialogue_context=dialogue_context_e2e)

            dialog_history_dst, dialog_history_memory_dst = self._update_dialogue_memory(
                utterance, dialogue_log, self.dialog_history_limit_dst, dialog_history_memory_dst)
            dialog_history_rg, dialog_history_memory_rg = self._update_dialogue_memory(
                utterance, dialogue_log, self.dialog_history_limit_rg, dialog_history_memory_rg)
            dialog_history_e2e, dialog_history_memory_e2e = self._update_dialogue_memory(
                utterance, dialogue_log, self.dialog_history_limit_e2e, dialog_history_memory_e2e)

            metadata = turn["metadata"]
            if metadata:
                bspn_dict = {}
                for domain in metadata:
                    for slot, value in metadata[domain]["semi"].items():
                        if value and value not in ["not mentioned", "none"]:
                            bspn_dict.setdefault(domain, []).extend([remapping(slot), remapping(value)])
                bspn = " ".join([f"[{domain}] {' '.join(bspn_dict[domain])}" for domain in bspn_dict])

            self.idx += 1
            if is_user_turn:
                self.dataset["gold_turn_bs"].append(dialog_act)
                self.dataset["dialogue_context"].append(dialogue_context_dst)
                self.dataset["gold_database_result"].append(None) 
                self.dataset["turn"].append(turn_nb//2)
                self.dataset["domains"].append(domains)
                self.dataset["id"].append(self.idx//2)
                self.dataset["dialogue_id"].append(sample)
                self.dataset["prompt_dst"].append(prompt_dst)
                self.dataset["prompt_dst_update"].append(prompt_dst)
                self.dataset["prompt_rg"].append(prompt_rg)
                self.dataset["prompt_e2e"].append(prompt_e2e)
                self.dataset["turn_domain"].append(turn_domain)
            else:
                self.dataset["gold_response"].append(utterance)
                self.dataset["gold_bs"].append(bspn)
                self.dataset["gold_act"].append(dialog_act)

    def _update_dialogue_memory(self, utterance, dialogue_log, dialog_history_limit, dialog_history_memory):
        """Push ``utterance`` into a bounded history and return the joined history.

        Args:
            utterance: text to append.
            dialogue_log: the full dialogue; its length is the bound when the
                limit is ``-1``.
            dialog_history_limit: ``0`` stores nothing, ``-1`` uses
                ``len(dialogue_log)`` as bound, any other value is the maximum
                number of stored utterances (the oldest is dropped first).
            dialog_history_memory: list of stored utterances, mutated in place.

        Returns:
            ``(dialog_history, dialog_history_memory)`` where the first item is
            the concatenation of the stored utterances.
        """
        if dialog_history_limit != 0:
            if dialog_history_limit == -1:
                dialog_history_limit = len(dialogue_log)
            if len(dialog_history_memory) >= dialog_history_limit:
                dialog_history_memory.pop(0)
            dialog_history_memory.append(utterance)

        return "".join(dialog_history_memory), dialog_history_memory
    
    def _lexicalize_act(self, act):
        """Turn a dialogue-act mapping into short natural-language directives.

        Args:
            act: mapping from act name to a sequence of ``(slot, value)``
                pairs. Act names containing ``request``, ``recommend`` or
                ``inform`` (case-insensitive, checked in that order) are
                rendered; all others are ignored.

        Returns:
            The rendered sentences joined by spaces, or the string ``"None"``
            when nothing was rendered.
        """
        lexicalize_mapping = {"leave": "leave time",
                              "arrive":"arrival time",
                              "departure":"departure place",
                              "post":"postcode",
                              "addr":"address"}

        lexicalized_acts = []
        for act_name, slot_values in act.items():
            act_key = act_name.lower()
            if "request" in act_key:
                prefix = "Request the user about "
                render = lambda slot, value: slot
            elif "recommend" in act_key:
                prefix = "Recommend the user for "
                render = lambda slot, value: value.lower()
            elif "inform" in act_key:
                prefix = "Inform the user that "
                render = lambda slot, value: f"the {slot} is {value.lower()}"
            else:
                continue

            phrases = []
            for (slot, value) in slot_values:
                slot = slot.lower()
                slot = lexicalize_mapping.get(slot, slot)
                if slot == "none":
                    # A "none" slot ends the scan of this act's pairs.
                    break
                phrases.append(render(slot, value))
            if phrases:
                lexicalized_acts.append(prefix + ", ".join(phrases) + ".")

        if lexicalized_acts:
            return " ".join(lexicalized_acts)
        return "None"
        
    def _get_domain_from_turn(self, domain, act):
        """Return the first known domain named by the items of ``act``.

        Each item is lower-cased and cut at the first ``"-"``; the first prefix
        found in ``self.all_domains`` wins. ``domain`` is returned unchanged
        when no item matches.
        """
        for k in act:
            turn_domain = k.lower().split("-")[0]
            if turn_domain in self.all_domains:
                return turn_domain
        return domain

    def _get_domains_from_log(self, dialogue_log):
        """List, in order of first appearance, the known domains named in the turns' dialogue acts."""
        domains = []
        for log in dialogue_log:
            for domain_act in log["dialog_act"]:
                domain = domain_act.split("-")[0].lower()
                if domain in self.all_domains and domain not in domains:
                    domains.append(domain)
        return domains
                

In [3]:
# Build the prompt dataset from the MultiWOZ files, then keep a short alias to
# the resulting DataFrame for the cells below.
mwoz = MWOZ_Dataset(config=CONFIG, data_args=data_args)
dataset = mwoz.dataset

Loading data...
Loading databases...
Processing mwoz...


100%|██████████| 10438/10438 [00:02<00:00, 3700.20it/s]


In [4]:
# Load the saved predictions, rename their 'gold_bs' column to 'gold_turn_bs',
# then attach the turn domain and the dataset's 'gold_bs' column by 'id'.
# NOTE(review): the CSV location is a machine-specific absolute path.
df_results = (
    pd.read_csv("/home/willy/instructod/src/DST/results_single/gpt-3.5-turbo_0-end_singleDomainOnlyTrue_withSlotDescriptionFalse_withSlotDifferentiationFalse_dialogHistoryLimit0_prompt3.csv")
    .rename(columns={'gold_bs': 'gold_turn_bs'})
    .merge(dataset[['id', 'turn_domain', 'gold_bs']], on='id', how='left')
)
df_results.shape

(1059, 16)

In [5]:
import copy

def add_running_accumulated_bs_column(df, mode = 'preds', new_column_suffix=''):
    """Accumulate per-turn belief states into a running per-dialogue state.

    Rows are walked in order. While consecutive rows share the same
    ``dialogue_id`` the previous row's state is deep-copied and updated with
    the current row's slots; when ``dialogue_id`` changes (or on the first
    row) the state restarts as ``{turn_domain: {}}``.

    ``df`` is modified in place:
      * ``<mode>_bs<new_column_suffix>`` receives, for every row, the
        accumulated state as ``{turn_domain: {slot: value}}``. ``mode`` is
        normalised to ``'gold'`` when it contains ``'gold'``, so ``'golds'``
        writes ``gold_bs...``.
      * ``turn_domain`` is overwritten with the resolved turn domains.

    Args:
        df: DataFrame with ``turn_domain`` and ``dialogue_id`` columns, plus
            ``preds`` (mode ``'preds'``) or ``gold_turn_bs`` (any other mode).
        mode: ``'preds'`` to accumulate predictions; any value containing
            ``'gold'`` to accumulate gold turn annotations.
        new_column_suffix: Text appended to the name of the new column.

    Returns:
        None.
    """
    column_name = 'preds' if mode == 'preds' else 'gold_turn_bs'
    if 'gold' in mode:
        mode = 'gold'

    # Plain lists give positional access whatever the frame's index labels are.
    turn_domains = df['turn_domain'].tolist()
    dialogue_ids = df['dialogue_id'].tolist()
    items = df[column_name].tolist()

    running_bs_list = []
    new_turn_domains = []
    for i, item in enumerate(items):
        turn_domain = turn_domains[i]
        # An empty turn domain borrows the next row's. The last row has no
        # successor to borrow from, so its value is kept as is.
        if turn_domain == '' and i + 1 < len(turn_domains):
            turn_domain = turn_domains[i + 1]

        if i > 0 and dialogue_ids[i] == dialogue_ids[i - 1]:
            # Same dialogue: continue from a private copy of the previous state.
            running_bs = copy.deepcopy(running_bs_list[i - 1])
        else:
            running_bs = {turn_domain: {}}

        if mode == 'preds':
            if unpack_belief_states(item, 'pred') != ['none-none']:
                turn_slots = ast.literal_eval(item)
                domain_state = running_bs.setdefault(turn_domain, {})
                for slot, value in turn_slots.items():
                    domain_state[slot] = value
        elif mode == 'gold':
            if unpack_belief_states(item, 'gold') != ['none-none']:
                acts = item if isinstance(item, dict) else ast.literal_eval(item)
                # NOTE(review): only the first act's slot/value pairs are read;
                # further acts annotated on the same turn are ignored.
                # Confirm this is intended.
                first_act_pairs = list(acts.values())[0]
                turn_slots = {pair[0]: pair[1] for pair in first_act_pairs}
                domain_state = running_bs.setdefault(turn_domain, {})
                for slot, value in turn_slots.items():
                    domain_state[slot] = value

        running_bs_list.append(running_bs)
        new_turn_domains.append(turn_domain)

    df[mode+'_bs'+new_column_suffix] = running_bs_list
    df['turn_domain'] = new_turn_domains

In [None]:
# Both calls mutate df_results in place. The first adds `preds_bs`; the second
# (mode 'golds' is normalised to 'gold' inside the function) adds `gold_bs_new`,
# leaving the `gold_bs` column merged in from `dataset` untouched.
add_running_accumulated_bs_column(df_results, mode = 'preds')
add_running_accumulated_bs_column(df_results, mode = 'golds', new_column_suffix='_new')

In [7]:
# Side-by-side view of ten rows (positions 390-399): per-turn prediction, accumulated
# prediction, the dataset's gold state, and the gold state accumulated above.
df_results[['dialogue_id', 'turn_domain', 'preds', 'preds_bs', 'gold_bs', 'gold_bs_new']].iloc[390:400]

Unnamed: 0,dialogue_id,turn_domain,preds,preds_bs,gold_bs,gold_bs_new
390,SNG0459.json,restaurant,{'time': '19:00'},"{'restaurant': {'area': 'centre', 'pricerange': 'moderate', 'people': '4', 'time': '19:00', 'day': 'Wednesday'}}",[restaurant] price moderate name yipee noodle bar area centre,"{'restaurant': {'Price': 'moderate', 'Area': 'centre', 'Time': '19:00', 'Day': 'wednesday', 'People': '4'}}"
391,SNG0459.json,restaurant,"{'food': None, 'address': None, 'reference': None, 'name': None, 'area': None, 'day': None, 'postcode': None, 'time': None, 'phone': None, 'people': None, 'pricerange': None}","{'restaurant': {'area': 'centre', 'pricerange': 'moderate', 'people': '4', 'time': '19:00', 'day': 'Wednesday'}}",[restaurant] price moderate name yipee noodle bar area centre,"{'restaurant': {'Price': 'moderate', 'Area': 'centre', 'Time': '19:00', 'Day': 'wednesday', 'People': '4'}}"
392,SNG0459.json,restaurant,"{'food': None, 'address': None, 'reference': None, 'name': None, 'area': None, 'day': None, 'postcode': None, 'time': None, 'phone': None, 'people': None, 'pricerange': None}","{'restaurant': {'area': 'centre', 'pricerange': 'moderate', 'people': '4', 'time': '19:00', 'day': 'Wednesday'}}",[restaurant] price moderate name yipee noodle bar area centre,"{'restaurant': {'Price': 'moderate', 'Area': 'centre', 'Time': '19:00', 'Day': 'wednesday', 'People': '4'}}"
393,SNG0897.json,hotel,{'pricerange': 'moderate'},{'hotel': {'pricerange': 'moderate'}},[hotel] price moderate,{'hotel': {'Price': 'moderate'}}
394,SNG0897.json,hotel,"{'area': None, 'pricerange': 'moderate', 'internet': 'free'}","{'hotel': {'pricerange': 'moderate', 'area': None, 'internet': 'free'}}",[hotel] area dontcare price moderate internet yes,"{'hotel': {'Price': 'moderate', 'Internet': 'yes', 'Area': 'do nt care'}}"
395,SNG0897.json,hotel,"{'stars': '4', 'parking': 'free'}","{'hotel': {'pricerange': 'moderate', 'area': None, 'internet': 'free', 'stars': '4', 'parking': 'free'}}",[hotel] area dontcare parking yes price moderate stars 4 internet yes,"{'hotel': {'Price': 'moderate', 'Internet': 'yes', 'Area': 'do nt care', 'Parking': 'yes', 'Stars': '4'}}"
396,SNG0897.json,hotel,"{'phone': 'requested', 'area': 'requested'}","{'hotel': {'pricerange': 'moderate', 'area': 'requested', 'internet': 'free', 'stars': '4', 'parking': 'free', 'phone': 'requested'}}",[hotel] area dontcare parking yes price moderate stars 4 internet yes,"{'hotel': {'Price': 'moderate', 'Internet': 'yes', 'Area': '?', 'Parking': 'yes', 'Stars': '4', 'Phone': '?'}}"
397,SNG0897.json,hotel,"{'stay': None, 'address': None, 'parking': None, 'reference': None, 'name': None, 'area': None, 'internet': None, 'postcode': None, 'stars': None, 'day': None, 'phone': None, 'people': None, 'pricerange': None, 'type': None}","{'hotel': {'pricerange': 'moderate', 'area': 'requested', 'internet': 'free', 'stars': '4', 'parking': 'free', 'phone': 'requested'}}",[hotel] area dontcare parking yes price moderate stars 4 internet yes,"{'hotel': {'Price': 'moderate', 'Internet': 'yes', 'Area': '?', 'Parking': 'yes', 'Stars': '4', 'Phone': '?'}}"
398,SNG01943.json,hotel,"{'pricerange': 'expensive', 'type': 'guesthouse'}","{'hotel': {'pricerange': 'expensive', 'type': 'guesthouse'}}",[hotel] price expensive type guesthouse,"{'hotel': {'Type': 'guesthouse', 'Price': 'expensive'}}"
399,SNG01943.json,hotel,{'parking': 'free'},"{'hotel': {'pricerange': 'expensive', 'type': 'guesthouse', 'parking': 'free'}}",[hotel] parking yes price expensive type guesthouse,"{'hotel': {'Type': 'guesthouse', 'Price': 'expensive'}}"


## Multidomain

In [8]:
# Load the multi-domain predictions and left-join the dialogue id, turn domain
# and both gold columns from `dataset` on `id`.
df_results_multidomain = (
    pd.read_csv("/home/willy/instructod/src/DST/results_single/gpt-3.5-turbo_0-end_debugFalse_singleDomainOnlyFalse_withSlotDescriptionTrue_withSlotDifferentiationFalse_withAllSlotsTrue_dialogHistoryLimit0_prompt3.csv")
    .merge(
        dataset[['id', 'dialogue_id', 'turn_domain', 'gold_bs', 'gold_turn_bs']],
        on='id',
        how='left',
    )
)

In [None]:
# Same two in-place passes as for the single-domain frame: adds `preds_bs`, then
# `gold_bs_new` (mode 'golds' is normalised to 'gold' inside the function).
add_running_accumulated_bs_column(df_results_multidomain, mode = 'preds')
add_running_accumulated_bs_column(df_results_multidomain, mode = 'golds', new_column_suffix='_new')

In [10]:
# Ten rows (positions 390-399) of the multi-domain frame: per-turn prediction and its
# accumulated state next to the dataset's gold state and gold per-turn annotation.
df_results_multidomain[['dialogue_id', 'turn_domain', 'preds', 'preds_bs', 'gold_bs', 'gold_turn_bs']].iloc[390:400]

Unnamed: 0,dialogue_id,turn_domain,preds,preds_bs,gold_bs,gold_turn_bs
390,MUL0845.json,taxi,"{'name': None, 'leaveat': None, 'type': None, 'stars': None, 'food': None, 'area': None, 'stay': None, 'pricerange': None, 'people': None, 'destination': None, 'parking': None, 'internet': None, 'departure': None, 'day': None, 'arriveby': None, 'time': None}","{'restaurant': {'pricerange': 'affordable', 'food': 'Italian', 'area': 'postcode'}, 'attraction': {'area': 'same', 'type': 'cinema'}, 'taxi': {'leaveat': '22:15', 'destination': 'Vue Cinema', 'departure': 'Zizzi Cambridge'}}",[taxi] leave 22:15 dest vue cinema depart zizzi cambridge [restaurant] food italian price cheap name zizzi cambridge area centre [attraction] type cinema name vue cinema area centre,"{'general-thank': [['none', 'none']]}"
391,PMUL2755.json,hotel,"{'stars': '4', 'area': 'east'}","{'hotel': {'stars': '4', 'area': 'east'}}",[hotel] area east stars 4 type hotel,"{'Hotel-Inform': [['Area', 'east'], ['Stars', '4']]}"
392,PMUL2755.json,hotel,{'type': 'guesthouse'},"{'hotel': {'stars': '4', 'area': 'east', 'type': 'guesthouse'}}",[hotel] area east stars 4 type guesthouse,"{'Hotel-Inform': [['Type', 'guesthouse']]}"
393,PMUL2755.json,hotel,"{'people': '8', 'stay': '5', 'day': 'Wednesday'}","{'hotel': {'stars': '4', 'area': 'east', 'type': 'guesthouse', 'people': '8', 'stay': '5', 'day': 'Wednesday'}}",[hotel] area east price dontcare stars 4 type guesthouse,"{'Hotel-Inform': [['Stay', '5']]}"
394,PMUL2755.json,attraction,"{'area': 'centre', 'type': 'museum'}","{'hotel': {'stars': '4', 'area': 'east', 'type': 'guesthouse', 'people': '8', 'stay': '5', 'day': 'Wednesday'}, 'attraction': {'area': 'centre', 'type': 'museum'}}",[hotel] name autumn house area east price dontcare stars 4 type guesthouse [attraction] type museum area centre,"{'Attraction-Inform': [['Area', 'centre'], ['Type', 'museum']]}"
395,PMUL2755.json,attraction,"As there is no response from the system, it is not possible to generate the belief state for the very last dialogue turn. Please provide the complete conversation.","{'hotel': {'stars': '4', 'area': 'east', 'type': 'guesthouse', 'people': '8', 'stay': '5', 'day': 'Wednesday'}, 'attraction': {'area': 'centre', 'type': 'museum'}}",[hotel] name autumn house area east price dontcare stars 4 type guesthouse [attraction] type museum area centre,"{'Attraction-Request': [['Post', '?']]}"
396,PMUL2755.json,taxi,{'taxi_needed': 'True'},"{'hotel': {'stars': '4', 'area': 'east', 'type': 'guesthouse', 'people': '8', 'stay': '5', 'day': 'Wednesday'}, 'attraction': {'area': 'centre', 'type': 'museum'}, 'taxi': {'taxi_needed': 'True'}}",[taxi] dest castle galleries depart autumn house [hotel] name autumn house area east price dontcare stars 4 type guesthouse [attraction] type museum area centre,"{'Taxi-Inform': [['none', 'none']], 'Hotel-Inform': [['none', 'none']]}"
397,PMUL2755.json,taxi,"{'people': '8', 'stay': '5', 'day': 'Wednesday'}","{'hotel': {'stars': '4', 'area': 'east', 'type': 'guesthouse', 'people': '8', 'stay': '5', 'day': 'Wednesday'}, 'attraction': {'area': 'centre', 'type': 'museum'}, 'taxi': {'taxi_needed': 'True', 'people': '8', 'stay': '5', 'day': 'Wednesday'}}",[taxi] dest castle galleries depart autumn house [hotel] name autumn house area east price dontcare stars 4 type guesthouse [attraction] type museum area centre,"{'Hotel-Inform': [['Stay', '5'], ['Day', 'wednesday'], ['People', '8']]}"
398,PMUL2755.json,taxi,{'leaveat': '12:30'},"{'hotel': {'stars': '4', 'area': 'east', 'type': 'guesthouse', 'people': '8', 'stay': '5', 'day': 'Wednesday'}, 'attraction': {'area': 'centre', 'type': 'museum'}, 'taxi': {'taxi_needed': 'True', 'people': '8', 'stay': '5', 'day': 'Wednesday', 'leaveat': '12:30'}}",[taxi] leave 12:30 dest castle galleries depart autumn house [hotel] name autumn house area east price dontcare stars 4 type guesthouse [attraction] type museum area centre,"{'Taxi-Inform': [['Leave', '12:30']]}"
399,PMUL2755.json,taxi,"{'name': None, 'leaveat': None, 'type': None, 'stars': None, 'food': None, 'area': None, 'stay': None, 'pricerange': None, 'people': None, 'destination': None, 'parking': None, 'internet': None, 'departure': None, 'day': None, 'arriveby': None, 'time': None}","{'hotel': {'stars': '4', 'area': 'east', 'type': 'guesthouse', 'people': '8', 'stay': '5', 'day': 'Wednesday'}, 'attraction': {'area': 'centre', 'type': 'museum'}, 'taxi': {'taxi_needed': 'True', 'people': '8', 'stay': '5', 'day': 'Wednesday', 'leaveat': '12:30'}}",[taxi] leave 12:30 dest castle galleries depart autumn house [hotel] name autumn house area east price dontcare stars 4 type guesthouse [attraction] type museum area centre,"{'general-bye': [['none', 'none']]}"


In [11]:
# Spot-check the first ten result rows: print the unpacked prediction next to
# the dataset's turn domain and gold belief state for the same turn id.
# `turn_id` is used instead of `id` so the builtin is not rebound in the kernel,
# and `.head(10)` is positional and tolerates frames shorter than ten rows.
for turn_id, pred in zip(df_results["id"].head(10), df_results["preds"].head(10)):
    unpacked_pred = unpack_belief_states(pred, "pred")
    # `.item()` below requires exactly one dataset row to match this id.
    row = dataset.loc[dataset["id"] == turn_id]
    print("unpacked pred: ", unpacked_pred)
    print("turn domain: ", row["turn_domain"].item())
    print("gold belief state: ", row["gold_bs"].item())
    print("-------")

unpacked pred:  ['depart-saint john"s college', 'dest-pizza hut fen ditton']
turn domain:  taxi
gold belief state:  [taxi] dest pizza hut fenditton depart saint johns college
-------
unpacked pred:  ['leave-17:15']
turn domain:  taxi
gold belief state:  [taxi] leave 17:15 dest pizza hut fenditton depart saint johns college
-------
No belief state: as there is no mention of any slot in last turn of conversation, belief state will be empty. here's belief state in json format:

{}
unpacked pred:  ['none-none']
turn domain:  taxi
gold belief state:  [taxi] leave 17:15 dest pizza hut fenditton depart saint johns college
-------
unpacked pred:  ['none-none']
turn domain:  taxi
gold belief state:  [taxi] leave 17:15 dest pizza hut fenditton depart saint johns college
-------
unpacked pred:  ['food-portuguese', 'area-cambridge']
turn domain:  restaurant
gold belief state:  [restaurant] food portugese
-------
unpacked pred:  ['price-moderate']
turn domain:  restaurant
gold belief state:  [resta