From 2c4a9d7b46fee219f17e4b59834160f086dfdd80 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 13 Jul 2020 15:44:03 -0700 Subject: [PATCH 1/8] joint extractors, articulation impl --- .../Sensors/ArticulationBodyJointExtractor.cs | 130 ++++++++++++++++++ .../Sensors/ArticulationBodyPoseExtractor.cs | 2 + .../ArticulationBodySensorComponent.cs | 10 +- .../Runtime/Sensors/JointExtractor.cs | 10 ++ .../Runtime/Sensors/PhysicsBodySensor.cs | 22 ++- .../Runtime/Sensors/PhysicsSensorSettings.cs | 2 + 6 files changed, 173 insertions(+), 3 deletions(-) create mode 100644 com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs create mode 100644 com.unity.ml-agents.extensions/Runtime/Sensors/JointExtractor.cs diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs new file mode 100644 index 0000000000..8c91a0d2d4 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs @@ -0,0 +1,130 @@ +#if UNITY_2020_1_OR_NEWER + +using System.Collections.Generic; +using UnityEngine; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Extensions.Sensors +{ + public class ArticulationBodyJointExtractor : JointExtractor + { + ArticulationBody m_Body; + + public ArticulationBodyJointExtractor(ArticulationBody body) + { + m_Body = body; + } + + public int NumObservations(PhysicsSensorSettings settings) + { + return NumObservations(m_Body, settings); + } + + public static int NumObservations(ArticulationBody body, PhysicsSensorSettings settings) + { + if (body == null || body.isRoot) + { + return 0; + } + + var totalCount = 0; + if (settings.UseJointPositions) + { + switch (body.jointType) + { + case ArticulationJointType.RevoluteJoint: + case ArticulationJointType.SphericalJoint: + // Two floats per angular component + totalCount += 2 * body.dofCount; + break; + case ArticulationJointType.FixedJoint: + break; // TODO (none?) + case ArticulationJointType.PrismaticJoint: + // One linear component + totalCount += body.dofCount; + break; + } + } + + return totalCount; + } + + public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset) + { + if (m_Body == null || m_Body.isRoot) + { + return 0; + } + + var currentOffset = offset; + + // Write joint positions + if (settings.UseJointPositions) + { + switch (m_Body.jointType) + { + case ArticulationJointType.RevoluteJoint: + case ArticulationJointType.SphericalJoint: + // All joint positions are angular + for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++) + { + var jointRotationRads = m_Body.jointPosition[dofIndex]; + writer[currentOffset++] = Mathf.Sin(jointRotationRads); + writer[currentOffset++] = Mathf.Cos(jointRotationRads); + } + break; + case ArticulationJointType.FixedJoint: + break; // TODO + case ArticulationJointType.PrismaticJoint: + writer[currentOffset++] = GetPrismaticValue(); + break; + } + } + + return currentOffset - offset; + } + + float GetPrismaticValue() + { + // Prismatic joints should have at most one free axis. + bool limited = false; + var drive = m_Body.xDrive; + if (m_Body.linearLockX == ArticulationDofLock.LimitedMotion) + { + drive = m_Body.xDrive; + limited = true; + } + else if (m_Body.linearLockY == ArticulationDofLock.LimitedMotion) + { + drive = m_Body.yDrive; + limited = true; + } + else if (m_Body.linearLockZ == ArticulationDofLock.LimitedMotion) + { + drive = m_Body.zDrive; + limited = true; + } + + var jointPos = m_Body.jointPosition[0]; + if (limited) + { + // If locked, interpolate between the limits. + var upperLimit = drive.upperLimit; + var lowerLimit = drive.lowerLimit; + if (upperLimit <= lowerLimit) + { + // Invalid limits (probably equal), so don't try to lerp + return 0; + } + var invLerped = Mathf.InverseLerp(lowerLimit, upperLimit, jointPos); + + // Convert [0, 1] -> [-1, 1] + var normalized = 2.0f * invLerped - 1.0f; + return normalized; + } + // TODO take tanh() to keep in [-1, 1]? + return jointPos; + } + } +} +#endif \ No newline at end of file diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs index 512b857345..f354a614b7 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs @@ -71,6 +71,8 @@ protected override Pose GetPoseAt(int index) var t = go.transform; return new Pose { rotation = t.rotation, position = t.position }; } + + internal ArticulationBody[] Bodies => m_Bodies; } } #endif // UNITY_2020_1_OR_NEWER \ No newline at end of file diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs index 23682bf2f6..a9995cbd34 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs @@ -33,7 +33,15 @@ public override int[] GetObservationShape() // TODO only update PoseExtractor when body changes? var poseExtractor = new ArticulationBodyPoseExtractor(RootBody); var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses); - return new[] { numTransformObservations }; + var numJointObservations = 0; + // Start from i=1 to ignore the root + for (var i = 1; i < poseExtractor.Bodies.Length; i++) + { + numJointObservations += ArticulationBodyJointExtractor.NumObservations( + poseExtractor.Bodies[i], Settings + ); + } + return new[] { numTransformObservations + numJointObservations }; } } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/JointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/JointExtractor.cs new file mode 100644 index 0000000000..e3219de315 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/JointExtractor.cs @@ -0,0 +1,10 @@ +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Extensions.Sensors +{ + public interface JointExtractor + { + int NumObservations(PhysicsSensorSettings settings); + int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset); + } +} diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs index 6b0bb2ca0f..2a9c068130 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs @@ -12,6 +12,7 @@ public class PhysicsBodySensor : ISensor string m_SensorName; PoseExtractor m_PoseExtractor; + JointExtractor[] m_JointExtractors; PhysicsSensorSettings m_Settings; /// @@ -25,20 +26,33 @@ public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsS m_PoseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject); m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName; m_Settings = settings; + m_JointExtractors = new JointExtractor[0]; var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses); + // TODO Account for JointExtractor sizes m_Shape = new[] { numTransformObservations }; } #if UNITY_2020_1_OR_NEWER public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null) { - m_PoseExtractor = new ArticulationBodyPoseExtractor(rootBody); + var poseExtractor = new ArticulationBodyPoseExtractor(rootBody); + m_PoseExtractor = poseExtractor; m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootBody?.name}" : sensorName; m_Settings = settings; + var numJointExtractorObservations = 0; + var articBodies = poseExtractor.Bodies; + m_JointExtractors = new JointExtractor[articBodies.Length - 1]; // skip the root + for (var i = 1; i < articBodies.Length; i++) + { + var jointExtractor= new ArticulationBodyJointExtractor(articBodies[i]); + numJointExtractorObservations += jointExtractor.NumObservations(settings); + m_JointExtractors[i - 1] = jointExtractor; + } + var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses); - m_Shape = new[] { numTransformObservations }; + m_Shape = new[] { numTransformObservations + numJointExtractorObservations }; } #endif @@ -52,6 +66,10 @@ public int[] GetObservationShape() public int Write(ObservationWriter writer) { var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor); + foreach (var jointExtractor in m_JointExtractors) + { + numWritten += jointExtractor.Write(m_Settings, writer, numWritten); + } return numWritten; } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs index 31a48e31c9..231aa77efa 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs @@ -40,6 +40,8 @@ public struct PhysicsSensorSettings /// public bool UseLocalSpaceLinearVelocity; + public bool UseJointPositions; + /// /// Creates a PhysicsSensorSettings with reasonable default values. /// From 5e6dcac2a2fb0a055e0a3355cd2ef1ffa339ad55 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 13 Jul 2020 16:16:31 -0700 Subject: [PATCH 2/8] rb joint extractor WIP --- .../Runtime/Sensors/ArticulationBodyJointExtractor.cs | 2 +- .../Sensors/ArticulationBodyJointExtractor.cs.meta | 11 +++++++++++ .../Sensors/{JointExtractor.cs => IJointExtractor.cs} | 2 +- .../Runtime/Sensors/IJointExtractor.cs.meta | 11 +++++++++++ .../Runtime/Sensors/PhysicsBodySensor.cs | 8 ++++---- .../Runtime/Sensors/RigidBodyJointExtractor.cs | 7 +++++++ .../Runtime/Sensors/RigidBodyJointExtractor.cs.meta | 11 +++++++++++ 7 files changed, 46 insertions(+), 6 deletions(-) create mode 100644 com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta rename com.unity.ml-agents.extensions/Runtime/Sensors/{JointExtractor.cs => IJointExtractor.cs} (86%) create mode 100644 com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta create mode 100644 com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs create mode 100644 com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs.meta diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs index 8c91a0d2d4..37e690c5ff 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs @@ -6,7 +6,7 @@ namespace Unity.MLAgents.Extensions.Sensors { - public class ArticulationBodyJointExtractor : JointExtractor + public class ArticulationBodyJointExtractor : IJointExtractor { ArticulationBody m_Body; diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta new file mode 100644 index 0000000000..8b5c4d6729 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 238d15f867b9c4ced9cef331b7420b27 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/JointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs similarity index 86% rename from com.unity.ml-agents.extensions/Runtime/Sensors/JointExtractor.cs rename to com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs index e3219de315..19df7bb24c 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/JointExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs @@ -2,7 +2,7 @@ namespace Unity.MLAgents.Extensions.Sensors { - public interface JointExtractor + public interface IJointExtractor { int NumObservations(PhysicsSensorSettings settings); int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset); diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta new file mode 100644 index 0000000000..a1ef9c2f7b --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 2d2a01ea194334a4682d5c8cad4a956b +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs index 2a9c068130..012dfdd493 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs @@ -12,7 +12,7 @@ public class PhysicsBodySensor : ISensor string m_SensorName; PoseExtractor m_PoseExtractor; - JointExtractor[] m_JointExtractors; + IJointExtractor[] m_JointExtractors; PhysicsSensorSettings m_Settings; /// @@ -26,10 +26,10 @@ public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsS m_PoseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject); m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName; m_Settings = settings; - m_JointExtractors = new JointExtractor[0]; + m_JointExtractors = new IJointExtractor[0]; var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses); - // TODO Account for JointExtractor sizes + // TODO Account for IJointExtractor sizes m_Shape = new[] { numTransformObservations }; } @@ -43,7 +43,7 @@ public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settin var numJointExtractorObservations = 0; var articBodies = poseExtractor.Bodies; - m_JointExtractors = new JointExtractor[articBodies.Length - 1]; // skip the root + m_JointExtractors = new IJointExtractor[articBodies.Length - 1]; // skip the root for (var i = 1; i < articBodies.Length; i++) { var jointExtractor= new ArticulationBodyJointExtractor(articBodies[i]); diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs new file mode 100644 index 0000000000..ca106a744f --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs @@ -0,0 +1,7 @@ +namespace Unity.MLAgents.Extensions.Sensors +{ + public class RigidBodyJointExtractor + { + + } +} diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs.meta new file mode 100644 index 0000000000..9d3dc91df9 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 5014d7ab95c6a44469f447b8a7019746 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: From dd5e6866fdae85d4024c7077dc9109f570b90cbc Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 13 Jul 2020 16:17:20 -0700 Subject: [PATCH 3/8] rb joint extractor WIP --- .../Runtime/Sensors/PhysicsBodySensor.cs | 19 +++++++--- .../Runtime/Sensors/PhysicsSensorSettings.cs | 2 ++ .../Sensors/RigidBodyJointExtractor.cs | 35 +++++++++++++++++-- .../Runtime/Sensors/RigidBodyPoseExtractor.cs | 2 ++ .../Sensors/RigidBodySensorComponent.cs | 11 +++++- 5 files changed, 61 insertions(+), 8 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs index 012dfdd493..a76e206f73 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs @@ -23,14 +23,23 @@ public class PhysicsBodySensor : ISensor /// public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null) { - m_PoseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject); + var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject); + m_PoseExtractor = poseExtractor; m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName; m_Settings = settings; - m_JointExtractors = new IJointExtractor[0]; + + var numJointExtractorObservations = 0; + var rigidBodies = poseExtractor.Bodies; + m_JointExtractors = new IJointExtractor[rigidBodies.Length - 1]; // skip the root + for (var i = 1; i < rigidBodies.Length; i++) + { + var jointExtractor = new RigidBodyJointExtractor(rigidBodies[i]); + numJointExtractorObservations += jointExtractor.NumObservations(settings); + m_JointExtractors[i - 1] = jointExtractor; + } var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses); - // TODO Account for IJointExtractor sizes - m_Shape = new[] { numTransformObservations }; + m_Shape = new[] { numTransformObservations + numJointExtractorObservations }; } #if UNITY_2020_1_OR_NEWER @@ -46,7 +55,7 @@ public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settin m_JointExtractors = new IJointExtractor[articBodies.Length - 1]; // skip the root for (var i = 1; i < articBodies.Length; i++) { - var jointExtractor= new ArticulationBodyJointExtractor(articBodies[i]); + var jointExtractor = new ArticulationBodyJointExtractor(articBodies[i]); numJointExtractorObservations += jointExtractor.NumObservations(settings); m_JointExtractors[i - 1] = jointExtractor; } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs index 231aa77efa..9783004fa3 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs @@ -42,6 +42,8 @@ public struct PhysicsSensorSettings public bool UseJointPositions; + public bool UseJointForces; + /// /// Creates a PhysicsSensorSettings with reasonable default values. /// diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs index ca106a744f..5bee8e1c90 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs @@ -1,7 +1,38 @@ +using System.Collections.Generic; +using UnityEngine; +using Unity.MLAgents.Sensors; + namespace Unity.MLAgents.Extensions.Sensors { - public class RigidBodyJointExtractor + public class RigidBodyJointExtractor : IJointExtractor { - + Rigidbody m_Body; + Joint m_Joint; + + public RigidBodyJointExtractor(Rigidbody body) + { + m_Body = body; + m_Joint = m_Body?.GetComponent(); + } + + public int NumObservations(PhysicsSensorSettings settings) + { + return NumObservations(m_Body, m_Joint, settings); + } + + public static int NumObservations(Rigidbody body, Joint joint, PhysicsSensorSettings settings) + { + if(body == null || joint == null) + { + return 0; + } + + + } + + public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset) + { + return 0; + } } } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs index 9036cbc4e5..05b55ef737 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs @@ -75,6 +75,8 @@ protected override Pose GetPoseAt(int index) var body = m_Bodies[index]; return new Pose { rotation = body.rotation, position = body.position }; } + + internal Rigidbody[] Bodies => m_Bodies; } } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs index 88202bbac1..6715e87054 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs @@ -45,7 +45,16 @@ public override int[] GetObservationShape() // TODO only update PoseExtractor when body changes? var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject); var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses); - return new[] { numTransformObservations }; + + var numJointObservations = 0; + // Start from i=1 to ignore the root + for (var i = 1; i < poseExtractor.Bodies.Length; i++) + { + var body = poseExtractor.Bodies[i]; + var joint = body?.GetComponent(); + numJointObservations += RigidBodyJointExtractor.NumObservations(body, joint, Settings); + } + return new[] { numTransformObservations + numJointObservations }; } } From f65f7f7d39950c1b61a050d33c2321c9a7d800ca Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 13 Jul 2020 18:41:45 -0700 Subject: [PATCH 4/8] cleanup and tests --- .../Sensors/ArticulationBodyJointExtractor.cs | 18 ++++++++- .../ArticulationBodySensorComponent.cs | 4 +- .../Runtime/Sensors/PhysicsBodySensor.cs | 38 +++++++++++++------ .../Runtime/Sensors/PhysicsSensorSettings.cs | 20 ---------- .../Runtime/Sensors/PoseExtractor.cs | 18 +++++++++ .../Sensors/RigidBodyJointExtractor.cs | 26 ++++++++++++- .../Sensors/RigidBodySensorComponent.cs | 4 +- .../Sensors/ArticulationBodySensorTests.cs | 34 ++++++++++++++++- .../Editor/Sensors/RigidBodySensorTests.cs | 22 +++++++++++ 9 files changed, 144 insertions(+), 40 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs index 37e690c5ff..78df2d2718 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs @@ -46,6 +46,11 @@ public static int NumObservations(ArticulationBody body, PhysicsSensorSettings s } } + if (settings.UseJointForces) + { + totalCount += body.dofCount; + } + return totalCount; } @@ -81,6 +86,15 @@ public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int o } } + if (settings.UseJointForces) + { + for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++) + { + // take tanh to keep in [-1, 1] + writer[currentOffset++] = (float) System.Math.Tanh(m_Body.jointForce[dofIndex]); + } + } + return currentOffset - offset; } @@ -122,8 +136,8 @@ float GetPrismaticValue() var normalized = 2.0f * invLerped - 1.0f; return normalized; } - // TODO take tanh() to keep in [-1, 1]? - return jointPos; + // take tanh() to keep in [-1, 1] + return (float) System.Math.Tanh(jointPos); } } } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs index a9995cbd34..82418a99cf 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs @@ -32,7 +32,7 @@ public override int[] GetObservationShape() // TODO static method in PhysicsBodySensor? // TODO only update PoseExtractor when body changes? var poseExtractor = new ArticulationBodyPoseExtractor(RootBody); - var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses); + var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings); var numJointObservations = 0; // Start from i=1 to ignore the root for (var i = 1; i < poseExtractor.Bodies.Length; i++) @@ -41,7 +41,7 @@ public override int[] GetObservationShape() poseExtractor.Bodies[i], Settings ); } - return new[] { numTransformObservations + numJointObservations }; + return new[] { numPoseObservations + numJointObservations }; } } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs index a76e206f73..de9d3866f6 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs @@ -30,15 +30,22 @@ public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsS var numJointExtractorObservations = 0; var rigidBodies = poseExtractor.Bodies; - m_JointExtractors = new IJointExtractor[rigidBodies.Length - 1]; // skip the root - for (var i = 1; i < rigidBodies.Length; i++) + if (rigidBodies != null) { - var jointExtractor = new RigidBodyJointExtractor(rigidBodies[i]); - numJointExtractorObservations += jointExtractor.NumObservations(settings); - m_JointExtractors[i - 1] = jointExtractor; + m_JointExtractors = new IJointExtractor[rigidBodies.Length - 1]; // skip the root + for (var i = 1; i < rigidBodies.Length; i++) + { + var jointExtractor = new RigidBodyJointExtractor(rigidBodies[i]); + numJointExtractorObservations += jointExtractor.NumObservations(settings); + m_JointExtractors[i - 1] = jointExtractor; + } + } + else + { + m_JointExtractors = new IJointExtractor[0]; } - var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses); + var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings); m_Shape = new[] { numTransformObservations + numJointExtractorObservations }; } @@ -52,15 +59,22 @@ public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settin var numJointExtractorObservations = 0; var articBodies = poseExtractor.Bodies; - m_JointExtractors = new IJointExtractor[articBodies.Length - 1]; // skip the root - for (var i = 1; i < articBodies.Length; i++) + if (articBodies != null) + { + m_JointExtractors = new IJointExtractor[articBodies.Length - 1]; // skip the root + for (var i = 1; i < articBodies.Length; i++) + { + var jointExtractor = new ArticulationBodyJointExtractor(articBodies[i]); + numJointExtractorObservations += jointExtractor.NumObservations(settings); + m_JointExtractors[i - 1] = jointExtractor; + } + } + else { - var jointExtractor = new ArticulationBodyJointExtractor(articBodies[i]); - numJointExtractorObservations += jointExtractor.NumObservations(settings); - m_JointExtractors[i - 1] = jointExtractor; + m_JointExtractors = new IJointExtractor[0]; } - var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses); + var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings); m_Shape = new[] { numTransformObservations + numJointExtractorObservations }; } #endif diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs index 9783004fa3..94729b2979 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs @@ -72,26 +72,6 @@ public bool UseLocalSpace { get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; } } - - - /// - /// The number of floats needed to represent a given number of transforms. - /// - /// - /// - public int TransformSize(int numTransforms) - { - int obsPerTransform = 0; - obsPerTransform += UseModelSpaceTranslations ? 3 : 0; - obsPerTransform += UseModelSpaceRotations ? 4 : 0; - obsPerTransform += UseLocalSpaceTranslations ? 3 : 0; - obsPerTransform += UseLocalSpaceRotations ? 4 : 0; - - obsPerTransform += UseModelSpaceLinearVelocity ? 3 : 0; - obsPerTransform += UseLocalSpaceLinearVelocity ? 3 : 0; - - return numTransforms * obsPerTransform; - } } internal static class ObservationWriterPhysicsExtensions diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs index 03902442ec..6a5c31a7c0 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs @@ -167,6 +167,24 @@ public void UpdateLocalSpacePoses() } } + /// + /// Compute the number of floats needed to represent the poses for the given PhysicsSensorSettings. + /// + /// + /// + public int GetNumPoseObservations(PhysicsSensorSettings settings) + { + int obsPerPose = 0; + obsPerPose += settings.UseModelSpaceTranslations ? 3 : 0; + obsPerPose += settings.UseModelSpaceRotations ? 4 : 0; + obsPerPose += settings.UseLocalSpaceTranslations ? 3 : 0; + obsPerPose += settings.UseLocalSpaceRotations ? 4 : 0; + + obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0; + obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0; + + return NumPoses * obsPerPose; + } internal void DrawModelSpace(Vector3 offset) { diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs index 5bee8e1c90..dda1aed27f 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs @@ -27,12 +27,36 @@ public static int NumObservations(Rigidbody body, Joint joint, PhysicsSensorSett return 0; } + var numObservations = 0; + if (settings.UseJointForces) + { + // 3 force and 3 torque values + numObservations += 6; + } + return numObservations; } public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset) { - return 0; + if (m_Body == null || m_Joint == null) + { + return 0; + } + + var currentOffset = offset; + if (settings.UseJointForces) + { + // Take tanh of the forces and torques to ensure they're in [-1, 1] + writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentForce.x); + writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentForce.y); + writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentForce.z); + + writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentTorque.x); + writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentTorque.y); + writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentTorque.z); + } + return currentOffset - offset; } } } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs index 6715e87054..ce6cf05379 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs @@ -44,7 +44,7 @@ public override int[] GetObservationShape() // TODO static method in PhysicsBodySensor? // TODO only update PoseExtractor when body changes? var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject); - var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses); + var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings); var numJointObservations = 0; // Start from i=1 to ignore the root @@ -54,7 +54,7 @@ public override int[] GetObservationShape() var joint = body?.GetComponent(); numJointObservations += RigidBodyJointExtractor.NumObservations(body, joint, Settings); } - return new[] { numTransformObservations + numJointObservations }; + return new[] { numPoseObservations + numJointObservations }; } } diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs index 33d23d4697..31f7ccdfe5 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs @@ -42,6 +42,7 @@ public void TestSingleBody() 0f, 0f, 0f, 1f // LocalSpaceRotations }; SensorTestHelper.CompareObservation(sensor, expected); + Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); } [Test] @@ -61,7 +62,14 @@ public void TestBodiesWithJoint() var leafArticBody = leafGameObj.AddComponent(); leafGameObj.transform.SetParent(middleGamObj.transform); leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f); - leafArticBody.jointType = ArticulationJointType.RevoluteJoint; + leafArticBody.jointType = ArticulationJointType.PrismaticJoint; + leafArticBody.linearLockZ = ArticulationDofLock.LimitedMotion; + leafArticBody.zDrive = new ArticulationDrive + { + lowerLimit = -3, + upperLimit = 1 + }; + #if UNITY_2020_2_OR_NEWER // ArticulationBody.velocity is read-only in 2020.1 @@ -107,6 +115,30 @@ public void TestBodiesWithJoint() #endif }; SensorTestHelper.CompareObservation(sensor, expected); + Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); + + // Update the settings to only process joint observations + sensorComponent.Settings = new PhysicsSensorSettings + { + UseJointForces = true, + UseJointPositions = true, + }; + + sensor = sensorComponent.CreateSensor(); + sensor.Update(); + + expected = new[] + { + // revolute + 0f, 1f, // joint1.position (sin and cos) + 0f, // joint1.force + + // prismatic + 0.5f, // joint2.position (interpolate between limits) + 0f, // joint2.force + }; + SensorTestHelper.CompareObservation(sensor, expected); + Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); } } } diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs index 5fbb74c9cf..7baaeec60d 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs @@ -52,6 +52,7 @@ public void TestSingleRigidbody() 0f, 0f, 0f, 1f // LocalSpaceRotations }; SensorTestHelper.CompareObservation(sensor, expected); + Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); } [Test] @@ -107,6 +108,27 @@ public void TestBodiesWithJoint() 0f, -1f, 1f // Leaf vel }; SensorTestHelper.CompareObservation(sensor, expected); + Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); + + // Update the settings to only process joint observations + sensorComponent.Settings = new PhysicsSensorSettings + { + UseJointForces = true, + UseJointPositions = true, + }; + + sensor = sensorComponent.CreateSensor(); + sensor.Update(); + + expected = new[] + { + 0f, 0f, 0f, // joint1.force + 0f, 0f, 0f, // joint1.torque + 0f, 0f, 0f, // joint2.force + 0f, 0f, 0f, // joint2.torque + }; + SensorTestHelper.CompareObservation(sensor, expected); + Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); } } From 3044754c3d045bd103cc0d98f10cbe3632eb7892 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Wed, 15 Jul 2020 15:26:58 -0700 Subject: [PATCH 5/8] rename and docstrings --- .yamato/com.unity.ml-agents-pack.yml | 2 +- .../Runtime/Sensors/ArticulationBodyJointExtractor.cs | 4 ++-- .../Runtime/Sensors/PhysicsSensorSettings.cs | 8 +++++++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/.yamato/com.unity.ml-agents-pack.yml b/.yamato/com.unity.ml-agents-pack.yml index 293869e7b8..1487123949 100644 --- a/.yamato/com.unity.ml-agents-pack.yml +++ b/.yamato/com.unity.ml-agents-pack.yml @@ -6,7 +6,7 @@ pack: flavor: b1.small commands: - npm install upm-ci-utils@stable -g --registry https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm - - upm-ci project pack --project-path Project + - upm-ci package pack --package-path com.unity.ml-agents artifacts: packages: paths: diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs index 78df2d2718..31c70afc68 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs @@ -28,7 +28,7 @@ public static int NumObservations(ArticulationBody body, PhysicsSensorSettings s } var totalCount = 0; - if (settings.UseJointPositions) + if (settings.UseJointPositionsAndAngles) { switch (body.jointType) { @@ -64,7 +64,7 @@ public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int o var currentOffset = offset; // Write joint positions - if (settings.UseJointPositions) + if (settings.UseJointPositionsAndAngles) { switch (m_Body.jointType) { diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs index 94729b2979..9109d9592e 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs @@ -40,8 +40,14 @@ public struct PhysicsSensorSettings /// public bool UseLocalSpaceLinearVelocity; - public bool UseJointPositions; + /// + /// Whether to use joint-specific positions and angles as observations. + /// + public bool UseJointPositionsAndAngles; + /// + /// Whether to use the joint forces and torques that are applied by the solver as observations. + /// public bool UseJointForces; /// From 15dfcc794004bbcdec76ccf83b608816beaaae91 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Thu, 16 Jul 2020 14:36:25 -0700 Subject: [PATCH 6/8] fix rename --- .../Tests/Editor/Sensors/ArticulationBodySensorTests.cs | 2 +- .../Tests/Editor/Sensors/RigidBodySensorTests.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs index 31f7ccdfe5..94e708ed68 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs @@ -121,7 +121,7 @@ public void TestBodiesWithJoint() sensorComponent.Settings = new PhysicsSensorSettings { UseJointForces = true, - UseJointPositions = true, + UseJointPositionsAndAngles = true, }; sensor = sensorComponent.CreateSensor(); diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs index 7baaeec60d..279fc7007d 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs @@ -113,8 +113,8 @@ public void TestBodiesWithJoint() // Update the settings to only process joint observations sensorComponent.Settings = new PhysicsSensorSettings { + UseJointPositionsAndAngles = true, UseJointForces = true, - UseJointPositions = true, }; sensor = sensorComponent.CreateSensor(); From fb2089ab752400333129e7415e3b9f99c9bf5297 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Thu, 16 Jul 2020 16:03:29 -0700 Subject: [PATCH 7/8] undo pack job change --- .yamato/com.unity.ml-agents-pack.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.yamato/com.unity.ml-agents-pack.yml b/.yamato/com.unity.ml-agents-pack.yml index 1487123949..293869e7b8 100644 --- a/.yamato/com.unity.ml-agents-pack.yml +++ b/.yamato/com.unity.ml-agents-pack.yml @@ -6,7 +6,7 @@ pack: flavor: b1.small commands: - npm install upm-ci-utils@stable -g --registry https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm - - upm-ci package pack --package-path com.unity.ml-agents + - upm-ci project pack --project-path Project artifacts: packages: paths: From 2b5bac79326179b73edf55e4c531592458b172a4 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Wed, 22 Jul 2020 16:53:21 -0700 Subject: [PATCH 8/8] comments --- .../Sensors/ArticulationBodyJointExtractor.cs | 9 ++++++--- .../Runtime/Sensors/IJointExtractor.cs | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs index 31c70afc68..49dd17b1a0 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs @@ -34,11 +34,13 @@ public static int NumObservations(ArticulationBody body, PhysicsSensorSettings s { case ArticulationJointType.RevoluteJoint: case ArticulationJointType.SphericalJoint: - // Two floats per angular component + // Both RevoluteJoint and SphericalJoint have all angular components. + // We use sine and cosine of the angles for the observations. totalCount += 2 * body.dofCount; break; case ArticulationJointType.FixedJoint: - break; // TODO (none?) + // Since FixedJoint can't moved, there aren't any interesting observations for it. + break; case ArticulationJointType.PrismaticJoint: // One linear component totalCount += body.dofCount; @@ -79,7 +81,8 @@ public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int o } break; case ArticulationJointType.FixedJoint: - break; // TODO + // No observations + break; case ArticulationJointType.PrismaticJoint: writer[currentOffset++] = GetPrismaticValue(); break; diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs index 19df7bb24c..401e3abf50 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs @@ -2,9 +2,26 @@ namespace Unity.MLAgents.Extensions.Sensors { + /// + /// Interface for generating observations from a physical joint or constraint. + /// public interface IJointExtractor { + /// + /// Determine the number of observations that would be generated for the particular joint + /// using the provided PhysicsSensorSettings. + /// + /// + /// Number of floats that will be written. int NumObservations(PhysicsSensorSettings settings); + + /// + /// Write the observations to the ObservationWriter, starting at the specified offset. + /// + /// + /// + /// + /// Number of floats that were written. int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset); } }