diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs
new file mode 100644
index 0000000000..49dd17b1a0
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs
@@ -0,0 +1,147 @@
+#if UNITY_2020_1_OR_NEWER
+
+using System.Collections.Generic;
+using UnityEngine;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Extensions.Sensors
+{
+ public class ArticulationBodyJointExtractor : IJointExtractor
+ {
+ ArticulationBody m_Body;
+
+ public ArticulationBodyJointExtractor(ArticulationBody body)
+ {
+ m_Body = body;
+ }
+
+ public int NumObservations(PhysicsSensorSettings settings)
+ {
+ return NumObservations(m_Body, settings);
+ }
+
+ public static int NumObservations(ArticulationBody body, PhysicsSensorSettings settings)
+ {
+ if (body == null || body.isRoot)
+ {
+ return 0;
+ }
+
+ var totalCount = 0;
+ if (settings.UseJointPositionsAndAngles)
+ {
+ switch (body.jointType)
+ {
+ case ArticulationJointType.RevoluteJoint:
+ case ArticulationJointType.SphericalJoint:
+ // Both RevoluteJoint and SphericalJoint have all angular components.
+ // We use sine and cosine of the angles for the observations.
+ totalCount += 2 * body.dofCount;
+ break;
+ case ArticulationJointType.FixedJoint:
+ // Since FixedJoint can't moved, there aren't any interesting observations for it.
+ break;
+ case ArticulationJointType.PrismaticJoint:
+ // One linear component
+ totalCount += body.dofCount;
+ break;
+ }
+ }
+
+ if (settings.UseJointForces)
+ {
+ totalCount += body.dofCount;
+ }
+
+ return totalCount;
+ }
+
+ public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset)
+ {
+ if (m_Body == null || m_Body.isRoot)
+ {
+ return 0;
+ }
+
+ var currentOffset = offset;
+
+ // Write joint positions
+ if (settings.UseJointPositionsAndAngles)
+ {
+ switch (m_Body.jointType)
+ {
+ case ArticulationJointType.RevoluteJoint:
+ case ArticulationJointType.SphericalJoint:
+ // All joint positions are angular
+ for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++)
+ {
+ var jointRotationRads = m_Body.jointPosition[dofIndex];
+ writer[currentOffset++] = Mathf.Sin(jointRotationRads);
+ writer[currentOffset++] = Mathf.Cos(jointRotationRads);
+ }
+ break;
+ case ArticulationJointType.FixedJoint:
+ // No observations
+ break;
+ case ArticulationJointType.PrismaticJoint:
+ writer[currentOffset++] = GetPrismaticValue();
+ break;
+ }
+ }
+
+ if (settings.UseJointForces)
+ {
+ for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++)
+ {
+ // take tanh to keep in [-1, 1]
+ writer[currentOffset++] = (float) System.Math.Tanh(m_Body.jointForce[dofIndex]);
+ }
+ }
+
+ return currentOffset - offset;
+ }
+
+ float GetPrismaticValue()
+ {
+ // Prismatic joints should have at most one free axis.
+ bool limited = false;
+ var drive = m_Body.xDrive;
+ if (m_Body.linearLockX == ArticulationDofLock.LimitedMotion)
+ {
+ drive = m_Body.xDrive;
+ limited = true;
+ }
+ else if (m_Body.linearLockY == ArticulationDofLock.LimitedMotion)
+ {
+ drive = m_Body.yDrive;
+ limited = true;
+ }
+ else if (m_Body.linearLockZ == ArticulationDofLock.LimitedMotion)
+ {
+ drive = m_Body.zDrive;
+ limited = true;
+ }
+
+ var jointPos = m_Body.jointPosition[0];
+ if (limited)
+ {
+ // If locked, interpolate between the limits.
+ var upperLimit = drive.upperLimit;
+ var lowerLimit = drive.lowerLimit;
+ if (upperLimit <= lowerLimit)
+ {
+ // Invalid limits (probably equal), so don't try to lerp
+ return 0;
+ }
+ var invLerped = Mathf.InverseLerp(lowerLimit, upperLimit, jointPos);
+
+ // Convert [0, 1] -> [-1, 1]
+ var normalized = 2.0f * invLerped - 1.0f;
+ return normalized;
+ }
+ // take tanh() to keep in [-1, 1]
+ return (float) System.Math.Tanh(jointPos);
+ }
+ }
+}
+#endif
\ No newline at end of file
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta
new file mode 100644
index 0000000000..8b5c4d6729
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 238d15f867b9c4ced9cef331b7420b27
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
index 512b857345..f354a614b7 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
@@ -71,6 +71,8 @@ protected override Pose GetPoseAt(int index)
var t = go.transform;
return new Pose { rotation = t.rotation, position = t.position };
}
+
+ internal ArticulationBody[] Bodies => m_Bodies;
}
}
#endif // UNITY_2020_1_OR_NEWER
\ No newline at end of file
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
index 23682bf2f6..82418a99cf 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
@@ -32,8 +32,16 @@ public override int[] GetObservationShape()
// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
var poseExtractor = new ArticulationBodyPoseExtractor(RootBody);
- var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses);
- return new[] { numTransformObservations };
+ var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);
+ var numJointObservations = 0;
+ // Start from i=1 to ignore the root
+ for (var i = 1; i < poseExtractor.Bodies.Length; i++)
+ {
+ numJointObservations += ArticulationBodyJointExtractor.NumObservations(
+ poseExtractor.Bodies[i], Settings
+ );
+ }
+ return new[] { numPoseObservations + numJointObservations };
}
}
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs
new file mode 100644
index 0000000000..401e3abf50
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs
@@ -0,0 +1,27 @@
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Extensions.Sensors
+{
+ ///
+ /// Interface for generating observations from a physical joint or constraint.
+ ///
+ public interface IJointExtractor
+ {
+ ///
+ /// Determine the number of observations that would be generated for the particular joint
+ /// using the provided PhysicsSensorSettings.
+ ///
+ ///
+ /// Number of floats that will be written.
+ int NumObservations(PhysicsSensorSettings settings);
+
+ ///
+ /// Write the observations to the ObservationWriter, starting at the specified offset.
+ ///
+ ///
+ ///
+ ///
+ /// Number of floats that were written.
+ int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset);
+ }
+}
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta
new file mode 100644
index 0000000000..a1ef9c2f7b
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 2d2a01ea194334a4682d5c8cad4a956b
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
index 6b0bb2ca0f..de9d3866f6 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
@@ -12,6 +12,7 @@ public class PhysicsBodySensor : ISensor
string m_SensorName;
PoseExtractor m_PoseExtractor;
+ IJointExtractor[] m_JointExtractors;
PhysicsSensorSettings m_Settings;
///
@@ -22,23 +23,59 @@ public class PhysicsBodySensor : ISensor
///
public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null)
{
- m_PoseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);
+ var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);
+ m_PoseExtractor = poseExtractor;
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName;
m_Settings = settings;
- var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
- m_Shape = new[] { numTransformObservations };
+ var numJointExtractorObservations = 0;
+ var rigidBodies = poseExtractor.Bodies;
+ if (rigidBodies != null)
+ {
+ m_JointExtractors = new IJointExtractor[rigidBodies.Length - 1]; // skip the root
+ for (var i = 1; i < rigidBodies.Length; i++)
+ {
+ var jointExtractor = new RigidBodyJointExtractor(rigidBodies[i]);
+ numJointExtractorObservations += jointExtractor.NumObservations(settings);
+ m_JointExtractors[i - 1] = jointExtractor;
+ }
+ }
+ else
+ {
+ m_JointExtractors = new IJointExtractor[0];
+ }
+
+ var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
+ m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
}
#if UNITY_2020_1_OR_NEWER
public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null)
{
- m_PoseExtractor = new ArticulationBodyPoseExtractor(rootBody);
+ var poseExtractor = new ArticulationBodyPoseExtractor(rootBody);
+ m_PoseExtractor = poseExtractor;
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootBody?.name}" : sensorName;
m_Settings = settings;
- var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
- m_Shape = new[] { numTransformObservations };
+ var numJointExtractorObservations = 0;
+ var articBodies = poseExtractor.Bodies;
+ if (articBodies != null)
+ {
+ m_JointExtractors = new IJointExtractor[articBodies.Length - 1]; // skip the root
+ for (var i = 1; i < articBodies.Length; i++)
+ {
+ var jointExtractor = new ArticulationBodyJointExtractor(articBodies[i]);
+ numJointExtractorObservations += jointExtractor.NumObservations(settings);
+ m_JointExtractors[i - 1] = jointExtractor;
+ }
+ }
+ else
+ {
+ m_JointExtractors = new IJointExtractor[0];
+ }
+
+ var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
+ m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
}
#endif
@@ -52,6 +89,10 @@ public int[] GetObservationShape()
public int Write(ObservationWriter writer)
{
var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
+ foreach (var jointExtractor in m_JointExtractors)
+ {
+ numWritten += jointExtractor.Write(m_Settings, writer, numWritten);
+ }
return numWritten;
}
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
index 31a48e31c9..9109d9592e 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
@@ -40,6 +40,16 @@ public struct PhysicsSensorSettings
///
public bool UseLocalSpaceLinearVelocity;
+ ///
+ /// Whether to use joint-specific positions and angles as observations.
+ ///
+ public bool UseJointPositionsAndAngles;
+
+ ///
+ /// Whether to use the joint forces and torques that are applied by the solver as observations.
+ ///
+ public bool UseJointForces;
+
///
/// Creates a PhysicsSensorSettings with reasonable default values.
///
@@ -68,26 +78,6 @@ public bool UseLocalSpace
{
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; }
}
-
-
- ///
- /// The number of floats needed to represent a given number of transforms.
- ///
- ///
- ///
- public int TransformSize(int numTransforms)
- {
- int obsPerTransform = 0;
- obsPerTransform += UseModelSpaceTranslations ? 3 : 0;
- obsPerTransform += UseModelSpaceRotations ? 4 : 0;
- obsPerTransform += UseLocalSpaceTranslations ? 3 : 0;
- obsPerTransform += UseLocalSpaceRotations ? 4 : 0;
-
- obsPerTransform += UseModelSpaceLinearVelocity ? 3 : 0;
- obsPerTransform += UseLocalSpaceLinearVelocity ? 3 : 0;
-
- return numTransforms * obsPerTransform;
- }
}
internal static class ObservationWriterPhysicsExtensions
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
index 03902442ec..6a5c31a7c0 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
@@ -167,6 +167,24 @@ public void UpdateLocalSpacePoses()
}
}
+ ///
+ /// Compute the number of floats needed to represent the poses for the given PhysicsSensorSettings.
+ ///
+ ///
+ ///
+ public int GetNumPoseObservations(PhysicsSensorSettings settings)
+ {
+ int obsPerPose = 0;
+ obsPerPose += settings.UseModelSpaceTranslations ? 3 : 0;
+ obsPerPose += settings.UseModelSpaceRotations ? 4 : 0;
+ obsPerPose += settings.UseLocalSpaceTranslations ? 3 : 0;
+ obsPerPose += settings.UseLocalSpaceRotations ? 4 : 0;
+
+ obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0;
+ obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0;
+
+ return NumPoses * obsPerPose;
+ }
internal void DrawModelSpace(Vector3 offset)
{
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs
new file mode 100644
index 0000000000..dda1aed27f
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs
@@ -0,0 +1,62 @@
+using System.Collections.Generic;
+using UnityEngine;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Extensions.Sensors
+{
+ public class RigidBodyJointExtractor : IJointExtractor
+ {
+ Rigidbody m_Body;
+ Joint m_Joint;
+
+ public RigidBodyJointExtractor(Rigidbody body)
+ {
+ m_Body = body;
+ m_Joint = m_Body?.GetComponent();
+ }
+
+ public int NumObservations(PhysicsSensorSettings settings)
+ {
+ return NumObservations(m_Body, m_Joint, settings);
+ }
+
+ public static int NumObservations(Rigidbody body, Joint joint, PhysicsSensorSettings settings)
+ {
+ if(body == null || joint == null)
+ {
+ return 0;
+ }
+
+ var numObservations = 0;
+ if (settings.UseJointForces)
+ {
+ // 3 force and 3 torque values
+ numObservations += 6;
+ }
+
+ return numObservations;
+ }
+
+ public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset)
+ {
+ if (m_Body == null || m_Joint == null)
+ {
+ return 0;
+ }
+
+ var currentOffset = offset;
+ if (settings.UseJointForces)
+ {
+ // Take tanh of the forces and torques to ensure they're in [-1, 1]
+ writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentForce.x);
+ writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentForce.y);
+ writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentForce.z);
+
+ writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentTorque.x);
+ writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentTorque.y);
+ writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentTorque.z);
+ }
+ return currentOffset - offset;
+ }
+ }
+}
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs.meta
new file mode 100644
index 0000000000..9d3dc91df9
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 5014d7ab95c6a44469f447b8a7019746
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
index 9036cbc4e5..05b55ef737 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
@@ -75,6 +75,8 @@ protected override Pose GetPoseAt(int index)
var body = m_Bodies[index];
return new Pose { rotation = body.rotation, position = body.position };
}
+
+ internal Rigidbody[] Bodies => m_Bodies;
}
}
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
index 88202bbac1..ce6cf05379 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
@@ -44,8 +44,17 @@ public override int[] GetObservationShape()
// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject);
- var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses);
- return new[] { numTransformObservations };
+ var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);
+
+ var numJointObservations = 0;
+ // Start from i=1 to ignore the root
+ for (var i = 1; i < poseExtractor.Bodies.Length; i++)
+ {
+ var body = poseExtractor.Bodies[i];
+ var joint = body?.GetComponent();
+ numJointObservations += RigidBodyJointExtractor.NumObservations(body, joint, Settings);
+ }
+ return new[] { numPoseObservations + numJointObservations };
}
}
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
index 33d23d4697..94e708ed68 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
@@ -42,6 +42,7 @@ public void TestSingleBody()
0f, 0f, 0f, 1f // LocalSpaceRotations
};
SensorTestHelper.CompareObservation(sensor, expected);
+ Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
}
[Test]
@@ -61,7 +62,14 @@ public void TestBodiesWithJoint()
var leafArticBody = leafGameObj.AddComponent();
leafGameObj.transform.SetParent(middleGamObj.transform);
leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f);
- leafArticBody.jointType = ArticulationJointType.RevoluteJoint;
+ leafArticBody.jointType = ArticulationJointType.PrismaticJoint;
+ leafArticBody.linearLockZ = ArticulationDofLock.LimitedMotion;
+ leafArticBody.zDrive = new ArticulationDrive
+ {
+ lowerLimit = -3,
+ upperLimit = 1
+ };
+
#if UNITY_2020_2_OR_NEWER
// ArticulationBody.velocity is read-only in 2020.1
@@ -107,6 +115,30 @@ public void TestBodiesWithJoint()
#endif
};
SensorTestHelper.CompareObservation(sensor, expected);
+ Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
+
+ // Update the settings to only process joint observations
+ sensorComponent.Settings = new PhysicsSensorSettings
+ {
+ UseJointForces = true,
+ UseJointPositionsAndAngles = true,
+ };
+
+ sensor = sensorComponent.CreateSensor();
+ sensor.Update();
+
+ expected = new[]
+ {
+ // revolute
+ 0f, 1f, // joint1.position (sin and cos)
+ 0f, // joint1.force
+
+ // prismatic
+ 0.5f, // joint2.position (interpolate between limits)
+ 0f, // joint2.force
+ };
+ SensorTestHelper.CompareObservation(sensor, expected);
+ Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
}
}
}
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
index 5fbb74c9cf..279fc7007d 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
@@ -52,6 +52,7 @@ public void TestSingleRigidbody()
0f, 0f, 0f, 1f // LocalSpaceRotations
};
SensorTestHelper.CompareObservation(sensor, expected);
+ Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
}
[Test]
@@ -107,6 +108,27 @@ public void TestBodiesWithJoint()
0f, -1f, 1f // Leaf vel
};
SensorTestHelper.CompareObservation(sensor, expected);
+ Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
+
+ // Update the settings to only process joint observations
+ sensorComponent.Settings = new PhysicsSensorSettings
+ {
+ UseJointPositionsAndAngles = true,
+ UseJointForces = true,
+ };
+
+ sensor = sensorComponent.CreateSensor();
+ sensor.Update();
+
+ expected = new[]
+ {
+ 0f, 0f, 0f, // joint1.force
+ 0f, 0f, 0f, // joint1.torque
+ 0f, 0f, 0f, // joint2.force
+ 0f, 0f, 0f, // joint2.torque
+ };
+ SensorTestHelper.CompareObservation(sensor, expected);
+ Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
}
}