Skip to content

Commit

Permalink
[MLA-16] add filter mask to ray perception (#3111)
Browse files Browse the repository at this point in the history
* add filter mask to ray perception

* use LayerMask type
  • Loading branch information
Chris Elion committed Dec 20, 2019
1 parent dfe9c11 commit ee81d99
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 8 deletions.
6 changes: 6 additions & 0 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/RayPerceptionTests.cs
Expand Up @@ -30,5 +30,11 @@ public void TestPerception2D()
tags);
Assert.IsTrue(result.Count == angles.Length * (tags.Length + 2));
}

[Test]
public void TestConstants()
{
Assert.AreEqual(Physics.DefaultRaycastLayers, Physics2D.DefaultRaycastLayers);
}
}
}
20 changes: 14 additions & 6 deletions UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensor.cs
Expand Up @@ -25,6 +25,7 @@ public enum CastType
float m_CastRadius;
CastType m_CastType;
Transform m_Transform;
int m_LayerMask;

/// <summary>
/// Debug information for the raycast hits. This is used by the RayPerceptionSensorComponent.
Expand Down Expand Up @@ -67,7 +68,8 @@ public DebugDisplayInfo debugDisplayInfo
}

public RayPerceptionSensor(string name, float rayDistance, List<string> detectableObjects, float[] angles,
Transform transform, float startOffset, float endOffset, float castRadius, CastType castType)
Transform transform, float startOffset, float endOffset, float castRadius, CastType castType,
int rayLayerMask)
{
var numObservations = (detectableObjects.Count + 2) * angles.Length;
m_Shape = new[] { numObservations };
Expand All @@ -84,6 +86,7 @@ public DebugDisplayInfo debugDisplayInfo
m_EndOffset = endOffset;
m_CastRadius = castRadius;
m_CastType = castType;
m_LayerMask = rayLayerMask;

if (Application.isEditor)
{
Expand All @@ -97,7 +100,8 @@ public int Write(WriteAdapter adapter)
{
PerceiveStatic(
m_RayDistance, m_Angles, m_DetectableObjects, m_StartOffset, m_EndOffset,
m_CastRadius, m_Transform, m_CastType, m_Observations, false, m_DebugDisplayInfo
m_CastRadius, m_Transform, m_CastType, m_Observations, false, m_LayerMask,
m_DebugDisplayInfo
);
adapter.AddRange(m_Observations);
}
Expand Down Expand Up @@ -164,6 +168,7 @@ public virtual SensorCompressionType GetCompressionType()
float startOffset, float endOffset, float castRadius,
Transform transform, CastType castType, float[] perceptionBuffer,
bool legacyHitFractionBehavior = false,
int layerMask = Physics.DefaultRaycastLayers,
DebugDisplayInfo debugInfo = null)
{
Array.Clear(perceptionBuffer, 0, perceptionBuffer.Length);
Expand Down Expand Up @@ -221,11 +226,13 @@ public virtual SensorCompressionType GetCompressionType()
RaycastHit rayHit;
if (castRadius > 0f)
{
castHit = Physics.SphereCast(startPositionWorld, castRadius, rayDirection, out rayHit, rayLength);
castHit = Physics.SphereCast(startPositionWorld, castRadius, rayDirection, out rayHit,
rayLength, layerMask);
}
else
{
castHit = Physics.Raycast(startPositionWorld, rayDirection, out rayHit, rayLength);
castHit = Physics.Raycast(startPositionWorld, rayDirection, out rayHit,
rayLength, layerMask);
}

hitFraction = castHit ? rayHit.distance / rayLength : 1.0f;
Expand All @@ -236,11 +243,12 @@ public virtual SensorCompressionType GetCompressionType()
RaycastHit2D rayHit;
if (castRadius > 0f)
{
rayHit = Physics2D.CircleCast(startPositionWorld, castRadius, rayDirection, rayLength);
rayHit = Physics2D.CircleCast(startPositionWorld, castRadius, rayDirection,
rayLength, layerMask);
}
else
{
rayHit = Physics2D.Raycast(startPositionWorld, rayDirection, rayLength);
rayHit = Physics2D.Raycast(startPositionWorld, rayDirection, rayLength, layerMask);
}

castHit = rayHit;
Expand Down
Expand Up @@ -27,6 +27,9 @@ public abstract class RayPerceptionSensorComponentBase : SensorComponent
[Tooltip("Length of the rays to cast.")]
public float rayLength = 20f;

[Tooltip("Controls which layers the rays can hit.")]
public LayerMask rayLayerMask = Physics.DefaultRaycastLayers;

[Range(1, 50)]
[Tooltip("Whether to stack previous observations. Using 1 means no previous observations.")]
public int observationStacks = 1;
Expand Down Expand Up @@ -57,7 +60,8 @@ public override ISensor CreateSensor()
{
var rayAngles = GetRayAngles(raysPerDirection, maxRayDegrees);
m_RaySensor = new RayPerceptionSensor(sensorName, rayLength, detectableTags, rayAngles,
transform, GetStartVerticalOffset(), GetEndVerticalOffset(), sphereCastRadius, GetCastType()
transform, GetStartVerticalOffset(), GetEndVerticalOffset(), sphereCastRadius, GetCastType(),
rayLayerMask
);

if (observationStacks != 1)
Expand Down
2 changes: 1 addition & 1 deletion UnitySDK/ProjectSettings/ProjectVersion.txt
@@ -1 +1 @@
m_EditorVersion: 2017.4.32f1
m_EditorVersion: 2017.4.33f1

0 comments on commit ee81d99

Please sign in to comment.