Skip to content

Commit cbc6a70

Browse files
Pass action masker as input to CollectObservations (#3389)
* Sentencing Action masking the same as observations I am rather unsure about the doubling of the CollectObservation methods (and the copy pasta that comes along) Need to edit the documentation and the migrating doc once we agree we want to do this * Addressing the comments * Improvements to the documentation * Editing the documentation
1 parent c3d4417 commit cbc6a70

File tree

5 files changed

+92
-68
lines changed

5 files changed

+92
-68
lines changed

Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,22 @@ public override void InitializeAgent()
3131
{
3232
}
3333

34-
public override void CollectObservations(VectorSensor sensor)
34+
public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
3535
{
3636
// There are no numeric observations to collect as this environment uses visual
3737
// observations.
3838

3939
// Mask the necessary actions if selected by the user.
4040
if (maskActions)
4141
{
42-
SetMask();
42+
SetMask(actionMasker);
4343
}
4444
}
4545

4646
/// <summary>
4747
/// Applies the mask for the agent's actions to disallow unnecessary actions.
4848
/// </summary>
49-
void SetMask()
49+
void SetMask(ActionMasker actionMasker)
5050
{
5151
// Prevents the agent from picking an action that would make it collide with a wall
5252
var positionX = (int)transform.position.x;
@@ -55,22 +55,22 @@ void SetMask()
5555

5656
if (positionX == 0)
5757
{
58-
SetActionMask(k_Left);
58+
actionMasker.SetActionMask(k_Left);
5959
}
6060

6161
if (positionX == maxPosition)
6262
{
63-
SetActionMask(k_Right);
63+
actionMasker.SetActionMask(k_Right);
6464
}
6565

6666
if (positionZ == 0)
6767
{
68-
SetActionMask(k_Down);
68+
actionMasker.SetActionMask(k_Down);
6969
}
7070

7171
if (positionZ == maxPosition)
7272
{
73-
SetActionMask(k_Up);
73+
actionMasker.SetActionMask(k_Up);
7474
}
7575
}
7676

com.unity.ml-agents/Runtime/ActionMasker.cs

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
namespace MLAgents
66
{
7-
internal class ActionMasker
7+
public class ActionMasker
88
{
99
/// When using discrete control, these are the starting indices of the actions
1010
/// when all the branches are concatenated with each other.
@@ -19,11 +19,47 @@ internal ActionMasker(BrainParameters brainParameters)
1919
m_BrainParameters = brainParameters;
2020
}
2121

22+
/// <summary>
23+
/// Sets an action mask for discrete control agents. When used, the agent will not be
24+
/// able to perform the actions passed as argument at the next decision.
25+
/// The actionIndices correspond to the actions the agent will be unable to perform
26+
/// on the branch 0.
27+
/// </summary>
28+
/// <param name="actionIndices">The indices of the masked actions on branch 0</param>
29+
public void SetActionMask(IEnumerable<int> actionIndices)
30+
{
31+
SetActionMask(0, actionIndices);
32+
}
33+
34+
/// <summary>
35+
/// Sets an action mask for discrete control agents. When used, the agent will not be
36+
/// able to perform the action passed as argument at the next decision for the specified
37+
/// action branch. The actionIndex corresponds to the action the agent will be unable
38+
/// to perform.
39+
/// </summary>
40+
/// <param name="branch">The branch for which the actions will be masked</param>
41+
/// <param name="actionIndex">The index of the masked action</param>
42+
public void SetActionMask(int branch, int actionIndex)
43+
{
44+
SetActionMask(branch, new[] { actionIndex });
45+
}
46+
47+
/// <summary>
48+
/// Sets an action mask for discrete control agents. When used, the agent will not be
49+
/// able to perform the action passed as argument at the next decision. The actionIndex
50+
/// corresponds to the action the agent will be unable to perform on the branch 0.
51+
/// </summary>
52+
/// <param name="actionIndex">The index of the masked action on branch 0</param>
53+
public void SetActionMask(int actionIndex)
54+
{
55+
SetActionMask(0, new[] { actionIndex });
56+
}
57+
2258
/// <summary>
2359
/// Modifies an action mask for discrete control agents. When used, the agent will not be
24-
/// able to perform the action passed as argument at the next decision. If no branch is
25-
/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
26-
/// to the action the agent will be unable to perform.
60+
/// able to perform the actions passed as argument at the next decision for the specified
61+
/// action branch. The actionIndices correspond to the action options the agent will
62+
/// be unable to perform.
2763
/// </summary>
2864
/// <param name="branch">The branch for which the actions will be masked</param>
2965
/// <param name="actionIndices">The indices of the masked actions</param>
@@ -67,7 +103,7 @@ public void SetActionMask(int branch, IEnumerable<int> actionIndices)
67103
/// </summary>
68104
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
69105
/// actions.</returns>
70-
public bool[] GetMask()
106+
internal bool[] GetMask()
71107
{
72108
if (m_CurrentMask != null)
73109
{
@@ -103,7 +139,7 @@ void AssertMask()
103139
/// <summary>
104140
/// Resets the current mask for an agent
105141
/// </summary>
106-
public void ResetMask()
142+
internal void ResetMask()
107143
{
108144
if (m_CurrentMask != null)
109145
{

com.unity.ml-agents/Runtime/Agent.cs

Lines changed: 37 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ void SendInfoToBrain()
480480
UpdateSensors();
481481
using (TimerStack.Instance.Scoped("CollectObservations"))
482482
{
483-
CollectObservations(collectObservationsSensor);
483+
CollectObservations(collectObservationsSensor, m_ActionMasker);
484484
}
485485
m_Info.actionMasks = m_ActionMasker.GetMask();
486486

@@ -522,12 +522,6 @@ void UpdateSensors()
522522
/// - <see cref="AddObservation(float)"/>
523523
/// - <see cref="AddObservation(Vector3)"/>
524524
/// - <see cref="AddObservation(Vector2)"/>
525-
/// - <see>
526-
/// <cref>AddVectorObs(float[])</cref>
527-
/// </see>
528-
/// - <see>
529-
/// <cref>AddVectorObs(List{float})</cref>
530-
/// </see>
531525
/// - <see cref="AddObservation(Quaternion)"/>
532526
/// - <see cref="AddObservation(bool)"/>
533527
/// - <see cref="AddOneHotObservation(int, int)"/>
@@ -543,53 +537,44 @@ public virtual void CollectObservations(VectorSensor sensor)
543537
}
544538

545539
/// <summary>
546-
/// Sets an action mask for discrete control agents. When used, the agent will not be
547-
/// able to perform the action passed as argument at the next decision. If no branch is
548-
/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
549-
/// to the action the agent will be unable to perform.
550-
/// </summary>
551-
/// <param name="actionIndices">The indices of the masked actions on branch 0</param>
552-
protected void SetActionMask(IEnumerable<int> actionIndices)
553-
{
554-
m_ActionMasker.SetActionMask(0, actionIndices);
555-
}
556-
557-
/// <summary>
558-
/// Sets an action mask for discrete control agents. When used, the agent will not be
559-
/// able to perform the action passed as argument at the next decision. If no branch is
560-
/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
561-
/// to the action the agent will be unable to perform.
562-
/// </summary>
563-
/// <param name="actionIndex">The index of the masked action on branch 0</param>
564-
protected void SetActionMask(int actionIndex)
565-
{
566-
m_ActionMasker.SetActionMask(0, new[] { actionIndex });
567-
}
568-
569-
/// <summary>
570-
/// Sets an action mask for discrete control agents. When used, the agent will not be
571-
/// able to perform the action passed as argument at the next decision. If no branch is
572-
/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
573-
/// to the action the agent will be unable to perform.
574-
/// </summary>
575-
/// <param name="branch">The branch for which the actions will be masked</param>
576-
/// <param name="actionIndex">The index of the masked action</param>
577-
protected void SetActionMask(int branch, int actionIndex)
578-
{
579-
m_ActionMasker.SetActionMask(branch, new[] { actionIndex });
580-
}
581-
582-
/// <summary>
583-
/// Modifies an action mask for discrete control agents. When used, the agent will not be
584-
/// able to perform the action passed as argument at the next decision. If no branch is
585-
/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
586-
/// to the action the agent will be unable to perform.
540+
/// Collects the vector observations of the agent.
541+
/// The agent observation describes the current environment from the
542+
/// perspective of the agent.
587543
/// </summary>
588-
/// <param name="branch">The branch for which the actions will be masked</param>
589-
/// <param name="actionIndices">The indices of the masked actions</param>
590-
protected void SetActionMask(int branch, IEnumerable<int> actionIndices)
544+
/// <remarks>
545+
/// An agent's observation is any environment information that helps
546+
/// the Agent achieve its goal. For example, for a fighting Agent, its
547+
/// observation could include distances to friends or enemies, or the
548+
/// current level of ammunition at its disposal.
549+
/// Recall that an Agent may attach vector or visual observations.
550+
/// Vector observations are added by calling the provided helper methods
551+
/// on the VectorSensor input:
552+
/// - <see cref="AddObservation(int)"/>
553+
/// - <see cref="AddObservation(float)"/>
554+
/// - <see cref="AddObservation(Vector3)"/>
555+
/// - <see cref="AddObservation(Vector2)"/>
556+
/// - <see cref="AddObservation(Quaternion)"/>
557+
/// - <see cref="AddObservation(bool)"/>
558+
/// - <see cref="AddOneHotObservation(int, int)"/>
559+
/// Depending on your environment, any combination of these helpers can
560+
/// be used. They just need to be used in the exact same order each time
561+
/// this method is called and the resulting size of the vector observation
562+
/// needs to match the vectorObservationSize attribute of the linked Brain.
563+
/// Visual observations are implicitly added from the cameras attached to
564+
/// the Agent.
565+
/// When using Discrete Control, you can prevent the Agent from using a certain
566+
/// action by masking it. You can call the following method on the ActionMasker
567+
/// input:
568+
/// - <see cref="SetActionMask(int branch, IEnumerable{int} actionIndices)"/>
569+
/// - <see cref="SetActionMask(int branch, int actionIndex)"/>
570+
/// - <see cref="SetActionMask(IEnumerable{int} actionIndices)"/>
571+
/// - <see cref="SetActionMask(int actionIndex)"/>
572+
/// The branch input is the index of the action, actionIndices are the indices of the
573+
/// invalid options for that action.
574+
/// </remarks>
575+
public virtual void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
591576
{
592-
m_ActionMasker.SetActionMask(branch, actionIndices);
577+
CollectObservations(sensor);
593578
}
594579

595580
/// <summary>

docs/Learning-Environment-Design-Agents.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -391,10 +391,12 @@ impossible for the next decision. When the Agent is controlled by a
391391
neural network, the Agent will be unable to perform the specified action. Note
392392
that when the Agent is controlled by its Heuristic, the Agent will
393393
still be able to decide to perform the masked action. In order to mask an
394-
action, call the method `SetActionMask` within the `CollectObservation` method :
394+
action, call the method `SetActionMask` on the optional `ActionMasker` argument of the `CollectObservations` method:
395395

396396
```csharp
397-
SetActionMask(branch, actionIndices)
397+
public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker){
398+
actionMasker.SetActionMask(branch, actionIndices)
399+
}
398400
```
399401

400402
Where:

docs/Migrating.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@ The versions can be found in
1313
* The `Agent.CollectObservations()` virtual method now takes a `VectorSensor` sensor as argument. The `Agent.AddVectorObs()` methods were removed.
1414
* The `Monitor` class has been moved to the Examples Project. (It was prone to errors during testing)
1515
* The `MLAgents.Sensor` namespace has been removed. All sensors now belong to the `MLAgents` namespace.
16+
* The `SetActionMask` method must now be called on the optional `ActionMasker` argument of the `CollectObservations` method. (We now consider an action mask as a type of observation)
1617

1718

1819
### Steps to Migrate
1920
* Replace your Agent's implementation of `CollectObservations()` with `CollectObservations(VectorSensor sensor)`. In addition, replace all calls to `AddVectorObs()` with `sensor.AddObservation()` or `sensor.AddOneHotObservation()` on the `VectorSensor` passed as argument.
20-
21+
* Replace calls to `SetActionMask` on your Agent with calls to `ActionMasker.SetActionMask` in `CollectObservations`.
2122

2223

2324
## Migrating from 0.13 to 0.14

0 commit comments

Comments
 (0)