From d27e2091c01eec15b4b12fd15d6d7bbbf493d894 Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Wed, 19 Aug 2020 11:43:26 -0700 Subject: [PATCH 1/3] Adding seeds to simple RL tests --- ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py b/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py index e7b9e8c939..d5d3db4d0b 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py @@ -2,6 +2,7 @@ import tempfile import pytest import numpy as np +import torch import attr from typing import Dict @@ -114,7 +115,10 @@ def _check_environment_trains( env_parameter_manager=None, success_threshold=0.9, env_manager=None, + seed=1337, ): + np.random.seed(seed) + torch.manual_seed(seed) if env_parameter_manager is None: env_parameter_manager = EnvironmentParameterManager() # Create controller and begin training. @@ -210,7 +214,7 @@ def test_visual_advanced_ppo(vis_encode_type, num_visual): PPO_CONFIG, hyperparameters=new_hyperparams, network_settings=new_networksettings, - max_steps=500, + max_steps=1000, summary_freq=100, ) # The number of steps is pretty small for these encoders From ffc8e100a08a22e583dce7b098384005b37ea1a3 Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Wed, 19 Aug 2020 11:53:50 -0700 Subject: [PATCH 2/3] Setting seed in trainer_controller --- ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py | 4 ---- ml-agents/mlagents/trainers/trainer_controller.py | 7 +++++++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py b/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py index d5d3db4d0b..1306525d69 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py @@ -2,7 +2,6 @@ import tempfile import pytest import numpy as np -import torch import attr from typing import Dict @@ -115,10 +114,7 @@ def _check_environment_trains( env_parameter_manager=None, success_threshold=0.9, env_manager=None, - seed=1337, ): - np.random.seed(seed) - torch.manual_seed(seed) if env_parameter_manager is None: env_parameter_manager = EnvironmentParameterManager() # Create controller and begin training. diff --git a/ml-agents/mlagents/trainers/trainer_controller.py b/ml-agents/mlagents/trainers/trainer_controller.py index 870ce7a813..65fc05b381 100644 --- a/ml-agents/mlagents/trainers/trainer_controller.py +++ b/ml-agents/mlagents/trainers/trainer_controller.py @@ -30,6 +30,11 @@ from mlagents.trainers.agent_processor import AgentManager from mlagents.tf_utils.globals import get_rank +try: + import torch +except ModuleNotFoundError: + torch = None # type: ignore + class TrainerController: def __init__( @@ -66,6 +71,8 @@ def __init__( self.kill_trainers = False np.random.seed(training_seed) tf.set_random_seed(training_seed) + if torch: + torch.manual_seed(training_seed) self.rank = get_rank() @timed From b517edd693cc4e2f5afcbf74e8c4cb22be54959f Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Wed, 19 Aug 2020 11:54:44 -0700 Subject: [PATCH 3/3] Fixing if statement --- ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py | 2 +- ml-agents/mlagents/trainers/trainer_controller.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py b/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py index 1306525d69..e7b9e8c939 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py @@ -210,7 +210,7 @@ def test_visual_advanced_ppo(vis_encode_type, num_visual): PPO_CONFIG, hyperparameters=new_hyperparams, network_settings=new_networksettings, - max_steps=1000, + max_steps=500, summary_freq=100, ) # The number of steps is pretty small for these encoders diff --git a/ml-agents/mlagents/trainers/trainer_controller.py b/ml-agents/mlagents/trainers/trainer_controller.py index 65fc05b381..550a514a1d 100644 --- a/ml-agents/mlagents/trainers/trainer_controller.py +++ b/ml-agents/mlagents/trainers/trainer_controller.py @@ -71,7 +71,7 @@ def __init__( self.kill_trainers = False np.random.seed(training_seed) tf.set_random_seed(training_seed) - if torch: + if torch is not None: torch.manual_seed(training_seed) self.rank = get_rank()