diff --git a/com.unity.ml-agents.extensions/LICENSE.md b/com.unity.ml-agents.extensions/LICENSE.md
index 1f43a89bfa..3b2574fbe2 100644
--- a/com.unity.ml-agents.extensions/LICENSE.md
+++ b/com.unity.ml-agents.extensions/LICENSE.md
@@ -1,4 +1,4 @@
-com.unity.ml-agents.extensions copyright © 2020 Unity Technologies
+com.unity.ml-agents.extensions copyright © 2020 Unity Technologies ApS
Licensed under the Unity Companion License for Unity-dependent projects -- see
[Unity Companion License](http://www.unity3d.com/legal/licenses/Unity_Companion_License).
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
index 3f5068ed92..512b857345 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
@@ -5,13 +5,20 @@
namespace Unity.MLAgents.Extensions.Sensors
{
-
+ ///
+ /// Utility class to track a hierarchy of ArticulationBodies.
+ ///
public class ArticulationBodyPoseExtractor : PoseExtractor
{
ArticulationBody[] m_Bodies;
public ArticulationBodyPoseExtractor(ArticulationBody rootBody)
{
+ if (rootBody == null)
+ {
+ return;
+ }
+
if (!rootBody.isRoot)
{
Debug.Log("Must pass ArticulationBody.isRoot");
@@ -38,14 +45,25 @@ public ArticulationBodyPoseExtractor(ArticulationBody rootBody)
for (var i = 1; i < numBodies; i++)
{
- var body = m_Bodies[i];
- var parent = body.GetComponentInParent();
- parentIndices[i] = bodyToIndex[parent];
+ var currentArticBody = m_Bodies[i];
+ // Component.GetComponentInParent will consider the provided object as well.
+ // So start looking from the parent.
+ var currentGameObject = currentArticBody.gameObject;
+ var parentGameObject = currentGameObject.transform.parent;
+ var parentArticBody = parentGameObject.GetComponentInParent();
+ parentIndices[i] = bodyToIndex[parentArticBody];
}
SetParentIndices(parentIndices);
}
+ ///
+ protected override Vector3 GetLinearVelocityAt(int index)
+ {
+ return m_Bodies[index].velocity;
+ }
+
+ ///
protected override Pose GetPoseAt(int index)
{
var body = m_Bodies[index];
@@ -53,8 +71,6 @@ protected override Pose GetPoseAt(int index)
var t = go.transform;
return new Pose { rotation = t.rotation, position = t.position };
}
-
-
}
}
#endif // UNITY_2020_1_OR_NEWER
\ No newline at end of file
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
new file mode 100644
index 0000000000..23682bf2f6
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
@@ -0,0 +1,41 @@
+#if UNITY_2020_1_OR_NEWER
+using UnityEngine;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Extensions.Sensors
+{
+ public class ArticulationBodySensorComponent : SensorComponent
+ {
+ public ArticulationBody RootBody;
+
+ [SerializeField]
+ public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default();
+ public string sensorName;
+
+ ///
+ /// Creates a PhysicsBodySensor.
+ ///
+ ///
+ public override ISensor CreateSensor()
+ {
+ return new PhysicsBodySensor(RootBody, Settings, sensorName);
+ }
+
+ ///
+ public override int[] GetObservationShape()
+ {
+ if (RootBody == null)
+ {
+ return new[] { 0 };
+ }
+
+ // TODO static method in PhysicsBodySensor?
+ // TODO only update PoseExtractor when body changes?
+ var poseExtractor = new ArticulationBodyPoseExtractor(RootBody);
+ var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses);
+ return new[] { numTransformObservations };
+ }
+ }
+
+}
+#endif // UNITY_2020_1_OR_NEWER
\ No newline at end of file
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs.meta
new file mode 100644
index 0000000000..3cdd83ac52
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: e57a788acd5e049c6aa9642b450ca318
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
new file mode 100644
index 0000000000..6b0bb2ca0f
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
@@ -0,0 +1,93 @@
+using UnityEngine;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Extensions.Sensors
+{
+ ///
+ /// ISensor implementation that generates observations for a group of Rigidbodies or ArticulationBodies.
+ ///
+ public class PhysicsBodySensor : ISensor
+ {
+ int[] m_Shape;
+ string m_SensorName;
+
+ PoseExtractor m_PoseExtractor;
+ PhysicsSensorSettings m_Settings;
+
+ ///
+ /// Construct a new PhysicsBodySensor
+ ///
+ ///
+ ///
+ ///
+ public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null)
+ {
+ m_PoseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);
+ m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName;
+ m_Settings = settings;
+
+ var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
+ m_Shape = new[] { numTransformObservations };
+ }
+
+#if UNITY_2020_1_OR_NEWER
+ public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null)
+ {
+ m_PoseExtractor = new ArticulationBodyPoseExtractor(rootBody);
+ m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootBody?.name}" : sensorName;
+ m_Settings = settings;
+
+ var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
+ m_Shape = new[] { numTransformObservations };
+ }
+#endif
+
+ ///
+ public int[] GetObservationShape()
+ {
+ return m_Shape;
+ }
+
+ ///
+ public int Write(ObservationWriter writer)
+ {
+ var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
+ return numWritten;
+ }
+
+ ///
+ public byte[] GetCompressedObservation()
+ {
+ return null;
+ }
+
+ ///
+ public void Update()
+ {
+ if (m_Settings.UseModelSpace)
+ {
+ m_PoseExtractor.UpdateModelSpacePoses();
+ }
+
+ if (m_Settings.UseLocalSpace)
+ {
+ m_PoseExtractor.UpdateLocalSpacePoses();
+ }
+ }
+
+ ///
+ public void Reset() {}
+
+ ///
+ public SensorCompressionType GetCompressionType()
+ {
+ return SensorCompressionType.None;
+ }
+
+ ///
+ public string GetName()
+ {
+ return m_SensorName;
+ }
+ }
+}
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs.meta
new file mode 100644
index 0000000000..2fce9c0200
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 254640b3578a24bd2838c1fa39f1011a
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
index 3a920eb8cd..31a48e31c9 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
@@ -4,6 +4,9 @@
namespace Unity.MLAgents.Extensions.Sensors
{
+ ///
+ /// Settings that define the observations generated for physics-based sensors.
+ ///
[Serializable]
public struct PhysicsSensorSettings
{
@@ -13,7 +16,7 @@ public struct PhysicsSensorSettings
public bool UseModelSpaceTranslations;
///
- /// Whether to use model space (relative to the root body) rotatoins as observations.
+ /// Whether to use model space (relative to the root body) rotations as observations.
///
public bool UseModelSpaceRotations;
@@ -27,6 +30,16 @@ public struct PhysicsSensorSettings
///
public bool UseLocalSpaceRotations;
+ ///
+ /// Whether to use model space (relative to the root body) linear velocities as observations.
+ ///
+ public bool UseModelSpaceLinearVelocity;
+
+ ///
+ /// Whether to use local space (relative to the parent body) linear velocities as observations.
+ ///
+ public bool UseLocalSpaceLinearVelocity;
+
///
/// Creates a PhysicsSensorSettings with reasonable default values.
///
@@ -45,7 +58,7 @@ public static PhysicsSensorSettings Default()
///
public bool UseModelSpace
{
- get { return UseModelSpaceTranslations || UseModelSpaceRotations; }
+ get { return UseModelSpaceTranslations || UseModelSpaceRotations || UseModelSpaceLinearVelocity; }
}
///
@@ -53,7 +66,7 @@ public bool UseModelSpace
///
public bool UseLocalSpace
{
- get { return UseLocalSpaceTranslations || UseLocalSpaceRotations; }
+ get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; }
}
@@ -70,6 +83,9 @@ public int TransformSize(int numTransforms)
obsPerTransform += UseLocalSpaceTranslations ? 3 : 0;
obsPerTransform += UseLocalSpaceRotations ? 4 : 0;
+ obsPerTransform += UseModelSpaceLinearVelocity ? 3 : 0;
+ obsPerTransform += UseLocalSpaceLinearVelocity ? 3 : 0;
+
return numTransforms * obsPerTransform;
}
}
@@ -89,8 +105,12 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting
var offset = baseOffset;
if (settings.UseModelSpace)
{
- foreach (var pose in poseExtractor.ModelSpacePoses)
+ var poses = poseExtractor.ModelSpacePoses;
+ var vels = poseExtractor.ModelSpaceVelocities;
+
+ for(var i=0; i
/// Read access to the model space transforms.
///
@@ -35,13 +39,44 @@ public IList LocalSpacePoses
}
///
- /// Number of transforms in the hierarchy (read-only).
+ /// Read access to the model space linear velocities.
+ ///
+ public IList ModelSpaceVelocities
+ {
+ get { return m_ModelSpaceLinearVelocities; }
+ }
+
+ ///
+ /// Read access to the local space linear velocities.
+ ///
+ public IList LocalSpaceVelocities
+ {
+ get { return m_LocalSpaceLinearVelocities; }
+ }
+
+ ///
+ /// Number of poses in the hierarchy (read-only).
///
public int NumPoses
{
get { return m_ModelSpacePoses?.Length ?? 0; }
}
+ ///
+ /// Get the parent index of the body at the specified index.
+ ///
+ ///
+ ///
+ public int GetParentIndex(int index)
+ {
+ if (m_ParentIndices == null)
+ {
+ return -1;
+ }
+
+ return m_ParentIndices[index];
+ }
+
///
/// Initialize with the mapping of parent indices.
/// The 0th element is assumed to be -1, indicating that it's the root.
@@ -53,6 +88,9 @@ protected void SetParentIndices(int[] parentIndices)
var numTransforms = parentIndices.Length;
m_ModelSpacePoses = new Pose[numTransforms];
m_LocalSpacePoses = new Pose[numTransforms];
+
+ m_ModelSpaceLinearVelocities = new Vector3[numTransforms];
+ m_LocalSpaceLinearVelocities = new Vector3[numTransforms];
}
///
@@ -62,6 +100,14 @@ protected void SetParentIndices(int[] parentIndices)
///
protected abstract Pose GetPoseAt(int index);
+ ///
+ /// Return the world space linear velocity of the i'th object.
+ ///
+ ///
+ ///
+ protected abstract Vector3 GetLinearVelocityAt(int index);
+
+
///
/// Update the internal model space transform storage based on the underlying system.
///
@@ -72,13 +118,19 @@ public void UpdateModelSpacePoses()
return;
}
- var worldTransform = GetPoseAt(0);
- var worldToModel = worldTransform.Inverse();
+ var rootWorldTransform = GetPoseAt(0);
+ var worldToModel = rootWorldTransform.Inverse();
+ var rootLinearVel = GetLinearVelocityAt(0);
for (var i = 0; i < m_ModelSpacePoses.Length; i++)
{
- var currentTransform = GetPoseAt(i);
- m_ModelSpacePoses[i] = worldToModel.Multiply(currentTransform);
+ var currentWorldSpacePose = GetPoseAt(i);
+ var currentModelSpacePose = worldToModel.Multiply(currentWorldSpacePose);
+ m_ModelSpacePoses[i] = currentModelSpacePose;
+
+ var currentBodyLinearVel = GetLinearVelocityAt(i);
+ var relativeVelocity = currentBodyLinearVel - rootLinearVel;
+ m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity;
}
}
@@ -102,16 +154,21 @@ public void UpdateLocalSpacePoses()
var invParent = parentTransform.Inverse();
var currentTransform = GetPoseAt(i);
m_LocalSpacePoses[i] = invParent.Multiply(currentTransform);
+
+ var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]);
+ var currentLinearVel = GetLinearVelocityAt(i);
+ m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel);
}
else
{
m_LocalSpacePoses[i] = Pose.identity;
+ m_LocalSpaceLinearVelocities[i] = Vector3.zero;
}
}
}
- public void DrawModelSpace(Vector3 offset)
+ internal void DrawModelSpace(Vector3 offset)
{
UpdateLocalSpacePoses();
UpdateModelSpacePoses();
@@ -138,6 +195,9 @@ public void DrawModelSpace(Vector3 offset)
}
}
+ ///
+ /// Extension methods for the Pose struct, in order to improve the readability of some math.
+ ///
public static class PoseExtensions
{
///
@@ -165,6 +225,19 @@ public static Pose Multiply(this Pose pose, Pose rhs)
return rhs.GetTransformedBy(pose);
}
+ ///
+ /// Transform the vector by the pose. Conceptually this is equivalent to treating the Pose
+ /// as a 4x4 matrix and multiplying the augmented vector.
+ /// See https://en.wikipedia.org/wiki/Affine_transformation#Augmented_matrix for more details.
+ ///
+ ///
+ ///
+ ///
+ public static Vector3 Multiply(this Pose pose, Vector3 rhs)
+ {
+ return pose.rotation * rhs + pose.position;
+ }
+
// TODO optimize inv(A)*B?
}
}
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
index 176f21b2da..9036cbc4e5 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
@@ -16,13 +16,22 @@ public class RigidBodyPoseExtractor : PoseExtractor
/// Initialize given a root RigidBody.
///
///
- public RigidBodyPoseExtractor(Rigidbody rootBody)
+ public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null)
{
if (rootBody == null)
{
return;
}
- var rbs = rootBody.GetComponentsInChildren ();
+
+ Rigidbody[] rbs;
+ if (rootGameObject == null)
+ {
+ rbs = rootBody.GetComponentsInChildren();
+ }
+ else
+ {
+ rbs = rootGameObject.GetComponentsInChildren();
+ }
var bodyToIndex = new Dictionary(rbs.Length);
var parentIndices = new int[rbs.Length];
@@ -54,17 +63,18 @@ public RigidBodyPoseExtractor(Rigidbody rootBody)
SetParentIndices(parentIndices);
}
- ///
- /// Get the pose of the i'th RigidBody.
- ///
- ///
- ///
+ ///
+ protected override Vector3 GetLinearVelocityAt(int index)
+ {
+ return m_Bodies[index].velocity;
+ }
+
+ ///
protected override Pose GetPoseAt(int index)
{
var body = m_Bodies[index];
return new Pose { rotation = body.rotation, position = body.position };
}
-
-
}
+
}
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
new file mode 100644
index 0000000000..88202bbac1
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
@@ -0,0 +1,52 @@
+using UnityEngine;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Extensions.Sensors
+{
+ ///
+ /// Editor component that creates a PhysicsBodySensor for the Agent.
+ ///
+ public class RigidBodySensorComponent : SensorComponent
+ {
+ ///
+ /// The root Rigidbody of the system.
+ ///
+ public Rigidbody RootBody;
+
+ ///
+ /// Settings defining what types of observations will be generated.
+ ///
+ [SerializeField]
+ public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default();
+
+ ///
+ /// Optional sensor name. This must be unique for each Agent.
+ ///
+ public string sensorName;
+
+ ///
+ /// Creates a PhysicsBodySensor.
+ ///
+ ///
+ public override ISensor CreateSensor()
+ {
+ return new PhysicsBodySensor(RootBody, gameObject, Settings, sensorName);
+ }
+
+ ///
+ public override int[] GetObservationShape()
+ {
+ if (RootBody == null)
+ {
+ return new[] { 0 };
+ }
+
+ // TODO static method in PhysicsBodySensor?
+ // TODO only update PoseExtractor when body changes?
+ var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject);
+ var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses);
+ return new[] { numTransformObservations };
+ }
+ }
+
+}
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs.meta
new file mode 100644
index 0000000000..59ba148382
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: df0f8be9a37d6486498061e2cbc4cd94
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs
new file mode 100644
index 0000000000..b6f640e53e
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs
@@ -0,0 +1,63 @@
+#if UNITY_2020_1_OR_NEWER
+using UnityEngine;
+using NUnit.Framework;
+using Unity.MLAgents.Extensions.Sensors;
+
+namespace Unity.MLAgents.Extensions.Tests.Sensors
+{
+ public class ArticulationBodyPoseExtractorTests
+ {
+ [TearDown]
+ public void RemoveGameObjects()
+ {
+ var objects = GameObject.FindObjectsOfType();
+ foreach (var o in objects)
+ {
+ UnityEngine.Object.DestroyImmediate(o);
+ }
+ }
+
+ [Test]
+ public void TestNullRoot()
+ {
+ var poseExtractor = new ArticulationBodyPoseExtractor(null);
+ // These should be no-ops
+ poseExtractor.UpdateLocalSpacePoses();
+ poseExtractor.UpdateModelSpacePoses();
+
+ Assert.AreEqual(0, poseExtractor.NumPoses);
+ }
+
+ [Test]
+ public void TestSingleBody()
+ {
+ var go = new GameObject();
+ var rootArticBody = go.AddComponent();
+ var poseExtractor = new ArticulationBodyPoseExtractor(rootArticBody);
+ Assert.AreEqual(1, poseExtractor.NumPoses);
+ }
+
+ [Test]
+ public void TestTwoBodies()
+ {
+ // * rootObj
+ // - rootArticBody
+ // * leafGameObj
+ // - leafArticBody
+ var rootObj = new GameObject();
+ var rootArticBody = rootObj.AddComponent();
+
+ var leafGameObj = new GameObject();
+ var leafArticBody = leafGameObj.AddComponent();
+ leafGameObj.transform.SetParent(rootObj.transform);
+
+ leafArticBody.jointType = ArticulationJointType.RevoluteJoint;
+
+ var poseExtractor = new ArticulationBodyPoseExtractor(rootArticBody);
+ Assert.AreEqual(2, poseExtractor.NumPoses);
+ Assert.AreEqual(-1, poseExtractor.GetParentIndex(0));
+ Assert.AreEqual(0, poseExtractor.GetParentIndex(1));
+ }
+ }
+}
+#endif
\ No newline at end of file
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs.meta b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs.meta
new file mode 100644
index 0000000000..af76a34446
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 934ea08cde59d4356bc41e040d333c3d
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
new file mode 100644
index 0000000000..33d23d4697
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
@@ -0,0 +1,113 @@
+#if UNITY_2020_1_OR_NEWER
+using UnityEngine;
+using NUnit.Framework;
+using Unity.MLAgents.Extensions.Sensors;
+
+
+namespace Unity.MLAgents.Extensions.Tests.Sensors
+{
+
+ public class ArticulationBodySensorTests
+ {
+ [Test]
+ public void TestNullRootBody()
+ {
+ var gameObj = new GameObject();
+
+ var sensorComponent = gameObj.AddComponent();
+ var sensor = sensorComponent.CreateSensor();
+ SensorTestHelper.CompareObservation(sensor, new float[0]);
+ }
+
+ [Test]
+ public void TestSingleBody()
+ {
+ var gameObj = new GameObject();
+ var articulationBody = gameObj.AddComponent();
+ var sensorComponent = gameObj.AddComponent();
+ sensorComponent.RootBody = articulationBody;
+ sensorComponent.Settings = new PhysicsSensorSettings
+ {
+ UseModelSpaceLinearVelocity = true,
+ UseLocalSpaceTranslations = true,
+ UseLocalSpaceRotations = true
+ };
+
+ var sensor = sensorComponent.CreateSensor();
+ sensor.Update();
+ var expected = new[]
+ {
+ 0f, 0f, 0f, // ModelSpaceLinearVelocity
+ 0f, 0f, 0f, // LocalSpaceTranslations
+ 0f, 0f, 0f, 1f // LocalSpaceRotations
+ };
+ SensorTestHelper.CompareObservation(sensor, expected);
+ }
+
+ [Test]
+ public void TestBodiesWithJoint()
+ {
+ var rootObj = new GameObject();
+ var rootArticBody = rootObj.AddComponent();
+
+ var middleGamObj = new GameObject();
+ var middleArticBody = middleGamObj.AddComponent();
+ middleArticBody.AddForce(new Vector3(0f, 1f, 0f));
+ middleGamObj.transform.SetParent(rootObj.transform);
+ middleGamObj.transform.localPosition = new Vector3(13.37f, 0f, 0f);
+ middleArticBody.jointType = ArticulationJointType.RevoluteJoint;
+
+ var leafGameObj = new GameObject();
+ var leafArticBody = leafGameObj.AddComponent();
+ leafGameObj.transform.SetParent(middleGamObj.transform);
+ leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f);
+ leafArticBody.jointType = ArticulationJointType.RevoluteJoint;
+
+#if UNITY_2020_2_OR_NEWER
+ // ArticulationBody.velocity is read-only in 2020.1
+ rootArticBody.velocity = new Vector3(1f, 0f, 0f);
+ middleArticBody.velocity = new Vector3(0f, 1f, 0f);
+ leafArticBody.velocity = new Vector3(0f, 0f, 1f);
+#endif
+
+ var sensorComponent = rootObj.AddComponent();
+ sensorComponent.RootBody = rootArticBody;
+ sensorComponent.Settings = new PhysicsSensorSettings
+ {
+ UseModelSpaceTranslations = true,
+ UseLocalSpaceTranslations = true,
+#if UNITY_2020_2_OR_NEWER
+ UseLocalSpaceLinearVelocity = true
+#endif
+ };
+
+ var sensor = sensorComponent.CreateSensor();
+ sensor.Update();
+ var expected = new[]
+ {
+ // Model space
+ 0f, 0f, 0f, // Root pos
+ 13.37f, 0f, 0f, // Middle pos
+ leafGameObj.transform.position.x, 0f, 0f, // Leaf pos
+
+ // Local space
+ 0f, 0f, 0f, // Root pos
+#if UNITY_2020_2_OR_NEWER
+ 0f, 0f, 0f, // Root vel
+#endif
+
+ 13.37f, 0f, 0f, // Attached pos
+#if UNITY_2020_2_OR_NEWER
+ -1f, 1f, 0f, // Attached vel
+#endif
+
+ 4.2f, 0f, 0f, // Leaf pos
+#if UNITY_2020_2_OR_NEWER
+ 0f, -1f, 1f // Leaf vel
+#endif
+ };
+ SensorTestHelper.CompareObservation(sensor, expected);
+ }
+ }
+}
+#endif // #if UNITY_2020_1_OR_NEWER
\ No newline at end of file
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs.meta b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs.meta
new file mode 100644
index 0000000000..97a2d4ba47
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 0ef757469348342418a68826f51d0783
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
index b140e1f607..627b4d7b8b 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
@@ -13,6 +13,11 @@ protected override Pose GetPoseAt(int index)
return Pose.identity;
}
+ protected override Vector3 GetLinearVelocityAt(int index)
+ {
+ return Vector3.zero;
+ }
+
public void Init(int[] parentIndices)
{
SetParentIndices(parentIndices);
@@ -68,6 +73,12 @@ protected override Pose GetPoseAt(int index)
position = translation
};
}
+
+ protected override Vector3 GetLinearVelocityAt(int index)
+ {
+ return Vector3.zero;
+ }
+
}
[Test]
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
index 079e7dac23..a5d8b5bcb5 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
@@ -1,4 +1,3 @@
-using System.Collections.Generic;
using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
new file mode 100644
index 0000000000..5fbb74c9cf
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
@@ -0,0 +1,113 @@
+using UnityEngine;
+using NUnit.Framework;
+using Unity.MLAgents.Sensors;
+using Unity.MLAgents.Extensions.Sensors;
+
+
+namespace Unity.MLAgents.Extensions.Tests.Sensors
+{
+
+ public static class SensorTestHelper
+ {
+ public static void CompareObservation(ISensor sensor, float[] expected)
+ {
+ string errorMessage;
+ bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage);
+ Assert.IsTrue(isOK, errorMessage);
+ }
+ }
+
+ public class RigidBodySensorTests
+ {
+ [Test]
+ public void TestNullRootBody()
+ {
+ var gameObj = new GameObject();
+
+ var sensorComponent = gameObj.AddComponent();
+ var sensor = sensorComponent.CreateSensor();
+ SensorTestHelper.CompareObservation(sensor, new float[0]);
+ }
+
+ [Test]
+ public void TestSingleRigidbody()
+ {
+ var gameObj = new GameObject();
+ var rootRb = gameObj.AddComponent();
+ var sensorComponent = gameObj.AddComponent();
+ sensorComponent.RootBody = rootRb;
+ sensorComponent.Settings = new PhysicsSensorSettings
+ {
+ UseModelSpaceLinearVelocity = true,
+ UseLocalSpaceTranslations = true,
+ UseLocalSpaceRotations = true
+ };
+
+ var sensor = sensorComponent.CreateSensor();
+ sensor.Update();
+ var expected = new[]
+ {
+ 0f, 0f, 0f, // ModelSpaceLinearVelocity
+ 0f, 0f, 0f, // LocalSpaceTranslations
+ 0f, 0f, 0f, 1f // LocalSpaceRotations
+ };
+ SensorTestHelper.CompareObservation(sensor, expected);
+ }
+
+ [Test]
+ public void TestBodiesWithJoint()
+ {
+ var rootObj = new GameObject();
+ var rootRb = rootObj.AddComponent();
+ rootRb.velocity = new Vector3(1f, 0f, 0f);
+
+ var middleGamObj = new GameObject();
+ var middleRb = middleGamObj.AddComponent();
+ middleRb.velocity = new Vector3(0f, 1f, 0f);
+ middleGamObj.transform.SetParent(rootObj.transform);
+ middleGamObj.transform.localPosition = new Vector3(13.37f, 0f, 0f);
+ var joint = middleGamObj.AddComponent();
+ joint.connectedBody = rootRb;
+
+ var leafGameObj = new GameObject();
+ var leafRb = leafGameObj.AddComponent();
+ leafRb.velocity = new Vector3(0f, 0f, 1f);
+ leafGameObj.transform.SetParent(middleGamObj.transform);
+ leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f);
+ var joint2 = leafGameObj.AddComponent();
+ joint2.connectedBody = middleRb;
+
+
+ var sensorComponent = rootObj.AddComponent();
+ sensorComponent.RootBody = rootRb;
+ sensorComponent.Settings = new PhysicsSensorSettings
+ {
+ UseModelSpaceTranslations = true,
+ UseLocalSpaceTranslations = true,
+ UseLocalSpaceLinearVelocity = true
+ };
+
+ var sensor = sensorComponent.CreateSensor();
+ sensor.Update();
+ var expected = new[]
+ {
+ // Model space
+ 0f, 0f, 0f, // Root pos
+ 13.37f, 0f, 0f, // Middle pos
+ leafGameObj.transform.position.x, 0f, 0f, // Leaf pos
+
+ // Local space
+ 0f, 0f, 0f, // Root pos
+ 0f, 0f, 0f, // Root vel
+
+ 13.37f, 0f, 0f, // Attached pos
+ -1f, 1f, 0f, // Attached vel
+
+ 4.2f, 0f, 0f, // Leaf pos
+ 0f, -1f, 1f // Leaf vel
+ };
+ SensorTestHelper.CompareObservation(sensor, expected);
+
+ }
+ }
+}
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs.meta b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs.meta
new file mode 100644
index 0000000000..f7f204ce99
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: d8daf5517a7c94bfd9ac7f45f8d1bcd3
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Unity.ML-Agents.Extensions.EditorTests.asmdef b/com.unity.ml-agents.extensions/Tests/Editor/Unity.ML-Agents.Extensions.EditorTests.asmdef
index 7d0dd20f4f..2c13fc3ac1 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Unity.ML-Agents.Extensions.EditorTests.asmdef
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Unity.ML-Agents.Extensions.EditorTests.asmdef
@@ -2,7 +2,8 @@
"name": "Unity.ML-Agents.Extensions.EditorTests",
"references": [
"Unity.ML-Agents.Extensions.Editor",
- "Unity.ML-Agents.Extensions"
+ "Unity.ML-Agents.Extensions",
+ "Unity.ML-Agents"
],
"optionalUnityReferences": [
"TestAssemblies"
@@ -10,5 +11,8 @@
"includePlatforms": [
"Editor"
],
- "excludePlatforms": []
+ "excludePlatforms": [],
+ "defineConstraints": [
+ "UNITY_INCLUDE_TESTS"
+ ]
}
diff --git a/com.unity.ml-agents/Runtime/SensorHelper.cs b/com.unity.ml-agents/Runtime/SensorHelper.cs
new file mode 100644
index 0000000000..471a768b27
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/SensorHelper.cs
@@ -0,0 +1,66 @@
+using UnityEngine;
+
+namespace Unity.MLAgents.Sensors
+{
+ ///
+ /// Utility methods related to implementations.
+ ///
+ public static class SensorHelper
+ {
+ ///
+ /// Generates the observations for the provided sensor, and returns true if they equal the
+ /// expected values. If they are unequal, errorMessage is also set.
+ /// This should not generally be used in production code. It is only intended for
+ /// simplifying unit tests.
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static bool CompareObservation(ISensor sensor, float[] expected, out string errorMessage)
+ {
+ var numExpected = expected.Length;
+ const float fill = -1337f;
+ var output = new float[numExpected];
+ for (var i = 0; i < numExpected; i++)
+ {
+ output[i] = fill;
+ }
+
+ if (numExpected > 0)
+ {
+ if (fill != output[0])
+ {
+ errorMessage = "Error setting output buffer.";
+ return false;
+ }
+ }
+
+ ObservationWriter writer = new ObservationWriter();
+ writer.SetTarget(output, sensor.GetObservationShape(), 0);
+
+ // Make sure ObservationWriter didn't touch anything
+ if (numExpected > 0)
+ {
+ if (fill != output[0])
+ {
+ errorMessage = "ObservationWriter.SetTarget modified a buffer it shouldn't have.";
+ return false;
+ }
+ }
+
+ sensor.Write(writer);
+ for (var i = 0; i < output.Length; i++)
+ {
+ if (expected[i] != output[i])
+ {
+ errorMessage = $"Expected and actual differed in position {i}. Expected: {expected[i]} Actual: {output[i]} ";
+ return false;
+ }
+ }
+
+ errorMessage = null;
+ return true;
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/SensorHelper.cs.meta b/com.unity.ml-agents/Runtime/SensorHelper.cs.meta
new file mode 100644
index 0000000000..c331abd0b6
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/SensorHelper.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 7c1189c0af42c46f7b533350d49ad3e7
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs
index 7df0638680..b24dbddcc0 100644
--- a/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs
+++ b/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs
@@ -4,27 +4,13 @@
namespace Unity.MLAgents.Tests
{
- public class SensorTestHelper
+ public static class SensorTestHelper
{
public static void CompareObservation(ISensor sensor, float[] expected)
{
- var numExpected = expected.Length;
- const float fill = -1337f;
- var output = new float[numExpected];
- for (var i = 0; i < numExpected; i++)
- {
- output[i] = fill;
- }
- Assert.AreEqual(fill, output[0]);
-
- ObservationWriter writer = new ObservationWriter();
- writer.SetTarget(output, sensor.GetObservationShape(), 0);
-
- // Make sure ObservationWriter didn't touch anything
- Assert.AreEqual(fill, output[0]);
-
- sensor.Write(writer);
- Assert.AreEqual(expected, output);
+ string errorMessage;
+ bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage);
+ Assert.IsTrue(isOK, errorMessage);
}
}