diff --git a/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/Crawler.prefab b/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/Crawler.prefab
index ce6583ecd3..22474a8612 100644
--- a/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/Crawler.prefab
+++ b/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/Crawler.prefab
@@ -2742,6 +2742,7 @@ GameObject:
- component: {fileID: 4845971001715176662}
- component: {fileID: 4845971001715176663}
- component: {fileID: 4845971001715176660}
+ - component: {fileID: 4622120667686875944}
m_Layer: 0
m_Name: Crawler
m_TagString: Untagged
@@ -2779,7 +2780,7 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
- VectorObservationSize: 138
+ VectorObservationSize: 21
NumStackedVectorObservations: 1
VectorActionSize: 14000000
VectorActionDescriptions: []
@@ -2872,6 +2873,30 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
debugCommandLineOverride:
+--- !u!114 &4622120667686875944
+MonoBehaviour:
+ m_ObjectHideFlags: 0
+ m_CorrespondingSourceObject: {fileID: 0}
+ m_PrefabInstance: {fileID: 0}
+ m_PrefabAsset: {fileID: 0}
+ m_GameObject: {fileID: 4845971001715176661}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 11500000, guid: df0f8be9a37d6486498061e2cbc4cd94, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ RootBody: {fileID: 4845971001588102145}
+ VirtualRoot: {fileID: 2270141184585723037}
+ Settings:
+ UseModelSpaceTranslations: 1
+ UseModelSpaceRotations: 1
+ UseLocalSpaceTranslations: 0
+ UseLocalSpaceRotations: 1
+ UseModelSpaceLinearVelocity: 1
+ UseLocalSpaceLinearVelocity: 0
+ UseJointPositionsAndAngles: 0
+ UseJointForces: 0
+ sensorName:
--- !u!1 &4845971001730692034
GameObject:
m_ObjectHideFlags: 0
@@ -3018,6 +3043,12 @@ PrefabInstance:
objectReference: {fileID: 0}
m_RemovedComponents: []
m_SourcePrefab: {fileID: 100100000, guid: 72f745913c5a34df5aaadd5c1f0024cb, type: 3}
+--- !u!1 &2270141184585723037 stripped
+GameObject:
+ m_CorrespondingSourceObject: {fileID: 2591864627249999519, guid: 72f745913c5a34df5aaadd5c1f0024cb,
+ type: 3}
+ m_PrefabInstance: {fileID: 4357529801223143938}
+ m_PrefabAsset: {fileID: 0}
--- !u!4 &2270141184585723026 stripped
Transform:
m_CorrespondingSourceObject: {fileID: 2591864627249999504, guid: 72f745913c5a34df5aaadd5c1f0024cb,
@@ -3030,7 +3061,7 @@ MonoBehaviour:
type: 3}
m_PrefabInstance: {fileID: 4357529801223143938}
m_PrefabAsset: {fileID: 0}
- m_GameObject: {fileID: 0}
+ m_GameObject: {fileID: 2270141184585723037}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 771e78c5e980e440e8cd19716b55075f, type: 3}
diff --git a/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab b/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab
index fb182341b2..aa1378135e 100644
--- a/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab
+++ b/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab
@@ -392,6 +392,11 @@ PrefabInstance:
propertyPath: targetToLookAt
value:
objectReference: {fileID: 2673081981996998229}
+ - target: {fileID: 4622120667686875944, guid: 0456c89e8c9c243d595b039fe7aa0bf9,
+ type: 3}
+ propertyPath: Settings.UseLocalSpaceLinearVelocity
+ value: 1
+ objectReference: {fileID: 0}
- target: {fileID: 4845971000000621469, guid: 0456c89e8c9c243d595b039fe7aa0bf9,
type: 3}
propertyPath: m_ConnectedAnchor.x
diff --git a/Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs b/Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
index e0b7951833..fda546a13c 100644
--- a/Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
+++ b/Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
@@ -91,17 +91,8 @@ public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
//GROUND CHECK
sensor.AddObservation(bp.groundContact.touchingGround); // Is this bp touching the ground
- //Get velocities in the context of our orientation cube's space
- //Note: You can get these velocities in world space as well but it may not train as well.
- sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.velocity));
- sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.angularVelocity));
-
- //Get position relative to hips in the context of our orientation cube's space
- sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.position - body.position));
-
if (bp.rb.transform != body)
{
- sensor.AddObservation(bp.rb.transform.localRotation);
sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit);
}
}
@@ -111,9 +102,6 @@ public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
///
public override void CollectObservations(VectorSensor sensor)
{
- //Add body rotation delta relative to orientation cube
- sensor.AddObservation(Quaternion.FromToRotation(body.forward, orientationCube.transform.forward));
-
//Add pos of target relative to orientation cube
sensor.AddObservation(orientationCube.transform.InverseTransformPoint(target.transform.position));
diff --git a/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn b/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn
index 9413f25653..9902063433 100644
Binary files a/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn and b/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn differ
diff --git a/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn b/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn
index 70fac39c0c..85e15c089d 100644
Binary files a/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn and b/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn differ
diff --git a/com.unity.ml-agents.extensions/README.md b/com.unity.ml-agents.extensions/README.md
index 651f450e09..5cba2759c9 100644
--- a/com.unity.ml-agents.extensions/README.md
+++ b/com.unity.ml-agents.extensions/README.md
@@ -1,3 +1,5 @@
# ML-Agents Extensions
This is a source-only package for new features based on ML-Agents.
+
+More details coming soon.
diff --git a/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs b/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs
new file mode 100644
index 0000000000..0cd831e21f
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs
@@ -0,0 +1,3 @@
+using System.Runtime.CompilerServices;
+
+[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.EditorTests")]
diff --git a/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs.meta b/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs.meta
new file mode 100644
index 0000000000..21cec76829
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 48c8790647c3345e19c57d6c21065112
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
index f354a614b7..49aef67b74 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
@@ -54,17 +54,17 @@ public ArticulationBodyPoseExtractor(ArticulationBody rootBody)
parentIndices[i] = bodyToIndex[parentArticBody];
}
- SetParentIndices(parentIndices);
+ Setup(parentIndices);
}
///
- protected override Vector3 GetLinearVelocityAt(int index)
+ protected internal override Vector3 GetLinearVelocityAt(int index)
{
return m_Bodies[index].velocity;
}
///
- protected override Pose GetPoseAt(int index)
+ protected internal override Pose GetPoseAt(int index)
{
var body = m_Bodies[index];
var go = body.gameObject;
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
index de9d3866f6..ec9eddfae1 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
@@ -18,12 +18,20 @@ public class PhysicsBodySensor : ISensor
///
/// Construct a new PhysicsBodySensor
///
- ///
+ /// The root Rigidbody. This has no Joints on it (but other Joints may connect to it).
+ /// Optional GameObject used to find Rigidbodies in the hierarchy.
+ /// Optional GameObject used to determine the root of the poses,
///
///
- public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null)
+ public PhysicsBodySensor(
+ Rigidbody rootBody,
+ GameObject rootGameObject,
+ GameObject virtualRoot,
+ PhysicsSensorSettings settings,
+ string sensorName=null
+ )
{
- var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);
+ var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject, virtualRoot);
m_PoseExtractor = poseExtractor;
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName;
m_Settings = settings;
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
index 9109d9592e..5488be8666 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
@@ -1,5 +1,4 @@
using System;
-
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
@@ -95,25 +94,26 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting
var offset = baseOffset;
if (settings.UseModelSpace)
{
- var poses = poseExtractor.ModelSpacePoses;
- var vels = poseExtractor.ModelSpaceVelocities;
-
- for(var i=0; i
- /// Read access to the model space transforms.
+ /// Read iterator for the enabled model space transforms.
///
- public IList ModelSpacePoses
+ public IEnumerable GetEnabledModelSpacePoses()
{
- get { return m_ModelSpacePoses; }
+ if (m_ModelSpacePoses == null)
+ {
+ yield break;
+ }
+
+ for (var i = 0; i < m_ModelSpacePoses.Length; i++)
+ {
+ if (m_PoseEnabled[i])
+ {
+ yield return m_ModelSpacePoses[i];
+ }
+ }
}
///
- /// Read access to the local space transforms.
+ /// Read iterator for the enabled local space transforms.
///
- public IList LocalSpacePoses
+ public IEnumerable GetEnabledLocalSpacePoses()
{
- get { return m_LocalSpacePoses; }
+ if (m_LocalSpacePoses == null)
+ {
+ yield break;
+ }
+
+ for (var i = 0; i < m_LocalSpacePoses.Length; i++)
+ {
+ if (m_PoseEnabled[i])
+ {
+ yield return m_LocalSpacePoses[i];
+ }
+ }
}
///
- /// Read access to the model space linear velocities.
+ /// Read iterator for the enabled model space linear velocities.
///
- public IList ModelSpaceVelocities
+ public IEnumerable GetEnabledModelSpaceVelocities()
{
- get { return m_ModelSpaceLinearVelocities; }
+ if (m_ModelSpaceLinearVelocities == null)
+ {
+ yield break;
+ }
+
+ for (var i = 0; i < m_ModelSpaceLinearVelocities.Length; i++)
+ {
+ if (m_PoseEnabled[i])
+ {
+ yield return m_ModelSpaceLinearVelocities[i];
+ }
+ }
}
///
- /// Read access to the local space linear velocities.
+ /// Read iterator for the enabled local space linear velocities.
///
- public IList LocalSpaceVelocities
+ public IEnumerable GetEnabledLocalSpaceVelocities()
{
- get { return m_LocalSpaceLinearVelocities; }
+ if (m_LocalSpaceLinearVelocities == null)
+ {
+ yield break;
+ }
+
+ for (var i = 0; i < m_LocalSpaceLinearVelocities.Length; i++)
+ {
+ if (m_PoseEnabled[i])
+ {
+ yield return m_LocalSpaceLinearVelocities[i];
+ }
+ }
}
///
- /// Number of poses in the hierarchy (read-only).
+ /// Number of enabled poses in the hierarchy (read-only).
+ ///
+ public int NumEnabledPoses
+ {
+ get
+ {
+ if (m_PoseEnabled == null)
+ {
+ return 0;
+ }
+
+ var numEnabled = 0;
+ for (var i = 0; i < m_PoseEnabled.Length; i++)
+ {
+ numEnabled += m_PoseEnabled[i] ? 1 : 0;
+ }
+
+ return numEnabled;
+ }
+ }
+
+ ///
+ /// Number of total poses in the hierarchy (read-only).
///
public int NumPoses
{
- get { return m_ModelSpacePoses?.Length ?? 0; }
+ get { return m_ModelSpacePoses?.Length ?? 0; }
}
///
@@ -77,20 +145,43 @@ public int GetParentIndex(int index)
return m_ParentIndices[index];
}
+ ///
+ /// Set whether the pose at the given index is enabled or disabled for observations.
+ ///
+ ///
+ ///
+ public void SetPoseEnabled(int index, bool val)
+ {
+ m_PoseEnabled[index] = val;
+ }
+
///
/// Initialize with the mapping of parent indices.
/// The 0th element is assumed to be -1, indicating that it's the root.
///
///
- protected void SetParentIndices(int[] parentIndices)
+ protected void Setup(int[] parentIndices)
{
+#if DEBUG
+ if (parentIndices[0] != -1)
+ {
+ throw new UnityAgentsException($"Expected parentIndices[0] to be -1, got {parentIndices[0]}");
+ }
+#endif
m_ParentIndices = parentIndices;
- var numTransforms = parentIndices.Length;
- m_ModelSpacePoses = new Pose[numTransforms];
- m_LocalSpacePoses = new Pose[numTransforms];
+ var numPoses = parentIndices.Length;
+ m_ModelSpacePoses = new Pose[numPoses];
+ m_LocalSpacePoses = new Pose[numPoses];
- m_ModelSpaceLinearVelocities = new Vector3[numTransforms];
- m_LocalSpaceLinearVelocities = new Vector3[numTransforms];
+ m_ModelSpaceLinearVelocities = new Vector3[numPoses];
+ m_LocalSpaceLinearVelocities = new Vector3[numPoses];
+
+ m_PoseEnabled = new bool[numPoses];
+ // All poses are enabled by default. Generally we'll want to disable the root though.
+ for (var i = 0; i < numPoses; i++)
+ {
+ m_PoseEnabled[i] = true;
+ }
}
///
@@ -98,14 +189,14 @@ protected void SetParentIndices(int[] parentIndices)
///
///
///
- protected abstract Pose GetPoseAt(int index);
+ protected internal abstract Pose GetPoseAt(int index);
///
/// Return the world space linear velocity of the i'th object.
///
///
///
- protected abstract Vector3 GetLinearVelocityAt(int index);
+ protected internal abstract Vector3 GetLinearVelocityAt(int index);
///
@@ -113,24 +204,27 @@ protected void SetParentIndices(int[] parentIndices)
///
public void UpdateModelSpacePoses()
{
- if (m_ModelSpacePoses == null)
+ using (TimerStack.Instance.Scoped("UpdateModelSpacePoses"))
{
- return;
- }
+ if (m_ModelSpacePoses == null)
+ {
+ return;
+ }
- var rootWorldTransform = GetPoseAt(0);
- var worldToModel = rootWorldTransform.Inverse();
- var rootLinearVel = GetLinearVelocityAt(0);
+ var rootWorldTransform = GetPoseAt(0);
+ var worldToModel = rootWorldTransform.Inverse();
+ var rootLinearVel = GetLinearVelocityAt(0);
- for (var i = 0; i < m_ModelSpacePoses.Length; i++)
- {
- var currentWorldSpacePose = GetPoseAt(i);
- var currentModelSpacePose = worldToModel.Multiply(currentWorldSpacePose);
- m_ModelSpacePoses[i] = currentModelSpacePose;
+ for (var i = 0; i < m_ModelSpacePoses.Length; i++)
+ {
+ var currentWorldSpacePose = GetPoseAt(i);
+ var currentModelSpacePose = worldToModel.Multiply(currentWorldSpacePose);
+ m_ModelSpacePoses[i] = currentModelSpacePose;
- var currentBodyLinearVel = GetLinearVelocityAt(i);
- var relativeVelocity = currentBodyLinearVel - rootLinearVel;
- m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity;
+ var currentBodyLinearVel = GetLinearVelocityAt(i);
+ var relativeVelocity = currentBodyLinearVel - rootLinearVel;
+ m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity;
+ }
}
}
@@ -139,30 +233,33 @@ public void UpdateModelSpacePoses()
///
public void UpdateLocalSpacePoses()
{
- if (m_LocalSpacePoses == null)
- {
- return;
- }
-
- for (var i = 0; i < m_LocalSpacePoses.Length; i++)
+ using (TimerStack.Instance.Scoped("UpdateLocalSpacePoses"))
{
- if (m_ParentIndices[i] != -1)
+ if (m_LocalSpacePoses == null)
{
- var parentTransform = GetPoseAt(m_ParentIndices[i]);
- // This is slightly inefficient, since for a body with multiple children, we'll end up inverting
- // the transform multiple times. Might be able to trade space for perf here.
- var invParent = parentTransform.Inverse();
- var currentTransform = GetPoseAt(i);
- m_LocalSpacePoses[i] = invParent.Multiply(currentTransform);
-
- var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]);
- var currentLinearVel = GetLinearVelocityAt(i);
- m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel);
+ return;
}
- else
+
+ for (var i = 0; i < m_LocalSpacePoses.Length; i++)
{
- m_LocalSpacePoses[i] = Pose.identity;
- m_LocalSpaceLinearVelocities[i] = Vector3.zero;
+ if (m_ParentIndices[i] != -1)
+ {
+ var parentTransform = GetPoseAt(m_ParentIndices[i]);
+ // This is slightly inefficient, since for a body with multiple children, we'll end up inverting
+ // the transform multiple times. Might be able to trade space for perf here.
+ var invParent = parentTransform.Inverse();
+ var currentTransform = GetPoseAt(i);
+ m_LocalSpacePoses[i] = invParent.Multiply(currentTransform);
+
+ var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]);
+ var currentLinearVel = GetLinearVelocityAt(i);
+ m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel);
+ }
+ else
+ {
+ m_LocalSpacePoses[i] = Pose.identity;
+ m_LocalSpaceLinearVelocities[i] = Vector3.zero;
+ }
}
}
}
@@ -183,7 +280,7 @@ public int GetNumPoseObservations(PhysicsSensorSettings settings)
obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0;
obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0;
- return NumPoses * obsPerPose;
+ return NumEnabledPoses * obsPerPose;
}
internal void DrawModelSpace(Vector3 offset)
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
index 05b55ef737..44ff9a7641 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
@@ -12,11 +12,21 @@ public class RigidBodyPoseExtractor : PoseExtractor
{
Rigidbody[] m_Bodies;
+ ///
+ /// Optional game object used to determine the root of the poses, separate from the actual Rigidbodies
+ /// in the hierarchy. For locomotion
+ ///
+ GameObject m_VirtualRoot;
+
///
/// Initialize given a root RigidBody.
///
- ///
- public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null)
+ /// The root Rigidbody. This has no Joints on it (but other Joints may connect to it).
+ /// Optional GameObject used to find Rigidbodies in the hierarchy.
+ /// Optional GameObject used to determine the root of the poses,
+ /// separate from the actual Rigidbodies in the hierarchy. For locomotion tasks, with ragdolls, this provides
+ /// a stabilized refernece frame, which can improve learning.
+ public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null, GameObject virtualRoot = null)
{
if (rootBody == null)
{
@@ -32,18 +42,42 @@ public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = nu
{
rbs = rootGameObject.GetComponentsInChildren();
}
- var bodyToIndex = new Dictionary(rbs.Length);
- var parentIndices = new int[rbs.Length];
- if (rbs[0] != rootBody)
+ if (rbs == null || rbs.Length == 0)
+ {
+ Debug.Log("No rigid bodies found!");
+ return;
+ }
+
+ if (rbs[0] != rootBody)
{
Debug.Log("Expected root body at index 0");
return;
}
+ // Adjust the array if we have a virtual root.
+ // This will be at index 0, and the "real" root will be parented to it.
+ if (virtualRoot != null)
+ {
+ var extendedRbs = new Rigidbody[rbs.Length + 1];
+ for (var i = 0; i < rbs.Length; i++)
+ {
+ extendedRbs[i + 1] = rbs[i];
+ }
+
+ rbs = extendedRbs;
+ }
+
+ var bodyToIndex = new Dictionary(rbs.Length);
+ var parentIndices = new int[rbs.Length];
+ parentIndices[0] = -1;
+
for (var i = 0; i < rbs.Length; i++)
{
- bodyToIndex[rbs[i]] = i;
+ if(rbs[i] != null)
+ {
+ bodyToIndex[rbs[i]] = i;
+ }
}
var joints = rootBody.GetComponentsInChildren ();
@@ -59,19 +93,44 @@ public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = nu
parentIndices[childIndex] = parentIndex;
}
+ if (virtualRoot != null)
+ {
+ // Make sure the original root treats the virtual root as its parent.
+ parentIndices[1] = 0;
+ m_VirtualRoot = virtualRoot;
+ }
+
m_Bodies = rbs;
- SetParentIndices(parentIndices);
+ Setup(parentIndices);
+
+ // By default, ignore the root
+ SetPoseEnabled(0, false);
}
///
- protected override Vector3 GetLinearVelocityAt(int index)
+ protected internal override Vector3 GetLinearVelocityAt(int index)
{
+ if (index == 0 && m_VirtualRoot != null)
+ {
+ // No velocity on the virtual root
+ return Vector3.zero;
+ }
return m_Bodies[index].velocity;
}
///
- protected override Pose GetPoseAt(int index)
+ protected internal override Pose GetPoseAt(int index)
{
+ if (index == 0 && m_VirtualRoot != null)
+ {
+ // Use the GameObject's world transform
+ return new Pose
+ {
+ rotation = m_VirtualRoot.transform.rotation,
+ position = m_VirtualRoot.transform.position
+ };
+ }
+
var body = m_Bodies[index];
return new Pose { rotation = body.rotation, position = body.position };
}
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
index ce6cf05379..9a077a6594 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
@@ -13,6 +13,11 @@ public class RigidBodySensorComponent : SensorComponent
///
public Rigidbody RootBody;
+ ///
+ /// Optional GameObject used to determine the root of the poses.
+ ///
+ public GameObject VirtualRoot;
+
///
/// Settings defining what types of observations will be generated.
///
@@ -30,7 +35,7 @@ public class RigidBodySensorComponent : SensorComponent
///
public override ISensor CreateSensor()
{
- return new PhysicsBodySensor(RootBody, gameObject, Settings, sensorName);
+ return new PhysicsBodySensor(RootBody, gameObject, VirtualRoot, Settings, sensorName);
}
///
@@ -43,7 +48,7 @@ public override int[] GetObservationShape()
// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
- var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject);
+ var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject, VirtualRoot);
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);
var numJointObservations = 0;
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
index 94e708ed68..f642c32c5d 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
@@ -100,17 +100,12 @@ public void TestBodiesWithJoint()
// Local space
0f, 0f, 0f, // Root pos
-#if UNITY_2020_2_OR_NEWER
- 0f, 0f, 0f, // Root vel
-#endif
-
13.37f, 0f, 0f, // Attached pos
-#if UNITY_2020_2_OR_NEWER
- -1f, 1f, 0f, // Attached vel
-#endif
-
4.2f, 0f, 0f, // Leaf pos
+
#if UNITY_2020_2_OR_NEWER
+ 0f, 0f, 0f, // Root vel
+ -1f, 1f, 0f, // Attached vel
0f, -1f, 1f // Leaf vel
#endif
};
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
index 627b4d7b8b..5f862d613d 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
@@ -8,19 +8,19 @@ public class PoseExtractorTests
{
class UselessPoseExtractor : PoseExtractor
{
- protected override Pose GetPoseAt(int index)
+ protected internal override Pose GetPoseAt(int index)
{
return Pose.identity;
}
- protected override Vector3 GetLinearVelocityAt(int index)
+ protected internal override Vector3 GetLinearVelocityAt(int index)
{
return Vector3.zero;
}
public void Init(int[] parentIndices)
{
- SetParentIndices(parentIndices);
+ Setup(parentIndices);
}
}
@@ -60,10 +60,10 @@ public ChainPoseExtractor(int size)
{
parents[i] = i - 1;
}
- SetParentIndices(parents);
+ Setup(parents);
}
- protected override Pose GetPoseAt(int index)
+ protected internal override Pose GetPoseAt(int index)
{
var rotation = Quaternion.identity;
var translation = offset + new Vector3(index, index, index);
@@ -74,7 +74,7 @@ protected override Pose GetPoseAt(int index)
};
}
- protected override Vector3 GetLinearVelocityAt(int index)
+ protected internal override Vector3 GetLinearVelocityAt(int index)
{
return Vector3.zero;
}
@@ -91,23 +91,77 @@ public void TestChain()
chain.UpdateModelSpacePoses();
chain.UpdateLocalSpacePoses();
- // Root transforms are currently always the identity.
- Assert.IsTrue(chain.ModelSpacePoses[0] == Pose.identity);
- Assert.IsTrue(chain.LocalSpacePoses[0] == Pose.identity);
- // Check the non-root transforms
- for (var i = 1; i < size; i++)
+ var modelPoseIndex = 0;
+ foreach (var modelSpace in chain.GetEnabledModelSpacePoses())
{
- var modelSpace = chain.ModelSpacePoses[i];
- var expectedModelTranslation = new Vector3(i, i, i);
- Assert.IsTrue(expectedModelTranslation == modelSpace.position);
+ if (modelPoseIndex == 0)
+ {
+ // Root transforms are currently always the identity.
+ Assert.IsTrue(modelSpace == Pose.identity);
+ }
+ else
+ {
+ var expectedModelTranslation = new Vector3(modelPoseIndex, modelPoseIndex, modelPoseIndex);
+ Assert.IsTrue(expectedModelTranslation == modelSpace.position);
- var localSpace = chain.LocalSpacePoses[i];
- var expectedLocalTranslation = new Vector3(1, 1, 1);
- Assert.IsTrue(expectedLocalTranslation == localSpace.position);
+ }
+ modelPoseIndex++;
}
+ Assert.AreEqual(size, modelPoseIndex);
+
+ var localPoseIndex = 0;
+ foreach (var localSpace in chain.GetEnabledLocalSpacePoses())
+ {
+ if (localPoseIndex == 0)
+ {
+ // Root transforms are currently always the identity.
+ Assert.IsTrue(localSpace == Pose.identity);
+ }
+ else
+ {
+ var expectedLocalTranslation = new Vector3(1, 1, 1);
+ Assert.IsTrue(expectedLocalTranslation == localSpace.position, $"{expectedLocalTranslation} != {localSpace.position}");
+ }
+
+ localPoseIndex++;
+ }
+ Assert.AreEqual(size, localPoseIndex);
}
+ class BadPoseExtractor : PoseExtractor
+ {
+ public BadPoseExtractor()
+ {
+ var size = 2;
+ var parents = new int[size];
+ // Parents are intentionally invalid - expect -1 at root
+ for (var i = 0; i < size; i++)
+ {
+ parents[i] = i;
+ }
+ Setup(parents);
+ }
+
+ protected internal override Pose GetPoseAt(int index)
+ {
+ return Pose.identity;
+ }
+
+ protected internal override Vector3 GetLinearVelocityAt(int index)
+ {
+ return Vector3.zero;
+ }
+ }
+
+ [Test]
+ public void TestExpectedRoot()
+ {
+ Assert.Throws(() =>
+ {
+ var bad = new BadPoseExtractor();
+ });
+ }
}
public class PoseExtensionTests
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
index a5d8b5bcb5..2d157b88e0 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
@@ -1,6 +1,7 @@
using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;
+using UnityEditor;
namespace Unity.MLAgents.Extensions.Tests.Sensors
{
@@ -56,6 +57,63 @@ public void TestTwoBodies()
var poseExtractor = new RigidBodyPoseExtractor(rb1);
Assert.AreEqual(2, poseExtractor.NumPoses);
+
+ rb1.position = new Vector3(1, 0, 0);
+ rb1.rotation = Quaternion.Euler(0, 13.37f, 0);
+ rb1.velocity = new Vector3(2, 0, 0);
+
+ Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(0).position);
+ Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(0).rotation);
+ Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(0));
+ }
+
+ [Test]
+ public void TestTwoBodiesVirtualRoot()
+ {
+ // * virtualRoot
+ // * rootObj
+ // - rb1
+ // * go2
+ // - rb2
+ // - joint
+ var virtualRoot = new GameObject("I am vroot");
+
+ var rootObj = new GameObject();
+ var rb1 = rootObj.AddComponent();
+
+ var go2 = new GameObject();
+ var rb2 = go2.AddComponent();
+ go2.transform.SetParent(rootObj.transform);
+
+ var joint = go2.AddComponent();
+ joint.connectedBody = rb1;
+
+ var poseExtractor = new RigidBodyPoseExtractor(rb1, null, virtualRoot);
+ Assert.AreEqual(3, poseExtractor.NumPoses);
+
+ // "body" 0 has no parent
+ Assert.AreEqual(-1, poseExtractor.GetParentIndex(0));
+
+ // body 1 has parent 0
+ Assert.AreEqual(0, poseExtractor.GetParentIndex(1));
+
+ var virtualRootPos = new Vector3(0,2,0);
+ var virtualRootRot = Quaternion.Euler(0, 42, 0);
+ virtualRoot.transform.position = virtualRootPos;
+ virtualRoot.transform.rotation = virtualRootRot;
+
+ Assert.AreEqual(virtualRootPos, poseExtractor.GetPoseAt(0).position);
+ Assert.IsTrue(virtualRootRot == poseExtractor.GetPoseAt(0).rotation);
+ Assert.AreEqual(Vector3.zero, poseExtractor.GetLinearVelocityAt(0));
+
+ // Same as above test, but using index 1
+ rb1.position = new Vector3(1, 0, 0);
+ rb1.rotation = Quaternion.Euler(0, 13.37f, 0);
+ rb1.velocity = new Vector3(2, 0, 0);
+
+ Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(1).position);
+ Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(1).rotation);
+ Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(1));
}
}
}
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
index 279fc7007d..a6c8b9f366 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
@@ -45,14 +45,12 @@ public void TestSingleRigidbody()
var sensor = sensorComponent.CreateSensor();
sensor.Update();
- var expected = new[]
- {
- 0f, 0f, 0f, // ModelSpaceLinearVelocity
- 0f, 0f, 0f, // LocalSpaceTranslations
- 0f, 0f, 0f, 1f // LocalSpaceRotations
- };
- SensorTestHelper.CompareObservation(sensor, expected);
+
+ // The root body is ignored since it always generates identity values
+ // and there are no other bodies to generate observations.
+ var expected = new float[0];
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
+ SensorTestHelper.CompareObservation(sensor, expected);
}
[Test]
@@ -78,6 +76,7 @@ public void TestBodiesWithJoint()
var joint2 = leafGameObj.AddComponent();
joint2.connectedBody = middleRb;
+ var virtualRoot = new GameObject();
var sensorComponent = rootObj.AddComponent();
sensorComponent.RootBody = rootRb;
@@ -87,9 +86,12 @@ public void TestBodiesWithJoint()
UseLocalSpaceTranslations = true,
UseLocalSpaceLinearVelocity = true
};
+ sensorComponent.VirtualRoot = virtualRoot;
var sensor = sensorComponent.CreateSensor();
sensor.Update();
+
+ // Note that the VirtualRoot is ignored from the observations
var expected = new[]
{
// Model space
@@ -99,16 +101,15 @@ public void TestBodiesWithJoint()
// Local space
0f, 0f, 0f, // Root pos
- 0f, 0f, 0f, // Root vel
-
13.37f, 0f, 0f, // Attached pos
- -1f, 1f, 0f, // Attached vel
-
4.2f, 0f, 0f, // Leaf pos
+
+ 1f, 0f, 0f, // Root vel (relative to virtual root)
+ -1f, 1f, 0f, // Attached vel
0f, -1f, 1f // Leaf vel
};
- SensorTestHelper.CompareObservation(sensor, expected);
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
+ SensorTestHelper.CompareObservation(sensor, expected);
// Update the settings to only process joint observations
sensorComponent.Settings = new PhysicsSensorSettings
diff --git a/com.unity.ml-agents/Runtime/AssemblyInfo.cs b/com.unity.ml-agents/Runtime/AssemblyInfo.cs
index 5a6e5ced39..4bc7a8bbb0 100644
--- a/com.unity.ml-agents/Runtime/AssemblyInfo.cs
+++ b/com.unity.ml-agents/Runtime/AssemblyInfo.cs
@@ -2,3 +2,4 @@
[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")]
[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor")]
+[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions")]