In [1]:
import os
import ast
import json
import openai
import pandas as pd

from tqdm import tqdm
from langchain import PromptTemplate

from src.DST.evaluate_utils import unpack_belief_states
from src.DST.config import CONFIG


class PromptConstructor():
    def __init__(self, 
                 config):
        self.config = config
        self.instructions = config["INSTRUCTIONS"]
        self.prompt_templates = config["PROMPT_TEMPLATES"]
        
    def _get_slots_from_domains(self, domains, with_slot_description, with_req_inf_differentiation):
        # slot_description = self.config["slot_descrpition"]
        if not with_slot_description:
            if domains == "all":
                if with_req_inf_differentiation:
                    req_slots = ", ".join(self.config["multiwoz21"]["all_requestable_slots"])
                    inf_slots = ", ".join(self.config["multiwoz21"]["all_informable_slots"])
                else:
                    slots = set(self.config["multiwoz21"]["all_requestable_slots"] + 
                                self.config["multiwoz21"]["all_informable_slots"])
                    slots = ", ".join(slots)
            elif not isinstance(domains, list):
                raise ValueError("""Provided domain should be either 'all' or list of valid domain names:
                                    - for multiwoz2.1 and 2.4: taxi, restaurant, hotel, train, attraction 
                                    - for SGD: To-do""")
            else:
                req_slots = ""
                inf_slots = ""
                domain_req_slots = []
                domain_inf_slots = []
                for domain in domains:
                    domain_req_slots += self.config["multiwoz21"]["requestable_slots"][domain]
                    domain_inf_slots += self.config["multiwoz21"]["informable_slots"][domain]
                if with_req_inf_differentiation:
                    domain_req_slots = set(domain_req_slots)
                    domain_inf_slots = set(domain_inf_slots)
                    req_slots += ", ".join(domain_req_slots)
                    inf_slots += ", ".join(domain_inf_slots)
                else:
                    slots = set(domain_req_slots + domain_inf_slots)
                    slots = ", ".join(slots)

            if with_req_inf_differentiation:
                slots_info = f"Requestable slots: {req_slots}\nInformable slots: {inf_slots}"
            else:
                slots_info = f"{slots}"

        else:
            raise ValueError("Not Implemented Yet")
        
        return slots_info
    
    
    def _build_prompt(self, mode="", dialogue_context="", ontology="", slots="", dialogue_acts="", belief_states=""):
        prompt = ""
        if mode == "dst":
            instruction = self.instructions["instruction_with_slots"]
            template_variables = self.prompt_templates["template_with_slots"]
            template = PromptTemplate(input_variables= template_variables["input_variables"],
                                      template = template_variables["template"])
            prompt = template.format(instruction=instruction,
                                     slots=slots,
                                     dialogue_context=dialogue_context)
            
        elif mode == "dst_recorrect":
            instruction = self.instructions["instruction_with_slots_recorrect"]
            template = self.prompt_templates["template_with_slots_recorrect"]
            template = PromptTemplate(input_variables= template_variables["input_variables"],
                                      template = template_variables["template"])            
            prompt = template.format(instruction=instruction,
                                    slots=slots,
                                    dialogue_context=dialogue_context,
                                    belief_states=belief_states)
            
        elif mode == "database_query":
            instruction = self.instructions["instruction_query_database"]
            template = self.prompt_templates["template_query_database"]
            template = PromptTemplate(input_variables= template_variables["input_variables"],
                                      template = template_variables["template"])
            prompt = template.format(instruction=instruction,
                                    belief_states=belief_states)
            
        elif mode == "response_generation":
            instruction = self.instructions["instruction_response_generation"]
            template = self.prompt_templates["template_response_generation"]
            template = PromptTemplate(input_variables = template_variables["input_variables"],
                                      template = template_variables["template"])
            prompt = template.format(instruction=instruction,
                                    dialogue_acts=dialogue_acts,
                                    dialogue_context=dialogue_context)
        elif mode == "dst_extracted_ontology":
            pass

        else:
            raise ValueError("'mode' should be one of: [dst, dst_recorrect, database_query, response_generation]")
        
        return prompt


class MWOZ_Dataset(PromptConstructor):
    def __init__(self,
                 config,
                 mwoz_path,
                 dialog_history_limit,
                 with_slot_description,
                 with_req_inf_differentiation,
                 single_domain_only):
        PromptConstructor.__init__(self, config)
        self.dataset = {"id":[],
                        "dialogue_id":[],
                        "dialogue_context":[],
                        "turn":[],
                        "prompt":[],
                        "domains":[],
                        "gold_bs":[],
                        "gold_act":[],
                        "gold_response":[],
                        "gold_database_result":[],
                        }
        self.all_data, self.testfiles = self._get_mwoz_data(mwoz_path)
        self.idx = 0
        self.dialog_history_limit = dialog_history_limit
        self.single_domain_only = single_domain_only
        self.with_slot_description = with_slot_description
        self.with_req_inf_differentiation = with_req_inf_differentiation

        print("Processing mwoz...")
        for sample in tqdm(self.all_data):
            if sample in self.testfiles:
                dialogue_log = self.all_data[sample]["log"]
                self._process_dialogue_log(sample=sample,
                                           dialogue_log=dialogue_log)

        self.dataset = pd.DataFrame(self.dataset)
        if single_domain_only:
            for index, row in tqdm(self.dataset.iterrows()):
                if len(row["domains"]) != 1:
                    self.dataset.drop(index, inplace=True)

    def _get_mwoz_data(self, mwoz_path):
        data_path = os.path.join(mwoz_path, "data.json")
        testListFile_path = os.path.join(mwoz_path, "testListFile.txt")

        with open(data_path, "r") as f:
            all_data = json.load(f)
            
        with open(testListFile_path, "r") as f:
            testfiles = f.read()
        testfiles = testfiles.split("\n")
        return all_data, testfiles
    
    def _process_dialogue_log(self, sample, dialogue_log):

        dialog_history_memory = []
        dialog_history = ""
        domains = self._get_domains_from_log(dialogue_log)
        slots = self._get_slots_from_domains(domains, 
                                             self.with_slot_description,
                                             self.with_req_inf_differentiation) # or all

        for turn_nb, turn in enumerate(dialogue_log):

            if turn_nb % 2 == 0:
                speaker = "USER"
            else:
                speaker = "SYSTEM"

            utterance = f"""{speaker}: {turn["text"]}\n"""
            dialogue_context = dialog_history + utterance
            dialog_act = turn["dialog_act"]
            prompt = self._build_prompt(mode="dst",
                                        slots=slots,
                                        dialogue_context=dialogue_context) 


            if self.dialog_history_limit != 0:
                if self.dialog_history_limit == -1:
                    self.dialog_history_limit = len(dialogue_log)

                if len(dialog_history_memory) >= self.dialog_history_limit:
                    dialog_history_memory.pop(0)
                dialog_history_memory.append(utterance)
                dialog_history = "".join(dialog_history_memory)

            self.idx += 1
            if turn_nb % 2 == 0:
                self.dataset["gold_bs"].append(dialog_act)
                self.dataset["dialogue_context"].append(dialogue_context)
                self.dataset["gold_database_result"].append(None) 
                self.dataset["turn"].append(turn_nb//2)
                self.dataset["domains"].append(domains)
                self.dataset["id"].append(self.idx//2)
                self.dataset["dialogue_id"].append(sample)
                self.dataset["prompt"].append(prompt)
            else:
                self.dataset["gold_response"].append(utterance)
                self.dataset["gold_act"].append(dialog_act)


    def _get_domains_from_log(self, dialogue_log):
        domains = []
        all_domains = ["restaurant", "taxi", "hotel", "train", "attraction"]
        for log in dialogue_log:
            for domain_act in log["dialog_act"]:
                domain = domain_act.split("-")[0].lower()
                if domain in all_domains and domain not in domains:
                    domains.append(domain)
        return domains
                
def evaluate_dst(results_df, vocal=True, save_path=None):
    global_turns = 0
    global_jga = 0
    results_single_domain = {"taxi":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0},
                    "restaurant":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0},
                    "hotel":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0},
                    "train":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0},
                    "attraction":{"turns":0, "correct_turns_jga":0, "correct_slots":0, "total_slots":0}}
    
    for idx, row in results_df.iterrows():
        print(row["gold_bs"])
        unpacked_gold = unpack_belief_states(row["gold_bs"], "gold")
        unpacked_pred = unpack_belief_states(row["preds"], "pred")
        domains = row["domains"]
        if isinstance(domains, str):
            domains = ast.literal_eval(domains)

        if set(unpacked_gold)==set(unpacked_pred):
            global_jga += 1
            if len(domains) == 1:
                results_single_domain[domains[0]]["correct_turns_jga"] += 1
                gold_values = [gold.split("-")[1] for gold in unpacked_gold]
                pred_values = [pred.split("-")[1] for pred in unpacked_pred]
                for gold_value in gold_values:
                    if gold_value in pred_values:
                        results_single_domain[domains[0]]["correct_slots"] += 1
                    results_single_domain[domains[0]]["total_slots"] += 1

        if len(domains) == 1:
            gold_values = [gold.split("-")[1] for gold in unpacked_gold]
            pred_values = [pred.split("-")[1] for pred in unpacked_pred]
            for gold_value in gold_values:
                if gold_value in pred_values:
                    results_single_domain[domains[0]]["correct_slots"] += 1
                results_single_domain[domains[0]]["total_slots"] += 1
            results_single_domain[domains[0]]["turns"] += 1
        global_turns += 1

    total_single_domain_jga = 0
    total_single_domain_turns = 0
    for domain in results_single_domain:
        domain_correct_slots = results_single_domain[domain]["correct_slots"]
        domain_total_slots = results_single_domain[domain]["total_slots"]
        domain_jga = results_single_domain[domain]["correct_turns_jga"]
        domain_turns = results_single_domain[domain]["turns"]
        total_single_domain_jga += domain_jga
        total_single_domain_turns += domain_turns
        results_single_domain[domain]["JGA"] = domain_jga/domain_turns
        results_single_domain[domain]["SLOT-F1"] = domain_correct_slots/domain_total_slots
        if vocal:
            print(f"""For {domain}, JGA: {results_single_domain[domain]["JGA"]} - SLOT-F1: {results_single_domain[domain]["SLOT-F1"]}""")
    jga_single_domain_average = total_single_domain_jga/total_single_domain_turns
    jga_average = global_jga/global_turns            
    if vocal:
        print(f"""Average JGA in single domain samples only: {jga_single_domain_average}""")
        print(f"""Average JGA overall: {jga_average}""")

    results = results_single_domain
    results["JGA_single_domain_average"] = jga_single_domain_average
    results["JGA_average"] = jga_average

    if save_path:
        with open(save_path, "w") as f:
            json.dump(results, f, indent=4)

    return results

In [8]:
data = json.load(open("/home/willy/instructod/MultiWOZ_2.1/data.json", "r"))

In [2]:
mwoz_path = "/home/willy/instructod/MultiWOZ_2.1/"
dialog_history_limit = 20
with_slot_description = False
single_domain_only = False
with_req_inf_differentiation = False
mwoz_multi = MWOZ_Dataset(config=CONFIG, 
                    mwoz_path=mwoz_path,
                    dialog_history_limit=dialog_history_limit,
                    with_slot_description=with_slot_description,
                    with_req_inf_differentiation=with_req_inf_differentiation,
                    single_domain_only=single_domain_only)
dataset_multi = mwoz_multi.dataset

mwoz_path = "/home/willy/instructod/MultiWOZ_2.1/"
dialog_history_limit = 0
with_slot_description = False
single_domain_only = True
with_req_inf_differentiation = False
mwoz_single = MWOZ_Dataset(config=CONFIG, 
                    mwoz_path=mwoz_path,
                    dialog_history_limit=dialog_history_limit,
                    with_slot_description=with_slot_description,
                    with_req_inf_differentiation=with_req_inf_differentiation,
                    single_domain_only=single_domain_only)
dataset_single = mwoz_single.dataset

Processing mwoz...


100%|██████████| 10438/10438 [00:00<00:00, 10861.34it/s]


Processing mwoz...


100%|██████████| 10438/10438 [00:00<00:00, 11248.14it/s]
7372it [00:06, 1071.71it/s]


In [5]:
test = dataset_single["gold_bs"][0]
print(unpack_belief_states(test, "gold"))

['dest-pizza hut fen ditton', 'depart-saint john"s college']


In [40]:
from src.DST.evaluate_utils import remapping

In [42]:
turn = 0
logs = data["PMUL1635.json"]["log"]
for log in logs:
    turn += 1
    metadata = log["metadata"]
    bspn_dict = {}
    if not metadata:
        continue
    print(turn)
    for domain in metadata:
        slot_values = metadata[domain]["semi"]
        for slot in slot_values:
            value = slot_values[slot]
            if value and value not in ["not mentioned", "none"]:
                if domain in bspn_dict:
                    bspn_dict[domain].append(remapping(slot))
                    bspn_dict[domain].append(remapping(value))
                else:
                    bspn_dict[domain] = [remapping(slot), remapping(value)]
    bspn = " ".join([f"[{domain}] {' '.join(bspn_dict[domain])}" for domain in bspn_dict])
    print(bspn)
    print("-------")

        

2
[hotel] area east stars 4
-------
4
[hotel] area east parking yes stars 4 internet yes
-------
6
[hotel] name wartworth area east parking yes stars 4 internet yes
-------
8
[hotel] name wartworth area east parking yes stars 4 internet yes
-------
10
[hotel] name wartworth area east parking yes stars 4 internet yes [train] dest bishops stortford day friday depart cambridge
-------
12
[hotel] name wartworth area east parking yes stars 4 internet yes [train] dest bishops stortford day friday arrive 19:45 depart cambridge
-------
14
[hotel] name wartworth area east parking yes stars 4 internet yes [train] dest bishops stortford day friday arrive 19:45 depart cambridge
-------
16
[hotel] name wartworth area east parking yes stars 4 internet yes [train] dest bishops stortford day friday arrive 19:45 depart cambridge
-------
18
[hotel] name wartworth area east parking yes stars 4 internet yes [train] dest bishops stortford day friday arrive 19:45 depart cambridge
-------


In [11]:
for log in data["SNG01856.json"]["log"]:
    if log["metadata"]:
        

[{'text': 'am looking for a place to to stay that has cheap price range it should be in a type of hotel',
  'metadata': {},
  'dialog_act': {'Hotel-Inform': [['Type', 'hotel'], ['Price', 'cheap']]},
  'span_info': [['Hotel-Inform', 'Type', 'hotel', 20, 20],
   ['Hotel-Inform', 'Price', 'cheap', 10, 10]]},
 {'text': 'Okay, do you have a specific area you want to stay in?',
  'metadata': {'taxi': {'book': {'booked': []},
    'semi': {'leaveAt': '',
     'destination': '',
     'departure': '',
     'arriveBy': ''}},
   'police': {'book': {'booked': []}, 'semi': {}},
   'restaurant': {'book': {'booked': [], 'time': '', 'day': '', 'people': ''},
    'semi': {'food': '', 'pricerange': '', 'name': '', 'area': ''}},
   'hospital': {'book': {'booked': []}, 'semi': {'department': ''}},
   'hotel': {'book': {'booked': [], 'stay': '', 'day': '', 'people': ''},
    'semi': {'name': 'not mentioned',
     'area': 'not mentioned',
     'parking': 'not mentioned',
     'pricerange': 'cheap',
     'sta