Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Shentao-YANG committed Feb 7, 2023
0 parents commit 60978a3
Show file tree
Hide file tree
Showing 69 changed files with 114,642 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto
15 changes: 15 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
damd_multiwoz/data/embeddings
damd_multiwoz/data/multi-woz-analysis
damd_multiwoz/data/multi-woz-processed
damd_multiwoz/data/multi-woz-oppe
__pycache__
log/*/*.txt
log/*/*.log
experiments/**
damd_multiwoz/log/*/*.txt
damd_multiwoz/log/*/*.log
damd_multiwoz/experiments/**
.idea/*
*.pdf
*.ipynb
*.log
512 changes: 512 additions & 0 deletions BART.py

Large diffs are not rendered by default.

663 changes: 663 additions & 0 deletions DST.py

Large diffs are not rendered by default.

264 changes: 264 additions & 0 deletions EstimateBehaviorPolicy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
from collections import defaultdict
import json
import os
import random

import numpy as np
from argparse import ArgumentParser

# Module-level configuration. The defaults below are placeholders and are
# overwritten by the parsed CLI arguments in the __main__ block.
K=10  # number of cross-validation folds (encoded in the fn_Gs file name)
TRAIN_ON=['act','resp'][0]  # action space: delexicalised system act vs. full response
GAMMA_GLOBAL = 0.0  # global discount override; None falls back to the per-turn gamma
USE_R_AS_G = True  # treat the stored per-turn value directly as the return G (no bootstrapping)
METRIC=['soft', 'hard'][0]  # metric used for pairwise reward candidate generation


def get_turn_state(full_state, turn_domain):
    """Extract the slice of ``full_state`` belonging to ``turn_domain``.

    The delexicalised belief state lists constraints grouped under domain
    tags such as ``[hotel]``.  The returned slice starts at ``turn_domain``
    and ends just before the nearest other domain tag appearing after it
    (or at the end of the string), with surrounding whitespace stripped.
    If the tag is absent from ``full_state``, the bare tag is returned.
    """
    domain_tags = ['[police]', '[taxi]', '[restaurant]', '[attraction]', '[hotel]', '[hospital]', '[train]', '[general]']
    if turn_domain not in full_state:
        return turn_domain
    start = full_state.index(turn_domain)
    end = len(full_state)
    # Shrink `end` to the closest other domain tag located after `turn_domain`.
    for tag in domain_tags:
        if tag == turn_domain or tag not in full_state:
            continue
        pos = full_state.index(tag)
        if start < pos < end:
            end = pos
    return full_state[start:end].strip()

def get_state_act(data_for_damd):
    """Collect tabular state/action statistics over all non-test dialogues.

    Returns three mappings:
      state_act   -- state string -> list of system acts taken in that state
      act_state   -- system act -> list of states observed right after it
      fn_tn_state -- dialogue file name -> turn number -> state string
    Dialogues whose file name is in the module-global ``test_fn`` are skipped.
    """
    fn_tn_state = defaultdict(dict)
    state_act = defaultdict(list)
    act_state = defaultdict(list)
    for fn, dia in data_for_damd.items():
        if fn in test_fn:
            continue
        prev_act = None
        for turn_num, log in enumerate(dia['log']):
            full_state = log['cons_delex']
            # The last listed domain is taken as the turn's active domain.
            turn_domain = log['turn_domain'].split(' ')[-1]
            if turn_domain in full_state:
                turn_state = get_turn_state(full_state, turn_domain)
            else:
                turn_state = turn_domain
            if prev_act is not None:
                act_state[prev_act].append(turn_state)
            act = log['sys_act']
            state_act[turn_state].append(act)
            fn_tn_state[fn][turn_num] = turn_state
            prev_act = act
    return state_act, act_state, fn_tn_state

def get_state(fn, fn_tn_state, tn):
    """Look up the cached state string of dialogue ``fn`` at turn ``tn``."""
    return fn_tn_state[fn][tn]

def get_act(turn):
    """Return the system act recorded on a turn dict."""
    return turn['sys_act']

def get_gamma(gamma_local):
    """Return the discount factor: the global override when set, else the local one."""
    return gamma_local if GAMMA_GLOBAL is None else GAMMA_GLOBAL


# Dialogue file names that have no entry in fn_Gs, collected for inspection.
not_in_fn_Gs = set()
def get_reward_gamma(fn_Gs, fn, turn_num):
    """Return the stored (reward/return, gamma) pair for a turn, or None if absent.

    With USE_R_AS_G=True (the module default) the raw ``fn_Gs[fn][turn_num]``
    entry is returned together with gamma=0; callers then index it as a dict
    (``R_gamma[0]['G']``, ``R_gamma[0]['gamma']``) — presumably each entry
    holds both keys; TODO confirm against the fn_Gs file format.
    NOTE(review): with USE_R_AS_G=False a scalar R is returned instead, which
    the callers' dict-indexing would break on — verify before flipping the flag.
    """
    turn_num = str(turn_num)  # fn_Gs comes from JSON, so turn keys are strings
    if fn not in fn_Gs:
        not_in_fn_Gs.add(fn)
        return None
    if USE_R_AS_G == False:
        reward = fn_Gs[fn][turn_num]['R']
        gamma = fn_Gs[fn][turn_num]['gamma']
    elif USE_R_AS_G == True:
        reward = fn_Gs[fn][turn_num]
        gamma = 0
    else:
        raise Exception('Invalid USE_R_AS_G selection')
    return reward, gamma


def get_value_function(data_for_damd, fn_tn_state, fn_Gs):
    """Estimate V(s) as the running mean of Monte-Carlo returns over visits to s.

    Each non-test dialogue is walked backwards so that the following turn's
    return (G_nxt) is already available when the current turn's return is
    formed.  Turns with no entry in ``fn_Gs`` are skipped.

    Returns a dict: state -> {'V': mean return, '|S|': visit count}.
    """
    V_info = {}
    for fn, dia in data_for_damd.items():
        if fn not in test_fn:
            log = dia['log']
            G_nxt = 0
            for turn in reversed(log):
                turn_num = turn['turn_num']
                state = get_state(fn, fn_tn_state, turn_num)
                R_gamma = get_reward_gamma(fn_Gs ,fn, turn_num)
                if R_gamma is None:
                    continue
                # NOTE(review): R_gamma[0] is indexed as a dict with 'G' and
                # 'gamma' keys — presumably the fn_Gs entry format; confirm.
                R,gamma = R_gamma[0]['G'], R_gamma[0]['gamma']
                if state not in V_info:
                    V_info[state] = {
                        'V':0,
                        '|S|':0
                    }
                if USE_R_AS_G == False:
                    # Bootstrap from the following turn's return.
                    G = R + get_gamma(gamma) * G_nxt
                elif USE_R_AS_G == True:
                    # The stored value is already the full return.
                    G = R
                else:
                    raise Exception('Invalid USE_R_AS_G selection')
                # Incremental mean update of V(state).
                V_info[state]['V'] = (V_info[state]['V'] * V_info[state]['|S|'] + G)/(V_info[state]['|S|']+1)
                V_info[state]['|S|']+=1
                G_nxt = G
    return V_info


def get_Q_function(data_for_damd, V_info, fn_tn_state, fn_Gs):
    """Estimate Q(s, a) as the running mean of R + gamma * V(s') over visits.

    Each non-test dialogue is walked backwards; ``V_nxt`` holds the estimated
    value of the state of the turn that follows the current one in time
    (0 for the dialogue's final turn).  Turns missing from ``fn_Gs`` are
    skipped.

    Returns a dict: state -> act -> {'Q': mean value, '|S|': visit count}.
    """
    Q_info = {}
    for fn, dia in data_for_damd.items():
        if fn not in test_fn:
            log = dia['log']
            V_nxt = 0
            for turn in reversed(log):
                turn_num = turn['turn_num']
                state = get_state(fn, fn_tn_state, turn_num)
                act = get_act(turn)
                R_gamma = get_reward_gamma(fn_Gs, fn, turn_num)
                if R_gamma is None:
                    continue
                # NOTE(review): R_gamma[0] is indexed as a dict with 'G' and
                # 'gamma' keys — presumably the fn_Gs entry format; confirm.
                R,gamma = R_gamma[0]['G'], R_gamma[0]['gamma']
                if state not in Q_info:
                    Q_info[state] = {}
                if act not in Q_info[state]:
                    Q_info[state][act] = {
                        'Q':0,
                        '|S|':0
                    }
                # Incremental mean update of Q(state, act).
                G = R + get_gamma(gamma) * V_nxt
                Q_info[state][act]['Q'] = (Q_info[state][act]['Q'] * Q_info[state][act]['|S|'] + G)/(Q_info[state][act]['|S|']+1)
                Q_info[state][act]['|S|']+=1
                # The current state's V becomes the next-state value for the
                # (earlier) turn processed in the following loop iteration.
                V_nxt = V_info[state]['V']
    return Q_info


def estimate_bh_policy(state_act, state, act):
    """MLE estimate of the discrete behaviour policy pi_b(act | state).

    Computed as the empirical frequency of ``act`` among all actions
    recorded for ``state`` in ``state_act``.
    """
    observed_acts = state_act[state]
    return observed_acts.count(act) / len(observed_acts)


def persist_Q_function(data_for_damds, Q_infos, state_acts, fn_tn_states, path_to_persist):
    """Assemble and optionally persist per-turn Q-values and behaviour-policy probs.

    For every non-test dialogue in each (data, Q_info, state_act, fn_tn_state)
    split, builds ``Q_fn[fn][turn_num] = {'Q': ..., 'prob': ...}`` where 'Q' is
    the tabular estimate from get_Q_function and 'prob' the MLE behaviour-policy
    probability of the logged act.  When ``path_to_persist`` is not None, the
    result is also dumped there as JSON.  Returns ``Q_fn``.

    Raises if a logged (state, act) pair is missing from ``Q_info`` — by
    construction get_Q_function inserted every pair it saw, so this should
    not happen for turns it processed.
    Fix vs. original: removed the dead local ``act_len``, which was computed
    and never used.
    """
    Q_fn = {}
    for data_for_damd, Q_info, state_act, fn_tn_state in zip(data_for_damds, Q_infos, state_acts, fn_tn_states):
        for fn, dia in data_for_damd.items():
            if fn in test_fn:
                continue
            Q_fn[fn] = {}
            for turn in dia['log']:
                turn_num = turn['turn_num']
                state = get_state(fn, fn_tn_state, turn_num)
                act = get_act(turn)
                if state not in Q_info or act not in Q_info[state]:
                    raise Exception('I Dont see a reason to be here!')
                Q_fn[fn][turn_num] = {
                    'Q': Q_info[state][act]['Q'],
                    'prob': estimate_bh_policy(state_act, state, act)
                }
    if path_to_persist is not None:
        print('path_to_persist:', path_to_persist, flush=True)
        with open(path_to_persist, 'w') as f:
            json.dump(Q_fn, f, indent=2)
    return Q_fn


def set_seed(seed):
    """Seed the stdlib and NumPy RNGs (and hash randomisation) for reproducibility.

    NOTE(review): setting PYTHONHASHSEED here only affects child processes —
    the running interpreter's hash seed is fixed at startup; confirm intent.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)


if __name__ == '__main__':
    # CLI entry point: parse options, then rebind the module-level
    # configuration constants before any of the functions above run.
    parser = ArgumentParser()
    parser.add_argument("-s", "--seed", dest="seed",
                        default=11,
                        type=int,
                        help="seed")
    parser.add_argument("-K", "--folds",
                        dest="folds", default=10,
                        type=int,
                        help="Number of folds")
    parser.add_argument("-a", "--action_space",
                        dest="action_space",
                        choices={"act", "resp"},
                        default='act',
                        help="action space. can either be act or resp")
    parser.add_argument("-m", "--metric",
                        dest="metric",
                        choices={"hard", "soft"},
                        default='soft',
                        help="metric used for pairwise reward candidate generation")
    parser.add_argument("-g", "--gamma",
                        dest="gamma",
                        default=0.0,
                        type=float,
                        help="The discount factor used in reward learning")
    parser.add_argument("--data_folder", type=str, default="")
    args = parser.parse_args()
    assert args.data_folder != ""

    # Override the module-level defaults with the CLI values; these globals
    # are read by get_gamma / get_reward_gamma and friends.
    K=args.folds
    TRAIN_ON=args.action_space
    GAMMA_GLOBAL = args.gamma
    METRIC = args.metric
    fn_G_file_name = 'fn_Gs_{}_{}_{}_{}.json'.format(K, GAMMA_GLOBAL, TRAIN_ON, METRIC)

    set_seed(args.seed)

    root_path = f'./damd_multiwoz/{args.data_folder}'

    # Test/validation dialogue ids, lower-cased with the '.json' suffix stripped.
    test_fn_json_path = os.path.join(root_path,'multi-woz/testListFile.json')
    valid_fn_json_path = os.path.join(root_path,'multi-woz/valListFile.json')

    test_fn = set(open(test_fn_json_path,'r').read().lower().replace('.json','').split())
    valid_fn = set(open(valid_fn_json_path,'r').read().lower().replace('.json','').split())

    data_for_damd = json.loads(open(os.path.join(root_path,'multi-woz-processed/data_for_damd.json'),'r').read())

    print(fn_G_file_name, flush=True)
    # Per-dialogue, per-turn returns produced by the reward-learning stage.
    fn_Gs_file_path = os.path.join(root_path,'multi-woz-oppe',fn_G_file_name)
    fn_Gs = json.loads(open(fn_Gs_file_path,'r').read())

    # Split the corpus: train = neither test nor validation; val = validation.
    data_for_damd_only_train = {
        fn:v for fn,v in data_for_damd.items() if fn not in test_fn and fn not in valid_fn
    }
    print('Train filtered/unfiltered={}/{}'.format(len(data_for_damd_only_train),len(data_for_damd)), flush=True)

    data_for_damd_only_val = {
        fn:v for fn,v in data_for_damd.items() if fn in valid_fn
    }
    print('Val filtered/unfiltered={}/{}'.format(len(data_for_damd_only_val),len(data_for_damd)), flush=True)

    # Tabular state/act statistics, then Monte-Carlo V and Q estimates per split.
    state_act_train,_,fn_tn_state_train = get_state_act(data_for_damd_only_train)
    state_act_val,_,fn_tn_state_val = get_state_act(data_for_damd_only_val)

    V_info_train = get_value_function(data_for_damd_only_train, fn_tn_state_train, fn_Gs)
    V_info_val = get_value_function(data_for_damd_only_val, fn_tn_state_val, fn_Gs)

    Q_info_train = get_Q_function(data_for_damd_only_train, V_info_train, fn_tn_state_train, fn_Gs)
    Q_info_val = get_Q_function(data_for_damd_only_val, V_info_val, fn_tn_state_val, fn_Gs)

    # Persist Q-values + behaviour-policy probs next to the fn_Gs file,
    # under the matching 'fn_Qs_...' name.
    Q_fn_path_to_persist = os.path.join(root_path,'multi-woz-oppe',fn_G_file_name.replace('fn_Gs_','fn_Qs_'))

    Q_fn = persist_Q_function([data_for_damd_only_train, data_for_damd_only_val],
                              [Q_info_train, Q_info_val],
                              [state_act_train, state_act_val],
                              [fn_tn_state_train, fn_tn_state_val],
                              Q_fn_path_to_persist)


21 changes: 21 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Shentao-YANG

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
63 changes: 63 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Fantastic Rewards and How to Tame Them: A Case Study on Reward Learning for Task-Oriented Dialogue Systems

***

## Dependency
To install the required packages, first create and activate a `fantastic_reward` env in conda.
Then execute the following command:
```angular2html
bash install_packages.sh
```

***
## Experiments

### Data Setup
Our data setup follows that of the [**CASPI**](https://github.com/salesforce/CASPI) paper.
Please download the pre-processed data from [here](https://drive.google.com/file/d/15A88j-pyI-jBznKmvJ17HtgW1h7fwDEl/view?usp=sharing).
Unzip the downloaded file and put the resulting folder `ExpStoreddata` into the folder `damd_multiwoz`.

### Training the Reward and Response Models

For our variant of RewardNet+GS $N = 3,\Phi=(\cdot)^1$ in Table 1 of the paper, please run the following command
```angular2html
bash ./run_multiple_seeds.sh --EXP_IDX ${EXP_IDX} --REWARD_SAMPLES 3 --REWARD_LOSS "listNet" --LISTMLE_TEMP 1 --LISTNET_POW 1 --POLICY_TRAIN_DATA_FRAC 1 --NEG_REW_WEIGHT 0.1 --REW_MODEL_EXP '0'
```
where `${EXP_IDX}` is the index of the experiment, such as `"2023"`.

For our variant of RewardMLE+GS $N = 5,\Phi=\exp(\cdot)$ in Table 1 of the paper, please run the following command
```angular2html
bash ./run_multiple_seeds.sh --EXP_IDX ${EXP_IDX} --REWARD_SAMPLES 5 --REWARD_LOSS "listMLE" --LISTMLE_TEMP 1 --LISTNET_POW 0 --POLICY_TRAIN_DATA_FRAC 1 --NEG_REW_WEIGHT 1.0 --REW_MODEL_EXP '0'
```
where `${EXP_IDX}` is again the index of the experiment.

### Evaluating the Released Checkpoints

To facilitate reproducibility, we release a checkpoint for each of the variants
RewardNet+GS $N = 3,\Phi=(\cdot)^1$ and RewardMLE+GS $N = 5,\Phi=\exp(\cdot)$ in Table 1 of the paper.
Both released checkpoints were trained with random seed `999`, one of the five tested seeds `(111 333 555 777 999)`.

To evaluate the checkpoints, please try the following steps.
Here `Exp1` corresponds to the RewardNet+GS $N = 3,\Phi=(\cdot)^1$ variant and `Exp2` to the RewardMLE+GS $N = 5,\Phi=\exp(\cdot)$ variant.

1. Download and unzip the checkpoints from [here](https://drive.google.com/file/d/1EUIno8hq94smUqBBnzr_m8svMWKKH7P5/view?usp=sharing). Put the resulting folders into a folder named `experiments`.
2. Download and unzip the processed data from [here](https://drive.google.com/file/d/1fwLK62U38B3pxYxzrycGyEwt4_AFRv7l/view?usp=sharing). Put the resulting folders into the folder `damd_multiwoz`.
3. Try the following command
```angular2html
python train.py --model_path "Exp${EXP_IDX}/all_sd999/" \
--mode 'test' --context_window 2 --pretrained_checkpoint bart-large-cnn \
--back_bone bart --cfg seed=999 cuda_device=0 batch_size=8 early_stop_count=7 \
--caspi_returns_file="fn_Gs_10_0.0_resp_soft.json" --caspi_wt=5. \
--caspi_data_file=data_for_damd.json --caspi_val_fraction=.5 --caspi --data_folder "Exp${EXP_IDX}data/s999_K10_GAMMA0.0" \
--exp_idx ${EXP_IDX}
```
where `${EXP_IDX}` should be replaced by `1` or `2`.


## Acknowledgement
This codebase builds on the following codebases and datasets:
* [**CASPI**](https://github.com/salesforce/CASPI).
* [**MinTL**](https://github.com/zlinao/MinTL).
* [**DAMD**](https://gitlab.com/ucdavisnlp/damd-multiwoz).
* [**Multiwoz2.0**](https://github.com/budzianowski/multiwoz).
* [**ConvLab Multiwoz2.0 annotation**](https://github.com/ConvLab/ConvLab/tree/master/data/multiwoz/annotation).
Loading

0 comments on commit 60978a3

Please sign in to comment.