In [12]:
import json
import pandas as pd

def get_df_from_path(path):
    """Load a JSON database file (a list of records) into a DataFrame.

    The columns are the keys of the first record. A record that lacks one of
    those keys gets the string 'None' in that column, and a summary line is
    printed when at least one such record exists. Keys that appear only in
    later records are not turned into columns.

    Args:
        path: Path to a JSON file containing a list of dicts.

    Returns:
        pd.DataFrame with one row per record and one column per key of the
        first record; an empty DataFrame when the file holds no records.
    """
    # Context manager so the file handle is closed deterministically.
    with open(path, 'r') as f:
        data_all = json.load(f)
    if not data_all:
        return pd.DataFrame()

    set_all_attr = set(data_all[0].keys())

    # A record has "less attributes" when some reference key is absent from it
    # (reference minus record), not when it carries extra keys.
    count_less = sum(1 for data in data_all if set_all_attr - set(data.keys()))
    if count_less:
        print(path)
        print(f'{count_less} these many instances have less attributes out of {len(data_all)}')

    d = {attr: [data.get(attr, 'None') for data in data_all] for attr in set_all_attr}
    return pd.DataFrame(d)


In [None]:
## TODO: Change the path to the json files
# Load each domain database into its own DataFrame (one row per database record).
df_hotel = get_df_from_path('/path/to/hotel_db.json')
df_attraction = get_df_from_path('/path/to/attraction_db.json')
df_restaurant = get_df_from_path('/path/to/restaurant_db.json')
# The train placeholder previously had a stray space after the leading slash.
df_train = get_df_from_path('/path/to/train_db.json')

In [None]:
df_hotel.head()

In [None]:
df_restaurant.head()

In [11]:
import sys
sys.path.append("MultiWOZ_Evaluation/mwzeval/")

from normalization import normalize_state_slot_value, time_str_to_minutes


In [87]:
import os
import json
from fuzzywuzzy import fuzz


class MultiWOZVenueDatabase:
    """In-memory copy of the per-domain JSON databases with constraint-based lookup.

    On construction the restaurant, attraction, hotel and train database files
    are read from ``<dir_path>/data/database/<domain>_db.json``, the columns
    listed in ``IGNORE_VALUES`` are dropped, and column names are normalized
    (lower-cased, spaces removed, ``arriveby`` -> ``arrive``,
    ``leaveat`` -> ``leave``).
    """

    # Domains for which a database file is loaded and queries are answered.
    DOMAINS = ("restaurant", "attraction", "hotel", "train")

    # Per-domain columns removed from every record right after loading.
    IGNORE_VALUES = {
        'attraction' : ['location', 'openhours'],
        'hotel' : ['location', 'price', 'takesbookings'],
        'restaurant' : ['location', 'introduction', 'signature']
    }

    # Per-domain columns compared with fuzzy partial matching instead of equality.
    FUZZY_KEYS = {
        'hotel' : {'name'},
        'attraction' : {'name'},
        'restaurant' : {'name', 'food'},
        'train' : {'departure', 'destination'}
    }

    def __init__(self, dir_path='MultiWOZ_Evaluation/mwzeval/'):
        """Load all domain databases.

        Args:
            dir_path: Directory that contains ``data/database/``. Defaults to
                the location that used to be hard-coded.
        """
        self.dir_path = dir_path
        self.data, self.data_keys = self._load_data()

    def _load_data(self):
        """Read, prune and normalize every domain database.

        Returns:
            (database_data, database_keys): ``database_data[domain]`` is the
            list of normalized records, ``database_keys[domain]`` the set of
            column names of the first record (empty set for an empty file).
        """

        def normalize_column_name(name):
            if name == "id":
                return name
            name = name.lower().replace(' ', '')
            if name == "arriveby": return "arrive"
            if name == "leaveat": return "leave"
            return name

        database_data, database_keys = {}, {}

        for domain in self.DOMAINS:
            db_file = os.path.join(self.dir_path, "data", "database", f"{domain}_db.json")
            with open(db_file, "r") as f:
                records = json.load(f)

            # Ignored columns are matched against the raw (un-normalized) names.
            for record in records:
                for ignore in self.IGNORE_VALUES.get(domain, []):
                    record.pop(ignore, None)

            records = [
                {normalize_column_name(k): v for k, v in record.items()}
                for record in records
            ]
            database_data[domain] = records
            # Guard against an empty database file instead of raising IndexError.
            database_keys[domain] = set(records[0].keys()) if records else set()

        return database_data, database_keys

    def query(self, domain, constraints, fuzzy_ratio=90):
        """Count the records of ``domain`` that satisfy ``constraints``.

        The first five matches are printed to stdout.

        Args:
            domain: One of "hotel", "restaurant", "attraction", "train".
            constraints: Mapping of slot name -> requested value. Slots that are
                not columns of the domain are ignored, as are "don't care"-style
                values. Other values go through ``normalize_state_slot_value``;
                'arrive' / 'leave' are additionally converted with
                ``time_str_to_minutes``.
            fuzzy_ratio: Minimum ``fuzz.partial_ratio`` score for columns listed
                in ``FUZZY_KEYS``.

        Returns:
            ``(count, first_match)``: the number of matching records and the
            first matching record (``None`` when nothing matches). An unknown
            domain yields ``(0, None)`` so callers can always unpack a pair.
        """

        # Hotel database keys:      address, area, name, phone, postcode, pricerange, type, internet, parking, stars (other are ignored)
        # Attraction database keys: address, area, name, phone, postcode, pricerange, type, entrance fee (other are ignored)
        # Restaurant database keys: address, area, name, phone, postcode, pricerange, type, food
        # Train database contains keys: arriveby, departure, day, leaveat, destination, trainid, price, duration

        if domain not in ["hotel", "restaurant", "attraction", "train"]:
            # Same shape as the normal return value (previously a bare list).
            return 0, None

        query = {}
        for key in self.data_keys[domain]:
            if key in constraints:
                if constraints[key] in ["dontcare", "not mentioned", "don't care", "dont care", "do n't care", "do not care"]:
                    continue
                query[key] = normalize_state_slot_value(key, constraints[key])
                if key in ['arrive', 'leave']:
                    query[key] = time_str_to_minutes(query[key])
            else:
                query[key] = None

        fuzzy_keys = self.FUZZY_KEYS.get(domain, set())

        count = 0
        saved_item = None
        for item in self.data[domain]:
            if not self._matches(item, query, fuzzy_keys, fuzzy_ratio):
                continue
            count += 1
            if count == 1:
                saved_item = item
            if count <= 5:
                self._print_item(domain, item)
        return count, saved_item

    @staticmethod
    def _matches(item, query, fuzzy_keys, fuzzy_ratio):
        """Return True when ``item`` satisfies every non-None entry of ``query``."""
        for k, v in query.items():
            # Unconstrained slot, or a database value of '?', never rejects a record.
            if v is None or item[k] == '?':
                continue
            if k == 'arrive':
                # Reject records whose arrival time is later than requested.
                if time_str_to_minutes(item[k]) > v:
                    return False
            elif k == 'leave':
                # Reject records whose leave time is earlier than requested.
                if time_str_to_minutes(item[k]) < v:
                    return False
            elif k in fuzzy_keys:
                if fuzz.partial_ratio(item[k], v) < fuzzy_ratio:
                    return False
            elif item[k] != v:
                return False
        return True

    @staticmethod
    def _print_item(domain, item):
        """Print one matching record; venues are wrapped in ``<domain>`` tags."""
        if domain == "train":
            print(item['trainid'])
            for attr in item.keys():
                print ("\t", attr, '-', item[attr])
            print()
        else:
            # Tag with the queried domain (was always "<hotel>", even for restaurants).
            print("\t", f"<{domain}>")
            for attr in item.keys():
                print ("\t", attr, '-',item[attr])
            print("\t", f"</{domain}>")
            print()

In [88]:
# Build the database wrapper; its constructor reads the domain JSON files from disk.
mwdb = MultiWOZVenueDatabase()

In [None]:
# Sanity check: with an empty constraint dict every train record matches.
# NOTE(review): for a known domain query() returns a (count, first_match) tuple, not a list of ids.
l = mwdb.query("train", {}, fuzzy_ratio=90)
print(l)

In [None]:
# NOTE(review): `l` is the (count, first_match) tuple from the train query in the previous cell,
# so filtering restaurant ids with it looks stale — confirm which id list was intended.
df_restaurant[df_restaurant['id'].isin(l)]

In [71]:
import pickle
# Load the saved test data; later cells read its 'diag_id', 'conv_idx' and 'prompts' entries.
# NOTE(review): pickle.load can execute arbitrary code — only open a file from a trusted source.
with open("sav_text_prompts.pickle", 'rb') as f:
    test_data = pickle.load(f)


In [None]:
# Show which fields the loaded test data provides.
test_data.keys()

In [73]:
# Map "<lower-cased dialogue id without extension>_<conversation index>" to the
# first prompt recorded for that pair; later duplicates are ignored.
dict_to_prompt = {}
for i in range(len(test_data['diag_id'])):
    prompt_key = f"{test_data['diag_id'][i].lower().split('.')[0]}_{test_data['conv_idx'][i]}"
    if prompt_key in dict_to_prompt:
        continue
    dict_to_prompt[prompt_key] = test_data['prompts'][i]

In [None]:
# Inspect the key -> prompt mapping (displays the whole dict, which may be large).
dict_to_prompt

In [75]:
import json
# TODO: Change the path to the predictions.json file
pred_path = "/path/to/predictions.json"

# Later cells iterate pred_data as {dialogue id: [turn, ...]} where each turn has a "state" entry.
with open(pred_path, 'r') as f:
    pred_data = json.load(f)

In [76]:
def get_domain_estimates_from_state(data):
    """Annotate every turn, in place, with an estimate of its active domain.

    The estimate is derived from dialog-state changes, because the slot names
    used for delexicalization carry no domain information: the system most
    likely talks about the domain whose state changed most recently. This also
    avoids assuming that every requestable slot belongs to exactly one domain
    for the whole dialog.

    Args:
        data: Mapping of dialog id -> list of turns; each turn is a dict with a
            "state" entry mapping domain -> {slot: value}.

    Returns:
        None. Each turn gains ``turn["active_domains"]``: ``[]`` while no domain
        has been seen yet, otherwise a one-element list.
    """
    for turns in data.values():

        current_domain = None
        previous_state = {}
        previously_changed = []

        for turn in turns:
            state = turn["state"]

            # Domains whose (slot, value) pairs gained at least one new entry.
            changed = [
                name for name in state
                if set(state.get(name, {}).items()) - set(previous_state.get(name, {}).items())
            ]

            if not changed:
                if current_domain is None:
                    # Nothing has been mentioned so far; bookkeeping stays untouched.
                    turn["active_domains"] = []
                    continue
                # Several domains changed last turn and one of them was kept as
                # current: now that nothing changed, move on to another of them.
                if len(previously_changed) > 1:
                    alternatives = [
                        name for name in previously_changed
                        if name in state and name != current_domain
                    ]
                    if alternatives:
                        current_domain = alternatives[0]
            elif current_domain not in changed:
                # The current domain did not change: switch to the changed
                # domain with the most filled slots.
                current_domain = max(changed, key=lambda name: len(state[name]))

            previous_state = state
            previously_changed = changed

            turn["active_domains"] = [current_domain]


# Adds an "active_domains" entry to every turn of pred_data in place.
get_domain_estimates_from_state(pred_data)



In [77]:
# read_ground_truth
# Context manager so the reference file handle is closed once parsing is done.
with open('MultiWOZ_Evaluation/mwzeval/data/references/mwz22.json', 'r') as f:
    gt_response = json.load(f)

In [None]:
# Displays the entire reference structure; the output can be very large.
gt_response

In [None]:
# For every turn whose estimated active domain is 'restaurant', print the prompt,
# the reference response, the number of matching database entries (bucketed as
# BIG / MEDIUM / SMALL) and whether the reference contains a NAME/TRAINID slot.
for diag_id, dialog_turns in pred_data.items():
    for i, turn in enumerate(dialog_turns):
        active_domains = turn['active_domains']
        if len(active_domains) == 0 or active_domains[0] not in ['restaurant']:
            continue
        active_domain = active_domains[0]

        print("=="*20)
        print("Active:", active_domain)
        print(diag_id, i)
        print("Inputs:")
        print(dict_to_prompt[f"{diag_id}_{i+1}"])
        print("Outputs:")
        reference = gt_response[diag_id][i]
        print(reference)

        if active_domain not in turn['state']:
            continue

        # A name-only search with a lower fuzzy ratio (75) was tried here earlier;
        # the full predicted state of the active domain is used instead.
        fuzz_ratio = 90
        search_state = turn['state'][active_domain]
        print(search_state)
        temp_num, _ = mwdb.query(active_domain, search_state, fuzzy_ratio=fuzz_ratio)
        print("Number of venus:", temp_num)

        if temp_num >= 10:
            category = "BIG"
        elif temp_num >= 5:
            category = "MEDIUM"
        elif temp_num > 1:
            category = "SMALL"
        else:
            category = ""

        suggestion_made = "NAME" in reference or "TRAINID" in reference

        if suggestion_made:
            print(f"{category}, with suggestion")
        else:
            print(f"{category}, without suggestion")


In [98]:
# Generate the test data with updated information for the domain selected below.
# NOTE(review): the variable names say "rest" (restaurant) but the selected domain is 'taxi'.
import pandas as pd
domain_create_test = 'taxi'
df_test_rest = pd.read_excel(f'domain_agents/{domain_create_test}/{domain_create_test}_test.xlsx')

In [None]:
# Preview the first rows of the loaded test sheet.
df_test_rest.head()

In [100]:
# Build one updated prompt per test row. In the active (TAXI) variant the prompt
# is the original text plus a trailing newline; the commented-out block below is
# the variant that appends the database match count and the first matching venue.
updated_prompts = []
for ittr, row in df_test_rest.iterrows():
    # print(row['prompts'])
    # diag_id_converted / utter_id are only consumed by the commented-out block below.
    diag_id_converted = row['diag_id'].split('.')[0].lower()
    utter_id = row['conv_idx'] - 1
    ###TAXI
    convereted_prompt = row['prompts'] + "\n"
    # if domain_create_test in pred_data[diag_id_converted][utter_id]['state']:
    #     search_state = {}
    #     fuzz_ratio = 90
    #     # if 'name' in pred_data[diag_id_converted][utter_id]['state'][domain_create_test]:
    #     #     search_state = {'name': pred_data[diag_id_converted][utter_id]['state'][domain_create_test]['name'].replace('b and b', 'bed and breakfast')}
    #     #     fuzz_ratio = 75
    #     # else:
    #     search_state = pred_data[diag_id_converted][utter_id]['state'][domain_create_test]
    #     num_venues, first_venue = mwdb.query(domain_create_test, search_state, fuzzy_ratio=fuzz_ratio)
    #     # print(num_venues, first_venue)
    #     # print(pred_data[diag_id_converted][utter_id]['state']['restaurant'])
    #     convereted_prompt = row['prompts']
    #     convereted_prompt += f"\nNumber of {domain_create_test}s that meet the user's criteria: {num_venues}\n"
    #     if first_venue is not None:
    #         convereted_prompt += "One of them is the following:\n"
    #         convereted_prompt += f"\t <{domain_create_test}> \n"
    #         for key, value in first_venue.items():
    #             convereted_prompt += f"\t {key} - {value}\n"
    #         convereted_prompt += f"\t </{domain_create_test}>\n"
    #     else:
    #         convereted_prompt += "One of them is the following:\n"
    # else:
    #     num_venues, first_venue = mwdb.query(domain_create_test, {})

    #     convereted_prompt = row['prompts']
    #     convereted_prompt += f"\nNumber of {domain_create_test}s that meet the user's criteria: {num_venues}\n"
    #     if first_venue is not None:
    #         convereted_prompt += "One of them is the following:\n"
    #         convereted_prompt += f"\t <{domain_create_test}> \n"
    #         for key, value in first_venue.items():
    #             convereted_prompt += f"\t {key} - {value}\n"
    #         convereted_prompt += f"\t </{domain_create_test}>\n"
    #     else:
    #         convereted_prompt += "One of them is the following:\n"
    
    updated_prompts.append(convereted_prompt)



    

In [None]:
# Print every updated prompt, each followed by a separator line, for manual inspection.
for p in updated_prompts:
    print(p)
    print("===="*20)

In [102]:
# Work on a copy so the loaded sheet stays unchanged, then replace its 'prompts' column.
df_test_rest_updated = df_test_rest.copy()
df_test_rest_updated['prompts'] = updated_prompts


In [103]:
# Write the updated sheet next to the input file, with an "_updated" suffix.
# NOTE(review): to_excel also writes the DataFrame index as a column by default —
# confirm whether index=False is wanted here.
df_test_rest_updated.to_excel(f'domain_agents/{domain_create_test}/{domain_create_test}_test_updated.xlsx')