Skip to content

Commit

Permalink
update readme, config import
Browse files Browse the repository at this point in the history
  • Loading branch information
bdhingra committed Apr 20, 2017
1 parent 32a612c commit e699953
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 107 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
*.pyc
data/
6 changes: 3 additions & 3 deletions README.md
Expand Up @@ -111,14 +111,14 @@ optional arguments:
what you gave for training). Pass "pretrained" to use
pretrained models.
```
Example:
Run without the `--model_name` argument to test on pre-trained models. Example:
```sh
python sim.py --agent rl-soft --db imdb-M --model_name rl_soft_example.m
python sim.py --agent rl-soft --db imdb-M
```
Hyperparameters
-------------------------------------------------
The following agent options can be specified in the config.py file-
The default hyperparameters for each KB split are in `settings/config_<db_name>.py`. These include:
1. RL agent options-
* `nhid`: Number of hidden units
* `batch`: Batch size
Expand Down
98 changes: 0 additions & 98 deletions config.py

This file was deleted.

6 changes: 4 additions & 2 deletions interact.py
Expand Up @@ -46,6 +46,7 @@
import numpy as np
import cPickle as pkl
import datetime
import importlib

agent_map = {'rule-no' : 'nl-rule-no',
'rl-no' : 'simple-rl-no',
Expand Down Expand Up @@ -77,8 +78,9 @@
params['dontknow_prob'] = 0.5
params['sub_prob'] = 0.05
params['max_first_turn'] = 5
shutil.copyfile('settings/config_'+params['db']+'.py', 'config.py')
from config import *
config = importlib.import_module('settings.config_'+params['db'])
agent_params = config.agent_params
dataset_params = config.dataset_params
for k,v in dataset_params[params['db']].iteritems():
params[k] = v

Expand Down
Empty file added settings/__init__.py
Empty file.
5 changes: 3 additions & 2 deletions sim.py
Expand Up @@ -42,8 +42,9 @@
params['nlg_slots_path'] = './data/nlg_slot_set.txt'
params['nlg_model_path'] = './data/pretrained/lstm_tanh_[1470015675.73]_115_120_0.657.p'

shutil.copyfile('settings/config_'+params['db']+'.py', 'config.py')
from config import *
config = importlib.import_module('settings.config_'+params['db'])
agent_params = config.agent_params
dataset_params = config.dataset_params
for k,v in dataset_params[params['db']].iteritems():
params[k] = v
for k,v in agent_params[agent_map[params['agent_type']]].iteritems():
Expand Down
5 changes: 3 additions & 2 deletions train.py
Expand Up @@ -46,8 +46,9 @@
params['nlg_slots_path'] = './data/nlg_slot_set.txt'
params['nlg_model_path'] = './data/pretrained/lstm_tanh_[1470015675.73]_115_120_0.657.p'

shutil.copyfile('settings/config_'+params['db']+'.py', 'config.py')
from config import *
config = importlib.import_module('settings.config_'+params['db'])
agent_params = config.agent_params
dataset_params = config.dataset_params
for k,v in dataset_params[params['db']].iteritems():
params[k] = v
for k,v in agent_params[agent_map[params['agent_type']]].iteritems():
Expand Down

0 comments on commit e699953

Please sign in to comment.