This repository has been archived by the owner on Dec 11, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 460
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
allow specifying preset as a commandline parameter to rollout worker
- Loading branch information
1 parent
3714d8e
commit e34b9ae
Showing
1 changed file
with
23 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,34 @@ | ||
import argparse | ||
|
||
from rl_coach.base_parameters import TaskParameters | ||
from rl_coach.coach import expand_preset | ||
from rl_coach.core_types import EnvironmentEpisodes, RunPhase | ||
from rl_coach.presets.CartPole_DQN import graph_manager | ||
from rl_coach.utils import short_dynamic_import | ||
|
||
|
||
|
||
# TODO: acce[t preset option | ||
# TODO: workers might need to define schedules in terms which can be synchronized: exploration(len(distributed_memory)) -> float | ||
# TODO: periodically reload policy (from disk?) | ||
# TODO: specify alternative distributed memory, or should this go in the preset? | ||
|
||
def main(): | ||
graph_manager.create_graph(TaskParameters()) | ||
def rollout_worker(graph_manager): | ||
task_parameters = TaskParameters() | ||
task_parameters.checkpoint_restore_dir='/checkpoint' | ||
graph_manager.create_graph(task_parameters) | ||
graph_manager.phase = RunPhase.TRAIN | ||
graph_manager.act(EnvironmentEpisodes(num_steps=10)) | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('-p', '--preset', | ||
help="(string) Name of a preset to run (class name from the 'presets' directory.)", | ||
type=str) | ||
args = parser.parse_args() | ||
|
||
graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True) | ||
rollout_worker(graph_manager) | ||
|
||
if __name__ == '__main__': | ||
main() |