diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs new file mode 100644 index 0000000000..49dd17b1a0 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs @@ -0,0 +1,147 @@ +#if UNITY_2020_1_OR_NEWER + +using System.Collections.Generic; +using UnityEngine; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Extensions.Sensors +{ + public class ArticulationBodyJointExtractor : IJointExtractor + { + ArticulationBody m_Body; + + public ArticulationBodyJointExtractor(ArticulationBody body) + { + m_Body = body; + } + + public int NumObservations(PhysicsSensorSettings settings) + { + return NumObservations(m_Body, settings); + } + + public static int NumObservations(ArticulationBody body, PhysicsSensorSettings settings) + { + if (body == null || body.isRoot) + { + return 0; + } + + var totalCount = 0; + if (settings.UseJointPositionsAndAngles) + { + switch (body.jointType) + { + case ArticulationJointType.RevoluteJoint: + case ArticulationJointType.SphericalJoint: + // Both RevoluteJoint and SphericalJoint have all angular components. + // We use sine and cosine of the angles for the observations. + totalCount += 2 * body.dofCount; + break; + case ArticulationJointType.FixedJoint: + // Since FixedJoint can't moved, there aren't any interesting observations for it. + break; + case ArticulationJointType.PrismaticJoint: + // One linear component + totalCount += body.dofCount; + break; + } + } + + if (settings.UseJointForces) + { + totalCount += body.dofCount; + } + + return totalCount; + } + + public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset) + { + if (m_Body == null || m_Body.isRoot) + { + return 0; + } + + var currentOffset = offset; + + // Write joint positions + if (settings.UseJointPositionsAndAngles) + { + switch (m_Body.jointType) + { + case ArticulationJointType.RevoluteJoint: + case ArticulationJointType.SphericalJoint: + // All joint positions are angular + for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++) + { + var jointRotationRads = m_Body.jointPosition[dofIndex]; + writer[currentOffset++] = Mathf.Sin(jointRotationRads); + writer[currentOffset++] = Mathf.Cos(jointRotationRads); + } + break; + case ArticulationJointType.FixedJoint: + // No observations + break; + case ArticulationJointType.PrismaticJoint: + writer[currentOffset++] = GetPrismaticValue(); + break; + } + } + + if (settings.UseJointForces) + { + for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++) + { + // take tanh to keep in [-1, 1] + writer[currentOffset++] = (float) System.Math.Tanh(m_Body.jointForce[dofIndex]); + } + } + + return currentOffset - offset; + } + + float GetPrismaticValue() + { + // Prismatic joints should have at most one free axis. + bool limited = false; + var drive = m_Body.xDrive; + if (m_Body.linearLockX == ArticulationDofLock.LimitedMotion) + { + drive = m_Body.xDrive; + limited = true; + } + else if (m_Body.linearLockY == ArticulationDofLock.LimitedMotion) + { + drive = m_Body.yDrive; + limited = true; + } + else if (m_Body.linearLockZ == ArticulationDofLock.LimitedMotion) + { + drive = m_Body.zDrive; + limited = true; + } + + var jointPos = m_Body.jointPosition[0]; + if (limited) + { + // If locked, interpolate between the limits. + var upperLimit = drive.upperLimit; + var lowerLimit = drive.lowerLimit; + if (upperLimit <= lowerLimit) + { + // Invalid limits (probably equal), so don't try to lerp + return 0; + } + var invLerped = Mathf.InverseLerp(lowerLimit, upperLimit, jointPos); + + // Convert [0, 1] -> [-1, 1] + var normalized = 2.0f * invLerped - 1.0f; + return normalized; + } + // take tanh() to keep in [-1, 1] + return (float) System.Math.Tanh(jointPos); + } + } +} +#endif \ No newline at end of file diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta new file mode 100644 index 0000000000..8b5c4d6729 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 238d15f867b9c4ced9cef331b7420b27 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs index 512b857345..f354a614b7 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs @@ -71,6 +71,8 @@ protected override Pose GetPoseAt(int index) var t = go.transform; return new Pose { rotation = t.rotation, position = t.position }; } + + internal ArticulationBody[] Bodies => m_Bodies; } } #endif // UNITY_2020_1_OR_NEWER \ No newline at end of file diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs index 23682bf2f6..82418a99cf 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs @@ -32,8 +32,16 @@ public override int[] GetObservationShape() // TODO static method in PhysicsBodySensor? // TODO only update PoseExtractor when body changes? var poseExtractor = new ArticulationBodyPoseExtractor(RootBody); - var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses); - return new[] { numTransformObservations }; + var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings); + var numJointObservations = 0; + // Start from i=1 to ignore the root + for (var i = 1; i < poseExtractor.Bodies.Length; i++) + { + numJointObservations += ArticulationBodyJointExtractor.NumObservations( + poseExtractor.Bodies[i], Settings + ); + } + return new[] { numPoseObservations + numJointObservations }; } } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs new file mode 100644 index 0000000000..401e3abf50 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs @@ -0,0 +1,27 @@ +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Extensions.Sensors +{ + /// + /// Interface for generating observations from a physical joint or constraint. + /// + public interface IJointExtractor + { + /// + /// Determine the number of observations that would be generated for the particular joint + /// using the provided PhysicsSensorSettings. + /// + /// + /// Number of floats that will be written. + int NumObservations(PhysicsSensorSettings settings); + + /// + /// Write the observations to the ObservationWriter, starting at the specified offset. + /// + /// + /// + /// + /// Number of floats that were written. + int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset); + } +} diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta new file mode 100644 index 0000000000..a1ef9c2f7b --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 2d2a01ea194334a4682d5c8cad4a956b +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs index 6b0bb2ca0f..de9d3866f6 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs @@ -12,6 +12,7 @@ public class PhysicsBodySensor : ISensor string m_SensorName; PoseExtractor m_PoseExtractor; + IJointExtractor[] m_JointExtractors; PhysicsSensorSettings m_Settings; /// @@ -22,23 +23,59 @@ public class PhysicsBodySensor : ISensor /// public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null) { - m_PoseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject); + var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject); + m_PoseExtractor = poseExtractor; m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName; m_Settings = settings; - var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses); - m_Shape = new[] { numTransformObservations }; + var numJointExtractorObservations = 0; + var rigidBodies = poseExtractor.Bodies; + if (rigidBodies != null) + { + m_JointExtractors = new IJointExtractor[rigidBodies.Length - 1]; // skip the root + for (var i = 1; i < rigidBodies.Length; i++) + { + var jointExtractor = new RigidBodyJointExtractor(rigidBodies[i]); + numJointExtractorObservations += jointExtractor.NumObservations(settings); + m_JointExtractors[i - 1] = jointExtractor; + } + } + else + { + m_JointExtractors = new IJointExtractor[0]; + } + + var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings); + m_Shape = new[] { numTransformObservations + numJointExtractorObservations }; } #if UNITY_2020_1_OR_NEWER public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null) { - m_PoseExtractor = new ArticulationBodyPoseExtractor(rootBody); + var poseExtractor = new ArticulationBodyPoseExtractor(rootBody); + m_PoseExtractor = poseExtractor; m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootBody?.name}" : sensorName; m_Settings = settings; - var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses); - m_Shape = new[] { numTransformObservations }; + var numJointExtractorObservations = 0; + var articBodies = poseExtractor.Bodies; + if (articBodies != null) + { + m_JointExtractors = new IJointExtractor[articBodies.Length - 1]; // skip the root + for (var i = 1; i < articBodies.Length; i++) + { + var jointExtractor = new ArticulationBodyJointExtractor(articBodies[i]); + numJointExtractorObservations += jointExtractor.NumObservations(settings); + m_JointExtractors[i - 1] = jointExtractor; + } + } + else + { + m_JointExtractors = new IJointExtractor[0]; + } + + var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings); + m_Shape = new[] { numTransformObservations + numJointExtractorObservations }; } #endif @@ -52,6 +89,10 @@ public int[] GetObservationShape() public int Write(ObservationWriter writer) { var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor); + foreach (var jointExtractor in m_JointExtractors) + { + numWritten += jointExtractor.Write(m_Settings, writer, numWritten); + } return numWritten; } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs index 31a48e31c9..9109d9592e 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs @@ -40,6 +40,16 @@ public struct PhysicsSensorSettings /// public bool UseLocalSpaceLinearVelocity; + /// + /// Whether to use joint-specific positions and angles as observations. + /// + public bool UseJointPositionsAndAngles; + + /// + /// Whether to use the joint forces and torques that are applied by the solver as observations. + /// + public bool UseJointForces; + /// /// Creates a PhysicsSensorSettings with reasonable default values. /// @@ -68,26 +78,6 @@ public bool UseLocalSpace { get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; } } - - - /// - /// The number of floats needed to represent a given number of transforms. - /// - /// - /// - public int TransformSize(int numTransforms) - { - int obsPerTransform = 0; - obsPerTransform += UseModelSpaceTranslations ? 3 : 0; - obsPerTransform += UseModelSpaceRotations ? 4 : 0; - obsPerTransform += UseLocalSpaceTranslations ? 3 : 0; - obsPerTransform += UseLocalSpaceRotations ? 4 : 0; - - obsPerTransform += UseModelSpaceLinearVelocity ? 3 : 0; - obsPerTransform += UseLocalSpaceLinearVelocity ? 3 : 0; - - return numTransforms * obsPerTransform; - } } internal static class ObservationWriterPhysicsExtensions diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs index 03902442ec..6a5c31a7c0 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs @@ -167,6 +167,24 @@ public void UpdateLocalSpacePoses() } } + /// + /// Compute the number of floats needed to represent the poses for the given PhysicsSensorSettings. + /// + /// + /// + public int GetNumPoseObservations(PhysicsSensorSettings settings) + { + int obsPerPose = 0; + obsPerPose += settings.UseModelSpaceTranslations ? 3 : 0; + obsPerPose += settings.UseModelSpaceRotations ? 4 : 0; + obsPerPose += settings.UseLocalSpaceTranslations ? 3 : 0; + obsPerPose += settings.UseLocalSpaceRotations ? 4 : 0; + + obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0; + obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0; + + return NumPoses * obsPerPose; + } internal void DrawModelSpace(Vector3 offset) { diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs new file mode 100644 index 0000000000..dda1aed27f --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs @@ -0,0 +1,62 @@ +using System.Collections.Generic; +using UnityEngine; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Extensions.Sensors +{ + public class RigidBodyJointExtractor : IJointExtractor + { + Rigidbody m_Body; + Joint m_Joint; + + public RigidBodyJointExtractor(Rigidbody body) + { + m_Body = body; + m_Joint = m_Body?.GetComponent(); + } + + public int NumObservations(PhysicsSensorSettings settings) + { + return NumObservations(m_Body, m_Joint, settings); + } + + public static int NumObservations(Rigidbody body, Joint joint, PhysicsSensorSettings settings) + { + if(body == null || joint == null) + { + return 0; + } + + var numObservations = 0; + if (settings.UseJointForces) + { + // 3 force and 3 torque values + numObservations += 6; + } + + return numObservations; + } + + public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset) + { + if (m_Body == null || m_Joint == null) + { + return 0; + } + + var currentOffset = offset; + if (settings.UseJointForces) + { + // Take tanh of the forces and torques to ensure they're in [-1, 1] + writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentForce.x); + writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentForce.y); + writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentForce.z); + + writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentTorque.x); + writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentTorque.y); + writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentTorque.z); + } + return currentOffset - offset; + } + } +} diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs.meta new file mode 100644 index 0000000000..9d3dc91df9 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 5014d7ab95c6a44469f447b8a7019746 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs index 9036cbc4e5..05b55ef737 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs @@ -75,6 +75,8 @@ protected override Pose GetPoseAt(int index) var body = m_Bodies[index]; return new Pose { rotation = body.rotation, position = body.position }; } + + internal Rigidbody[] Bodies => m_Bodies; } } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs index 88202bbac1..ce6cf05379 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs @@ -44,8 +44,17 @@ public override int[] GetObservationShape() // TODO static method in PhysicsBodySensor? // TODO only update PoseExtractor when body changes? var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject); - var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses); - return new[] { numTransformObservations }; + var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings); + + var numJointObservations = 0; + // Start from i=1 to ignore the root + for (var i = 1; i < poseExtractor.Bodies.Length; i++) + { + var body = poseExtractor.Bodies[i]; + var joint = body?.GetComponent(); + numJointObservations += RigidBodyJointExtractor.NumObservations(body, joint, Settings); + } + return new[] { numPoseObservations + numJointObservations }; } } diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs index 33d23d4697..94e708ed68 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs @@ -42,6 +42,7 @@ public void TestSingleBody() 0f, 0f, 0f, 1f // LocalSpaceRotations }; SensorTestHelper.CompareObservation(sensor, expected); + Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); } [Test] @@ -61,7 +62,14 @@ public void TestBodiesWithJoint() var leafArticBody = leafGameObj.AddComponent(); leafGameObj.transform.SetParent(middleGamObj.transform); leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f); - leafArticBody.jointType = ArticulationJointType.RevoluteJoint; + leafArticBody.jointType = ArticulationJointType.PrismaticJoint; + leafArticBody.linearLockZ = ArticulationDofLock.LimitedMotion; + leafArticBody.zDrive = new ArticulationDrive + { + lowerLimit = -3, + upperLimit = 1 + }; + #if UNITY_2020_2_OR_NEWER // ArticulationBody.velocity is read-only in 2020.1 @@ -107,6 +115,30 @@ public void TestBodiesWithJoint() #endif }; SensorTestHelper.CompareObservation(sensor, expected); + Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); + + // Update the settings to only process joint observations + sensorComponent.Settings = new PhysicsSensorSettings + { + UseJointForces = true, + UseJointPositionsAndAngles = true, + }; + + sensor = sensorComponent.CreateSensor(); + sensor.Update(); + + expected = new[] + { + // revolute + 0f, 1f, // joint1.position (sin and cos) + 0f, // joint1.force + + // prismatic + 0.5f, // joint2.position (interpolate between limits) + 0f, // joint2.force + }; + SensorTestHelper.CompareObservation(sensor, expected); + Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); } } } diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs index 5fbb74c9cf..279fc7007d 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs @@ -52,6 +52,7 @@ public void TestSingleRigidbody() 0f, 0f, 0f, 1f // LocalSpaceRotations }; SensorTestHelper.CompareObservation(sensor, expected); + Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); } [Test] @@ -107,6 +108,27 @@ public void TestBodiesWithJoint() 0f, -1f, 1f // Leaf vel }; SensorTestHelper.CompareObservation(sensor, expected); + Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); + + // Update the settings to only process joint observations + sensorComponent.Settings = new PhysicsSensorSettings + { + UseJointPositionsAndAngles = true, + UseJointForces = true, + }; + + sensor = sensorComponent.CreateSensor(); + sensor.Update(); + + expected = new[] + { + 0f, 0f, 0f, // joint1.force + 0f, 0f, 0f, // joint1.torque + 0f, 0f, 0f, // joint2.force + 0f, 0f, 0f, // joint2.torque + }; + SensorTestHelper.CompareObservation(sensor, expected); + Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); } }