Fantastic Rewards and How to Tame Them: A Case Study on Reward Learning for Task-Oriented Dialogue Systems
To install the required packages, first create and activate a conda environment named `fantastic_reward`, then run:

```bash
bash install_packages.sh
```
Our data setup follows the CASPI paper.
Please download the pre-processed data from here.
Unzip the downloaded file and put the resulting folder `ExpStoreddata` into the folder `damd_multiwoz`.
For our variant of RewardNet+GS, run:

```bash
bash ./run_multiple_seeds.sh --EXP_IDX ${EXP_IDX} --REWARD_SAMPLES 3 --REWARD_LOSS "listNet" --LISTMLE_TEMP 1 --LISTNET_POW 1 --POLICY_TRAIN_DATA_FRAC 1 --NEG_REW_WEIGHT 0.1 --REW_MODEL_EXP '0'
```
where ${EXP_IDX} is the index of the experiment, such as "2023".
For our variant of RewardMLE+GS, run:

```bash
bash ./run_multiple_seeds.sh --EXP_IDX ${EXP_IDX} --REWARD_SAMPLES 5 --REWARD_LOSS "listMLE" --LISTMLE_TEMP 1 --LISTNET_POW 0 --POLICY_TRAIN_DATA_FRAC 1 --NEG_REW_WEIGHT 1.0 --REW_MODEL_EXP '0'
```
where ${EXP_IDX} is again the index of the experiment.
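The two variants differ mainly in the listwise ranking loss selected by `--REWARD_LOSS` ("listNet" vs. "listMLE"). As a rough, repo-independent sketch of what these two losses compute (the actual implementation in the training scripts may differ, e.g. in how `LISTMLE_TEMP` and `LISTNET_POW` are applied):

```python
import numpy as np

def softmax(x):
    # numerically stable softmax over a 1-D score vector
    z = x - x.max()
    e = np.exp(z)
    return e / e.sum()

def listnet_loss(pred, true):
    """Top-one ListNet: cross-entropy between the softmax
    distributions induced by the true and predicted scores."""
    p_true = softmax(true)
    p_pred = softmax(pred)
    return -np.sum(p_true * np.log(p_pred + 1e-12))

def listmle_loss(pred, true, temp=1.0):
    """ListMLE: negative Plackett-Luce log-likelihood of the
    ordering given by the true scores, under predicted scores."""
    order = np.argsort(-true)          # best-first permutation
    s = pred[order] / temp
    loss = 0.0
    for i in range(len(s)):
        tail = s[i:]
        # stable log-sum-exp over the remaining items
        m = tail.max()
        loss += np.log(np.exp(tail - m).sum()) + m - s[i]
    return loss
```

Both losses are minimized when the predicted scores rank candidates in the same order as the true scores, which is the property the reward model is trained for.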
To facilitate reproducibility, we release a checkpoint for each variant (RewardNet+GS and RewardMLE+GS), trained with seed 999 out of the five tested seeds (111, 333, 555, 777, 999).
To evaluate the checkpoints, please follow the steps below.
Here, Exp1 corresponds to the RewardNet+GS variant and Exp2 to the RewardMLE+GS variant.
- Download and unzip the checkpoints from here.
- Download and unzip the processed data from here. Put the resulting folders into the folder `damd_multiwoz`.
- Run the following command:
```bash
python train.py --model_path "experiments/Exp${EXP_IDX}/all_sd999/" \
--mode 'test' --context_window 2 --pretrained_checkpoint bart-large-cnn \
--back_bone bart --cfg seed=999 cuda_device=0 batch_size=8 early_stop_count=7 \
--caspi_returns_file="fn_Gs_10_0.0_resp_soft.json" --caspi_wt=5. \
--caspi_data_file=data_for_damd.json --caspi_val_fraction=.5 --caspi --data_folder "Exp${EXP_IDX}data/s999_K10_GAMMA0.0" \
--exp_idx ${EXP_IDX}
```
where ${EXP_IDX} should be replaced by 1 or 2.
The following table shows the standardized evaluation results of our "RewardNet+GS" model.
Detailed numbers are provided in Example_generation/result_standard_eval.json.
| BLEU | Inform | Success | Combined Score | Av. len. | CBE | #uniq. words | #uniq. 3-grams |
|---|---|---|---|---|---|---|---|
| 17.6 | 87.6 | 81.5 | 102.2 | 13.22 | 1.99 | 423 | 3942 |
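The Combined Score column follows the standard MultiWOZ convention, (Inform + Success) / 2 + BLEU. A quick recomputation from the table's numbers:

```python
# Recompute the Combined Score from the table above using the
# standard MultiWOZ formula: (Inform + Success) / 2 + BLEU.
bleu, inform, success = 17.6, 87.6, 81.5
combined = (inform + success) / 2 + bleu  # ~102.15, reported as 102.2
```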
Examples of generated dialogues on the test-split of MultiWOZ2.0 can be found at Example_generation/gen_test_formatted.json.
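For intuition on the diversity columns of the table (Av. len., #uniq. words, #uniq. 3-grams), the sketch below shows one plausible way such statistics can be computed over generated responses. This is an illustration only, not the repository's evaluation code, and the toy responses are made up:

```python
def diversity_stats(responses):
    """Average token length, unique words, and unique word
    3-grams over a list of whitespace-tokenized responses."""
    words, trigrams = set(), set()
    total_len = 0
    for resp in responses:
        toks = resp.lower().split()
        total_len += len(toks)
        words.update(toks)
        trigrams.update(zip(toks, toks[1:], toks[2:]))
    return total_len / len(responses), len(words), len(trigrams)

# Toy usage with two invented system responses:
avg_len, n_words, n_trigrams = diversity_stats(
    ["i have booked the taxi for you",
     "the restaurant serves british food"])
```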
This codebase builds on the following codebases and datasets: