# Initialization

## Setup Jupyter

In [1]:
%load_ext jupyternotify

<IPython.core.display.Javascript object>

## Constants

In [2]:
# --- Runtime environment ---
use_colab = False  # True: install packages, mount Google Drive and use the Drive paths
use_api = False  # True: create the ChatBot client and let the prediction cells call it
update_prediction_files = False  # True: write predicted_dialogues back to its JSON file

# --- What to predict ---
predict_domains = False
predict_user_state_slots = True
predict_user_requested_slots = True
predict_mismatches = False  # True: re-predict ids listed in the debug mismatch files
predict_dontcare_slots = True  # offer an explicit "dontcare" option in QA slot prediction
overwrite_predictions = False  # works for both domains and slots

# --- Prompt configuration ---
prediction_level = 'TURN'  # TURN or DIALOGUE
prompt_type = 'TASK'  # QA or TASK
use_ontology = False  # also selects the "with-ontology" predictions file
use_domain_description = False

# USER or SYSTEM. The former 'USER' assignment was overwritten on the very next
# line and never took effect, so only the effective value is kept.
dataset_driver = 'SYSTEM'

## Global Directories

In [3]:
# Select the project root and dataset location for the current environment.
# On Colab: install the bardapi and openai packages, mount Google Drive and
# switch to the project folder stored there. Otherwise use paths relative to
# the current working directory.
if use_colab:
  !pip install bardapi
  !pip install openai
  import sys
  from google.colab import drive
  drive.mount('/content/drive')
  root_dir = '/content/drive/MyDrive/Colab/tod-nlu'
  dataset_dir = '/content/drive/MyDrive/Colab/datasets/MultiWOZ_2.4/'
  # put the project folder first on the module search path
  sys.path.insert(0,root_dir)
  %cd {root_dir}
else:
  root_dir = '.'
  dataset_dir = '../datasets/MultiWOZ2.4/data/MULTIWOZ2.4/'

## Modules

In [4]:
import json
import os
import csv
from Dataset import Dataset
from ChatBot import ChatBot
from multiwoz_utils import MultiWOZ

['train_id', 'value_time']


[nltk_data] Downloading package punkt to /home/asafa/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/asafa/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


## Constants and Global Variables

In [5]:
dataset_split = "test"
model = 'CHATGPT'
sub_model = 'gpt-4o'
response_format = 'json'

# '/' is replaced because these names are embedded in file names below.
sub_model_name = sub_model.replace('/', '-')
model_name = model.replace('/', '-')

# Active dataset. Both names were first set to 'SGD' and then immediately
# overwritten; only the effective values are kept.
dataset_name = 'MultiWOZ'
dataset_type = 'MWZ'

# Dialogues whose position in the iteration order is below the lower index are
# skipped, and processing stops once the position exceeds the upper index.
test_dialogue_lower_index = 0
test_dialogue_upper_index = 1001

In [6]:
# directories
data_dir = os.path.join(root_dir, 'data', dataset_name)
log_dir = os.path.join(root_dir, 'debug', dataset_name)
os.makedirs(log_dir, exist_ok=True)
split_dir = os.path.join(data_dir, dataset_split)
debug_dir = os.path.join(root_dir, 'debug')
metadata_file_path = os.path.join(data_dir, 'metadata.json')

# files
# Make sure to delete the predicted_slots file if you need to predict all slots
prompt_template_file_path = os.path.join(root_dir, 'prompt_templates.json')

# construct the predicted dialogue file path:
#   <model>-<sub_model>_<prompt_type>_<prediction_level>[_with-ontology]_dialogues.json
run_tag = f'{model_name}-{sub_model_name}_{prompt_type}_{prediction_level}'
if use_ontology:
    predicted_dialogues_file_name = run_tag + '_with-ontology_dialogues.json'
else:
    predicted_dialogues_file_name = run_tag + '_dialogues.json'
predicted_dialogues_file_path = os.path.join(split_dir, predicted_dialogues_file_name)
#predicted_dialogues_file_path = os.path.join(data_dir, dataset_split,'CHATGPT-gpt-4-turbo-preview_TASK_TURN_without-domain-fixes_dialogues.json')

# Uses the sanitised model_name (not the raw model) so that a model id
# containing '/' cannot introduce an unintended sub-directory, matching the
# predicted-dialogues file name above.
annotated_dialogues_file_path = os.path.join(
    split_dir, f'{model_name}-{sub_model_name}_annotated_dialogues.json')

# debug files listing the ids that should be re-predicted / inspected
mismatched_domains_file_path = os.path.join(debug_dir, 'mismatched_domains.txt')
mismatched_state_slots_file_path = os.path.join(debug_dir, 'mismatched_state_slots.txt')
additional_state_slots_file_path = os.path.join(debug_dir, 'additional_state_slots.psv')
missing_requested_slots_file_path = os.path.join(debug_dir, 'missing_requested_slots.txt')
additional_requested_slots_file_path = os.path.join(debug_dir, 'additional_requested_slots.psv')
slot_value_pool_file_path = os.path.join(data_dir, 'slot_value_pool.psv')

In [7]:
# Create the chat-bot client only when API usage is enabled. When use_api is
# False, chatBot is never defined; the prediction cells below check use_api
# before they use it.
if use_api:
    chatBot = ChatBot(model, sub_model, response_format, prompt_template_file_path, log_dir, debug=False)

# Load ground truth data

Load datasets

In [8]:
# Build the dataset wrapper from the metadata file and the dataset directory,
# then parse it.
dataset = Dataset(dataset_type, metadata_file_path, dataset_dir)
dataset.parse()

# to be refactored
# Helper built from the parsed domain names, the parsed slots and the
# slot-value pool file.
dataset_utils = MultiWOZ(domains=dataset.domains.keys(), slots=dataset.slots, slot_value_pool_file_path=slot_value_pool_file_path)

parsing domain Restaurant
parsing domain Attraction
parsing domain hotel
parsing domain taxi
parsing domain train
parsing domain Bus
parsing domain Hospital
parsing domain Police
reading all dialoge
training data size:  8438
development data size:  1000
testing data size:  1000
Finished reading the ./data/MultiWOZ/slot_value_pool.psv
Number of entries in ./data/MultiWOZ/slot_value_pool.psv is 0
Dumping ./data/MultiWOZ/slot_value_pool.psv


Dumping ./data/MultiWOZ/slot_value_pool.psv
Dumping ./data/MultiWOZ/slot_value_pool.psv
Dumping ./data/MultiWOZ/slot_value_pool.psv
Dumping ./data/MultiWOZ/slot_value_pool.psv
Dumping ./data/MultiWOZ/slot_value_pool.psv
Dumping ./data/MultiWOZ/slot_value_pool.psv
Dumping ./data/MultiWOZ/slot_value_pool.psv
Dumping ./data/MultiWOZ/slot_value_pool.psv
Dumping ./data/MultiWOZ/slot_value_pool.psv
Dumping ./data/MultiWOZ/slot_value_pool.psv


# Predict data

## Load Previous Predictions

In [9]:
# Start from an empty mapping and replace it with the stored predictions when
# a predictions file for the current configuration already exists.
predicted_dialogues = {}
if os.path.isfile(predicted_dialogues_file_path):
    with open(predicted_dialogues_file_path, 'r') as file:
        predicted_dialogues = json.load(file)

# Ground-truth dialogues of the selected split.
dataset_dialogues = dataset.get_split_data(split=dataset_split)

print('Dataset Dialogues size ', len(dataset_dialogues))
print(f'Parsing dialogue file {predicted_dialogues_file_path}')
print('Predicted Dialogues size ', len(predicted_dialogues))

Dataset Dialogues size  1000
Parsing dialogue file ./data/MultiWOZ/test/CHATGPT-gpt-4o_TASK_TURN_dialogues.json
Predicted Dialogues size  1000


## Predict Domains

### Load mismatched domains from files

In [10]:
#### read mismatched domains files ###
# Every row contributes the text before its first '|' (whitespace stripped);
# these ids are later compared against '<dialogue>-<turn>' turn ids.
try:
    with open(mismatched_domains_file_path, 'r') as mismatched_domains_file:
        mismatched_domains_ids = [row.split('|')[0].strip() for row in mismatched_domains_file]
except OSError:
    # The file is missing or unreadable: there is nothing to re-predict.
    # Only I/O errors are handled so that interrupts and programming errors
    # are no longer silently swallowed.
    mismatched_domains_ids = []
print(len(mismatched_domains_ids), ' entries were found in ',
      mismatched_domains_file_path)

344  entries were found in  ./debug/mismatched_domains.txt


### Predict Domains

In [11]:
# Predict the domains of the dataset dialogues whose position lies inside
# [test_dialogue_lower_index, test_dialogue_upper_index] and store the result
# in predicted_dialogues. Runs only when both use_api and predict_domains are
# enabled.
processed_dialogue_index = -1
if use_api and predict_domains:
    for dialogue in dataset_dialogues:
        dialogue_turns = []
        dialogue_speakers = []
        predict_dialogue_domains = False
        processed_dialogue_index += 1
        print('Number of processed dialogues so far = ',processed_dialogue_index)
        if processed_dialogue_index < test_dialogue_lower_index:
            continue
        if processed_dialogue_index > test_dialogue_upper_index:
            break

        # Collect the dialogue content and decide whether it must be predicted.
        for turn in dataset_dialogues[dialogue]:
            turn_id = dialogue+'-'+turn
            text = dataset_dialogues[dialogue][turn]['text']
            speaker = dataset_dialogues[dialogue][turn]['speaker']
            dialogue_turns.append(text)
            dialogue_speakers.append(speaker)
            print(turn_id)
            if overwrite_predictions:
                predict_dialogue_domains = True
            elif dialogue not in predicted_dialogues:
                predict_dialogue_domains = True
            elif predict_mismatches and turn_id in mismatched_domains_ids:
                predict_dialogue_domains = True
                print(dialogue+'-'+turn, ' found in the mismatched domains file')
            elif 'domains' in predicted_dialogues[dialogue].get(turn, {}):
                # Already predicted: keep whatever earlier turns decided.
                # .get() tolerates a stored dialogue that lacks this turn.
                print(dialogue+'-'+turn, 'has already predictions')
            else:
                # NOTE(review): this resets a True set by an earlier turn of
                # the same dialogue; kept as in the original - confirm intent.
                predict_dialogue_domains = False

        if predict_dialogue_domains:
            predicted_domains = chatBot.predict_dialogue_domains(dialogue_turns, dialogue_speakers, dataset.domains, prediction_level, use_domain_description)
            # Start from scratch: earlier predictions of this dialogue are dropped.
            predicted_dialogues[dialogue] = {}
            for turn_index in predicted_domains:
                # check if additional turns are back
                if turn_index not in dataset_dialogues[dialogue].keys():
                    print('Turn with index ', turn_index,' couldn\'t be found in the dialogue ', dialogue)
                    continue

                predicted_dialogues[dialogue][turn_index] = {
                    key: dataset_dialogues[dialogue][turn_index][key]
                    for key in ('text', 'speaker', 'history')
                }

                turn_predicted_domains = predicted_domains[turn_index]
                # A single domain returned as a plain string is wrapped in a list
                # (the original wrapped the wrong, possibly undefined, variable).
                if isinstance(turn_predicted_domains, str):
                    turn_predicted_domains = [turn_predicted_domains]

                # Keep only known domains. A new list is built because removing
                # items from the list being iterated skips elements.
                valid_turn_domains = []
                for turn_predicted_domain in turn_predicted_domains:
                    if turn_predicted_domain in dataset_utils.domains:
                        valid_turn_domains.append(turn_predicted_domain)
                    else:
                        print('Predicted domain ', turn_predicted_domain,' is not of the defined domain list')
                predicted_dialogues[dialogue][turn_index]['domains'] = valid_turn_domains

        # fill out the turns with no predicted domains
        for turn in dataset_dialogues[dialogue]:
            if turn not in predicted_dialogues[dialogue]:
                predicted_dialogues[dialogue][turn] = {
                    key: dataset_dialogues[dialogue][turn][key]
                    for key in ('text', 'speaker', 'history')
                }
            if 'domains' not in predicted_dialogues[dialogue][turn]:
                predicted_dialogues[dialogue][turn]['domains'] = []

## Predict Slots

### Load mismatched data from files

In [12]:
#### read mismatched slots files ###
def _read_leading_ids(file_path):
    """Return the text before the first '|' of every row of ``file_path``.

    Surrounding whitespace (including the trailing newline of rows without a
    '|') is stripped, as is done for the mismatched-domain ids. A file that
    cannot be opened or read yields an empty list; only I/O errors are handled,
    so interrupts and programming errors are not silently swallowed.
    """
    try:
        with open(file_path, 'r') as id_file:
            return [row.split('|')[0].strip() for row in id_file]
    except OSError:
        return []


mismatched_state_slot_ids = _read_leading_ids(mismatched_state_slots_file_path)
print(len(mismatched_state_slot_ids), ' entries were found in ',
      mismatched_state_slots_file_path)

missing_requested_slots_ids = _read_leading_ids(missing_requested_slots_file_path)
print(len(missing_requested_slots_ids), ' entries were found in ',
      missing_requested_slots_file_path)
#### finished reading mismatched slots files ###

2161  entries were found in  ./debug/mismatched_state_slots.txt
0  entries were found in  ./debug/missing_requested_slots.txt


### Predict Slots

In [13]:
def predict_turn_slots_as_qa(turns, speakers, domain, schema_slots, dialogue_state):
    """Predict the slots of the latest turn for one domain, one slot at a time.

    The entities of the last turn are extracted once through the module-level
    ``chatBot``. For every slot of ``domain`` whose category has extracted
    values, the candidates are those values plus the values already stored in
    ``dialogue_state`` for slots of the same category, and ``chatBot`` chooses
    among them. The module-level flag ``predict_dontcare_slots`` controls
    whether an explicit "dontcare" option is offered on USER turns.

    Args:
        turns: texts of the dialogue turns so far; the last one is predicted.
        speakers: speaker of each turn, parallel to ``turns``.
        domain: domain whose slots are predicted.
        schema_slots: mapping domain -> slot -> slot description; each
            description must have a 'category' entry.
        dialogue_state: mapping speaker -> domain -> slot -> value (or list of
            values). ``dialogue_state[speakers[-1]][domain]`` must exist; it
            is updated in place with the predicted slots.

    Returns:
        A ``(predicted_slots, dialogue_state)`` tuple, where predicted_slots
        maps each predicted slot of this turn to its chosen value.
    """
    dontcare_option = 'dontcare: USER explicitly said he has no preference for the slot value.'
    predicted_slots = {}
    last_turn_speaker = speakers[-1]
    # extract the possible values
    turn_extracted_slots = chatBot.extract_turn_entities(
        turn=turns[-1], prediction_level='TURN')
    domain_slots = schema_slots[domain]

    # check if dontcare slot found
    dontcare_value_found = (
        'DONTCARE' in turn_extracted_slots and len(turn_extracted_slots['DONTCARE']) > 0)

    for slot in domain_slots:
        slot_category = domain_slots[slot]['category']

        # Only slots whose category was mentioned in the turn are predicted.
        raw_slot_values = turn_extracted_slots.get(slot_category)
        if not raw_slot_values:
            continue

        # Work on a copy: the extracted list is shared by every slot of the
        # same category and must not accumulate the dontcare option. A single
        # string is one value, not a sequence of characters.
        if isinstance(raw_slot_values, str):
            extracted_slot_values = [raw_slot_values]
        else:
            try:
                extracted_slot_values = list(raw_slot_values)
            except TypeError:
                print('ERROR: extracted_slot_values is not valid list. ', raw_slot_values)
                extracted_slot_values = []

        # add possible slot values from the same slot type from dialogue state
        # to cover the cases where the slot value was proposed by the system
        possible_slot_values = set()
        for speaker_state in dialogue_state.values():
            for state_domain, state_slots in speaker_state.items():
                for state_slot, state_value in state_slots.items():
                    if slot_category != schema_slots[state_domain][state_slot]['category']:
                        continue
                    if isinstance(state_value, list):
                        possible_slot_values.update(state_value)
                    else:
                        possible_slot_values.add(state_value)

        if predict_dontcare_slots and dontcare_value_found and last_turn_speaker == 'USER':
            # Replace a bare 'dontcare' with the self-explaining option.
            extracted_slot_values = [
                value for value in extracted_slot_values if value != 'dontcare']
            extracted_slot_values.append(dontcare_option)
        print('extracted_slot_values: ', extracted_slot_values)
        try:
            possible_slot_values.update(extracted_slot_values)
        except TypeError:
            # raised when a value is unhashable
            print('ERROR: extracted_slot_values is not valid list. ', extracted_slot_values)

        predicted_slot_value = chatBot.choose_slot_value(
            slot, turns=turns, speakers=speakers, domain=domain, domain_slots=domain_slots, possible_values=possible_slot_values)
        # The string 'None' marks "no value for this slot".
        if predicted_slot_value == 'None':
            continue

        predicted_slots[slot] = predicted_slot_value
    dialogue_state[speakers[-1]][domain].update(predicted_slots)

    return predicted_slots, dialogue_state

In [14]:
def predict_dialogue_slots(dialogue, domains, dataset_slots, use_slots_possible_values, prediction_type):
    """Predict the slots of one dialogue and store them in ``dialogue``.

    Relies on the module-level ``chatBot`` and, for 'TASK', on the module-level
    ``prediction_level``.

    Args:
        dialogue: mapping turn id -> turn dict with 'text', 'speaker' and a
            'domains' mapping (domain -> prediction dict). Modified in place.
        domains: the domains to predict.
        dataset_slots: mapping domain -> slot definitions.
        use_slots_possible_values: forwarded to ``chatBot.predict_dialogue_slots``
            ('TASK' only).
        prediction_type: 'TASK' asks for all slots of a domain in one call; a
            returned value of '?' marks a requested slot, any other value a
            state slot. 'QA' predicts turn by turn through
            ``predict_turn_slots_as_qa``. Any other value leaves ``dialogue``
            unchanged.

    Returns:
        The same ``dialogue`` object.
    """
    if prediction_type == 'TASK':
        turns = []
        speakers = []

        for turn in dialogue:
            turns.append(dialogue[turn]['text'])
            speakers.append(dialogue[turn]['speaker'])

        for domain in domains:
            domain_slots = dataset_slots[domain]
            dialogue_domain_slots = chatBot.predict_dialogue_slots(turns, speakers, domain, domain_slots=domain_slots,use_slots_possible_values=use_slots_possible_values, prediction_level=prediction_level)
            for turn_index in dialogue_domain_slots:
                turn_slots = dialogue_domain_slots[turn_index]
                # The two checks are separate so that each case gets its own
                # warning; a combined check used to report malformed slots as
                # an additional turn index.
                if turn_index not in dialogue:
                    print('Warning: predicted slots have additional turn index ', turn_index)
                    continue
                if not isinstance(turn_slots, dict):
                    print('Warning: invalud predicted slots ', turn_slots)
                    continue

                turn_domain_state_slots = {}
                turn_domain_requested_slots = []
                for slot in turn_slots:
                    if turn_slots[slot] == '?':
                        turn_domain_requested_slots.append(slot)
                    else:
                        turn_domain_state_slots[slot] = turn_slots[slot]

                if not domain in dialogue[turn_index]['domains']:
                    dialogue[turn_index]['domains'][domain] = {}
                dialogue[turn_index]['domains'][domain]['slots'] = turn_domain_state_slots
                dialogue[turn_index]['domains'][domain]['requested_slots'] = turn_domain_requested_slots

    elif prediction_type == 'QA':
        print('dialogue domains to predict = ', domains)
        turns = []
        speakers = []
        dialogue_state = {}
        for turn in dialogue:
            speaker = dialogue[turn]['speaker']
            if not speaker in dialogue_state:
                dialogue_state[speaker] = {}
            turns.append(dialogue[turn]['text'])
            speakers.append(dialogue[turn]['speaker'])
            turn_domains = dialogue[turn]['domains']
            print(f'Turn domains {turn_domains}')
            for domain in turn_domains:
                print('domain = ', domain)
                print(f'working on {turn}-{domain}')
                if not domain in domains:
                    continue
                if not domain in dialogue_state[speaker]:
                    dialogue_state[speaker][domain] = {}
                # Use the schema passed by the caller instead of the global dataset.
                turn_predicted_slots, dialogue_state = predict_turn_slots_as_qa(turns = turns, speakers= speakers, domain=domain, schema_slots=dataset_slots, dialogue_state=dialogue_state)
                dialogue[turn]['domains'][domain]['slots'] = turn_predicted_slots

    #print('dialogue = ', dialogue)
    return dialogue

In [15]:
# Predict the state / requested slots of the dialogues in predicted_dialogues
# whose position lies inside [test_dialogue_lower_index, test_dialogue_upper_index].
# For each dialogue, first decide which domains need (re)prediction, reset
# their slot containers, then predict per domain through predict_dialogue_slots.
if use_api and (predict_user_state_slots or predict_user_requested_slots):
    processed_dialogue_index = -1
    for dialogue in predicted_dialogues:
        dialogue_turns = []
        dialogue_speakers = []
        dialogue_domains = {}
        processed_dialogue_index += 1
        dialogue_domains_to_predict = set()
        dialogue_domain_tracker = set()

        print('Number of processed dialogues so far = ', processed_dialogue_index)
        if processed_dialogue_index < test_dialogue_lower_index:
            continue
        if processed_dialogue_index > test_dialogue_upper_index:
            break

        for turn in predicted_dialogues[dialogue]:
            print(dialogue+':'+turn)
            speaker = predicted_dialogues[dialogue][turn]['speaker']
            text = predicted_dialogues[dialogue][turn]['text']
            dialogue_turns.append(text)
            dialogue_speakers.append(speaker)
            # set the domains to predict for #
            # 'domains' is a list right after domain prediction and a mapping
            # domain -> prediction dict once slots have been initialised.
            turn_domains = predicted_dialogues[dialogue][turn]['domains']

            # append the dataset domains #
            if isinstance(turn_domains, list):
                dialogue_domain_tracker.update(turn_domains)
            else:
                dialogue_domain_tracker.update(list(turn_domains.keys()))

            # add the datasets domains (disabled)
            # dataset_domains = dataset_dialogues[dialogue][turn]['domains']
            # for dataset_domain in dataset_domains:
            #     if dataset_domain not in dialogue_domain_tracker:
            #         turn_domains[dataset_domain] = {}
            #         turn_domains[dataset_domain]['source'] = 'dataset
            dialogue_domains[turn] = turn_domains

            # Check which domain to predict
            if overwrite_predictions or isinstance(turn_domains, list):
                # Fresh containers for every domain of the turn. requested_slots
                # is a list, consistent with the reset below and with what
                # predict_dialogue_slots stores.
                predicted_dialogues[dialogue][turn]['domains'] = {
                    turn_domain: {'slots': {}, 'requested_slots': []}
                    for turn_domain in turn_domains
                }
                dialogue_domains_to_predict.update(turn_domains)
            else:
                # loop over the turn domains
                for turn_domain in turn_domains:
                    domain_id = dialogue+'-'+turn+'-'+turn_domain
                    print('Working on: ', domain_id)
                    domain_prediction = predicted_dialogues[dialogue][turn]['domains'][turn_domain]
                    if 'slots' not in domain_prediction:
                        predict_domain_slots = True
                    elif predict_mismatches and domain_id in mismatched_state_slot_ids:
                        print(domain_id,' is found in the mismatched slots file. Will be predicted')
                        predict_domain_slots = True
                    elif predict_mismatches and domain_id in missing_requested_slots_ids:
                        print(domain_id,' is found in the missing slots file. Will be predicted')
                        predict_domain_slots = True
                    elif 'requested_slots' not in domain_prediction:
                        predict_domain_slots = True
                    else:
                        predict_domain_slots = False
                        print(domain_id, ' has already predicted slots')

                    if predict_domain_slots:
                        dialogue_domains_to_predict.add(turn_domain)

        # initiate the slots to avoid predicting the turn  domain again later
        for dialogue_domain in dialogue_domains_to_predict:
            for turn in predicted_dialogues[dialogue]:
                if dialogue_domain in predicted_dialogues[dialogue][turn]['domains']:
                    predicted_dialogues[dialogue][turn]['domains'][dialogue_domain]['slots'] = {}
                    predicted_dialogues[dialogue][turn]['domains'][dialogue_domain]['requested_slots'] = []

        # start predicting the slots
        print('Predicting the slots for ', dialogue)
        # predict_dialogue_slots updates the dialogue in place and returns it;
        # store the result under the same key instead of rebinding the loop
        # variable to the dialogue dict.
        predicted_dialogues[dialogue] = predict_dialogue_slots(dialogue=predicted_dialogues[dialogue], domains = dialogue_domains_to_predict, dataset_slots = dataset.slots, use_slots_possible_values=use_ontology, prediction_type=prompt_type)
            

## Update Prediction File

In [16]:
if update_prediction_files:
    # Serialize first: opening the file with 'w' truncates it, so a
    # serialization error raised in the middle of json.dump would destroy the
    # previously stored predictions. The written content is unchanged.
    serialized_predictions = json.dumps(predicted_dialogues, indent=4)
    with open(predicted_dialogues_file_path, 'w') as file:
        file.write(serialized_predictions)