In [162]:
from datasets import load_dataset

# Load the "multi_woz_v22" dataset and bind each of its three splits to its own name.
dataset = load_dataset("multi_woz_v22")
train_data, val_data, test_data = (
    dataset[split_name] for split_name in ("train", "validation", "test")
)

In [163]:
def filterDomains(data):
    """Select the dialogues that involve exactly one allowed service.

    Args:
        data: Iterable of dialogue entries, each exposing a "services"
            collection of service names.

    Returns:
        A list, in the original order, of the entries whose "services" has
        length 1 and only contains names from
        {"restaurant", "hotel", "booking"}.
    """
    allowed_services = {"restaurant", "hotel", "booking"}

    def is_single_allowed_service(entry):
        # Same two conditions, evaluated in the same order as before.
        only_allowed = set(entry["services"]).issubset(allowed_services)
        return only_allowed and len(entry["services"]) == 1

    return list(filter(is_single_allowed_service, data))
# Keep only single-service dialogues about restaurants or hotels (a lone "booking" service is also accepted).

# Apply the same service filter to every split, in train / validation / test order.
train_data_filtered, val_data_filtered, test_data_filtered = map(
    filterDomains, (train_data, val_data, test_data)
)

In [164]:
def create_augmented_dialogue_data(dataset, print_dialogue=False):
    """
    Annotate every dialogue, in place, with the information each SYSTEM turn
    is expected to retrieve (ground truth).

    For each dialogue a fresh ``turns["to_be_retrieved_ground_truth"]`` dict is
    created that maps every turn index to a list. For turns whose speaker flag
    is truthy, the list holds the sorted, de-duplicated "<domain>-<slot>" names
    derived from the turn's dialogue acts, plus "<domain>-availability" for
    every act that contributes at least one slot. "<domain>-choice" entries are
    dropped. All other turns keep an empty list.

    Heavily inspired by the code from the evaluation script.

    Args:
        dataset: Iterable of dialogue dicts; each dialogue's "turns" dict is
            mutated in place.
        print_dialogue: When True, print the utterance and the retrieval list
            of every annotated turn, and a separator line after each dialogue.

    Returns:
        None.
    """
    bookable_services = ["hotel", "restaurant"]

    for dialogue in dataset:
        turns = dialogue["turns"]
        n_turns = len(turns["turn_id"])
        turns["to_be_retrieved_ground_truth"] = {idx: [] for idx in range(n_turns)}

        for turn_idx in range(n_turns):
            # Only SYSTEM turns (truthy speaker flag) are annotated.
            if not turns["speaker"][turn_idx]:
                continue

            dialog_act = turns['dialogue_acts'][turn_idx]['dialog_act']
            act_slots = dialog_act['act_slots']
            act_types = dialog_act['act_type']
            active_bookable = [
                service
                for service in turns['frames'][turn_idx]['service']
                if service in bookable_services
            ]

            collected = set()
            for act_idx, slots in enumerate(act_slots):
                names = slots['slot_name']
                values = slots['slot_value']

                # The domain is the lower-cased prefix of the act type.
                domain = act_types[act_idx].split("-")[0].lower()
                # A generic "booking" act is attributed to the frame's bookable
                # service when exactly one such service is present.
                if domain == "booking" and len(active_bookable) == 1:
                    domain = active_bookable[0]

                # Skip slots with a "?" value and placeholder "none" slots.
                qualified = [
                    domain + "-" + name
                    for slot_idx, name in enumerate(names)
                    if values[slot_idx] != "?" and name != "none"
                ]
                if not qualified:
                    continue

                # The act contributes its availability marker even when the
                # only qualifying slot is "choice", which is itself removed.
                choice_slot = domain + "-choice"
                collected.add("%s-availability" % (domain))
                collected.update(name for name in qualified if name != choice_slot)

            retrieval_targets = sorted(collected)

            # augment the dataset
            turns["to_be_retrieved_ground_truth"][turn_idx].extend(retrieval_targets)

            if print_dialogue:
                print(f"Utterance: {turns['utterance'][turn_idx]}")
                print(f"To be retrieved: {retrieval_targets}")

        if print_dialogue:
            print("-"*50)
                       

In [165]:
# Annotate each filtered split in place; printing is disabled.
create_augmented_dialogue_data(train_data_filtered, print_dialogue=False)
create_augmented_dialogue_data(val_data_filtered, print_dialogue=False)
create_augmented_dialogue_data(test_data_filtered, print_dialogue=False)