In [53]:
import os
import ast
import json
import openai
import pandas as pd

from tqdm import tqdm
from langchain import PromptTemplate

from src.DST.evaluate_utils import unpack_belief_states, remapping, compute_prf
from src.DST.config import CONFIG


class PromptConstructor():
    """Build LLM prompts from a configuration dictionary.

    The constructor reads ``config["INSTRUCTIONS"]`` and
    ``config["PROMPT_TEMPLATES"]``; the slot helpers additionally read the
    slot inventories stored under ``config["multiwoz21"]``.
    """

    def __init__(self, 
                 config):
        self.config = config
        self.instructions = config["INSTRUCTIONS"]
        self.prompt_templates = config["PROMPT_TEMPLATES"]

    @staticmethod
    def _unique(items):
        """Return ``items`` as a list without duplicates, keeping first-seen order."""
        # dict.fromkeys (unlike set) preserves insertion order, so the slot
        # list that ends up in the prompt is identical from run to run.
        return list(dict.fromkeys(items))

    def _get_slots_from_domains(self, domains, with_slot_description, with_req_inf_differentiation, with_all_slots):
        """Return the slot section of a prompt as a single string.

        Args:
            domains: ``"all"`` or a list of domain names used as keys of
                ``config["multiwoz21"]["requestable_slots"]`` and
                ``config["multiwoz21"]["informable_slots"]``.
            with_slot_description: emit one ``name: ..., description: ...``
                line per informable slot. Takes precedence over
                ``with_req_inf_differentiation``.
            with_req_inf_differentiation: list requestable and informable
                slots on two separate labelled lines.
            with_all_slots: ignore ``domains`` and use every slot.

        Returns:
            The formatted slot listing.

        Raises:
            ValueError: if ``domains`` is neither ``"all"`` nor a list.
        """
        # slot_description = self.config["slot_descrpition"]
        ontology = self.config["multiwoz21"]

        if with_all_slots:
            domains = "all"

        if with_slot_description:
            with_req_inf_differentiation = False #Slot description is the discriminator

        if domains == "all":
            req_slots = list(ontology["all_requestable_slots"])
            inf_slots = list(ontology["all_informable_slots"])
        elif not isinstance(domains, list):
            raise ValueError("""Provided domain should be either 'all' or list of valid domain names:
                                - for multiwoz2.1 and 2.4: taxi, restaurant, hotel, train, attraction 
                                - for SGD: To-do""")
        else:
            # Fresh lists: "+=" below must never extend the lists held by the config.
            req_slots = []
            inf_slots = []
            for domain in domains:
                req_slots += ontology["requestable_slots"][domain]
                inf_slots += ontology["informable_slots"][domain]

        if with_req_inf_differentiation:
            req_text = ", ".join(self._unique(req_slots))
            inf_text = ", ".join(self._unique(inf_slots))
            return f"Requestable slots: {req_text}\nInformable slots: {inf_text}"

        slots = self._unique(req_slots + inf_slots)
        if not with_slot_description:
            return ", ".join(slots)

        # Only informable slots are described. Joining with "\n" avoids the
        # former "[:-2]" trim, which also removed the last character of the
        # final description.
        # NOTE(review): SLOTS_DESCRIPTIONS is not defined or imported in the
        # visible part of this file; it has to exist at module scope.
        informable = set(ontology["all_informable_slots"])
        return "\n".join(f"name: {slot}, description: {SLOTS_DESCRIPTIONS[slot]}"
                         for slot in slots if slot in informable)

    def _render(self, template_key, **variables):
        """Format the prompt template stored under ``template_key`` with ``variables``."""
        spec = self.prompt_templates[template_key]
        template = PromptTemplate(input_variables=spec["input_variables"],
                                  template=spec["template"])
        return template.format(**variables)

    def _build_prompt(self, mode="", dialogue_context="", ontology="", slots="", dialogue_acts="", belief_states=""):
        """Build the prompt for one of the supported modes.

        Args:
            mode: ``"dst"``, ``"dst_recorrect"``, ``"database_query"``,
                ``"response_generation"`` or ``"dst_extracted_ontology"``.
            dialogue_context: dialogue text inserted into the template.
            ontology: currently unused; kept for interface compatibility.
            slots: slot listing, e.g. from ``_get_slots_from_domains``.
            dialogue_acts: dialogue acts used by ``"response_generation"``.
            belief_states: belief states used by ``"dst_recorrect"`` and
                ``"database_query"``.

        Returns:
            The formatted prompt; an empty string for
            ``"dst_extracted_ontology"``, which is not implemented yet.

        Raises:
            ValueError: if ``mode`` is not one of the values above.
        """
        # Every branch goes through _render so that each mode reads its own
        # template entry (previously only "dst" defined the variable that the
        # other branches used, which made them fail with UnboundLocalError).
        if mode == "dst":
            return self._render("template_with_slots",
                                instruction=self.instructions["instruction_with_slots"],
                                slots=slots,
                                dialogue_context=dialogue_context)

        if mode == "dst_recorrect":
            return self._render("template_with_slots_recorrect",
                                instruction=self.instructions["instruction_with_slots_recorrect"],
                                slots=slots,
                                dialogue_context=dialogue_context,
                                belief_states=belief_states)

        if mode == "database_query":
            return self._render("template_query_database",
                                instruction=self.instructions["instruction_query_database"],
                                belief_states=belief_states)

        if mode == "response_generation":
            return self._render("template_response_generation",
                                instruction=self.instructions["instruction_response_generation"],
                                dialogue_acts=dialogue_acts,
                                dialogue_context=dialogue_context)

        if mode == "dst_extracted_ontology":
            return ""

        raise ValueError("'mode' should be one of: [dst, dst_recorrect, database_query, response_generation]")


class MWOZ_Dataset(PromptConstructor):
    """Turn the test split of a MultiWOZ dump into a per-user-turn DataFrame.

    ``mwoz_path`` must contain ``data.json`` and ``testListFile.txt``. After
    construction ``self.dataset`` is a ``pandas.DataFrame`` with one row per
    USER turn, paired with the SYSTEM turn that follows it.

    Args:
        config: configuration dictionary forwarded to ``PromptConstructor``.
        mwoz_path: directory holding the MultiWOZ files.
        dialog_history_limit: number of previous utterances kept in the
            dialogue context; ``0`` keeps none, ``-1`` keeps all of them.
        with_slot_description: forwarded to ``_get_slots_from_domains``.
        with_req_inf_differentiation: forwarded to ``_get_slots_from_domains``.
        single_domain_only: keep only rows whose dialogue touches exactly
            one domain.
        with_all_slots: forwarded to ``_get_slots_from_domains``.
    """

    def __init__(self,
                 config,
                 mwoz_path,
                 dialog_history_limit,
                 with_slot_description,
                 with_req_inf_differentiation,
                 single_domain_only,
                 with_all_slots):
        super().__init__(config)
        self.dataset = {"id":[],
                        "dialogue_id":[],
                        "dialogue_context":[],
                        "turn":[],
                        "prompt":[],
                        "domains":[],
                        "gold_turn_bs":[],
                        "gold_bs":[],
                        "gold_act":[],
                        "gold_response":[],
                        "gold_database_result":[],
                        }
        self.all_data, self.testfiles = self._get_mwoz_data(mwoz_path)
        self.idx = 0
        self.dialog_history_limit = dialog_history_limit
        self.single_domain_only = single_domain_only
        self.with_slot_description = with_slot_description
        self.with_req_inf_differentiation = with_req_inf_differentiation
        self.with_all_slots = with_all_slots

        # Set membership is O(1); the list form was scanned once per sample.
        test_ids = set(self.testfiles)

        print("Processing mwoz...")
        for sample in tqdm(self.all_data):
            if sample in test_ids:
                self._process_dialogue_log(sample=sample,
                                           dialogue_log=self.all_data[sample]["log"])

        self.dataset = pd.DataFrame(self.dataset)
        if single_domain_only:
            # One vectorised filter instead of dropping rows one by one while
            # iterating. The original index labels are preserved.
            is_single_domain = self.dataset["domains"].map(len) == 1
            self.dataset = self.dataset.loc[is_single_domain].copy()

    def _get_mwoz_data(self, mwoz_path):
        """Load ``data.json`` and the list of test dialogue ids from ``mwoz_path``.

        Returns:
            ``(all_data, testfiles)``: the parsed JSON object and the lines of
            ``testListFile.txt``.
        """
        data_path = os.path.join(mwoz_path, "data.json")
        testListFile_path = os.path.join(mwoz_path, "testListFile.txt")

        with open(data_path, "r", encoding="utf-8") as f:
            all_data = json.load(f)

        with open(testListFile_path, "r", encoding="utf-8") as f:
            # splitlines() does not produce a trailing empty entry.
            testfiles = f.read().splitlines()
        return all_data, testfiles

    @staticmethod
    def _flatten_belief_state(metadata):
        """Serialise a turn's ``metadata`` as ``"[domain] slot value ..."``.

        Only the ``"semi"`` slots are read; empty values, ``"not mentioned"``
        and ``"none"`` are skipped. Returns ``""`` when ``metadata`` is empty.
        """
        bspn_dict = {}
        for domain, state in (metadata or {}).items():
            for slot, value in state["semi"].items():
                if value and value not in ["not mentioned", "none"]:
                    bspn_dict.setdefault(domain, []).extend(
                        [remapping(slot), remapping(value)])
        return " ".join(f"[{domain}] {' '.join(tokens)}"
                        for domain, tokens in bspn_dict.items())

    def _process_dialogue_log(self, sample, dialogue_log):
        """Append the turns of one dialogue to the column lists in ``self.dataset``.

        Even-indexed turns are treated as USER turns and fill the prompt and
        context columns; odd-indexed turns are treated as SYSTEM turns and
        fill the gold response, belief state and act columns.

        NOTE(review): a dialogue with an odd number of turns would leave the
        USER and SYSTEM columns with different lengths.
        """
        dialog_history_memory = []
        dialog_history = ""
        domains = self._get_domains_from_log(dialogue_log)
        slots = self._get_slots_from_domains(domains, 
                                             self.with_slot_description,
                                             self.with_req_inf_differentiation,
                                             self.with_all_slots) # or all

        # Resolve "-1 = unlimited" per dialogue in a local variable. Writing
        # it back to self.dialog_history_limit froze the limit at the length
        # of the first processed dialogue for every later one.
        history_limit = self.dialog_history_limit
        if history_limit == -1:
            history_limit = len(dialogue_log)

        for turn_nb, turn in enumerate(dialogue_log):
            is_user_turn = turn_nb % 2 == 0
            speaker = "USER" if is_user_turn else "SYSTEM"

            utterance = f"""{speaker}: {turn["text"]}\n"""
            dialogue_context = dialog_history + utterance
            dialog_act = turn["dialog_act"]

            if history_limit != 0:
                if len(dialog_history_memory) >= history_limit:
                    dialog_history_memory.pop(0)
                dialog_history_memory.append(utterance)
                dialog_history = "".join(dialog_history_memory)

            self.idx += 1
            if is_user_turn:
                # The prompt is only stored for USER turns, so only build it here.
                prompt = self._build_prompt(mode="dst",
                                            slots=slots,
                                            dialogue_context=dialogue_context)
                self.dataset["gold_turn_bs"].append(dialog_act)
                self.dataset["dialogue_context"].append(dialogue_context)
                self.dataset["gold_database_result"].append(None) 
                self.dataset["turn"].append(turn_nb//2)
                self.dataset["domains"].append(domains)
                self.dataset["id"].append(self.idx//2)
                self.dataset["dialogue_id"].append(sample)
                self.dataset["prompt"].append(prompt)
            else:
                # Always recomputed from this turn's metadata, so an empty
                # metadata yields "" rather than a stale or unbound value.
                bspn = self._flatten_belief_state(turn["metadata"])
                self.dataset["gold_response"].append(utterance)
                self.dataset["gold_bs"].append(bspn)
                self.dataset["gold_act"].append(dialog_act)

    def _get_domains_from_log(self, dialogue_log):
        """Return the known domains mentioned in the dialogue acts, in first-seen order."""
        domains = []
        all_domains = ["restaurant", "taxi", "hotel", "train", "attraction"]
        for log in dialogue_log:
            for domain_act in log["dialog_act"]:
                # Act keys start with the domain name followed by "-".
                domain = domain_act.split("-")[0].lower()
                if domain in all_domains and domain not in domains:
                    domains.append(domain)
        return domains
                
def evaluate_dst(results_df, vocal=True, save_path=None):
    """Compute joint goal accuracy (JGA) and slot F1 for DST predictions.

    Args:
        results_df: DataFrame with the columns ``gold_bs``, ``preds`` and
            ``domains``; ``domains`` may be a list or its string repr.
        vocal: print the per-domain and average scores.
        save_path: currently unused; kept for interface compatibility.

    Returns:
        A dict with one entry per single domain (counters plus ``"JGA"`` and
        ``"SLOT-F1"``), an ``"all"`` entry with the global F1 accumulators,
        and the averages ``"JGA_single_domain_average"``, ``"JGA_average"``
        and ``"SLOT-F1_average"``. A ratio whose denominator is zero is
        reported as ``0.0``.
    """
    def _ratio(numerator, denominator):
        # A bucket without any turn scores 0.0 instead of raising
        # ZeroDivisionError (e.g. a domain absent from results_df).
        return numerator / denominator if denominator else 0.0

    global_turns = 0
    global_jga = 0
    results_single_domain = {
        domain: {"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0, "slot_f1":0}
        for domain in ("taxi", "restaurant", "hotel", "train", "attraction")
    }
    results_single_domain["all"] = {"global_turns":0, "global_f1":0}

    for _, row in results_df.iterrows():
        unpacked_gold = unpack_belief_states(row["gold_bs"], "gold")
        unpacked_pred = unpack_belief_states(row["preds"], "pred")
        domains = row["domains"]
        if isinstance(domains, str):
            # A CSV round-trip stores the list as its repr.
            domains = ast.literal_eval(domains)
        single_domain = domains[0] if len(domains) == 1 else None

        if set(unpacked_gold)==set(unpacked_pred):
            global_jga += 1
            if single_domain is not None:
                results_single_domain[single_domain]["correct_turns_jga"] += 1

        # Split on the first "-" only, so that a value which itself contains
        # a hyphen is kept whole.
        # NOTE(review): assumes entries look like "<slot>-<value>" - confirm
        # against unpack_belief_states.
        gold_values = [gold.split("-", 1)[1] for gold in unpacked_gold]
        pred_values = [pred.split("-", 1)[1] for pred in unpacked_pred]
        F1, recall, precision = compute_prf(gold_values, pred_values)
        if single_domain is not None:
            results_single_domain[single_domain]["slot_f1"] += F1
            results_single_domain[single_domain]["turns"] += 1
        results_single_domain["all"]["global_f1"] += F1
        results_single_domain["all"]["global_turns"] += 1
        global_turns += 1

    total_single_domain_jga = 0
    total_single_domain_turns = 0
    for domain, stats in results_single_domain.items():
        if domain == "all":
            continue
        total_single_domain_jga += stats["correct_turns_jga"]
        total_single_domain_turns += stats["turns"]
        stats["JGA"] = _ratio(stats["correct_turns_jga"], stats["turns"])
        stats["SLOT-F1"] = _ratio(stats["slot_f1"], stats["turns"])

        if vocal:
            print(f"""For {domain}, JGA: {stats["JGA"]} - SLOT-F1: {stats["SLOT-F1"]}""")

    jga_single_domain_average = _ratio(total_single_domain_jga, total_single_domain_turns)
    jga_average = _ratio(global_jga, global_turns)
    slot_f1_average = _ratio(results_single_domain["all"]["global_f1"],
                             results_single_domain["all"]["global_turns"])
    if vocal:
        print(f"""Average JGA in single domain samples only: {jga_single_domain_average}""")
        print(f"""Average JGA overall: {jga_average}""")
        print(f"""Average Slot F1 Overall: {slot_f1_average}""")

    results = results_single_domain
    results["JGA_single_domain_average"] = jga_single_domain_average
    results["JGA_average"] = jga_average
    # Previously computed and printed but never returned.
    results["SLOT-F1_average"] = slot_f1_average

    return results

In [8]:
# Load the raw MultiWOZ 2.1 dump. The context manager closes the file handle
# that the bare open() call used to leave open.
with open("/home/willy/instructod/MultiWOZ_2.1/data.json", "r") as f:
    data = json.load(f)

In [51]:
# Settings that are identical for both dataset variants built in this cell.
mwoz_path = "/home/willy/instructod/MultiWOZ_2.1/"
with_slot_description = False
with_req_inf_differentiation = False
with_all_slots = True
shared_options = dict(config=CONFIG,
                      mwoz_path=mwoz_path,
                      with_slot_description=with_slot_description,
                      with_req_inf_differentiation=with_req_inf_differentiation,
                      with_all_slots=with_all_slots)

# Variant 1: every test dialogue, with up to 20 utterances of history.
dialog_history_limit = 20
single_domain_only = False
mwoz_multi = MWOZ_Dataset(dialog_history_limit=dialog_history_limit,
                          single_domain_only=single_domain_only,
                          **shared_options)
dataset_multi = mwoz_multi.dataset

# Variant 2: single-domain dialogues only, without any history.
dialog_history_limit = 0
single_domain_only = True
mwoz_single = MWOZ_Dataset(dialog_history_limit=dialog_history_limit,
                           single_domain_only=single_domain_only,
                           **shared_options)
dataset_single = mwoz_single.dataset

Processing mwoz...


100%|██████████| 10438/10438 [00:01<00:00, 9160.58it/s]


Processing mwoz...


100%|██████████| 10438/10438 [00:01<00:00, 9615.21it/s]
7372it [00:07, 1001.83it/s]


In [65]:
# Display the top-level keys of the loaded MultiWOZ data.
data.keys()

dict_keys(['SNG01856.json', 'SNG0129.json', 'PMUL1635.json', 'MUL2168.json', 'SNG0073.json', 'SNG01445.json', 'MUL2105.json', 'PMUL1690.json', 'MUL2395.json', 'SNG0190.json', 'PMUL1170.json', 'SNG01741.json', 'PMUL4899.json', 'MUL2261.json', 'SSNG0348.json', 'MUL0784.json', 'MUL0886.json', 'PMUL2512.json', 'SNG0548.json', 'MUL1474.json', 'PMUL4372.json', 'PMUL4047.json', 'PMUL1181.json', 'PMUL0287.json', 'PMUL3470.json', 'PMUL0151.json', 'MUL0586.json', 'PMUL3552.json', 'PMUL1539.json', 'MUL1790.json', 'PMUL3021.json', 'SNG0699.json', 'SNG0228.json', 'PMUL3296.json', 'MUL1434.json', 'PMUL2203.json', 'PMUL3250.json', 'PMUL0510.json', 'MUL1124.json', 'PMUL3719.json', 'PMUL4648.json', 'PMUL2437.json', 'SNG0297.json', 'PMUL2049.json', 'SNG01722.json', 'PMUL2100.json', 'MUL1853.json', 'MUL2694.json', 'SNG1006.json', 'SNG1345.json', 'MUL1299.json', 'MUL1490.json', 'PMUL2749.json', 'PMUL2804.json', 'MUL1628.json', 'PMUL2202.json', 'SNG01450.json', 'SNG0131.json', 'SNG0984.json', 'PMUL1419.jso

In [74]:
# Print the flattened belief state of every annotated turn of one dialogue.
logs = data["SNG0073.json"]["log"]

turn = 0
for turn, log in enumerate(logs, start=1):
    metadata = log["metadata"]
    bspn_dict = {}
    if not metadata:
        # Turns without metadata carry no belief state; nothing to show.
        continue
    print(turn)
    for domain, domain_state in metadata.items():
        for slot, value in domain_state["semi"].items():
            if not value or value in ["not mentioned", "none"]:
                continue
            if domain not in bspn_dict:
                bspn_dict[domain] = []
            bspn_dict[domain] += [remapping(slot), remapping(value)]
    domain_chunks = []
    for domain, tokens in bspn_dict.items():
        domain_chunks.append(f"[{domain}] {' '.join(tokens)}")
    bspn = " ".join(domain_chunks)
    print(bspn)
    print("-------")


2
[taxi] dest pizza hut fenditton depart saint johns college
-------
4
[taxi] leave 17:15 dest pizza hut fenditton depart saint johns college
-------
6
[taxi] leave 17:15 dest pizza hut fenditton depart saint johns college
-------
8
[taxi] leave 17:15 dest pizza hut fenditton depart saint johns college
-------


In [54]:
# Saved predictions of one single-domain run; the file name encodes the run settings.
results_csv_path = "/home/willy/instructod/src/DST/results_single/gpt-3.5-turbo_0-end_debugFalse_singleDomainOnlyTrue_withSlotDescriptionTrue_withSlotDifferentiationFalse_withAllSlotsTrue_dialogHistoryLimit0_prompt3.csv"
df_results = pd.read_csv(results_csv_path)
        

In [55]:
# Preview the first two rows of the loaded results.
df_results.head(2)

Unnamed: 0.1,Unnamed: 0,id,dialogue_id,dialogue_context,turn,prompt,domains,gold_bs,gold_act,gold_response,gold_database_result,model_used,preds,completion_info
0,0,0,SNG0073.json,USER: I would like a taxi from Saint John's co...,0,Generate the belief state of the very last dia...,['taxi'],"{'Taxi-Inform': [['Dest', 'pizza hut fen ditto...","{'Taxi-Request': [['Leave', '?'], ['Arrive', '...",SYSTEM: What time do you want to leave and wha...,,gpt-3.5-turbo,"{'departure': 'Saint John\'s college', 'destin...","{\n ""choices"": [\n {\n ""finish_reason..."
1,1,1,SNG0073.json,USER: I want to leave after 17:15.\n,1,Generate the belief state of the very last dia...,['taxi'],"{'Taxi-Inform': [['Leave', '17:15']]}","{'Taxi-Inform': [['Car', 'blue honda'], ['Phon...",SYSTEM: \nBooking completed! your taxi will be...,,gpt-3.5-turbo,{'leaveat': 'after 17:15'},"{\n ""choices"": [\n {\n ""finish_reason..."


In [84]:
# Display the raw model predictions.
# (Alternative view, disabled: dataset_single["gold_bs"] for the gold belief states.)
df_results["preds"]

0       {'departure': 'Saint John\'s college', 'destin...
1                              {'leaveat': 'after 17:15'}
2       {'people': None, 'departure': None, 'type': No...
3       {'name': None, 'departure': None, 'type': None...
4             {'food': 'Portuguese', 'area': 'Cambridge'}
                              ...                        
1054    {'name': None, 'departure': None, 'type': None...
1055    {'departure': 'current_location', 'destination...
1056    I'm sorry, but there is no SYSTEM response pro...
1057    As there is no previous conversation provided,...
1058    {'name': None, 'departure': None, 'type': None...
Name: preds, Length: 1059, dtype: object

In [88]:
#single domain only

prev_dialogue_id = ""
L = len(df_results)
for i in range(L):
    ids = df_results["id"][i]
    pred, gold = df_results["preds"][i], dataset_single["gold_bs"][ids]
    dialogue_context = df_results["dialogue_context"][i]
    cur_dialogue_id = df_results["dialogue_id"][i]
    if prev_dialogue_id != cur_dialogue_id:
        bspn_pred = {}
    pred = unpack_belief_states(pred, "pred")
    if "none-none" not in pred:
        for slot_value in pred:
            slot, value = slot_value.split("-")
            bspn_pred[slot] = value

    print("context:", dialogue_context)
    print("pred", bspn_pred)
    print("gold", gold)
    prev_dialogue_id = cur_dialogue_id
    print("==============")

context: USER: I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.

pred {'depart': 'saint john"s college', 'dest': 'pizza hut fen ditton'}
gold [taxi] dest pizza hut fenditton depart saint johns college
context: USER: I want to leave after 17:15.

pred {'depart': 'saint john"s college', 'dest': 'pizza hut fen ditton', 'leave': '17:15'}
gold [taxi] leave 17:15 dest pizza hut fenditton depart saint johns college
context: USER: Thank you for all the help! I appreciate it.

pred {'depart': 'saint john"s college', 'dest': 'pizza hut fen ditton', 'leave': '17:15'}
gold [taxi] leave 17:15 dest pizza hut fenditton depart saint johns college
context: USER: No, I am all set.  Have a nice day.  Bye.

pred {'depart': 'saint john"s college', 'dest': 'pizza hut fen ditton', 'leave': '17:15'}
gold [taxi] leave 17:15 dest pizza hut fenditton depart saint johns college
context: USER: Are there any Portuguese restaurants in Cambridge?

pred {'food': 'portuguese', 'area': 'cambridge'}

ValueError: too many values to unpack (expected 2)