diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 66e8520ba5..88997f6262 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -762,6 +762,7 @@ void AgentStep() { m_StepCount += 1; } + if ((m_RequestAction) && (m_Brain != null)) { m_RequestAction = false; diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs index 1d84ff8638..ff2ba643b9 100644 --- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -37,7 +37,9 @@ internal void SetPolicy(IPolicy policy) public int initializeAgentCalls; public int collectObservationsCalls; + public int collectObservationsCallsSinceLastReset; public int agentActionCalls; + public int agentActionCallsSinceLastReset; public int agentResetCalls; public override void InitializeAgent() { @@ -54,18 +56,22 @@ public override void InitializeAgent() public override void CollectObservations() { collectObservationsCalls += 1; + collectObservationsCallsSinceLastReset += 1; AddVectorObs(0f); } public override void AgentAction(float[] vectorAction) { agentActionCalls += 1; + agentActionCallsSinceLastReset += 1; AddReward(0.1f); } public override void AgentReset() { agentResetCalls += 1; + collectObservationsCallsSinceLastReset = 0; + agentActionCallsSinceLastReset = 0; } public override float[] Heuristic() @@ -500,5 +506,40 @@ public void TestCumulativeReward() aca.EnvironmentStep(); } } + + [Test] + public void TestMaxStepsReset() + { + var agentGo1 = new GameObject("TestAgent"); + agentGo1.AddComponent(); + var agent1 = agentGo1.GetComponent(); + var aca = Academy.Instance; + + var decisionRequester = agent1.gameObject.AddComponent(); + decisionRequester.DecisionPeriod = 1; + decisionRequester.Awake(); + + var maxStep = 6; + agent1.maxStep = maxStep; + agent1.LazyInitialize(); + + for (var i = 0; i < 15; i++) + { + // We expect resets to occur when there are maxSteps actions since the last reset (and on the first step) + var expectReset = agent1.agentActionCallsSinceLastReset == maxStep || (i == 0); + var previousNumResets = agent1.agentResetCalls; + + aca.EnvironmentStep(); + + if (expectReset) + { + Assert.AreEqual(previousNumResets + 1, agent1.agentResetCalls); + } + else + { + Assert.AreEqual(previousNumResets, agent1.agentResetCalls); + } + } + } } }