From b80400d07ced8dcbcee5d104fdddebb044f14209 Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Fri, 24 Aug 2018 16:54:58 -0700 Subject: [PATCH 1/5] GridWorld now uses action masking --- .../Examples/GridWorld/Scripts/GridAgent.cs | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/unity-environment/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs b/unity-environment/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs index 96186d597c..48079e83f6 100755 --- a/unity-environment/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs +++ b/unity-environment/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs @@ -16,7 +16,26 @@ public override void InitializeAgent() public override void CollectObservations() { - + // Prevents the agent from picking an action that would make it collide with a wall + var positionX = (int) transform.position.x; + var positionZ = (int) transform.position.z; + var maxPosition = academy.gridSize - 1; + if (positionX == 0) + { + SetActionMask(3); + } + if (positionX == maxPosition) + { + SetActionMask(4); + } + if (positionZ == 0) + { + SetActionMask(2); + } + if (positionZ == maxPosition) + { + SetActionMask(1); + } } // to be implemented by the developer From 83932fc6f171798866b1b34354bb2a85d0a285b3 Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Mon, 27 Aug 2018 10:21:15 -0700 Subject: [PATCH 2/5] Addressed the comments --- .../Examples/GridWorld/Scenes/GridWorld.unity | 199 ++++++++++-------- .../Examples/GridWorld/Scripts/GridAgent.cs | 21 +- 2 files changed, 121 insertions(+), 99 deletions(-) diff --git a/MLAgentsSDK/Assets/ML-Agents/Examples/GridWorld/Scenes/GridWorld.unity b/MLAgentsSDK/Assets/ML-Agents/Examples/GridWorld/Scenes/GridWorld.unity index a026a6490b..60fca12cd5 100644 --- a/MLAgentsSDK/Assets/ML-Agents/Examples/GridWorld/Scenes/GridWorld.unity +++ b/MLAgentsSDK/Assets/ML-Agents/Examples/GridWorld/Scenes/GridWorld.unity @@ -13,7 +13,7 @@ OcclusionCullingSettings: --- !u!104 &2 RenderSettings: m_ObjectHideFlags: 0 - serializedVersion: 9 + serializedVersion: 8 m_Fog: 0 m_FogColor: {r: 0.5, g: 0.5, b: 0.5, a: 1} m_FogMode: 3 @@ -38,8 +38,7 @@ RenderSettings: m_ReflectionIntensity: 1 m_CustomReflection: {fileID: 0} m_Sun: {fileID: 0} - m_IndirectSpecularColor: {r: 0.43668893, g: 0.4842832, b: 0.56452656, a: 1} - m_UseRadianceAmbientProbe: 0 + m_IndirectSpecularColor: {r: 0.43667555, g: 0.4842717, b: 0.56452394, a: 1} --- !u!157 &3 LightmapSettings: m_ObjectHideFlags: 0 @@ -55,10 +54,11 @@ LightmapSettings: m_EnableBakedLightmaps: 1 m_EnableRealtimeLightmaps: 1 m_LightmapEditorSettings: - serializedVersion: 10 + serializedVersion: 9 m_Resolution: 2 m_BakeResolution: 40 - m_AtlasSize: 1024 + m_TextureWidth: 1024 + m_TextureHeight: 1024 m_AO: 0 m_AOMaxDistance: 1 m_CompAOExponent: 1 @@ -77,18 +77,15 @@ LightmapSettings: m_PVRDirectSampleCount: 32 m_PVRSampleCount: 500 m_PVRBounces: 2 - m_PVRFilterTypeDirect: 0 - m_PVRFilterTypeIndirect: 0 - m_PVRFilterTypeAO: 0 + m_PVRFiltering: 0 m_PVRFilteringMode: 1 m_PVRCulling: 1 m_PVRFilteringGaussRadiusDirect: 1 m_PVRFilteringGaussRadiusIndirect: 5 m_PVRFilteringGaussRadiusAO: 2 - m_PVRFilteringAtrousPositionSigmaDirect: 0.5 - m_PVRFilteringAtrousPositionSigmaIndirect: 2 - m_PVRFilteringAtrousPositionSigmaAO: 1 - m_ShowResolutionOverlay: 1 + m_PVRFilteringAtrousColorSigma: 1 + m_PVRFilteringAtrousNormalSigma: 1 + m_PVRFilteringAtrousPositionSigma: 1 m_LightingDataAsset: {fileID: 0} m_UseShadowmask: 0 --- !u!196 &4 @@ -110,8 +107,6 @@ NavMeshSettings: manualTileSize: 0 tileSize: 256 accuratePlacement: 0 - debug: - m_Flags: 0 m_NavMeshData: {fileID: 0} --- !u!1 &2047662 GameObject: @@ -231,6 +226,20 @@ Light: m_Lightmapping: 4 m_AreaSize: {x: 1, y: 1} m_BounceIntensity: 1 + m_FalloffTable: + m_Table[0]: 0 + m_Table[1]: 0 + m_Table[2]: 0 + m_Table[3]: 0 + m_Table[4]: 0 + m_Table[5]: 0 + m_Table[6]: 0 + m_Table[7]: 0 + m_Table[8]: 0 + m_Table[9]: 0 + m_Table[10]: 0 + m_Table[11]: 0 + m_Table[12]: 0 m_ColorTemperature: 6570 m_UseColorTemperature: 0 m_ShadowRadius: 0 @@ -311,11 +320,11 @@ Camera: m_TargetEye: 3 m_HDR: 0 m_AllowMSAA: 1 - m_AllowDynamicResolution: 0 m_ForceIntoRT: 0 m_OcclusionCulling: 1 m_StereoConvergence: 10 m_StereoSeparation: 0.022 + m_StereoMirrorMode: 0 --- !u!4 &231883447 Transform: m_ObjectHideFlags: 0 @@ -423,6 +432,18 @@ RectTransform: m_AnchoredPosition: {x: 0, y: 0} m_SizeDelta: {x: 0, y: 0} m_Pivot: {x: 0, y: 0} +--- !u!114 &382957182 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 0} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: 35813a1be64e144f887d7d5f15b963fa, type: 3} + m_Name: (Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone) + m_EditorClassIdentifier: + brain: {fileID: 1535917239} --- !u!1 &486401523 GameObject: m_ObjectHideFlags: 0 @@ -516,11 +537,11 @@ Camera: m_TargetEye: 3 m_HDR: 0 m_AllowMSAA: 0 - m_AllowDynamicResolution: 0 m_ForceIntoRT: 0 m_OcclusionCulling: 1 m_StereoConvergence: 10 m_StereoSeparation: 0.022 + m_StereoMirrorMode: 0 --- !u!1 &742849316 GameObject: m_ObjectHideFlags: 0 @@ -610,11 +631,9 @@ MeshRenderer: m_Enabled: 1 m_CastShadows: 0 m_ReceiveShadows: 1 - m_DynamicOccludee: 1 m_MotionVectors: 1 m_LightProbeUsage: 1 m_ReflectionProbeUsage: 1 - m_RenderingLayerMask: 4294967295 m_Materials: - {fileID: 2100000, guid: 214660f4189b04cada2137381f5c3607, type: 2} m_StaticBatchInfo: @@ -627,7 +646,6 @@ MeshRenderer: m_PreserveUVs: 1 m_IgnoreNormalsForChartDetection: 0 m_ImportantGI: 0 - m_StitchLightmapSeams: 0 m_SelectedEditorRenderState: 3 m_MinimumChartSize: 4 m_AutoUVMaxDistance: 0.5 @@ -695,11 +713,9 @@ MeshRenderer: m_Enabled: 1 m_CastShadows: 0 m_ReceiveShadows: 1 - m_DynamicOccludee: 1 m_MotionVectors: 1 m_LightProbeUsage: 1 m_ReflectionProbeUsage: 1 - m_RenderingLayerMask: 4294967295 m_Materials: - {fileID: 2100000, guid: 214660f4189b04cada2137381f5c3607, type: 2} m_StaticBatchInfo: @@ -712,7 +728,6 @@ MeshRenderer: m_PreserveUVs: 1 m_IgnoreNormalsForChartDetection: 0 m_ImportantGI: 0 - m_StitchLightmapSeams: 0 m_SelectedEditorRenderState: 3 m_MinimumChartSize: 4 m_AutoUVMaxDistance: 0.5 @@ -753,31 +768,6 @@ Transform: m_Father: {fileID: 486401524} m_RootOrder: 1 m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} ---- !u!114 &1049818079 -MonoBehaviour: - m_ObjectHideFlags: 0 - m_PrefabParentObject: {fileID: 0} - m_PrefabInternal: {fileID: 0} - m_GameObject: {fileID: 0} - m_Enabled: 1 - m_EditorHideFlags: 0 - m_Script: {fileID: 11500000, guid: 943466ab374444748a364f9d6c3e2fe2, type: 3} - m_Name: (Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone) - m_EditorClassIdentifier: - broadcast: 1 - brain: {fileID: 1535917239} ---- !u!114 &1176639739 -MonoBehaviour: - m_ObjectHideFlags: 0 - m_PrefabParentObject: {fileID: 0} - m_PrefabInternal: {fileID: 0} - m_GameObject: {fileID: 0} - m_Enabled: 1 - m_EditorHideFlags: 0 - m_Script: {fileID: 11500000, guid: 35813a1be64e144f887d7d5f15b963fa, type: 3} - m_Name: (Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone) - m_EditorClassIdentifier: - brain: {fileID: 1535917239} --- !u!1 &1208586857 GameObject: m_ObjectHideFlags: 0 @@ -805,11 +795,9 @@ MeshRenderer: m_Enabled: 1 m_CastShadows: 0 m_ReceiveShadows: 1 - m_DynamicOccludee: 1 m_MotionVectors: 1 m_LightProbeUsage: 1 m_ReflectionProbeUsage: 1 - m_RenderingLayerMask: 4294967295 m_Materials: - {fileID: 2100000, guid: 8d8e8962a89d44eb28cf1b21b88014ec, type: 2} m_StaticBatchInfo: @@ -822,7 +810,6 @@ MeshRenderer: m_PreserveUVs: 1 m_IgnoreNormalsForChartDetection: 0 m_ImportantGI: 0 - m_StitchLightmapSeams: 0 m_SelectedEditorRenderState: 3 m_MinimumChartSize: 4 m_AutoUVMaxDistance: 0.5 @@ -840,9 +827,9 @@ MeshCollider: m_Material: {fileID: 0} m_IsTrigger: 0 m_Enabled: 0 - serializedVersion: 3 + serializedVersion: 2 m_Convex: 0 - m_CookingOptions: 14 + m_InflateMesh: 0 m_SkinWidth: 0.01 m_Mesh: {fileID: 10209, guid: 0000000000000000e000000000000000, type: 0} --- !u!33 &1208586860 @@ -920,11 +907,11 @@ MonoBehaviour: vectorActionSpaceType: 0 brainType: 0 CoreBrains: - - {fileID: 2068928320} - - {fileID: 1049818079} - - {fileID: 1176639739} - - {fileID: 0} - instanceID: 142728 + - {fileID: 1836367505} + - {fileID: 1702244739} + - {fileID: 382957182} + - {fileID: 1802138297} + instanceID: 14072 --- !u!1 &1553342942 GameObject: m_ObjectHideFlags: 0 @@ -999,6 +986,19 @@ CanvasRenderer: m_PrefabParentObject: {fileID: 0} m_PrefabInternal: {fileID: 0} m_GameObject: {fileID: 1553342942} +--- !u!114 &1702244739 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 0} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: 943466ab374444748a364f9d6c3e2fe2, type: 3} + m_Name: (Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone) + m_EditorClassIdentifier: + broadcast: 1 + brain: {fileID: 1535917239} --- !u!1 &1726089810 GameObject: m_ObjectHideFlags: 0 @@ -1026,11 +1026,9 @@ MeshRenderer: m_Enabled: 1 m_CastShadows: 0 m_ReceiveShadows: 1 - m_DynamicOccludee: 1 m_MotionVectors: 1 m_LightProbeUsage: 1 m_ReflectionProbeUsage: 1 - m_RenderingLayerMask: 4294967295 m_Materials: - {fileID: 2100000, guid: 214660f4189b04cada2137381f5c3607, type: 2} m_StaticBatchInfo: @@ -1043,7 +1041,6 @@ MeshRenderer: m_PreserveUVs: 1 m_IgnoreNormalsForChartDetection: 0 m_ImportantGI: 0 - m_StitchLightmapSeams: 0 m_SelectedEditorRenderState: 3 m_MinimumChartSize: 4 m_AutoUVMaxDistance: 0.5 @@ -1084,6 +1081,57 @@ Transform: m_Father: {fileID: 486401524} m_RootOrder: 4 m_LocalEulerAnglesHint: {x: 0, y: 90, z: 0} +--- !u!114 &1802138297 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 0} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: 8b23992c8eb17439887f5e944bf04a40, type: 3} + m_Name: (Clone) + m_EditorClassIdentifier: + broadcast: 1 + graphModel: {fileID: 0} + graphScope: + graphPlaceholders: [] + BatchSizePlaceholderName: batch_size + VectorObservationPlacholderName: vector_observation + RecurrentInPlaceholderName: recurrent_in + RecurrentOutPlaceholderName: recurrent_out + VisualObservationPlaceholderName: [] + ActionPlaceholderName: action + PreviousActionPlaceholderName: prev_action + brain: {fileID: 0} +--- !u!114 &1836367505 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 0} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: 41e9bda8f3cf1492fa74926a530f6f70, type: 3} + m_Name: (Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone) + m_EditorClassIdentifier: + broadcast: 1 + keyContinuousPlayerActions: [] + axisContinuousPlayerActions: [] + discretePlayerActions: + - key: 273 + branchIndex: 0 + value: 1 + - key: 274 + branchIndex: 0 + value: 2 + - key: 276 + branchIndex: 0 + value: 3 + - key: 275 + branchIndex: 0 + value: 4 + brain: {fileID: 1535917239} --- !u!1 &1938864789 GameObject: m_ObjectHideFlags: 0 @@ -1111,11 +1159,9 @@ MeshRenderer: m_Enabled: 1 m_CastShadows: 0 m_ReceiveShadows: 1 - m_DynamicOccludee: 1 m_MotionVectors: 1 m_LightProbeUsage: 1 m_ReflectionProbeUsage: 1 - m_RenderingLayerMask: 4294967295 m_Materials: - {fileID: 2100000, guid: 214660f4189b04cada2137381f5c3607, type: 2} m_StaticBatchInfo: @@ -1128,7 +1174,6 @@ MeshRenderer: m_PreserveUVs: 1 m_IgnoreNormalsForChartDetection: 0 m_ImportantGI: 0 - m_StitchLightmapSeams: 0 m_SelectedEditorRenderState: 3 m_MinimumChartSize: 4 m_AutoUVMaxDistance: 0.5 @@ -1235,31 +1280,3 @@ GameObject: m_PrefabParentObject: {fileID: 1657514749044530, guid: 628960e910f094ad1909ecc88cc8016d, type: 2} m_PrefabInternal: {fileID: 2008405821} ---- !u!114 &2068928320 -MonoBehaviour: - m_ObjectHideFlags: 0 - m_PrefabParentObject: {fileID: 0} - m_PrefabInternal: {fileID: 0} - m_GameObject: {fileID: 0} - m_Enabled: 1 - m_EditorHideFlags: 0 - m_Script: {fileID: 11500000, guid: 41e9bda8f3cf1492fa74926a530f6f70, type: 3} - m_Name: (Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone) - m_EditorClassIdentifier: - broadcast: 1 - keyContinuousPlayerActions: [] - axisContinuousPlayerActions: [] - discretePlayerActions: - - key: 273 - branchIndex: 0 - value: 0 - - key: 274 - branchIndex: 0 - value: 1 - - key: 276 - branchIndex: 0 - value: 2 - - key: 275 - branchIndex: 0 - value: 3 - brain: {fileID: 1535917239} diff --git a/MLAgentsSDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs b/MLAgentsSDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs index 48079e83f6..c15c4a665e 100755 --- a/MLAgentsSDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs +++ b/MLAgentsSDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs @@ -9,6 +9,11 @@ public class GridAgent : Agent public float timeBetweenDecisionsAtInference; private float timeSinceDecision; + static int up = 1; + static int down = 2; + static int left = 3; + static int right = 4; + public override void InitializeAgent() { academy = FindObjectOfType(typeof(GridAcademy)) as GridAcademy; @@ -22,19 +27,19 @@ public override void CollectObservations() var maxPosition = academy.gridSize - 1; if (positionX == 0) { - SetActionMask(3); + SetActionMask(left); } if (positionX == maxPosition) { - SetActionMask(4); + SetActionMask(right); } if (positionZ == 0) { - SetActionMask(2); + SetActionMask(down); } if (positionZ == maxPosition) { - SetActionMask(1); + SetActionMask(up); } } @@ -46,22 +51,22 @@ public override void AgentAction(float[] vectorAction, string textAction) // 0 - Forward, 1 - Backward, 2 - Left, 3 - Right Vector3 targetPos = transform.position; - if (action == 4) + if (action == right) { targetPos = transform.position + new Vector3(1f, 0, 0f); } - if (action == 3) + if (action == left) { targetPos = transform.position + new Vector3(-1f, 0, 0f); } - if (action == 1) + if (action == up) { targetPos = transform.position + new Vector3(0f, 0, 1f); } - if (action == 2) + if (action == down) { targetPos = transform.position + new Vector3(0f, 0, -1f); } From b97c702f3bc99b3833ed7c6559562cd85beb579c Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Mon, 27 Aug 2018 16:49:06 -0700 Subject: [PATCH 3/5] addressed comments --- .../Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/MLAgentsSDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs b/MLAgentsSDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs index c15c4a665e..4ca92f1f20 100755 --- a/MLAgentsSDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs +++ b/MLAgentsSDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs @@ -49,7 +49,6 @@ public override void AgentAction(float[] vectorAction, string textAction) AddReward(-0.01f); int action = Mathf.FloorToInt(vectorAction[0]); - // 0 - Forward, 1 - Backward, 2 - Left, 3 - Right Vector3 targetPos = transform.position; if (action == right) { From b97d6e92934af2e60765cc331ae48a4e39c542d1 Mon Sep 17 00:00:00 2001 From: Marwan Mattar Date: Wed, 29 Aug 2018 12:14:59 -0700 Subject: [PATCH 4/5] Added checkbox to turn action masking on/off (#1146) * Added checkbox to turn action masking on/off * Fix to handle the no-action option --- .../Examples/GridWorld/Scripts/GridAgent.cs | 85 ++++++++++++------- 1 file changed, 55 insertions(+), 30 deletions(-) diff --git a/MLAgentsSDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs b/MLAgentsSDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs index 4ca92f1f20..a1957f5310 100755 --- a/MLAgentsSDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs +++ b/MLAgentsSDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs @@ -1,4 +1,5 @@ -using UnityEngine; +using System; +using UnityEngine; using System.Linq; using MLAgents; @@ -9,10 +10,15 @@ public class GridAgent : Agent public float timeBetweenDecisionsAtInference; private float timeSinceDecision; - static int up = 1; - static int down = 2; - static int left = 3; - static int right = 4; + [Tooltip("Selecting will turn on action masking. Note that a model trained with action " + + "masking turned on may not behave optimally when action masking is turned off.")] + public bool maskActions = true; + + private const int NoAction = 0; // do nothing! + private const int Up = 1; + private const int Down = 2; + private const int Left = 3; + private const int Right = 4; public override void InitializeAgent() { @@ -20,26 +26,45 @@ public override void InitializeAgent() } public override void CollectObservations() + { + // There are no numeric observations to collect as this environment uses visual + // observations. + + // Mask the necessary actions if selected by the user. + if (maskActions) + { + SetMask(); + } + } + + /// + /// Applies the mask for the agents action to disallow unnecessary actions. + /// + private void SetMask() { // Prevents the agent from picking an action that would make it collide with a wall var positionX = (int) transform.position.x; var positionZ = (int) transform.position.z; var maxPosition = academy.gridSize - 1; + if (positionX == 0) { - SetActionMask(left); + SetActionMask(Left); } + if (positionX == maxPosition) { - SetActionMask(right); + SetActionMask(Right); } + if (positionZ == 0) { - SetActionMask(down); + SetActionMask(Down); } + if (positionZ == maxPosition) { - SetActionMask(up); + SetActionMask(Up); } } @@ -50,42 +75,42 @@ public override void AgentAction(float[] vectorAction, string textAction) int action = Mathf.FloorToInt(vectorAction[0]); Vector3 targetPos = transform.position; - if (action == right) + switch (action) { - targetPos = transform.position + new Vector3(1f, 0, 0f); - } - - if (action == left) - { - targetPos = transform.position + new Vector3(-1f, 0, 0f); - } - - if (action == up) - { - targetPos = transform.position + new Vector3(0f, 0, 1f); - } - - if (action == down) - { - targetPos = transform.position + new Vector3(0f, 0, -1f); + case NoAction: + // do nothing + break; + case Right: + targetPos = transform.position + new Vector3(1f, 0, 0f); + break; + case Left: + targetPos = transform.position + new Vector3(-1f, 0, 0f); + break; + case Up: + targetPos = transform.position + new Vector3(0f, 0, 1f); + break; + case Down: + targetPos = transform.position + new Vector3(0f, 0, -1f); + break; + default: + throw new ArgumentException("Invalid action value"); } Collider[] blockTest = Physics.OverlapBox(targetPos, new Vector3(0.3f, 0.3f, 0.3f)); - if (blockTest.Where(col => col.gameObject.tag == "wall").ToArray().Length == 0) + if (blockTest.Where(col => col.gameObject.CompareTag("wall")).ToArray().Length == 0) { transform.position = targetPos; - if (blockTest.Where(col => col.gameObject.tag == "goal").ToArray().Length == 1) + if (blockTest.Where(col => col.gameObject.CompareTag("goal")).ToArray().Length == 1) { Done(); SetReward(1f); } - if (blockTest.Where(col => col.gameObject.tag == "pit").ToArray().Length == 1) + if (blockTest.Where(col => col.gameObject.CompareTag("pit")).ToArray().Length == 1) { Done(); SetReward(-1f); } - } } From d6642820739fc5db341b173311be4aae50697ad4 Mon Sep 17 00:00:00 2001 From: Marwan Mattar Date: Wed, 29 Aug 2018 13:05:28 -0700 Subject: [PATCH 5/5] Added comment to GridWorld mentioning the use of action masking. (#1153) --- docs/Learning-Environment-Examples.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/Learning-Environment-Examples.md b/docs/Learning-Environment-Examples.md index db4695c77c..fcdc38df7d 100644 --- a/docs/Learning-Environment-Examples.md +++ b/docs/Learning-Environment-Examples.md @@ -75,7 +75,11 @@ If you would like to contribute environments, please see our * Brains: One brain with the following observation/action space. * Vector Observation space: None * Vector Action space: (Discrete) Size of 4, corresponding to movement in - cardinal directions. + cardinal directions. Note that for this environment, + [action masking](Learning-Environment-Design-Agents.md#masking-discrete-actions) + is turned on by default (this option can be toggled + using the `Mask Actions` checkbox within the `trueAgent` GameObject). + The trained model file provided was generated with action masking turned on. * Visual Observations: One corresponding to top-down view of GridWorld. * Reset Parameters: Three, corresponding to grid size, number of obstacles, and number of goals.