diff --git a/ml-agents/mlagents/trainers/policy/policy.py b/ml-agents/mlagents/trainers/policy/policy.py index 0e23f10ec8..833caecf9f 100644 --- a/ml-agents/mlagents/trainers/policy/policy.py +++ b/ml-agents/mlagents/trainers/policy/policy.py @@ -132,6 +132,16 @@ def get_action( ) -> ActionInfo: raise NotImplementedError + @staticmethod + def check_nan_action(action: Optional[np.ndarray]) -> None: + # Fast NaN check on the action + # See https://stackoverflow.com/questions/6736590/fast-check-for-nan-in-numpy for background. + if action is not None: + d = np.sum(action) + has_nan = np.isnan(d) + if has_nan: + raise RuntimeError("NaN action detected.") + @abstractmethod def update_normalization(self, vector_obs: np.ndarray) -> None: pass diff --git a/ml-agents/mlagents/trainers/policy/tf_policy.py b/ml-agents/mlagents/trainers/policy/tf_policy.py index f10299dd90..7c35d01005 100644 --- a/ml-agents/mlagents/trainers/policy/tf_policy.py +++ b/ml-agents/mlagents/trainers/policy/tf_policy.py @@ -270,17 +270,10 @@ def get_action( ) self.save_memories(global_agent_ids, run_out.get("memory_out")) - action = run_out.get("action") - # Fast NaN check on the action - # See https://stackoverflow.com/questions/6736590/fast-check-for-nan-in-numpy for background. - if action is not None: - d = np.sum(action) - has_nan = np.isnan(d) - if has_nan: - raise RuntimeError("NaN action detected.") + self.check_nan_action(run_out.get("action")) return ActionInfo( - action=action, + action=run_out.get("action"), value=run_out.get("value"), outputs=run_out, agent_ids=decision_requests.agent_id, diff --git a/ml-agents/mlagents/trainers/policy/torch_policy.py b/ml-agents/mlagents/trainers/policy/torch_policy.py index 7e7fe521d5..5e6e07b674 100644 --- a/ml-agents/mlagents/trainers/policy/torch_policy.py +++ b/ml-agents/mlagents/trainers/policy/torch_policy.py @@ -229,6 +229,7 @@ def get_action( decision_requests, global_agent_ids ) # pylint: disable=assignment-from-no-return self.save_memories(global_agent_ids, run_out.get("memory_out")) + self.check_nan_action(run_out.get("action")) return ActionInfo( action=run_out.get("action"), value=run_out.get("value"), diff --git a/ml-agents/mlagents/trainers/tests/tensorflow/test_nn_policy.py b/ml-agents/mlagents/trainers/tests/tensorflow/test_nn_policy.py index 3308df44ec..6134619e8e 100644 --- a/ml-agents/mlagents/trainers/tests/tensorflow/test_nn_policy.py +++ b/ml-agents/mlagents/trainers/tests/tensorflow/test_nn_policy.py @@ -265,5 +265,25 @@ def test_min_visual_size(): enc_func(vis_input, 32, ModelUtils.swish, 1, "test", False) +def test_step_overflow(): + behavior_spec = mb.setup_test_behavior_specs( + use_discrete=True, use_visual=False, vector_action_space=[2], vector_obs_space=1 + ) + + policy = TFPolicy( + 0, + behavior_spec, + TrainerSettings(network_settings=NetworkSettings(normalize=True)), + create_tf_graph=False, + ) + policy.create_input_placeholders() + policy.initialize() + + policy.set_step(2 ** 31 - 1) + assert policy.get_current_step() == 2 ** 31 - 1 + policy.increment_step(3) + assert policy.get_current_step() == 2 ** 31 + 2 + + if __name__ == "__main__": pytest.main() diff --git a/ml-agents/mlagents/trainers/tests/torch/test_policy.py b/ml-agents/mlagents/trainers/tests/torch/test_policy.py index 192d0dd229..b21af43fa5 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_policy.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_policy.py @@ -140,3 +140,11 @@ def test_sample_actions(rnn, visual, discrete): if rnn: assert memories.shape == (1, 1, policy.m_size) + + +def test_step_overflow(): + policy = create_policy_mock(TrainerSettings()) + policy.set_step(2 ** 31 - 1) + assert policy.get_current_step() == 2 ** 31 - 1 # step = 2147483647 + policy.increment_step(3) + assert policy.get_current_step() == 2 ** 31 + 2 # step = 2147483650 diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index b89029e404..e9d4a5d96f 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -488,7 +488,9 @@ def update_normalization(self, vector_obs: List[torch.Tensor]) -> None: class GlobalSteps(nn.Module): def __init__(self): super().__init__() - self.__global_step = nn.Parameter(torch.Tensor([0]), requires_grad=False) + self.__global_step = nn.Parameter( + torch.Tensor([0]).to(torch.int64), requires_grad=False + ) @property def current_step(self):