In [1]:
import os
import json
import collections.abc
import regex as re
import pandas as pd

from tqdm import tqdm
from langchain import PromptTemplate

from src.DST.config import CONFIG


class PromptConstructor():
    def __init__(self, 
                 config):
        self.config = config
        self.instructions = config["INSTRUCTIONS"]
        self.prompt_templates = config["PROMPT_TEMPLATES"]
        
    def _get_slots_from_domains(self, domains, ontology, with_slot_description, with_all_slots, with_slot_domain_diff):
        
        if with_all_slots:
            domains = ["restaurant", "train", "attraction", "hotel", "taxi"]
        
        slots = []
        for slot in list(ontology.keys()):
            splitted_slot = slot.split("-")
            if splitted_slot[0] in domains:
                if with_slot_domain_diff:
                    if splitted_slot[-1] not in slots:
                        slots.append(splitted_slot[-1])
                else:
                    slots.append(splitted_slot[0] + "-" + splitted_slot[-1])
        
        slots_info = []
        added_slots = []
        if with_slot_description:
            for slot in slots:
                splitted_slot = slot.split("-")
                if with_slot_domain_diff:
                    if slot in added_slots:
                        continue
                    slots_info.append(f"name: {slot}, description: {SLOTS_DESCRIPTIONS[slot.lower()]}")
                    added_slots.append(slot)
                else:
                    slots_info.append(f"name: {slot}, description: {SLOTS_DESCRIPTIONS[splitted_slot[1].lower()]}")

                    
            slots = slots_info
        
        slots_prompt = "\n".join(slots)
        if with_slot_domain_diff:
            return slots_prompt + f"\n\nDOMAINS: {', '.join(domains)}"
        else:
            return slots_prompt
            
                
                
#         if with_all_slots:
#             domains = "all"
        
#         if with_slot_description:
#             with_req_inf_differentiation = False #Slot description is the discriminator

#         if domains == "all":
#             if with_req_inf_differentiation:
#                 req_slots = ", ".join(self.config["multiwoz21"]["all_requestable_slots"])
#                 inf_slots = ", ".join(self.config["multiwoz21"]["all_informable_slots"])
#             else:
#                 slots = set(self.config["multiwoz21"]["all_requestable_slots"] + 
#                             self.config["multiwoz21"]["all_informable_slots"])
#                 slots = ", ".join(slots)
#         elif not isinstance(domains, list):
#             raise ValueError("""Provided domain should be either 'all' or list of valid domain names:
#                                 - for multiwoz2.1 and 2.4: taxi, restaurant, hotel, train, attraction""")
#         else:
#             req_slots = ""
#             inf_slots = ""
#             domain_req_slots = []
#             domain_inf_slots = []
#             for domain in domains:
#                 domain_req_slots += self.config["multiwoz21"]["requestable_slots"][domain]
#                 domain_inf_slots += self.config["multiwoz21"]["informable_slots"][domain]
#             if with_req_inf_differentiation:
#                 domain_req_slots = set(domain_req_slots)
#                 domain_inf_slots = set(domain_inf_slots)
#                 req_slots += ", ".join(domain_req_slots)
#                 inf_slots += ", ".join(domain_inf_slots)
#             else:
#                 slots = set(domain_req_slots + domain_inf_slots)
#                 slots = ", ".join(slots)

#         if with_req_inf_differentiation:
#             slots_info = f"Requestable slots: {req_slots}\nInformable slots: {inf_slots}"
#         else:
#             slots_info = f"{slots}"

#         if with_slot_description:
#             slots = slots.split(", ")
#             slots_info = ""
#             for slot in slots:
#                 if slot not in self.config["multiwoz21"]["all_informable_slots"]:
#                     continue
#                 slots_info += f"name: {slot}, description: {SLOTS_DESCRIPTIONS[slot]}\n"
#             slots_info = slots_info[:-2]
        
#         return slots_info
    
    
    def _build_prompt(self, mode="", dialogue_context="", ontology="", slots="", dialogue_acts="", belief_states="", database=""):
        prompt = ""
        if mode == "dst":
            instruction = self.instructions["instruction_with_slots"]
            template_variables = self.prompt_templates["template_with_slots"]
            template = PromptTemplate(input_variables= template_variables["input_variables"],
                                      template = template_variables["template"])
            prompt = template.format(instruction=instruction,
                                     slots=slots,
                                     dialogue_context=dialogue_context)
            
        elif mode == "dst_recorrect":
            instruction = self.instructions["instruction_with_slots_recorrect"]
            template_variables = self.prompt_templates["template_with_slots_recorrect"]
            template = PromptTemplate(input_variables= template_variables["input_variables"],
                                      template = template_variables["template"])            
            prompt = template.format(instruction=instruction,
                                    slots=slots,
                                    dialogue_context=dialogue_context,
                                    belief_states=belief_states)
            
        elif mode == "database_query":
            instruction = self.instructions["instruction_query_database"]
            template_variables = self.prompt_templates["template_query_database"]
            template = PromptTemplate(input_variables= template_variables["input_variables"],
                                      template = template_variables["template"])
            prompt = template.format(instruction=instruction,
                                     belief_states=belief_states)
            
        elif mode == "response_generation":
            example = self.config["EXAMPLES"]["response_generation"]
            
            instruction = self.instructions["instruction_response_generation"]
            template_variables = self.prompt_templates["template_response_generation"]
            template = PromptTemplate(input_variables = template_variables["input_variables"],
                                      template = template_variables["template"])
            prompt = template.format(instruction=instruction,
                                     example=example,
                                     dialogue_context=dialogue_context)
        elif mode == "e2e":
            instruction = self.instructions["instruction_e2e"]
            template_variables = self.prompt_templates["template_e2e"]
            template = PromptTemplate(input_variables = template_variables["input_variables"],
                                      template = template_variables["template"])
            prompt = template.format(instruction=instruction,
                                     database=database,
                                     dialogue_context=dialogue_context)

        else:
            raise ValueError("'mode' should be one of: [dst, dst_recorrect, database_query, response_generation, e2e]")
        
        return prompt


class MWOZ_Dataset(PromptConstructor):
    def __init__(self,
                 config,
                 data_args):
        PromptConstructor.__init__(self, config)
        self.dataset = {"id":[],
                        "dialogue_id":[],
                        "dialogue_context":[],
                        "turn":[],
                        "prompt_dst":[],
                        "prompt_dst_update":[],
                        "prompt_rg":[],
                        "prompt_e2e":[],
                        "domains":[],
                        "turn_domain":[],
                        "gold_turn_bs":[],
                        "gold_bs":[],
                        "gold_act":[],
                        "gold_response":[],
                        "gold_database_result":[],
                        }
        
        print("Loading data...")
        self.all_data, self.testfiles, self.system_acts, self.ontology = self._get_mwoz_data(data_args.mwoz_path)
        print("Loading databases...")
        self.dbs_lexicalized = self._get_dbs_lexicalized(data_args.mwoz_path, data_args.db_format_type)
        self.idx = 0
        self.dialog_history_limit_dst = data_args.dialog_history_limit_dst
        self.dialog_history_limit_rg = data_args.dialog_history_limit_rg
        self.dialog_history_limit_e2e = data_args.dialog_history_limit_e2e
        self.single_domain_only = data_args.single_domain_only
        self.with_slot_description = data_args.with_slot_description
        self.with_slot_domain_diff = data_args.with_slot_domain_diff
        self.with_all_slots = data_args.with_all_slots
        self.all_domains = ["restaurant", "taxi", "hotel", "train", "attraction"]

        print("Processing mwoz...")
        for sample in tqdm(self.all_data):
            if sample in self.testfiles:
                dialogue_log = self.all_data[sample]["log"]
                self._process_dialogue_log(sample=sample,
                                           dialogue_log=dialogue_log)

        self.dataset = pd.DataFrame(self.dataset)
        if self.single_domain_only:
            for index, row in tqdm(self.dataset.iterrows()):
                if len(row["domains"]) != 1:
                    self.dataset.drop(index, inplace=True)

                    
    def _get_mwoz_data(self, mwoz_path):
        data_path = os.path.join(mwoz_path, "data.json")
        testListFile_path = os.path.join(mwoz_path, "testListFile.txt")
        system_acts_path = os.path.join(mwoz_path, "system_acts.json")
        ontology_path = os.path.join(mwoz_path, "ontology.json")

        with open(data_path, "r") as f:
            all_data = json.load(f)
            
        with open(testListFile_path, "r") as f:
            testfiles = f.read()
        testfiles = testfiles.split("\n")
        
        with open(system_acts_path, "r") as f:
            system_acts = json.load(f)
            
        with open(ontology_path, "r") as f:
            ontology = json.load(f)
            
        return all_data, testfiles, system_acts, ontology
    
    def _get_dbs_lexicalized(self, mwoz_path, format_type):
        domains = ["restaurant", "hotel", "train", "attraction"]
        keep_data = {"restaurant":["address", "area", "food", "name", "pricerange", "phone", "postcode"],
                    "attraction":["name", "area", "address", "type", "postcode"],
                    "hotel":["name", "address", "area", "phone", "postcode", "pricerange", "stars"],
                    "train":["departure", "destination"]}
        dbs_lexicalized = {}
        for domain in domains:
            db_path = os.path.join(mwoz_path, f"{domain}_db.json")
            with open(db_path, "r") as f:
                db_data = json.load(f)

            db_lexicalized = []
            if format_type == "1":
                for row in db_data:
                    row_keep = []
                    for key in keep_data[domain]:
                        if key in row:
                            row_keep.append(f"{key}: {row[key]}")
                    db_lexicalized.append(", ".join(row_keep))
            
            elif format_type == "2":
                #more concise db to fit in context length limit
                db_lexicalized.append(", ".join(keep_data[domain]))
                for row in db_data:
                    row_keep = []
                    for key in keep_data[domain]:
                        if key in row:
                            row_keep.append(f"{row[key]}")
                    db_lexicalized.append(", ".join(row_keep))
                    # db_lexicalized.append(", ".join([f"{row[key]}" for key in keep[domain]]))
            dbs_lexicalized[domain] = "\n".join(set(db_lexicalized))

        return dbs_lexicalized
    
    def _process_dialogue_log(self, sample, dialogue_log):

        dialog_history_memory_dst = []
        dialog_history_memory_rg = []
        dialog_history_memory_e2e = []
        dialog_history_dst = ""
        dialog_history_rg = ""
        dialog_history_e2e = ""
        turn_domain = ""
        domains = self._get_domains_from_log(dialogue_log)
        slots = self._get_slots_from_domains(domains=domains, 
                                             ontology=self.ontology,
                                             with_slot_description=self.with_slot_description,
                                             with_slot_domain_diff=self.with_slot_domain_diff,
                                             with_all_slots=self.with_all_slots) # or all

        for turn_nb, turn in enumerate(dialogue_log):

            if turn_nb % 2 == 0:
                speaker = "USER"
            else:
                speaker = "SYSTEM"
            
            utterance = f"""{speaker}: {turn["text"]}\n"""
            dialog_act = turn["dialog_act"]
            cur_system_act = self.system_acts[sample.split(".")[0]][str((turn_nb//2)+1)]
            
            dialogue_context_dst = dialog_history_dst + utterance
            prompt_dst = self._build_prompt(mode="dst",
                                            slots=slots,
                                            dialogue_context=dialogue_context_dst)
            
            lexicalized_act = self._lexicalize_act(cur_system_act)
            dialogue_context_rg = dialog_history_rg + utterance + f"ACT:{lexicalized_act}\nSYSTEM:"
            prompt_rg = self._build_prompt(mode="response_generation",
                                            dialogue_context=dialogue_context_rg)
            
            dialogue_context_e2e = dialog_history_e2e + utterance + "SYSTEM:"
    
            turn_domain = self._get_domain_from_turn(turn_domain, cur_system_act)
            if turn_domain and turn_domain != "taxi":
                database = self.dbs_lexicalized[turn_domain]
            else:
                database = ""
            prompt_e2e = self._build_prompt(mode="e2e",
                                            database=database,
                                            dialogue_context=dialogue_context_e2e).replace("\n\n\n", "\n")

            dialog_history_dst, dialog_history_memory_dst = self._update_dialogue_memory(utterance, 
                                                                                         dialogue_log, 
                                                                                         self.dialog_history_limit_dst, 
                                                                                         dialog_history_memory_dst)
            dialog_history_rg, dialog_history_memory_rg = self._update_dialogue_memory(utterance, 
                                                                                       dialogue_log, 
                                                                                       self.dialog_history_limit_rg,
                                                                                       dialog_history_memory_rg)
            dialog_history_e2e, dialog_history_memory_e2e = self._update_dialogue_memory(utterance, 
                                                                                         dialogue_log, 
                                                                                         self.dialog_history_limit_e2e, 
                                                                                         dialog_history_memory_e2e) 
                
            metadata = turn["metadata"]
            bspn = {}
            if metadata:
                for domain in domains:
                    for k, v in metadata[domain].items():
                        for slot, value in v.items():
                            if isinstance(value, str) and value not in ["", "not mentioned", "none"]:
                                bspn[domain+"-"+slot] = value
            self.idx += 1
            if turn_nb % 2 == 0:
                self.dataset["gold_turn_bs"].append(dialog_act)
                self.dataset["dialogue_context"].append(dialogue_context_dst)
                self.dataset["gold_database_result"].append(None) 
                self.dataset["turn"].append(turn_nb//2)
                self.dataset["domains"].append(domains)
                self.dataset["id"].append(self.idx//2)
                self.dataset["dialogue_id"].append(sample)
                self.dataset["prompt_dst"].append(prompt_dst)
                self.dataset["prompt_dst_update"].append(prompt_dst)
                self.dataset["prompt_rg"].append(prompt_rg)
                self.dataset["prompt_e2e"].append(prompt_e2e)
                self.dataset["turn_domain"].append(turn_domain)
            else:
                self.dataset["gold_response"].append(utterance)
                self.dataset["gold_bs"].append(bspn)
                self.dataset["gold_act"].append(dialog_act)

    def _update_dialogue_memory(self, utterance, dialogue_log, dialog_history_limit, dialog_history_memory):
        if dialog_history_limit != 0:
            if dialog_history_limit == -1:
                dialog_history_limit = len(dialogue_log)
            if len(dialog_history_memory) >= dialog_history_limit:
                dialog_history_memory.pop(0)
            dialog_history_memory.append(utterance)

        dialog_history = "".join(dialog_history_memory)
        return dialog_history, dialog_history_memory
    
    def _lexicalize_act(self, act):
        if act == "No Annotation":
            return "None"
        
        lexicalized_acts = []
        lexicalize_mapping = {"leave": "leave time",
                              "arrive":"arrival time",
                              "departure":"departure place",
                              "post":"postcode",
                              "addr":"address"}

        for act, slot_values in act.items():


            if "request" in act.lower():
                requests = []
                for (slot, value) in slot_values:
                    slot = slot.lower()
                    if slot in lexicalize_mapping:
                        slot = lexicalize_mapping[slot]
                    if slot == "none":
                        break
                    else:
                        requests.append(slot)
                if requests:
                    lexicalized_act = "Request the user about " + ", ".join(requests) + "."
                    lexicalized_acts.append(lexicalized_act)

            elif "recommend" in act.lower():
                recommends = []
                for (slot, value) in slot_values:
                    slot, value = slot.lower(), value.lower()
                    if slot in lexicalize_mapping:
                        slot = lexicalize_mapping[slot]
                    if slot == "none":
                        break
                    else:
                        recommends.append(value)
                if recommends:
                    lexicalized_act = "Recommend the user for " + ", ".join(recommends) + "."
                    lexicalized_acts.append(lexicalized_act)

            elif "inform" in act.lower():
                informs = []
                for (slot, value) in slot_values:
                    slot, value = slot.lower(), value.lower()
                    if slot in lexicalize_mapping:
                        slot = lexicalize_mapping[slot]
                    if slot == "none":
                        break
                    else:
                        informs.append(f"the {slot} is {value}")
                if informs:
                    lexicalized_act = "Inform the user that " + ", ".join(informs) + "."  
                    lexicalized_acts.append(lexicalized_act)

            else:
                pass
        if lexicalized_acts:
            return " ".join(lexicalized_acts)
        else:
            return "None"
        
    def _get_domain_from_turn(self, domain, act):
        for k in act:
            turn_domain = k.lower().split("-")[0]
            if turn_domain in self.all_domains:
                return turn_domain
        return domain
            

    def _get_domains_from_log(self, dialogue_log):
        domains = []
        for log in dialogue_log:
            for domain_act in log["dialog_act"]:
                domain = domain_act.split("-")[0].lower()
                if domain in self.all_domains and domain not in domains:
                    domains.append(domain)
        return domains
                

In [12]:
# results_path = "src/DST/results_single/gpt-3.5-turbo_0-end_singleDomainOnlyFalse_withSlotDescriptionFalse_dialogHistoryLimit0_latestSave.csv"
# results_path = "src/DST/results_single/gpt-4_0-end_singleDomainOnlyTrue_withSlotDescriptionFalse_dialogHistoryLimit0_latestSave.csv"
results_path = "/home/willy/InstrucTOD/src/DST/results_single/gpt-3.5-turbo_0-end_singleDomainOnlyFalse_withSlotDescriptionFalse_withSlotDifferentiationFalse_dialogHistoryLimit0_prompt3_latestSave.csv"
# results_path ="/home/willy/InstrucTOD/src/DST/results_single/gpt-4_0-end_singleDomainOnlyTrue_withSlotDescriptionFalse_withSlotDifferentiationFalse_dialogHistoryLimit0_prompt3.csv"
df_results = pd.read_csv(results_path)

In [13]:
len(df_results)

7352

In [15]:
merged_results = pd.merge(dataset, df_results, on=["id"])
merged_results = merged_results.loc[:, ~merged_results.columns.str.contains('^Unnamed')]
# save_path = results_path[:-4] + "_merged.csv"
# merged_results.to_csv(save_path)

In [19]:
count = 0
domain_count = {"multi":0}
for idx, row in merged_results.iterrows():
    if len(row["domains"]) == 1:
        count +=1
        domain = row["domains"][0]
        if domain in domain_count:
            domain_count[domain] += 1
        else:
            domain_count[domain] = 0
    else:
        domain_count["multi"] += 1


domain_count

{'multi': 6293,
 'taxi': 184,
 'restaurant': 289,
 'hotel': 380,
 'train': 159,
 'attraction': 42}

In [20]:
import os
import json
import pandas as pd

from src.DST.dst import cleaning_for_eval


def fix_typos(pred):
    pred = pred.replace("'", '"')
    pred = pred.replace('Catherine"s', "Catherine's")
    pred = pred.replace('John"s', "John's")
    pred = pred.replace('rosa"s', "rosa's")
    return pred

def unpack_belief_state(belief_state, type):
    slot_values = ""
    if type == "pred":
        domain = belief_state["domain"]
        act = belief_state["act"]
        bs = belief_state["belief_state"]
        domain_act = domain + "-" + act + ", "
        if isinstance(bs, dict):
            if bs:
                for k, v in bs.items():
                    if v:
                        v = str(v)
                        k, v = cleaning_for_eval(k, v)
                        slot_values += k + "-" + v + ", "
                    else:
                        k, _ = cleaning_for_eval(k, None)
                        slot_values += k + "-?, "
            else:
                slot_values += "None-None, "
        elif isinstance(bs, list):
            for slot_value in bs:
                for k, v in slot_value.items():
                    if k in ["slot", "slots"]:
                        _, v = cleaning_for_eval(None, v)
                        slot_values += v + "-"
                    elif k in ["value", "values"]:
                        _, v = cleaning_for_eval(None, v)
                        slot_values += v + ", "
                    else:
                        slot_values += "None-None, "
        if not bs:
            slot_values += "None-None, "

    elif type == "gold":
        if not belief_state:
            domain_act = "none-none, "
            slot_values += "none-none, "
        domain_act = ""
        for k, v in belief_state.items():
            domain_act += k + ", "
            for pair in v:
                slot, value = cleaning_for_eval(pair[0].lower(), pair[1].lower())
                slot_values += slot + "-" + value + ", "
            ##process list
    return domain_act[:-2].lower(), slot_values[:-2].lower() 

def process_preds_dst(results_df, save_path=None):
    all_preds = []
    all_golds = []
    all_domains = []
    all_prompts = []
    all_dialogues = []
    missed_states = 0 
    for idx, row in results_df.iterrows():
        domain = row["domains"]
        pred = row["preds"]
        gold = row["gold_bs"]
        prompt = row["prompt"]
        dialogue = prompt.split("CONTEXT:")[1]

        gold = fix_typos(gold)
        pred = fix_typos(pred)
        try:
            belief_state_gold =  json.loads(gold)
            belief_state_pred = json.loads(pred)
        except:
            print("Failed to load the json")
            print("gold", gold)
            print("pred", pred)
            print("=======================")
            continue
        if "domain" in belief_state_pred.keys() and "act" in belief_state_pred.keys() and "belief_state" in belief_state_pred.keys():
            domain_act_pred, slot_values_pred = unpack_belief_state(belief_state_pred, type="pred")
        else:
            missed_states += 1
            domain_act_pred, slot_values_pred = "none-none", "none-none"

        domain_act_gold, slot_values_gold = unpack_belief_state(belief_state_gold, type="gold")

        all_preds.append(domain_act_pred+"||"+slot_values_pred)
        all_golds.append(domain_act_gold+"||"+slot_values_gold)
        all_prompts.append(prompt)
        all_dialogues.append(dialogue)
        all_domains.append(domain)

    processed_results_df = pd.DataFrame({"prompts":all_prompts,
                                         "dialogue":all_dialogues,
                                         "preds":all_preds, 
                                         "golds":all_golds, 
                                         "domains":all_domains})
    print(f"Missed {missed_states} states")
    if save_path:
        processed_results_df.to_csv(save_path)

    return processed_results_df


def evaluate_single_domain_jga(processed_results_df, vocal=True, save_path=None):
    global_turns = 0
    global_jga = 0
    results_per_domain = {"taxi":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0},
                    "restaurant":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0},
                    "hotel":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0},
                    "train":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0},
                    "attraction":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0}}
    
    for idx, row in processed_results_df.iterrows():
        domain = row["domains"]
        preds = row["preds"].split("||")[1].split(", ")
        golds = row["gold_bs"].split("||")[1].split(", ")
        gold_slots = [gold.split("-")[1] for gold in golds]
        pred_slots = []
        for pred in preds:
            # if pred not in golds:
            #     print(f"NOT {pred}")
            # else:
            #     print(f"YES {pred}")

            if len(pred.split("-")) != 2:
                continue
            else:
                slot = pred.split("-")[1]
                pred_slots.append(slot)

        for gold_slot in gold_slots:
            if gold_slot in pred_slots:
                results_per_domain[domain]["correct_slots"] += 1
            results_per_domain[domain]["total_slots"] += 1

        if set(preds) == set(golds):
            results_per_domain[domain]["correct_turns_jga"] += 1
            global_jga += 1
        results_per_domain[domain]["turns"] += 1
        global_turns += 1 

    for domain in results_per_domain:
        results_per_domain[domain]["JGA"] = results_per_domain[domain]["correct_turns_jga"]/results_per_domain[domain]["turns"]
        results_per_domain[domain]["SLOT-F1"] = results_per_domain[domain]["correct_slots"]/results_per_domain[domain]["total_slots"]
        if vocal:
            print(f"""For {domain}, JGA: {results_per_domain[domain]["JGA"]} - SLOT-F1: {results_per_domain[domain]["SLOT-F1"]}""")
    
    if save_path:
        with open(save_path, "w") as f:
            json.dump(results_per_domain, f, indent=4)

    return results_per_domain




#utils to merge dataframe results
def retrieve_golds(MWOZ_dataset, start_sample_idx, last_sample_idx, results_df):
    dataset = MWOZ_dataset.dataset
    df = results_df.copy(deep=True)
    if last_sample_idx != 1053:
        df["golds"] = list(dataset["golds"][start_sample_idx:last_sample_idx+1])
        df["prompts"] = list(dataset["prompts"][start_sample_idx:last_sample_idx+1])
        df["domains"] = list(dataset["domains"][start_sample_idx:last_sample_idx+1])
        df["dialogue_ids"] = list(dataset["ids"][start_sample_idx:last_sample_idx+1])
        df["model_used"] = ["gpt-4" for _ in range(last_sample_idx-start_sample_idx+1)]
    else:
        df["golds"] = list(dataset["golds"][start_sample_idx:])
        df["prompts"] = list(dataset["prompts"][start_sample_idx:])
        df["domains"] = list(dataset["domains"][start_sample_idx:])
        df["dialogue_ids"] = list(dataset["ids"][start_sample_idx:])
        df["model_used"] = ["gpt-4" for _ in range(last_sample_idx-start_sample_idx)]
    return df

def merge_results(MWOZ_dataset, result_folder="src/DST/results/"):
    total_sample = 0
    merged_df = pd.DataFrame({})
    sorted_files = sorted(os.listdir(result_folder))
    for file in sorted_files:
        idxs = file.split("_")[1].split("-")
        start_idx, last_idx = int(idxs[0]), int(idxs[1])
        print(f"Processing samples between {start_idx} and {last_idx}")
        result_path = os.path.join(result_folder, file)
        results_df = pd.read_csv(result_path)
        total_sample += len(results_df)
        new_df = retrieve_golds(MWOZ_dataset, start_idx, last_idx, results_df)
        merged_df = merged_df.append(new_df, ignore_index=False)
    merged_df = merged_df.drop_duplicates(subset=["prompts", "dialogue_ids"], keep="first")
    return merged_df


In [21]:
SLOTS_REMAPPING = {
        # slots
        "address":"addr",
        "postcode":"post",
        "leaveat":"leave",
        "arriveby":"arrive",
        "pricerange":"price",
        "price":"fee",
        "reference":"ref",
        "departure":"depart",
        "destination":"dest",
        # values
        "unknown":"?", "inform":"?", "unk":"?", "needed":"?", "available":"?", "requested":"?", "request":"?", "n/a":"?",
}

def fix_typos(pred):

    pred = pred.replace("catherine 's", "catherine's")
    pred = pred.replace("john 's", 'john"s')
    pred = pred.replace("rosa 's", 'rosa"s')
    pred = pred.replace("mary 's", 'mary"s')
    pred = pred.replace("christ 's", "christ's")
    pred = pred.replace("alpha - milton", "alpha-milton")
    pred = pred.replace("michaelhouse cafe", "mic")
    pred = pred.replace("the ", "")
    pred = pred.replace(" nights", "")  
    pred = pred.replace(" person", "")
    pred = pred.replace(" night", "")   
    pred = pred.replace(" days", "")    
    pred = pred.replace("after ", "")  

    return pred


def remapping(pred):
    pred = pred.lower()
    if pred in SLOTS_REMAPPING:
        pred = SLOTS_REMAPPING[pred]
    return pred


def unpack_belief_states(belief_state, mode):

    def nested_fix(d, fix):
        if not d or isinstance(d, bool):
            return ""
        elif isinstance(d, dict):  # if dict, apply to each key
            return {k: nested_fix(v, fix) for k, v in d.items()}
        elif isinstance(d, list):  # if list, apply to each element
            return [nested_fix(elem, fix) for elem in d]
        else:
            return fix(d)

    unpacked_belief_states = []
    
    if mode == "gold":
        if isinstance(belief_state, str):
            # try:
            #     belief_state = json.loads(fix_typos(belief_state.lower()))
            # except:
            rx = re.compile(r'"[^"]*"(*SKIP)(*FAIL)|\'')
            belief_state = json.loads(rx.sub('"', belief_state.lower()))
        
        belief_state = nested_fix(belief_state, fix_typos)
        if not belief_state:
            return ["none-none"]
        for domain_act in belief_state:
            slot_values = belief_state[domain_act]
            for slot_value in slot_values:
                # slot, value = slot_value[0].lower(), slot_value[1].lower()
                slot, value = fix_typos(slot_value[0].lower()), fix_typos(slot_value[1].lower())
                unpacked_belief_states.append(f"{slot}-{value}")

    elif mode == "pred":
        if not isinstance(belief_state, dict):
            try:
                rx = re.compile(r'"[^"]*"(*SKIP)(*FAIL)|\'')
                belief_state = json.loads(rx.sub('"', belief_state.lower()))
            except:
                print(f"couldn't load: {fix_typos(belief_state.lower())}")
                return ["none-none"]
        else:
            if not belief_state:
                return ["none-none"]

        flag = False

        belief_state = nested_fix(belief_state, fix_typos)
        for slot, value in belief_state.items():

            if isinstance(value, dict):
                # print("is dict", value)
                flag = True
                unpacked_belief_states += unpack_belief_states(value, "pred")
            
            elif isinstance(value, list):
                flag = True
                # print("is list", value)
                unpacked_belief_states.append("none-none")

            elif not value:
                # print("is none", value)
                continue

            else:
                # print("is rest", value)
                fixed_value = remapping(str(value).lower())
                if fixed_value == "none":
                    fixed_slot = "none"
                else:
                    fixed_slot = remapping(slot.lower())

                slot_value = f"{fixed_slot}-{fixed_value}"
                if slot_value != "none-none":
                    flag = True
                unpacked_belief_states.append(slot_value)

        if flag:
            unpacked_belief_states = list(filter(lambda a: a != "none-none", unpacked_belief_states))
        if not unpacked_belief_states:
            unpacked_belief_states = ["none-none"]


    return unpacked_belief_states

In [31]:
def compute_prf(gold, pred):
    TP, FP, FN = 0, 0, 0
    if len(gold)!= 0:
        for g in gold:
            if g in pred:
                TP += 1
            else:
                FN += 1
        for p in pred:
            if p not in gold:
                FP += 1
        precision = TP / float(TP+FP) if (TP+FP)!=0 else 0
        recall = TP / float(TP+FN) if (TP+FN)!=0 else 0
        F1 = 2 * precision * recall / float(precision + recall) if (precision+recall)!=0 else 0
    else:
        if len(pred)==0:
            precision, recall, F1  = 1, 1, 1
        else:
            precision, recall, F1  = 0, 0, 0
    return F1, recall, precision

compute_prf(['food-portuguese', 'area-cambridge'], ['food-portuguese'])

count = 0
F1_score = []
recall_score = []
precision_score = []
L = len(merged_results)
for idx in range(L):
    unpacked_gold = unpack_belief_states(merged_results["gold_bs"][idx], "gold")
    unpacked_pred = unpack_belief_states(merged_results["preds"][idx], "pred")
    F1, recall, precision = compute_prf(unpacked_gold, unpacked_pred)
    F1_score.append(F1)
    recall_score.append(recall)
    precision_score.append(precision)
    if set(unpacked_gold)==set(unpacked_pred):
        count += 1
        continue
    print("idx", idx, "domains", merged_results["domains"][idx])
    print(merged_results["preds"][idx])
    print("prompt", merged_results["prompt"][idx])
    print("original pred", merged_results["preds"][idx])
    print("original gold", merged_results["gold_bs"][idx])
    print("context", merged_results["dialogue_context"][idx].replace("\n", ""))
    print("pred", unpacked_pred)
    print("gold", unpacked_gold)
    print("same", set(unpacked_gold)==set(unpacked_pred))
    print("=======")

couldn't load: as there is no mention of any slot in last turn of conversation, belief state will be empty. here's belief state in json format:

{}
couldn't load: {'departure': none, 'car': none, 'destination': none, 'phone': none, 'leaveat': none, 'arriveby': none}
idx 4 domains ['restaurant', 'attraction']
{'name': 'Nusha'}
prompt Generate the belief state of the last dialogue turn in the following conversation between a USER and a SYSTEM in a task-oriented dialogue setting. The results should be in json format following this format: {'slot1':'value1', 'slot2':'value2', etc...}. Use the slot from SLOTS to generate the belief state:

SLOTS:
area, reference, price, type, time, name, people, postcode, food, pricerange, phone, address, day

CONTEXT:
USER: Please find a restaurant called Nusha.

original pred {'name': 'Nusha'}
original gold {'Restaurant-Inform': [['none', 'none']], 'Attraction-Inform': [['Name', 'Nusha']]}
context USER: Please find a restaurant called Nusha.
pred ['name-n

In [34]:
import statistics
print(statistics.mean(F1_score))
print(statistics.mean(recall_score))
print(statistics.mean(precision_score))

0.6261206202412058
0.6562109759573035
0.617622543680637


In [35]:
count/len(merged_results)

0.5020402611534276

In [55]:

count = 0
count_single_domain = 0
for idx in range(len(merged_results)):
    domains = merged_results["domains"][idx]
    unpacked_gold = unpack_belief_states(merged_results["gold_bs"][idx], "gold")
    unpacked_pred = unpack_belief_states(merged_results["preds"][idx], "pred")
    if len(domains) == 1:
        count_single_domain += 1
        if set(unpacked_gold)==set(unpacked_pred):
            count += 1
            continue
    print("idx", idx, "domains", domains)
    print("original pred", merged_results["preds"][idx])
    print("original gold", merged_results["gold_bs"][idx])
    print("context", merged_results["dialogue_context"][idx].replace("\n", ""))
    print("prompt", merged_results["prompt"][idx])
    print(unpacked_pred)
    print(unpacked_gold)
    print("same", set(unpacked_gold)==set(unpacked_pred))
    print("=======")

couldn't load: {}

since user did not provide any information related to slots in last dialogue turn, belief state is empty.
idx 4 domains ['restaurant']
original pred {
  "food": "Portuguese",
  "area": "Cambridge"
}
original gold {'Restaurant-Inform': [['Food', 'portuguese']]}
context USER: Are there any Portuguese restaurants in Cambridge?
prompt Generate the belief state of the last dialogue turn in the following conversation between a USER and a SYSTEM in a task-oriented dialogue setting. The results should be in json format following this format: {'slot1':'value1', 'slot2':'value2', etc...}. Use the slot from SLOTS to generate the belief state:

SLOTS:
Requestable slots: pricerange, phone, reference, postcode, address, food, area
Informable slots: pricerange, name, day, time, food, people, area

CONTEXT:
USER: Are there any Portuguese restaurants in Cambridge?

['food-portuguese', 'area-cambridge']
['food-portuguese']
same False
idx 11 domains ['hotel']
original pred {
  "people"

In [238]:
# results_path_single = "src/DST/results_single/gpt-3.5-turbo_0-end_singleDomainOnlyTrue_withSlotDescriptionFalse_dialogHistoryLimit0.csv"
# df_results_single = pd.read_csv(results_path_single)

In [277]:
# merged_df_single = pd.merge(df_results_single, dataset, on=["id"])

In [62]:
# count = 0
results_per_domain = {"taxi":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0},
                    "restaurant":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0},
                    "hotel":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0},
                    "train":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0},
                    "attraction":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0}}
for idx in range(len(merged_results)):
    unpacked_gold = unpack_belief_states(merged_results["gold_bs_x"][idx], "gold")
    unpacked_pred = unpack_belief_states(merged_results["preds"][idx], "pred")        
    domain = merged_results["domains_x"][idx][0]
    results_per_domain[domain]["turns"] += 1
    if set(unpacked_gold)==set(unpacked_pred):
        count += 1
        results_per_domain[domain]["correct_turns_jga"] += 1
        continue
    if domain not in ["attraction"]:
        continue
    print("idx", idx, "domains", domain)
    print("original pred", merged_results["preds"][idx])
    print("original gold", merged_results["gold_bs_x"][idx])
    print("context", merged_results["dialogue_context_x"][idx].replace("\n", ""))
    print("pred", unpacked_pred)
    print("gold", unpacked_gold)
    print("same", set(unpacked_gold)==set(unpacked_pred))
    print("=======")

couldn't load: as there is no mention of any slot in last turn of conversation, belief state will be empty. here's belief state in json format:

{}
couldn't load: {'departure': none, 'leaveat': none, 'destination': none, 'arriveby': none, 'phone': none, 'car': none}
couldn't load: as there is no user input in last turn, belief state remains same as previous turn. without any context of previous turn, it is impossible to generate belief state.
couldn't load: based on last turn of conversation, there is no new information provided by user to update belief state. therefore, belief state remains same as previous turn. without additional context, it is not possible to generate belief state.
couldn't load: as there is no context provided, i cannot generate belief state of last dialogue turn. please provide me with previous conversation turns so that i can generate belief state.
couldn't load: {'food': none, 'address': none, 'reference': none, 'name': none, 'area': none, 'day': none, 'postcod

In [63]:
for key in results_per_domain:
    print(key, f" - JGA: {results_per_domain[key]['correct_turns_jga']/results_per_domain[key]['turns']}")

taxi  - JGA: 0.7081081081081081
restaurant  - JGA: 0.7379310344827587
hotel  - JGA: 0.5328083989501312
train  - JGA: 0.65625
attraction  - JGA: 0.6511627906976745


In [29]:
for key in results_per_domain:
    print(key, f" - JGA: {results_per_domain[key]['correct_turns_jga']/results_per_domain[key]['turns']}")

taxi  - JGA: 0.654054054054054
restaurant  - JGA: 0.7724137931034483
hotel  - JGA: 0.5905511811023622
train  - JGA: 0.7625
attraction  - JGA: 0.4883720930232558


In [289]:
count/len(df_results_single)

0.6402266288951841

In [291]:
def evaluate_single_domain_jga(processed_results_df, vocal=True, save_path=None):
    global_turns = 0
    global_jga = 0
    results_per_domain = {"taxi":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0},
                    "restaurant":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0},
                    "hotel":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0},
                    "train":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0},
                    "attraction":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0}}
    
    for idx, row in processed_results_df.iterrows():
        domain = row["domains"]
        preds = row["preds"].split("||")[1].split(", ")
        golds = row["gold_bs"].split("||")[1].split(", ")
        gold_slots = [gold.split("-")[1] for gold in golds]
        pred_slots = []
        for pred in preds:
            # if pred not in golds:
            #     print(f"NOT {pred}")
            # else:
            #     print(f"YES {pred}")

            if len(pred.split("-")) != 2:
                continue
            else:
                slot = pred.split("-")[1]
                pred_slots.append(slot)

        for gold_slot in gold_slots:
            if gold_slot in pred_slots:
                results_per_domain[domain]["correct_slots"] += 1
            results_per_domain[domain]["total_slots"] += 1

        if set(preds) == set(golds):
            results_per_domain[domain]["correct_turns_jga"] += 1
            global_jga += 1
        results_per_domain[domain]["turns"] += 1
        global_turns += 1 

    for domain in results_per_domain:
        results_per_domain[domain]["JGA"] = results_per_domain[domain]["correct_turns_jga"]/results_per_domain[domain]["turns"]
        results_per_domain[domain]["SLOT-F1"] = results_per_domain[domain]["correct_slots"]/results_per_domain[domain]["total_slots"]
        if vocal:
            print(f"""For {domain}, JGA: {results_per_domain[domain]["JGA"]} - SLOT-F1: {results_per_domain[domain]["SLOT-F1"]}""")
    
    # if save_path:
    #     with open(save_path, "w") as f:
    #         json.dump(results_per_domain, f, indent=4)

    return results_per_domain

In [314]:
mwoz_path = "/home/willy/InstrucTOD/MultiWOZ_2.4"
dialog_history_limit = 0
single_domain_only = False
with_slot_description = False
mwoz24 = MWOZ_Dataset(CONFIG, 
                    mwoz_path,
                    dialog_history_limit,
                    with_slot_description,
                    single_domain_only)
dataset24 = mwoz24.dataset

Processing mwoz...


  0%|          | 4/10438 [00:00<00:00, 23301.69it/s]


KeyError: 'dialog_act'

In [66]:
with open("/home/willy/InstrucTOD/MultiWOZ2.4/data/mwz24/MULTIWOZ2.4/data.json", "r") as f:
    data24 = json.load(f)

with open("/home/willy/InstrucTOD/MultiWOZ_2.1/data.json", "r") as f:
    data21 = json.load(f)

In [59]:
data21["SNG01856.json"]["log"][0].keys()

dict_keys(['text', 'metadata', 'dialog_act', 'span_info'])

In [67]:
data24["SNG01856.json"]["log"][0].keys()

dict_keys(['text', 'metadata'])

In [70]:
with open("/home/willy/InstrucTOD/MultiWOZ2.4/data/mwz2.4/test_dials.json", "r") as f:
    dials24 = json.load(f)

In [82]:
L = []
for dial in dials24:
    L.append(dial["dialogue_idx"])


In [80]:
dials24[0]

{'dialogue_idx': 'SNG0073.json',
 'domains': ['taxi'],
 'dialogue': [{'system_transcript': '',
   'turn_idx': 0,
   'belief_state': [{'slots': [['taxi-destination', 'pizza hut fenditton']],
     'act': 'inform'},
    {'slots': [['taxi-departure', 'saint johns college']], 'act': 'inform'}],
   'turn_label': [['taxi-destination', 'pizza hut fenditton'],
    ['taxi-departure', 'saint johns college']],
   'transcript': 'i would like a taxi from saint john s college to pizza hut fen ditton .',
   'system_acts': [],
   'domain': 'taxi'},
  {'system_transcript': 'what time do you want to leave and what time do you want to arrive by ?',
   'turn_idx': 1,
   'belief_state': [{'slots': [['taxi-leaveat', '17:15']], 'act': 'inform'},
    {'slots': [['taxi-destination', 'pizza hut fenditton']], 'act': 'inform'},
    {'slots': [['taxi-departure', 'saint johns college']], 'act': 'inform'}],
   'turn_label': [['taxi-leaveat', '17:15']],
   'transcript': 'i want to leave after 17:15 .',
   'system_acts

In [325]:
data21["SNG01856.json"]["log"][0]['dialog_act']

{'Hotel-Inform': [['Type', 'hotel'], ['Price', 'cheap']]}

In [338]:
data24["SNG01856.json"]

{'goal': {'taxi': {},
  'police': {},
  'hospital': {},
  'hotel': {'info': {'type': 'hotel',
    'parking': 'yes',
    'pricerange': 'cheap',
    'internet': 'yes'},
   'fail_info': {},
   'book': {'pre_invalid': True,
    'stay': '2',
    'day': 'tuesday',
    'invalid': False,
    'people': '6'},
   'fail_book': {'stay': '3'}},
  'topic': {'taxi': False,
   'police': False,
   'restaurant': False,
   'hospital': False,
   'hotel': False,
   'general': False,
   'attraction': False,
   'train': False,
   'booking': False},
  'attraction': {},
  'train': {},
  'message': ["You are looking for a <span class='emphasis'>place to stay</span>. The hotel should be in the <span class='emphasis'>cheap</span> price range and should be in the type of <span class='emphasis'>hotel</span>",
   "The hotel should <span class='emphasis'>include free parking</span> and should <span class='emphasis'>include free wifi</span>",
   "Once you find the <span class='emphasis'>hotel</span> you want to book it

In [3]:
df1 = pd.read_csv("/home/willy/InstrucTOD/src/DST/results_single/gpt-3.5-turbo_0-1675_singleDomainOnlyFalse_withSlotDescriptionFalse_withSlotDifferentiationFalse_dialogHistoryLimit0_prompt3_latestSave.csv")
df2 = pd.read_csv("/home/willy/InstrucTOD/src/DST/results_single/gpt-3.5-turbo_1676-2301_singleDomainOnlyFalse_withSlotDescriptionFalse_withSlotDifferentiationFalse_dialogHistoryLimit0_prompt3_latestSave.csv")
df3 = pd.read_csv("/home/willy/InstrucTOD/src/DST/results_single/gpt-3.5-turbo_2301-4625_singleDomainOnlyFalse_withSlotDescriptionFalse_withSlotDifferentiationFalse_dialogHistoryLimit0_prompt3_latestSave.csv")
df4 = pd.read_csv("/home/willy/InstrucTOD/src/DST/results_single/gpt-3.5-turbo_4626-6925_singleDomainOnlyFalse_withSlotDescriptionFalse_withSlotDifferentiationFalse_dialogHistoryLimit0_prompt3_latestSave.csv")
df5 = pd.read_csv("/home/willy/InstrucTOD/src/DST/results_single/gpt-3.5-turbo_6925-end_singleDomainOnlyFalse_withSlotDescriptionFalse_withSlotDifferentiationFalse_dialogHistoryLimit0_prompt3_latestSave.csv")

In [4]:
len(df2) + len(df1) + len(df3) + len(df4) + len(df5)

7352

In [6]:
merged_gpt3 = pd.concat([df1, df2, df3, df4, df5], join="inner")

In [7]:
len(merged_gpt3)

7352

In [9]:
merged_gpt3 = merged_gpt3.loc[:, ~merged_gpt3.columns.str.contains('^Unnamed')]

In [10]:
merged_gpt3.to_csv("/home/willy/InstrucTOD/src/DST/results_single/gpt-3.5-turbo_0-end_singleDomainOnlyFalse_withSlotDescriptionFalse_withSlotDifferentiationFalse_dialogHistoryLimit0_prompt3_latestSave.csv")

In [1]:
import pandas as pd
len(pd.read_csv("/home/willy/InstrucTOD/src/DST/results_single/gpt-3.5-turbo_4626-end_singleDomainOnlyFalse_withSlotDescriptionFalse_withSlotDifferentiationFalse_dialogHistoryLimit0_prompt3_latestSave.csv"))

2300

In [2]:
4626+2300-1

6925

In [36]:
df = pd.read_csv("/home/willy/InstrucTOD/src/DST/results_single/gpt-4_0-end_singleDomainOnlyTrue_withSlotDescriptionFalse_withSlotDifferentiationFalse_dialogHistoryLimit0_prompt3.csv")

In [40]:
dataset.head(2)

Unnamed: 0,id,dialogue_id,dialogue_context,turn,prompt,domains,gold_bs,gold_act,gold_response,gold_database_result
0,0,SNG0073.json,USER: I would like a taxi from Saint John's co...,0,Generate the belief state of the last dialogue...,[taxi],"{'Taxi-Inform': [['Dest', 'pizza hut fen ditto...","{'Taxi-Request': [['Leave', '?'], ['Arrive', '...",SYSTEM: What time do you want to leave and wha...,
1,1,SNG0073.json,USER: I want to leave after 17:15.\n,1,Generate the belief state of the last dialogue...,[taxi],"{'Taxi-Inform': [['Leave', '17:15']]}","{'Taxi-Inform': [['Car', 'blue honda'], ['Phon...",SYSTEM: \nBooking completed! your taxi will be...,


In [37]:
df.head(5)

Unnamed: 0.1,Unnamed: 0,id,preds,completion_info
0,0,0,"{\n ""destination"": ""Pizza Hut Fen Ditton"",\n ...","{\n ""choices"": [\n {\n ""finish_reason..."
1,1,1,"{\n ""leaveat"": ""after 17:15""\n}","{\n ""choices"": [\n {\n ""finish_reason..."
2,2,2,{\n},"{\n ""choices"": [\n {\n ""finish_reason..."
3,3,3,{}\n\nSince the user did not provide any infor...,"{\n ""choices"": [\n {\n ""finish_reason..."
4,4,79,"{\n ""food"": ""Portuguese"",\n ""area"": ""Cambrid...","{\n ""choices"": [\n {\n ""finish_reason..."


In [43]:
results_df_path = "/home/willy/InstrucTOD/src/DST/results_single/gpt-4_0-end_singleDomainOnlyTrue_withSlotDescriptionFalse_withSlotDifferentiationFalse_dialogHistoryLimit0_prompt3.csv"
save_path = "/home/willy/InstrucTOD/src/DST/processed_results/" + results_df_path.split("/")[-1][:-4] + "_results.csv"

In [49]:
len(pd.read_csv("/home/willy/InstrucTOD/src/DST/results_single/gpt-3.5-turbo_0-end_singleDomainOnlyTrue_withSlotDescriptionFalse_withSlotDifferentiationFalse_dialogHistoryLimit0_prompt2.csv"))

1059

In [67]:
df_single = pd.read_csv("/home/willy/InstrucTOD/src/DST/results_single/gpt-3.5-turbo_0-end_singleDomainOnlyTrue_withSlotDescriptionFalse_withSlotDifferentiationFalse_dialogHistoryLimit0_prompt3.csv")
df_multi = pd.read_csv("/home/willy/InstrucTOD/src/DST/results_single/gpt-3.5-turbo_0-end_singleDomainOnlyFalse_withSlotDescriptionFalse_withSlotDifferentiationFalse_dialogHistoryLimit0_prompt3_latestSave.csv")
df_multi = df_multi[["id", "preds"]]
df_multi = pd.merge(dataset, df_multi, on=["id"])
df_multi = df_multi.loc[:, ~df_multi.columns.str.contains('^Unnamed')]

In [96]:
count = 0
single_idx = 0
for idx, row in df_multi.iterrows():
    domains = row["domains"]
    if len(domains) == 1:
        if domains[0] == "attraction":
            unpacked_pred_multi = unpack_belief_states(row["preds"], "pred")
            unpacked_pred_multi = unpack_belief_states(row["preds"], "pred")
            print(domains[0])
            print("gold", row["gold_bs"])
            print("multi", unpacked_pred_multi)
            while df_single["domains"][single_idx][2:-2] != "attraction":
                single_idx += 1
            unpacked_pred_single = unpack_belief_states(df_single["preds"][single_idx], "pred")
            print("single", unpacked_pred_single)
            single_idx += 1
            print("=====")

attraction
gold {'Attraction-Inform': [['Area', 'east'], ['Type', 'entertainment']]}
multi ['area-east', 'type-entertainment']
single ['area-east']
=====
attraction
gold {'Attraction-Request': [['Addr', '?'], ['Post', '?']]}
multi ['addr-?', 'post-?']
single ['addr-?', 'post-?']
=====
attraction
gold {'Attraction-Request': [['Fee', '?']]}
multi ['fee-?']
single ['fee-?']
=====
couldn't load: {'area': none, 'postcode': none, 'reference': none, 'type': none, 'address': none, 'phone': none, 'price': none, 'name': none}
attraction
gold {'general-thank': [['none', 'none']]}
multi ['none-none']
couldn't load: {'price': none, 'address': none, 'reference': none, 'name': none, 'area': none, 'postcode': none, 'phone': none, 'type': none}
single ['none-none']
=====
attraction
gold {'Attraction-Inform': [['Name', 'williams art and antiques']]}
multi ['name-williams art and antiques']
single ['name-williams art and antiques']
=====
attraction
gold {'Attraction-Request': [['Area', '?'], ['Post', '?'

In [90]:
df_single["preds"]

0       {'departure': 'Saint John\'s college', 'destin...
1                                    {'leaveat': '17:15'}
2       As there is no mention of any slot in the last...
3       {'departure': None, 'leaveat': None, 'destinat...
4             {'food': 'Portuguese', 'area': 'Cambridge'}
                              ...                        
1054    As there is no new information provided by the...
1055    {'departure': 'unknown', 'leaveat': '11:00', '...
1056    SYSTEM: Where would you like to go?\nUSER: I n...
1057    I'm sorry, but I cannot generate the belief st...
1058    {'departure': None, 'leaveat': None, 'destinat...
Name: preds, Length: 1059, dtype: object

In [71]:
count

1059