diff --git a/com.unity.ml-agents.extensions/LICENSE.md b/com.unity.ml-agents.extensions/LICENSE.md index 1f43a89bfa..3b2574fbe2 100644 --- a/com.unity.ml-agents.extensions/LICENSE.md +++ b/com.unity.ml-agents.extensions/LICENSE.md @@ -1,4 +1,4 @@ -com.unity.ml-agents.extensions copyright © 2020 Unity Technologies +com.unity.ml-agents.extensions copyright © 2020 Unity Technologies ApS Licensed under the Unity Companion License for Unity-dependent projects -- see [Unity Companion License](http://www.unity3d.com/legal/licenses/Unity_Companion_License). diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs index 3f5068ed92..512b857345 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs @@ -5,13 +5,20 @@ namespace Unity.MLAgents.Extensions.Sensors { - + /// + /// Utility class to track a hierarchy of ArticulationBodies. + /// public class ArticulationBodyPoseExtractor : PoseExtractor { ArticulationBody[] m_Bodies; public ArticulationBodyPoseExtractor(ArticulationBody rootBody) { + if (rootBody == null) + { + return; + } + if (!rootBody.isRoot) { Debug.Log("Must pass ArticulationBody.isRoot"); @@ -38,14 +45,25 @@ public ArticulationBodyPoseExtractor(ArticulationBody rootBody) for (var i = 1; i < numBodies; i++) { - var body = m_Bodies[i]; - var parent = body.GetComponentInParent(); - parentIndices[i] = bodyToIndex[parent]; + var currentArticBody = m_Bodies[i]; + // Component.GetComponentInParent will consider the provided object as well. + // So start looking from the parent. + var currentGameObject = currentArticBody.gameObject; + var parentGameObject = currentGameObject.transform.parent; + var parentArticBody = parentGameObject.GetComponentInParent(); + parentIndices[i] = bodyToIndex[parentArticBody]; } SetParentIndices(parentIndices); } + /// + protected override Vector3 GetLinearVelocityAt(int index) + { + return m_Bodies[index].velocity; + } + + /// protected override Pose GetPoseAt(int index) { var body = m_Bodies[index]; @@ -53,8 +71,6 @@ protected override Pose GetPoseAt(int index) var t = go.transform; return new Pose { rotation = t.rotation, position = t.position }; } - - } } #endif // UNITY_2020_1_OR_NEWER \ No newline at end of file diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs new file mode 100644 index 0000000000..23682bf2f6 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs @@ -0,0 +1,41 @@ +#if UNITY_2020_1_OR_NEWER +using UnityEngine; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Extensions.Sensors +{ + public class ArticulationBodySensorComponent : SensorComponent + { + public ArticulationBody RootBody; + + [SerializeField] + public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default(); + public string sensorName; + + /// + /// Creates a PhysicsBodySensor. + /// + /// + public override ISensor CreateSensor() + { + return new PhysicsBodySensor(RootBody, Settings, sensorName); + } + + /// + public override int[] GetObservationShape() + { + if (RootBody == null) + { + return new[] { 0 }; + } + + // TODO static method in PhysicsBodySensor? + // TODO only update PoseExtractor when body changes? + var poseExtractor = new ArticulationBodyPoseExtractor(RootBody); + var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses); + return new[] { numTransformObservations }; + } + } + +} +#endif // UNITY_2020_1_OR_NEWER \ No newline at end of file diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs.meta new file mode 100644 index 0000000000..3cdd83ac52 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: e57a788acd5e049c6aa9642b450ca318 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs new file mode 100644 index 0000000000..6b0bb2ca0f --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs @@ -0,0 +1,93 @@ +using UnityEngine; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Extensions.Sensors +{ + /// + /// ISensor implementation that generates observations for a group of Rigidbodies or ArticulationBodies. + /// + public class PhysicsBodySensor : ISensor + { + int[] m_Shape; + string m_SensorName; + + PoseExtractor m_PoseExtractor; + PhysicsSensorSettings m_Settings; + + /// + /// Construct a new PhysicsBodySensor + /// + /// + /// + /// + public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null) + { + m_PoseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject); + m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName; + m_Settings = settings; + + var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses); + m_Shape = new[] { numTransformObservations }; + } + +#if UNITY_2020_1_OR_NEWER + public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null) + { + m_PoseExtractor = new ArticulationBodyPoseExtractor(rootBody); + m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootBody?.name}" : sensorName; + m_Settings = settings; + + var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses); + m_Shape = new[] { numTransformObservations }; + } +#endif + + /// + public int[] GetObservationShape() + { + return m_Shape; + } + + /// + public int Write(ObservationWriter writer) + { + var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor); + return numWritten; + } + + /// + public byte[] GetCompressedObservation() + { + return null; + } + + /// + public void Update() + { + if (m_Settings.UseModelSpace) + { + m_PoseExtractor.UpdateModelSpacePoses(); + } + + if (m_Settings.UseLocalSpace) + { + m_PoseExtractor.UpdateLocalSpacePoses(); + } + } + + /// + public void Reset() {} + + /// + public SensorCompressionType GetCompressionType() + { + return SensorCompressionType.None; + } + + /// + public string GetName() + { + return m_SensorName; + } + } +} diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs.meta new file mode 100644 index 0000000000..2fce9c0200 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 254640b3578a24bd2838c1fa39f1011a +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs index 3a920eb8cd..31a48e31c9 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs @@ -4,6 +4,9 @@ namespace Unity.MLAgents.Extensions.Sensors { + /// + /// Settings that define the observations generated for physics-based sensors. + /// [Serializable] public struct PhysicsSensorSettings { @@ -13,7 +16,7 @@ public struct PhysicsSensorSettings public bool UseModelSpaceTranslations; /// - /// Whether to use model space (relative to the root body) rotatoins as observations. + /// Whether to use model space (relative to the root body) rotations as observations. /// public bool UseModelSpaceRotations; @@ -27,6 +30,16 @@ public struct PhysicsSensorSettings /// public bool UseLocalSpaceRotations; + /// + /// Whether to use model space (relative to the root body) linear velocities as observations. + /// + public bool UseModelSpaceLinearVelocity; + + /// + /// Whether to use local space (relative to the parent body) linear velocities as observations. + /// + public bool UseLocalSpaceLinearVelocity; + /// /// Creates a PhysicsSensorSettings with reasonable default values. /// @@ -45,7 +58,7 @@ public static PhysicsSensorSettings Default() /// public bool UseModelSpace { - get { return UseModelSpaceTranslations || UseModelSpaceRotations; } + get { return UseModelSpaceTranslations || UseModelSpaceRotations || UseModelSpaceLinearVelocity; } } /// @@ -53,7 +66,7 @@ public bool UseModelSpace /// public bool UseLocalSpace { - get { return UseLocalSpaceTranslations || UseLocalSpaceRotations; } + get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; } } @@ -70,6 +83,9 @@ public int TransformSize(int numTransforms) obsPerTransform += UseLocalSpaceTranslations ? 3 : 0; obsPerTransform += UseLocalSpaceRotations ? 4 : 0; + obsPerTransform += UseModelSpaceLinearVelocity ? 3 : 0; + obsPerTransform += UseLocalSpaceLinearVelocity ? 3 : 0; + return numTransforms * obsPerTransform; } } @@ -89,8 +105,12 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting var offset = baseOffset; if (settings.UseModelSpace) { - foreach (var pose in poseExtractor.ModelSpacePoses) + var poses = poseExtractor.ModelSpacePoses; + var vels = poseExtractor.ModelSpaceVelocities; + + for(var i=0; i /// Read access to the model space transforms. /// @@ -35,13 +39,44 @@ public IList LocalSpacePoses } /// - /// Number of transforms in the hierarchy (read-only). + /// Read access to the model space linear velocities. + /// + public IList ModelSpaceVelocities + { + get { return m_ModelSpaceLinearVelocities; } + } + + /// + /// Read access to the local space linear velocities. + /// + public IList LocalSpaceVelocities + { + get { return m_LocalSpaceLinearVelocities; } + } + + /// + /// Number of poses in the hierarchy (read-only). /// public int NumPoses { get { return m_ModelSpacePoses?.Length ?? 0; } } + /// + /// Get the parent index of the body at the specified index. + /// + /// + /// + public int GetParentIndex(int index) + { + if (m_ParentIndices == null) + { + return -1; + } + + return m_ParentIndices[index]; + } + /// /// Initialize with the mapping of parent indices. /// The 0th element is assumed to be -1, indicating that it's the root. @@ -53,6 +88,9 @@ protected void SetParentIndices(int[] parentIndices) var numTransforms = parentIndices.Length; m_ModelSpacePoses = new Pose[numTransforms]; m_LocalSpacePoses = new Pose[numTransforms]; + + m_ModelSpaceLinearVelocities = new Vector3[numTransforms]; + m_LocalSpaceLinearVelocities = new Vector3[numTransforms]; } /// @@ -62,6 +100,14 @@ protected void SetParentIndices(int[] parentIndices) /// protected abstract Pose GetPoseAt(int index); + /// + /// Return the world space linear velocity of the i'th object. + /// + /// + /// + protected abstract Vector3 GetLinearVelocityAt(int index); + + /// /// Update the internal model space transform storage based on the underlying system. /// @@ -72,13 +118,19 @@ public void UpdateModelSpacePoses() return; } - var worldTransform = GetPoseAt(0); - var worldToModel = worldTransform.Inverse(); + var rootWorldTransform = GetPoseAt(0); + var worldToModel = rootWorldTransform.Inverse(); + var rootLinearVel = GetLinearVelocityAt(0); for (var i = 0; i < m_ModelSpacePoses.Length; i++) { - var currentTransform = GetPoseAt(i); - m_ModelSpacePoses[i] = worldToModel.Multiply(currentTransform); + var currentWorldSpacePose = GetPoseAt(i); + var currentModelSpacePose = worldToModel.Multiply(currentWorldSpacePose); + m_ModelSpacePoses[i] = currentModelSpacePose; + + var currentBodyLinearVel = GetLinearVelocityAt(i); + var relativeVelocity = currentBodyLinearVel - rootLinearVel; + m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity; } } @@ -102,16 +154,21 @@ public void UpdateLocalSpacePoses() var invParent = parentTransform.Inverse(); var currentTransform = GetPoseAt(i); m_LocalSpacePoses[i] = invParent.Multiply(currentTransform); + + var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]); + var currentLinearVel = GetLinearVelocityAt(i); + m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel); } else { m_LocalSpacePoses[i] = Pose.identity; + m_LocalSpaceLinearVelocities[i] = Vector3.zero; } } } - public void DrawModelSpace(Vector3 offset) + internal void DrawModelSpace(Vector3 offset) { UpdateLocalSpacePoses(); UpdateModelSpacePoses(); @@ -138,6 +195,9 @@ public void DrawModelSpace(Vector3 offset) } } + /// + /// Extension methods for the Pose struct, in order to improve the readability of some math. + /// public static class PoseExtensions { /// @@ -165,6 +225,19 @@ public static Pose Multiply(this Pose pose, Pose rhs) return rhs.GetTransformedBy(pose); } + /// + /// Transform the vector by the pose. Conceptually this is equivalent to treating the Pose + /// as a 4x4 matrix and multiplying the augmented vector. + /// See https://en.wikipedia.org/wiki/Affine_transformation#Augmented_matrix for more details. + /// + /// + /// + /// + public static Vector3 Multiply(this Pose pose, Vector3 rhs) + { + return pose.rotation * rhs + pose.position; + } + // TODO optimize inv(A)*B? } } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs index 176f21b2da..9036cbc4e5 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs @@ -16,13 +16,22 @@ public class RigidBodyPoseExtractor : PoseExtractor /// Initialize given a root RigidBody. /// /// - public RigidBodyPoseExtractor(Rigidbody rootBody) + public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null) { if (rootBody == null) { return; } - var rbs = rootBody.GetComponentsInChildren (); + + Rigidbody[] rbs; + if (rootGameObject == null) + { + rbs = rootBody.GetComponentsInChildren(); + } + else + { + rbs = rootGameObject.GetComponentsInChildren(); + } var bodyToIndex = new Dictionary(rbs.Length); var parentIndices = new int[rbs.Length]; @@ -54,17 +63,18 @@ public RigidBodyPoseExtractor(Rigidbody rootBody) SetParentIndices(parentIndices); } - /// - /// Get the pose of the i'th RigidBody. - /// - /// - /// + /// + protected override Vector3 GetLinearVelocityAt(int index) + { + return m_Bodies[index].velocity; + } + + /// protected override Pose GetPoseAt(int index) { var body = m_Bodies[index]; return new Pose { rotation = body.rotation, position = body.position }; } - - } + } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs new file mode 100644 index 0000000000..88202bbac1 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs @@ -0,0 +1,52 @@ +using UnityEngine; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Extensions.Sensors +{ + /// + /// Editor component that creates a PhysicsBodySensor for the Agent. + /// + public class RigidBodySensorComponent : SensorComponent + { + /// + /// The root Rigidbody of the system. + /// + public Rigidbody RootBody; + + /// + /// Settings defining what types of observations will be generated. + /// + [SerializeField] + public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default(); + + /// + /// Optional sensor name. This must be unique for each Agent. + /// + public string sensorName; + + /// + /// Creates a PhysicsBodySensor. + /// + /// + public override ISensor CreateSensor() + { + return new PhysicsBodySensor(RootBody, gameObject, Settings, sensorName); + } + + /// + public override int[] GetObservationShape() + { + if (RootBody == null) + { + return new[] { 0 }; + } + + // TODO static method in PhysicsBodySensor? + // TODO only update PoseExtractor when body changes? + var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject); + var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses); + return new[] { numTransformObservations }; + } + } + +} diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs.meta new file mode 100644 index 0000000000..59ba148382 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: df0f8be9a37d6486498061e2cbc4cd94 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs new file mode 100644 index 0000000000..b6f640e53e --- /dev/null +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs @@ -0,0 +1,63 @@ +#if UNITY_2020_1_OR_NEWER +using UnityEngine; +using NUnit.Framework; +using Unity.MLAgents.Extensions.Sensors; + +namespace Unity.MLAgents.Extensions.Tests.Sensors +{ + public class ArticulationBodyPoseExtractorTests + { + [TearDown] + public void RemoveGameObjects() + { + var objects = GameObject.FindObjectsOfType(); + foreach (var o in objects) + { + UnityEngine.Object.DestroyImmediate(o); + } + } + + [Test] + public void TestNullRoot() + { + var poseExtractor = new ArticulationBodyPoseExtractor(null); + // These should be no-ops + poseExtractor.UpdateLocalSpacePoses(); + poseExtractor.UpdateModelSpacePoses(); + + Assert.AreEqual(0, poseExtractor.NumPoses); + } + + [Test] + public void TestSingleBody() + { + var go = new GameObject(); + var rootArticBody = go.AddComponent(); + var poseExtractor = new ArticulationBodyPoseExtractor(rootArticBody); + Assert.AreEqual(1, poseExtractor.NumPoses); + } + + [Test] + public void TestTwoBodies() + { + // * rootObj + // - rootArticBody + // * leafGameObj + // - leafArticBody + var rootObj = new GameObject(); + var rootArticBody = rootObj.AddComponent(); + + var leafGameObj = new GameObject(); + var leafArticBody = leafGameObj.AddComponent(); + leafGameObj.transform.SetParent(rootObj.transform); + + leafArticBody.jointType = ArticulationJointType.RevoluteJoint; + + var poseExtractor = new ArticulationBodyPoseExtractor(rootArticBody); + Assert.AreEqual(2, poseExtractor.NumPoses); + Assert.AreEqual(-1, poseExtractor.GetParentIndex(0)); + Assert.AreEqual(0, poseExtractor.GetParentIndex(1)); + } + } +} +#endif \ No newline at end of file diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs.meta b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs.meta new file mode 100644 index 0000000000..af76a34446 --- /dev/null +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 934ea08cde59d4356bc41e040d333c3d +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs new file mode 100644 index 0000000000..33d23d4697 --- /dev/null +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs @@ -0,0 +1,113 @@ +#if UNITY_2020_1_OR_NEWER +using UnityEngine; +using NUnit.Framework; +using Unity.MLAgents.Extensions.Sensors; + + +namespace Unity.MLAgents.Extensions.Tests.Sensors +{ + + public class ArticulationBodySensorTests + { + [Test] + public void TestNullRootBody() + { + var gameObj = new GameObject(); + + var sensorComponent = gameObj.AddComponent(); + var sensor = sensorComponent.CreateSensor(); + SensorTestHelper.CompareObservation(sensor, new float[0]); + } + + [Test] + public void TestSingleBody() + { + var gameObj = new GameObject(); + var articulationBody = gameObj.AddComponent(); + var sensorComponent = gameObj.AddComponent(); + sensorComponent.RootBody = articulationBody; + sensorComponent.Settings = new PhysicsSensorSettings + { + UseModelSpaceLinearVelocity = true, + UseLocalSpaceTranslations = true, + UseLocalSpaceRotations = true + }; + + var sensor = sensorComponent.CreateSensor(); + sensor.Update(); + var expected = new[] + { + 0f, 0f, 0f, // ModelSpaceLinearVelocity + 0f, 0f, 0f, // LocalSpaceTranslations + 0f, 0f, 0f, 1f // LocalSpaceRotations + }; + SensorTestHelper.CompareObservation(sensor, expected); + } + + [Test] + public void TestBodiesWithJoint() + { + var rootObj = new GameObject(); + var rootArticBody = rootObj.AddComponent(); + + var middleGamObj = new GameObject(); + var middleArticBody = middleGamObj.AddComponent(); + middleArticBody.AddForce(new Vector3(0f, 1f, 0f)); + middleGamObj.transform.SetParent(rootObj.transform); + middleGamObj.transform.localPosition = new Vector3(13.37f, 0f, 0f); + middleArticBody.jointType = ArticulationJointType.RevoluteJoint; + + var leafGameObj = new GameObject(); + var leafArticBody = leafGameObj.AddComponent(); + leafGameObj.transform.SetParent(middleGamObj.transform); + leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f); + leafArticBody.jointType = ArticulationJointType.RevoluteJoint; + +#if UNITY_2020_2_OR_NEWER + // ArticulationBody.velocity is read-only in 2020.1 + rootArticBody.velocity = new Vector3(1f, 0f, 0f); + middleArticBody.velocity = new Vector3(0f, 1f, 0f); + leafArticBody.velocity = new Vector3(0f, 0f, 1f); +#endif + + var sensorComponent = rootObj.AddComponent(); + sensorComponent.RootBody = rootArticBody; + sensorComponent.Settings = new PhysicsSensorSettings + { + UseModelSpaceTranslations = true, + UseLocalSpaceTranslations = true, +#if UNITY_2020_2_OR_NEWER + UseLocalSpaceLinearVelocity = true +#endif + }; + + var sensor = sensorComponent.CreateSensor(); + sensor.Update(); + var expected = new[] + { + // Model space + 0f, 0f, 0f, // Root pos + 13.37f, 0f, 0f, // Middle pos + leafGameObj.transform.position.x, 0f, 0f, // Leaf pos + + // Local space + 0f, 0f, 0f, // Root pos +#if UNITY_2020_2_OR_NEWER + 0f, 0f, 0f, // Root vel +#endif + + 13.37f, 0f, 0f, // Attached pos +#if UNITY_2020_2_OR_NEWER + -1f, 1f, 0f, // Attached vel +#endif + + 4.2f, 0f, 0f, // Leaf pos +#if UNITY_2020_2_OR_NEWER + 0f, -1f, 1f // Leaf vel +#endif + }; + SensorTestHelper.CompareObservation(sensor, expected); + } + } +} +#endif // #if UNITY_2020_1_OR_NEWER \ No newline at end of file diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs.meta b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs.meta new file mode 100644 index 0000000000..97a2d4ba47 --- /dev/null +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 0ef757469348342418a68826f51d0783 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs index b140e1f607..627b4d7b8b 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs @@ -13,6 +13,11 @@ protected override Pose GetPoseAt(int index) return Pose.identity; } + protected override Vector3 GetLinearVelocityAt(int index) + { + return Vector3.zero; + } + public void Init(int[] parentIndices) { SetParentIndices(parentIndices); @@ -68,6 +73,12 @@ protected override Pose GetPoseAt(int index) position = translation }; } + + protected override Vector3 GetLinearVelocityAt(int index) + { + return Vector3.zero; + } + } [Test] diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs index 079e7dac23..a5d8b5bcb5 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs @@ -1,4 +1,3 @@ -using System.Collections.Generic; using UnityEngine; using NUnit.Framework; using Unity.MLAgents.Extensions.Sensors; diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs new file mode 100644 index 0000000000..5fbb74c9cf --- /dev/null +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs @@ -0,0 +1,113 @@ +using UnityEngine; +using NUnit.Framework; +using Unity.MLAgents.Sensors; +using Unity.MLAgents.Extensions.Sensors; + + +namespace Unity.MLAgents.Extensions.Tests.Sensors +{ + + public static class SensorTestHelper + { + public static void CompareObservation(ISensor sensor, float[] expected) + { + string errorMessage; + bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage); + Assert.IsTrue(isOK, errorMessage); + } + } + + public class RigidBodySensorTests + { + [Test] + public void TestNullRootBody() + { + var gameObj = new GameObject(); + + var sensorComponent = gameObj.AddComponent(); + var sensor = sensorComponent.CreateSensor(); + SensorTestHelper.CompareObservation(sensor, new float[0]); + } + + [Test] + public void TestSingleRigidbody() + { + var gameObj = new GameObject(); + var rootRb = gameObj.AddComponent(); + var sensorComponent = gameObj.AddComponent(); + sensorComponent.RootBody = rootRb; + sensorComponent.Settings = new PhysicsSensorSettings + { + UseModelSpaceLinearVelocity = true, + UseLocalSpaceTranslations = true, + UseLocalSpaceRotations = true + }; + + var sensor = sensorComponent.CreateSensor(); + sensor.Update(); + var expected = new[] + { + 0f, 0f, 0f, // ModelSpaceLinearVelocity + 0f, 0f, 0f, // LocalSpaceTranslations + 0f, 0f, 0f, 1f // LocalSpaceRotations + }; + SensorTestHelper.CompareObservation(sensor, expected); + } + + [Test] + public void TestBodiesWithJoint() + { + var rootObj = new GameObject(); + var rootRb = rootObj.AddComponent(); + rootRb.velocity = new Vector3(1f, 0f, 0f); + + var middleGamObj = new GameObject(); + var middleRb = middleGamObj.AddComponent(); + middleRb.velocity = new Vector3(0f, 1f, 0f); + middleGamObj.transform.SetParent(rootObj.transform); + middleGamObj.transform.localPosition = new Vector3(13.37f, 0f, 0f); + var joint = middleGamObj.AddComponent(); + joint.connectedBody = rootRb; + + var leafGameObj = new GameObject(); + var leafRb = leafGameObj.AddComponent(); + leafRb.velocity = new Vector3(0f, 0f, 1f); + leafGameObj.transform.SetParent(middleGamObj.transform); + leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f); + var joint2 = leafGameObj.AddComponent(); + joint2.connectedBody = middleRb; + + + var sensorComponent = rootObj.AddComponent(); + sensorComponent.RootBody = rootRb; + sensorComponent.Settings = new PhysicsSensorSettings + { + UseModelSpaceTranslations = true, + UseLocalSpaceTranslations = true, + UseLocalSpaceLinearVelocity = true + }; + + var sensor = sensorComponent.CreateSensor(); + sensor.Update(); + var expected = new[] + { + // Model space + 0f, 0f, 0f, // Root pos + 13.37f, 0f, 0f, // Middle pos + leafGameObj.transform.position.x, 0f, 0f, // Leaf pos + + // Local space + 0f, 0f, 0f, // Root pos + 0f, 0f, 0f, // Root vel + + 13.37f, 0f, 0f, // Attached pos + -1f, 1f, 0f, // Attached vel + + 4.2f, 0f, 0f, // Leaf pos + 0f, -1f, 1f // Leaf vel + }; + SensorTestHelper.CompareObservation(sensor, expected); + + } + } +} diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs.meta b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs.meta new file mode 100644 index 0000000000..f7f204ce99 --- /dev/null +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: d8daf5517a7c94bfd9ac7f45f8d1bcd3 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Unity.ML-Agents.Extensions.EditorTests.asmdef b/com.unity.ml-agents.extensions/Tests/Editor/Unity.ML-Agents.Extensions.EditorTests.asmdef index 7d0dd20f4f..2c13fc3ac1 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Unity.ML-Agents.Extensions.EditorTests.asmdef +++ b/com.unity.ml-agents.extensions/Tests/Editor/Unity.ML-Agents.Extensions.EditorTests.asmdef @@ -2,7 +2,8 @@ "name": "Unity.ML-Agents.Extensions.EditorTests", "references": [ "Unity.ML-Agents.Extensions.Editor", - "Unity.ML-Agents.Extensions" + "Unity.ML-Agents.Extensions", + "Unity.ML-Agents" ], "optionalUnityReferences": [ "TestAssemblies" @@ -10,5 +11,8 @@ "includePlatforms": [ "Editor" ], - "excludePlatforms": [] + "excludePlatforms": [], + "defineConstraints": [ + "UNITY_INCLUDE_TESTS" + ] } diff --git a/com.unity.ml-agents/Runtime/SensorHelper.cs b/com.unity.ml-agents/Runtime/SensorHelper.cs new file mode 100644 index 0000000000..471a768b27 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SensorHelper.cs @@ -0,0 +1,66 @@ +using UnityEngine; + +namespace Unity.MLAgents.Sensors +{ + /// + /// Utility methods related to implementations. + /// + public static class SensorHelper + { + /// + /// Generates the observations for the provided sensor, and returns true if they equal the + /// expected values. If they are unequal, errorMessage is also set. + /// This should not generally be used in production code. It is only intended for + /// simplifying unit tests. + /// + /// + /// + /// + /// + public static bool CompareObservation(ISensor sensor, float[] expected, out string errorMessage) + { + var numExpected = expected.Length; + const float fill = -1337f; + var output = new float[numExpected]; + for (var i = 0; i < numExpected; i++) + { + output[i] = fill; + } + + if (numExpected > 0) + { + if (fill != output[0]) + { + errorMessage = "Error setting output buffer."; + return false; + } + } + + ObservationWriter writer = new ObservationWriter(); + writer.SetTarget(output, sensor.GetObservationShape(), 0); + + // Make sure ObservationWriter didn't touch anything + if (numExpected > 0) + { + if (fill != output[0]) + { + errorMessage = "ObservationWriter.SetTarget modified a buffer it shouldn't have."; + return false; + } + } + + sensor.Write(writer); + for (var i = 0; i < output.Length; i++) + { + if (expected[i] != output[i]) + { + errorMessage = $"Expected and actual differed in position {i}. Expected: {expected[i]} Actual: {output[i]} "; + return false; + } + } + + errorMessage = null; + return true; + } + } +} diff --git a/com.unity.ml-agents/Runtime/SensorHelper.cs.meta b/com.unity.ml-agents/Runtime/SensorHelper.cs.meta new file mode 100644 index 0000000000..c331abd0b6 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SensorHelper.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 7c1189c0af42c46f7b533350d49ad3e7 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs index 7df0638680..b24dbddcc0 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs @@ -4,27 +4,13 @@ namespace Unity.MLAgents.Tests { - public class SensorTestHelper + public static class SensorTestHelper { public static void CompareObservation(ISensor sensor, float[] expected) { - var numExpected = expected.Length; - const float fill = -1337f; - var output = new float[numExpected]; - for (var i = 0; i < numExpected; i++) - { - output[i] = fill; - } - Assert.AreEqual(fill, output[0]); - - ObservationWriter writer = new ObservationWriter(); - writer.SetTarget(output, sensor.GetObservationShape(), 0); - - // Make sure ObservationWriter didn't touch anything - Assert.AreEqual(fill, output[0]); - - sensor.Write(writer); - Assert.AreEqual(expected, output); + string errorMessage; + bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage); + Assert.IsTrue(isOK, errorMessage); } }