Skip to content

Commit

Permalink
minor restructure of main
Browse files Browse the repository at this point in the history
  • Loading branch information
Mitchell Wortsman committed Mar 26, 2019
1 parent 7765015 commit eab929a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
24 changes: 12 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ Use the following code to run the pretrained models on the test set. Add the arg
#### SAVN
```bash
python main.py --eval \
--test_or_val test \
--test_or_val test \
--episode_type TestValEpisode \
--load_model pretrained_models/savn_pretrained.dat \
--model SAVN \
--load_model pretrained_models/savn_pretrained.dat \
--model SAVN \
--results_json savn_test.json

cat savn_test.json
Expand All @@ -73,11 +73,11 @@ cat savn_test.json
```bash
python main.py --eval \
--test_or_val test \
--episode_type TestValEpisode \
--load_model pretrained_models/gcn_pretrained.dat \
--model GCN \
--glove_dir ./data/gcn \
--results_json scene_priors_test.json
--episode_type TestValEpisode \
--load_model pretrained_models/gcn_pretrained.dat \
--model GCN \
--glove_dir ./data/gcn \
--results_json scene_priors_test.json

cat scene_priors_test.json
```
Expand All @@ -86,9 +86,9 @@ cat scene_priors_test.json
#### Non-Adaptvie-A3C
```bash
python main.py --eval \
--test_or_val test \
--episode_type TestValEpisode \
--load_model pretrained_models/nonadaptivea3c_pretrained.dat \
--test_or_val test \
--episode_type TestValEpisode \
--load_model pretrained_models/nonadaptivea3c_pretrained.dat \
--results_json nonadaptivea3c_test.json

cat nonadaptivea3c_test.json
Expand All @@ -112,7 +112,7 @@ python main.py \
```bash
python main.py \
--title nonadaptivea3c_train \
--gpu-ids 0 1 \
--gpu-ids 0 1 \
--workers 12
```

Expand Down
17 changes: 9 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ def main():
setproctitle.setproctitle("Train/Test Manager")
args = flag_parser.parse_arguments()

if args.model == "BaseModel" or args.model == "GCN":
args.learned_loss = False
args.num_steps = 50
target = nonadaptivea3c_val if args.eval else nonadaptivea3c_train
else:
args.learned_loss = True
args.num_steps = 6
target = savn_val if args.eval else savn_train

create_shared_model = model_class(args.model)
init_agent = agent_class(args.agent_type)
optimizer_type = optimizer_class(args.optimizer)
Expand Down Expand Up @@ -78,14 +87,6 @@ def main():
end_flag = mp.Value(ctypes.c_bool, False)

train_res_queue = mp.Queue()
if args.model == "BaseModel" or args.model == "GCN":
args.learned_loss = False
args.num_steps = 50
target = nonadaptivea3c_val if args.eval else nonadaptivea3c_train
else:
args.learned_loss = True
args.num_steps = 6
target = savn_val if args.eval else savn_train

for rank in range(0, args.workers):
p = mp.Process(
Expand Down

0 comments on commit eab929a

Please sign in to comment.