-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 60978a3
Showing
69 changed files
with
114,642 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| # Auto detect text files and perform LF normalization | ||
| * text=auto |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| damd_multiwoz/data/embeddings | ||
| damd_multiwoz/data/multi-woz-analysis | ||
| damd_multiwoz/data/multi-woz-processed | ||
| damd_multiwoz/data/multi-woz-oppe | ||
| __pycache__ | ||
| log/*/*.txt | ||
| log/*/*.log | ||
| experiments/** | ||
| damd_multiwoz/log/*/*.txt | ||
| damd_multiwoz/log/*/*.log | ||
| damd_multiwoz/experiments/** | ||
| .idea/* | ||
| *.ipynb | ||
| *.log |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,264 @@ | ||
| from collections import defaultdict | ||
| import json | ||
| import os | ||
| import random | ||
|
|
||
| import numpy as np | ||
| from argparse import ArgumentParser | ||
|
|
||
# Module-level defaults. When the file is run as a script, K, TRAIN_ON,
# GAMMA_GLOBAL and METRIC are overwritten from the command-line arguments and
# are used to build the name of the returns file that is loaded.

# Number of folds (overwritten from --folds).
K=10
# Action space selector, 'act' or 'resp' (overwritten from --action_space).
TRAIN_ON=['act','resp'][0]
# Discount factor returned by get_gamma() whenever it is not None; in that case
# it takes precedence over any per-turn gamma (overwritten from --gamma).
GAMMA_GLOBAL = 0.0
# When True, get_value_function() uses the stored value directly as the return
# instead of accumulating a discounted sum.
USE_R_AS_G = True
# Metric selector, 'soft' or 'hard' (overwritten from --metric).
METRIC=['soft', 'hard'][0]
|
|
||
|
|
||
def get_turn_state(full_state, turn_domain):
    """Extract the part of ``full_state`` that belongs to ``turn_domain``.

    ``full_state`` is a string in which each domain section starts with a
    bracketed domain token (e.g. ``[hotel]``).  The returned slice starts at
    the first occurrence of ``turn_domain`` and stops right before the nearest
    following token of any *other* known domain (or at the end of the string),
    with surrounding whitespace stripped.

    If ``turn_domain`` does not occur in ``full_state`` the domain token itself
    is returned unchanged.
    """
    known_domains = (
        '[police]', '[taxi]', '[restaurant]', '[attraction]',
        '[hotel]', '[hospital]', '[train]', '[general]',
    )

    if turn_domain not in full_state:
        return turn_domain

    start = full_state.index(turn_domain)

    # First-occurrence position of every other domain present in the state.
    other_starts = [
        full_state.index(name)
        for name in known_domains
        if name != turn_domain and name in full_state
    ]

    # The section ends at the closest other-domain token located after `start`.
    stop = min(
        (pos for pos in other_starts if pos > start),
        default=len(full_state),
    )
    return full_state[start:stop].strip()
|
|
||
def get_state_act(data_for_damd):
    """Collect state/act co-occurrence tables from the dialogues.

    Dialogues whose name is in the module-level ``test_fn`` are skipped.  For
    every remaining turn, the active domain is the last space-separated token
    of ``log['turn_domain']`` and the turn state is the matching section of
    ``log['cons_delex']`` (or the bare domain token when that domain is absent
    from the state string).

    Returns a tuple ``(state_act, act_state, fn_tn_state)``:

    * ``state_act``   -- turn state -> list of system acts taken in it.
    * ``act_state``   -- system act -> list of turn states that followed it
      within the same dialogue.
    * ``fn_tn_state`` -- dialogue name -> turn index -> turn state.
    """
    state_act = defaultdict(list)
    act_state = defaultdict(list)
    fn_tn_state = defaultdict(dict)

    for fn, dia in data_for_damd.items():
        if fn in test_fn:
            continue

        previous_act = None
        for turn_idx, turn_log in enumerate(dia['log']):
            belief = turn_log['cons_delex']
            active_domain = turn_log['turn_domain'].split(' ')[-1]
            turn_state = (
                get_turn_state(belief, active_domain)
                if active_domain in belief
                else active_domain
            )

            # Link the previous system act to the state it led to.
            if previous_act is not None:
                act_state[previous_act].append(turn_state)

            current_act = turn_log['sys_act']
            state_act[turn_state].append(current_act)
            fn_tn_state[fn][turn_idx] = turn_state
            previous_act = current_act

    return state_act, act_state, fn_tn_state
|
|
||
def get_state(fn, fn_tn_state, tn):
    """Return the state stored for turn ``tn`` of dialogue ``fn``.

    ``fn_tn_state`` is a nested mapping: dialogue name -> turn number -> state.
    A missing dialogue or turn raises ``KeyError``.
    """
    return fn_tn_state[fn][tn]
|
|
||
def get_act(turn):
    """Return the system act (``'sys_act'`` entry) of a turn record."""
    system_act = turn['sys_act']
    return system_act
|
|
||
def get_gamma(gamma_local):
    """Pick the discount factor to use.

    The module-level ``GAMMA_GLOBAL`` wins whenever it is set (not ``None``);
    only otherwise is the caller-supplied ``gamma_local`` returned.
    """
    return gamma_local if GAMMA_GLOBAL is None else GAMMA_GLOBAL
|
|
||
|
|
||
# Names of dialogues that were requested but are absent from fn_Gs.
not_in_fn_Gs = set()


def get_reward_gamma(fn_Gs, fn, turn_num):
    """Look up the reward and discount factor of one turn.

    ``fn_Gs`` maps dialogue name -> turn number (as a string) -> entry.

    Returns ``None`` when ``fn`` is not in ``fn_Gs`` (the name is remembered in
    the module-level ``not_in_fn_Gs`` set).  Otherwise returns a
    ``(reward, gamma)`` tuple:

    * ``USE_R_AS_G == False`` -- the entry's ``'R'`` and ``'gamma'`` fields.
    * ``USE_R_AS_G == True``  -- the whole entry as reward and ``0`` as gamma.

    Raises ``Exception`` for any other value of ``USE_R_AS_G``.

    NOTE(review): the callers in this file index the returned reward with
    ``['G']`` and ``['gamma']``, which presumes it is a mapping -- confirm
    that this also holds for the ``USE_R_AS_G == False`` branch.
    """
    turn_key = str(turn_num)

    if fn not in fn_Gs:
        not_in_fn_Gs.add(fn)
        return None

    # Equality (not identity) comparisons are kept on purpose so that the
    # accepted values of USE_R_AS_G stay exactly as before.
    if USE_R_AS_G == False:
        per_turn = fn_Gs[fn][turn_key]
        return per_turn['R'], per_turn['gamma']
    if USE_R_AS_G == True:
        return fn_Gs[fn][turn_key], 0
    raise Exception('Invalid USE_R_AS_G selection')
|
|
||
|
|
||
def get_value_function(data_for_damd, fn_tn_state, fn_Gs):
    """Estimate a per-state value as the running mean of observed returns.

    Dialogues listed in the module-level ``test_fn`` are skipped.  Each
    remaining dialogue is walked from its last turn to its first; turns for
    which ``get_reward_gamma`` returns ``None`` are ignored.

    The return ``G`` of a turn is the ``'G'`` field of the looked-up reward
    when ``USE_R_AS_G == True``, or ``R + gamma * G_next`` when
    ``USE_R_AS_G == False`` (``gamma`` chosen by ``get_gamma``).

    Returns a dict: state -> ``{'V': mean return, '|S|': number of visits}``.
    """
    V_info = {}
    for fn, dia in data_for_damd.items():
        if fn in test_fn:
            continue

        next_return = 0
        for turn in reversed(dia['log']):
            turn_num = turn['turn_num']
            state = get_state(fn, fn_tn_state, turn_num)
            looked_up = get_reward_gamma(fn_Gs, fn, turn_num)
            if looked_up is None:
                continue

            reward_entry = looked_up[0]
            R = reward_entry['G']
            gamma = reward_entry['gamma']

            stats = V_info.setdefault(state, {'V': 0, '|S|': 0})

            if USE_R_AS_G == False:
                G = R + get_gamma(gamma) * next_return
            elif USE_R_AS_G == True:
                G = R
            else:
                raise Exception('Invalid USE_R_AS_G selection')

            # Incremental mean over all visits of this state so far.
            visits = stats['|S|']
            stats['V'] = (stats['V'] * visits + G)/(visits+1)
            stats['|S|'] = visits + 1
            next_return = G
    return V_info
|
|
||
|
|
||
def get_Q_function(data_for_damd, V_info, fn_tn_state, fn_Gs):
    """Estimate a per-(state, act) value as a running mean.

    Dialogues listed in the module-level ``test_fn`` are skipped.  Each
    remaining dialogue is walked from its last turn to its first; turns for
    which ``get_reward_gamma`` returns ``None`` are ignored.

    For every processed turn the target is ``R + gamma * V_next`` where ``R``
    is the ``'G'`` field of the looked-up reward, ``gamma`` comes from
    ``get_gamma`` and ``V_next`` is ``V_info[...]['V']`` of the state handled
    in the previous loop iteration (0 for the first one).

    Returns a nested dict:
    state -> act -> ``{'Q': mean target, '|S|': number of visits}``.
    """
    Q_info = {}
    for fn, dia in data_for_damd.items():
        if fn in test_fn:
            continue

        next_value = 0
        for turn in reversed(dia['log']):
            turn_num = turn['turn_num']
            state = get_state(fn, fn_tn_state, turn_num)
            act = get_act(turn)
            looked_up = get_reward_gamma(fn_Gs, fn, turn_num)
            if looked_up is None:
                continue

            reward_entry = looked_up[0]
            R = reward_entry['G']
            gamma = reward_entry['gamma']

            per_state = Q_info.setdefault(state, {})
            stats = per_state.setdefault(act, {'Q': 0, '|S|': 0})

            G = R + get_gamma(gamma) * next_value

            # Incremental mean over all visits of this (state, act) pair.
            visits = stats['|S|']
            stats['Q'] = (stats['Q'] * visits + G)/(visits+1)
            stats['|S|'] = visits + 1
            next_value = V_info[state]['V']
    return Q_info
|
|
||
|
|
||
def estimate_bh_policy(state_act, state, act):
    """Return the empirical probability of taking ``act`` in ``state``.

    This is the maximum-likelihood estimate of a discrete policy: the number
    of times ``act`` was recorded for ``state`` divided by the total number of
    acts recorded for ``state``.

    Args:
        state_act: mapping from a state to the list of acts observed in it.
        state: the state to look up.
        act: the act whose relative frequency is wanted.

    Returns:
        The relative frequency as a float, or ``0.0`` when no act has been
        recorded for ``state`` (unknown state or empty list).
    """
    # .get() instead of [] so that an unseen state neither raises nor, when
    # state_act is a defaultdict, silently inserts an empty list.
    observed = state_act.get(state)
    if not observed:
        # Nothing observed: avoid the division by zero.
        return 0.0
    return observed.count(act) / len(observed)
|
|
||
|
|
||
def persist_Q_function(data_for_damds, Q_infos, state_acts, fn_tn_states, path_to_persist):
    """Build the per-turn Q table for several data splits and optionally save it.

    The four list arguments are processed in lock-step (one entry per split).
    Dialogues listed in the module-level ``test_fn`` are skipped.

    Args:
        data_for_damds: list of dialogue dicts (name -> dialogue).
        Q_infos: list of Q tables as produced by ``get_Q_function``.
        state_acts: list of state -> acts tables used for the behaviour policy.
        fn_tn_states: list of dialogue name -> turn number -> state tables.
        path_to_persist: JSON output path, or ``None`` to skip writing.

    Returns:
        dict: dialogue name -> turn number ->
        ``{'Q': Q value, 'prob': behaviour-policy probability}``.

    Raises:
        Exception: if a turn's (state, act) pair is missing from its Q table.
    """
    Q_fn = {}
    for data_for_damd, Q_info, state_act, fn_tn_state in zip(
            data_for_damds, Q_infos, state_acts, fn_tn_states):
        for fn, dia in data_for_damd.items():
            if fn in test_fn:
                continue

            per_turn = Q_fn[fn] = {}
            for turn in dia['log']:
                turn_num = turn['turn_num']
                state = get_state(fn, fn_tn_state, turn_num)
                act = get_act(turn)
                if state not in Q_info or act not in Q_info[state]:
                    raise Exception('I Dont see a reason to be here!')

                per_turn[turn_num] = {
                    'Q': Q_info[state][act]['Q'],
                    'prob': estimate_bh_policy(state_act, state, act),
                }

    if path_to_persist is not None:
        print('path_to_persist:', path_to_persist, flush=True)
        with open(path_to_persist, 'w') as f:
            json.dump(Q_fn, f, indent=2)
    return Q_fn
|
|
||
|
|
||
def set_seed(seed):
    """Seed the random number generators used by this script.

    Exports ``seed`` as the ``PYTHONHASHSEED`` environment variable and seeds
    both NumPy's global generator and the standard ``random`` module.

    NOTE(review): PYTHONHASHSEED is presumably only honoured at interpreter
    start-up, so setting it here would affect child processes only -- confirm.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (np.random.seed, random.seed):
        seeder(seed)
|
|
||
|
|
||
if __name__ == '__main__':
    # Script entry point: read the processed dialogues and the per-turn
    # returns, estimate V and Q tables for the train and validation splits,
    # and write the per-turn Q values next to the returns file.
    parser = ArgumentParser()
    parser.add_argument("-s", "--seed", dest="seed",
                        default=11,
                        type=int,
                        help="seed")
    parser.add_argument("-K", "--folds",
                        dest="folds", default=10,
                        type=int,
                        help="Number of folds")
    # Lists (not sets) keep the order of the choices stable in --help output.
    parser.add_argument("-a", "--action_space",
                        dest="action_space",
                        choices=["act", "resp"],
                        default='act',
                        help="action space. can either be act or resp")
    parser.add_argument("-m", "--metric",
                        dest="metric",
                        choices=["hard", "soft"],
                        default='soft',
                        help="metric used for pairwise reward candidate generation")
    parser.add_argument("-g", "--gamma",
                        dest="gamma",
                        default=0.0,
                        type=float,
                        help="The discount factor used in reward learning")
    parser.add_argument("--data_folder", type=str, default="")
    args = parser.parse_args()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.data_folder == "":
        parser.error("--data_folder must be provided")

    # These stay module-level assignments on purpose: the functions above read
    # GAMMA_GLOBAL (and test_fn below) as globals.
    K = args.folds
    TRAIN_ON = args.action_space
    GAMMA_GLOBAL = args.gamma
    METRIC = args.metric
    fn_G_file_name = 'fn_Gs_{}_{}_{}_{}.json'.format(K, GAMMA_GLOBAL, TRAIN_ON, METRIC)

    set_seed(args.seed)

    root_path = f'./damd_multiwoz/{args.data_folder}'

    test_fn_json_path = os.path.join(root_path, 'multi-woz/testListFile.json')
    valid_fn_json_path = os.path.join(root_path, 'multi-woz/valListFile.json')

    # The list files are read as whitespace-separated names; names are
    # lower-cased and any '.json' substring is removed.
    with open(test_fn_json_path, 'r') as f:
        test_fn = set(f.read().lower().replace('.json', '').split())
    with open(valid_fn_json_path, 'r') as f:
        valid_fn = set(f.read().lower().replace('.json', '').split())

    with open(os.path.join(root_path, 'multi-woz-processed/data_for_damd.json'), 'r') as f:
        data_for_damd = json.load(f)

    print(fn_G_file_name, flush=True)
    fn_Gs_file_path = os.path.join(root_path, 'multi-woz-oppe', fn_G_file_name)
    with open(fn_Gs_file_path, 'r') as f:
        fn_Gs = json.load(f)

    # Train split: everything that is neither a test nor a validation dialogue.
    data_for_damd_only_train = {
        fn: v for fn, v in data_for_damd.items() if fn not in test_fn and fn not in valid_fn
    }
    print('Train filtered/unfiltered={}/{}'.format(len(data_for_damd_only_train), len(data_for_damd)), flush=True)

    data_for_damd_only_val = {
        fn: v for fn, v in data_for_damd.items() if fn in valid_fn
    }
    print('Val filtered/unfiltered={}/{}'.format(len(data_for_damd_only_val), len(data_for_damd)), flush=True)

    state_act_train, _, fn_tn_state_train = get_state_act(data_for_damd_only_train)
    state_act_val, _, fn_tn_state_val = get_state_act(data_for_damd_only_val)

    V_info_train = get_value_function(data_for_damd_only_train, fn_tn_state_train, fn_Gs)
    V_info_val = get_value_function(data_for_damd_only_val, fn_tn_state_val, fn_Gs)

    Q_info_train = get_Q_function(data_for_damd_only_train, V_info_train, fn_tn_state_train, fn_Gs)
    Q_info_val = get_Q_function(data_for_damd_only_val, V_info_val, fn_tn_state_val, fn_Gs)

    # Output file: same name as the returns file with 'fn_Gs_' -> 'fn_Qs_'.
    Q_fn_path_to_persist = os.path.join(root_path, 'multi-woz-oppe', fn_G_file_name.replace('fn_Gs_', 'fn_Qs_'))

    Q_fn = persist_Q_function([data_for_damd_only_train, data_for_damd_only_val],
                              [Q_info_train, Q_info_val],
                              [state_act_train, state_act_val],
                              [fn_tn_state_train, fn_tn_state_val],
                              Q_fn_path_to_persist)
|
|
||
|
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| MIT License | ||
|
|
||
| Copyright (c) 2023 Shentao-YANG | ||
|
|
||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| of this software and associated documentation files (the "Software"), to deal | ||
| in the Software without restriction, including without limitation the rights | ||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
| copies of the Software, and to permit persons to whom the Software is | ||
| furnished to do so, subject to the following conditions: | ||
|
|
||
| The above copyright notice and this permission notice shall be included in all | ||
| copies or substantial portions of the Software. | ||
|
|
||
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
| SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| # Fantastic Rewards and How to Tame Them: A Case Study on Reward Learning for Task-Oriented Dialogue Systems | ||
|
|
||
| *** | ||
|
|
||
| ## Dependency | ||
| To install the required packages, first create and activate a `fantastic_reward` env in conda. | ||
| Then execute the following command: | ||
| ```angular2html | ||
| bash install_packages.sh | ||
| ``` | ||
|
|
||
| *** | ||
| ## Experiments | ||
|
|
||
| ### Data Setup | ||
| Our data setup follows the [**CASPI**](https://github.com/salesforce/CASPI) paper. | ||
| Please download the pre-processed data from [here](https://drive.google.com/file/d/15A88j-pyI-jBznKmvJ17HtgW1h7fwDEl/view?usp=sharing). | ||
| Unzip the downloaded file and put the resulting folder `ExpStoreddata` into the folder `damd_multiwoz`. | ||
|
|
||
| ### Training the Reward and Response Models | ||
|
|
||
| For our variant of RewardNet+GS $N = 3,\Phi=(\cdot)^1$ in Table 1 of the paper, please run the following command | ||
| ```angular2html | ||
| bash ./run_multiple_seeds.sh --EXP_IDX ${EXP_IDX} --REWARD_SAMPLES 3 --REWARD_LOSS "listNet" --LISTMLE_TEMP 1 --LISTNET_POW 1 --POLICY_TRAIN_DATA_FRAC 1 --NEG_REW_WEIGHT 0.1 --REW_MODEL_EXP '0' | ||
| ``` | ||
| where `${EXP_IDX}` is the index of the experiment, such as `"2023"`. | ||
|
|
||
| For our variant of RewardMLE+GS $N = 5,\Phi=\exp(\cdot)$ in Table 1 of the paper, please run the following command | ||
| ```angular2html | ||
| bash ./run_multiple_seeds.sh --EXP_IDX ${EXP_IDX} --REWARD_SAMPLES 5 --REWARD_LOSS "listMLE" --LISTMLE_TEMP 1 --LISTNET_POW 0 --POLICY_TRAIN_DATA_FRAC 1 --NEG_REW_WEIGHT 1.0 --REW_MODEL_EXP '0' | ||
| ``` | ||
| where `${EXP_IDX}` is again the index of the experiment. | ||
|
|
||
| ### Evaluating the Released Checkpoints | ||
|
|
||
| To facilitate reproducibility, we release a checkpoint for each of the variants | ||
| RewardNet+GS $N = 3,\Phi=(\cdot)^1$ and RewardMLE+GS $N = 5,\Phi=\exp(\cdot)$ in Table 1 of the paper. | ||
| The released checkpoints are both trained under random seed `999` of the tested five seeds `(111 333 555 777 999)`. | ||
|
|
||
| To evaluate the checkpoints, please try the following steps. | ||
| Here `Exp1` corresponds to the variant of RewardNet+GS $N = 3,\Phi=(\cdot)^1$ and `Exp2` corresponds to RewardMLE+GS $N = 5,\Phi=\exp(\cdot)$. | ||
|
|
||
| 1. Download and unzip the checkpoints from [here](https://drive.google.com/file/d/1EUIno8hq94smUqBBnzr_m8svMWKKH7P5/view?usp=sharing). Put the resulting folders into a folder named `experiments`. | ||
| 2. Download and unzip the processed data from [here](https://drive.google.com/file/d/1fwLK62U38B3pxYxzrycGyEwt4_AFRv7l/view?usp=sharing). Put the resulting folders into the folder `damd_multiwoz`. | ||
| 3. Try the following command | ||
| ```angular2html | ||
| python train.py --model_path "Exp${EXP_IDX}/all_sd999/" \ | ||
| --mode 'test' --context_window 2 --pretrained_checkpoint bart-large-cnn \ | ||
| --back_bone bart --cfg seed=999 cuda_device=0 batch_size=8 early_stop_count=7 \ | ||
| --caspi_returns_file="fn_Gs_10_0.0_resp_soft.json" --caspi_wt=5. \ | ||
| --caspi_data_file=data_for_damd.json --caspi_val_fraction=.5 --caspi --data_folder "Exp${EXP_IDX}data/s999_K10_GAMMA0.0" \ | ||
| --exp_idx ${EXP_IDX} | ||
| ``` | ||
| where `${EXP_IDX}` should be replaced by `1` or `2`. | ||
|
|
||
|
|
||
| ## Acknowledgement | ||
| This codebase builds on the following codebases and datasets: | ||
| * [**CASPI**](https://github.com/salesforce/CASPI). | ||
| * [**MinTL**](https://github.com/zlinao/MinTL). | ||
| * [**DAMD**](https://gitlab.com/ucdavisnlp/damd-multiwoz). | ||
| * [**Multiwoz2.0**](https://github.com/budzianowski/multiwoz). | ||
| * [**ConvLab Multiwoz2.0 annotation**](https://github.com/ConvLab/ConvLab/tree/master/data/multiwoz/annotation). |
Oops, something went wrong.