Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,7 @@ void AgentStep()
{
m_StepCount += 1;
}

if ((m_RequestAction) && (m_Brain != null))
{
m_RequestAction = false;
Expand Down
41 changes: 41 additions & 0 deletions com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ internal void SetPolicy(IPolicy policy)

// Call counters exposed as public fields so tests can inspect how often each agent callback ran.

// Presumably incremented in InitializeAgent(); that method's body is not fully visible here — confirm.
public int initializeAgentCalls;
// Incremented on every CollectObservations() call; not cleared by AgentReset().
public int collectObservationsCalls;
// Incremented on every CollectObservations() call; set back to 0 by AgentReset().
public int collectObservationsCallsSinceLastReset;
// Incremented on every AgentAction() call; not cleared by AgentReset().
public int agentActionCalls;
// Incremented on every AgentAction() call; set back to 0 by AgentReset().
public int agentActionCallsSinceLastReset;
// Incremented on every AgentReset() call.
public int agentResetCalls;
public override void InitializeAgent()
{
Expand All @@ -54,18 +56,22 @@ public override void InitializeAgent()
/// <summary>
/// Counts this observation pass in both the running total and the
/// since-last-reset tally, then adds a single constant observation of 0.
/// </summary>
public override void CollectObservations()
{
    ++collectObservationsCallsSinceLastReset;
    ++collectObservationsCalls;

    // The observed value itself is a fixed 0f; the counters above are what tests inspect.
    AddVectorObs(0f);
}

/// <summary>
/// Counts this action in both the running total and the since-last-reset
/// tally, then adds a fixed reward of 0.1.
/// </summary>
/// <param name="vectorAction">The action values supplied by the caller; not read by this method.</param>
public override void AgentAction(float[] vectorAction)
{
    ++agentActionCallsSinceLastReset;
    ++agentActionCalls;

    AddReward(0.1f);
}

/// <summary>
/// Counts this reset and zeroes the two "since last reset" tallies so they
/// start counting afresh.
/// </summary>
public override void AgentReset()
{
    // Both per-episode tallies restart from zero on every reset.
    collectObservationsCallsSinceLastReset = agentActionCallsSinceLastReset = 0;

    ++agentResetCalls;
}

public override float[] Heuristic()
Expand Down Expand Up @@ -500,5 +506,40 @@ public void TestCumulativeReward()
aca.EnvironmentStep();
}
}

/// <summary>
/// Steps the academy 15 times with an agent whose maxStep is 6 and checks, after every
/// step, that the reset counter advanced by exactly one when a reset was expected
/// (first step, or maxStep actions taken since the last reset) and stayed put otherwise.
/// </summary>
[Test]
public void TestMaxStepsReset()
{
    var agentObject = new GameObject("TestAgent");
    // AddComponent hands back the component it just attached.
    var agent = agentObject.AddComponent<TestAgent>();
    var academy = Academy.Instance;

    // Request a decision on every step.
    var requester = agent.gameObject.AddComponent<DecisionRequester>();
    requester.DecisionPeriod = 1;
    requester.Awake();

    const int stepLimit = 6;
    agent.maxStep = stepLimit;
    agent.LazyInitialize();

    for (var step = 0; step < 15; step++)
    {
        // Snapshot the state before stepping: a reset is due on the very first step, and
        // whenever stepLimit actions have accumulated since the previous reset.
        var resetsBefore = agent.agentResetCalls;
        var resetIsDue = (step == 0) || agent.agentActionCallsSinceLastReset == stepLimit;

        academy.EnvironmentStep();

        var expectedResets = resetIsDue ? resetsBefore + 1 : resetsBefore;
        Assert.AreEqual(expectedResets, agent.agentResetCalls);
    }
}
}
}