diff --git a/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/Crawler.prefab b/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/Crawler.prefab index ce6583ecd3..22474a8612 100644 --- a/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/Crawler.prefab +++ b/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/Crawler.prefab @@ -2742,6 +2742,7 @@ GameObject: - component: {fileID: 4845971001715176662} - component: {fileID: 4845971001715176663} - component: {fileID: 4845971001715176660} + - component: {fileID: 4622120667686875944} m_Layer: 0 m_Name: Crawler m_TagString: Untagged @@ -2779,7 +2780,7 @@ MonoBehaviour: m_Name: m_EditorClassIdentifier: m_BrainParameters: - VectorObservationSize: 138 + VectorObservationSize: 21 NumStackedVectorObservations: 1 VectorActionSize: 14000000 VectorActionDescriptions: [] @@ -2872,6 +2873,30 @@ MonoBehaviour: m_Name: m_EditorClassIdentifier: debugCommandLineOverride: +--- !u!114 &4622120667686875944 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 4845971001715176661} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: df0f8be9a37d6486498061e2cbc4cd94, type: 3} + m_Name: + m_EditorClassIdentifier: + RootBody: {fileID: 4845971001588102145} + VirtualRoot: {fileID: 2270141184585723037} + Settings: + UseModelSpaceTranslations: 1 + UseModelSpaceRotations: 1 + UseLocalSpaceTranslations: 0 + UseLocalSpaceRotations: 1 + UseModelSpaceLinearVelocity: 1 + UseLocalSpaceLinearVelocity: 0 + UseJointPositionsAndAngles: 0 + UseJointForces: 0 + sensorName: --- !u!1 &4845971001730692034 GameObject: m_ObjectHideFlags: 0 @@ -3018,6 +3043,12 @@ PrefabInstance: objectReference: {fileID: 0} m_RemovedComponents: [] m_SourcePrefab: {fileID: 100100000, guid: 72f745913c5a34df5aaadd5c1f0024cb, type: 3} +--- !u!1 &2270141184585723037 stripped +GameObject: + m_CorrespondingSourceObject: {fileID: 2591864627249999519, guid: 72f745913c5a34df5aaadd5c1f0024cb, + type: 3} + m_PrefabInstance: {fileID: 4357529801223143938} + m_PrefabAsset: {fileID: 0} --- !u!4 &2270141184585723026 stripped Transform: m_CorrespondingSourceObject: {fileID: 2591864627249999504, guid: 72f745913c5a34df5aaadd5c1f0024cb, @@ -3030,7 +3061,7 @@ MonoBehaviour: type: 3} m_PrefabInstance: {fileID: 4357529801223143938} m_PrefabAsset: {fileID: 0} - m_GameObject: {fileID: 0} + m_GameObject: {fileID: 2270141184585723037} m_Enabled: 1 m_EditorHideFlags: 0 m_Script: {fileID: 11500000, guid: 771e78c5e980e440e8cd19716b55075f, type: 3} diff --git a/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab b/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab index fb182341b2..aa1378135e 100644 --- a/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab +++ b/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab @@ -392,6 +392,11 @@ PrefabInstance: propertyPath: targetToLookAt value: objectReference: {fileID: 2673081981996998229} + - target: {fileID: 4622120667686875944, guid: 0456c89e8c9c243d595b039fe7aa0bf9, + type: 3} + propertyPath: Settings.UseLocalSpaceLinearVelocity + value: 1 + objectReference: {fileID: 0} - target: {fileID: 4845971000000621469, guid: 0456c89e8c9c243d595b039fe7aa0bf9, type: 3} propertyPath: m_ConnectedAnchor.x diff --git a/Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs b/Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs index e0b7951833..fda546a13c 100644 --- a/Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs @@ -91,17 +91,8 @@ public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor) //GROUND CHECK sensor.AddObservation(bp.groundContact.touchingGround); // Is this bp touching the ground - //Get velocities in the context of our orientation cube's space - //Note: You can get these velocities in world space as well but it may not train as well. - sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.velocity)); - sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.angularVelocity)); - - //Get position relative to hips in the context of our orientation cube's space - sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.position - body.position)); - if (bp.rb.transform != body) { - sensor.AddObservation(bp.rb.transform.localRotation); sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit); } } @@ -111,9 +102,6 @@ public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor) /// public override void CollectObservations(VectorSensor sensor) { - //Add body rotation delta relative to orientation cube - sensor.AddObservation(Quaternion.FromToRotation(body.forward, orientationCube.transform.forward)); - //Add pos of target relative to orientation cube sensor.AddObservation(orientationCube.transform.InverseTransformPoint(target.transform.position)); diff --git a/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn b/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn index 9413f25653..9902063433 100644 Binary files a/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn and b/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn differ diff --git a/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn b/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn index 70fac39c0c..85e15c089d 100644 Binary files a/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn and b/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn differ diff --git a/com.unity.ml-agents.extensions/README.md b/com.unity.ml-agents.extensions/README.md index 651f450e09..5cba2759c9 100644 --- a/com.unity.ml-agents.extensions/README.md +++ b/com.unity.ml-agents.extensions/README.md @@ -1,3 +1,5 @@ # ML-Agents Extensions This is a source-only package for new features based on ML-Agents. + +More details coming soon. diff --git a/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs b/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs new file mode 100644 index 0000000000..0cd831e21f --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.EditorTests")] diff --git a/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs.meta b/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs.meta new file mode 100644 index 0000000000..21cec76829 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 48c8790647c3345e19c57d6c21065112 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs index f354a614b7..49aef67b74 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs @@ -54,17 +54,17 @@ public ArticulationBodyPoseExtractor(ArticulationBody rootBody) parentIndices[i] = bodyToIndex[parentArticBody]; } - SetParentIndices(parentIndices); + Setup(parentIndices); } /// - protected override Vector3 GetLinearVelocityAt(int index) + protected internal override Vector3 GetLinearVelocityAt(int index) { return m_Bodies[index].velocity; } /// - protected override Pose GetPoseAt(int index) + protected internal override Pose GetPoseAt(int index) { var body = m_Bodies[index]; var go = body.gameObject; diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs index de9d3866f6..ec9eddfae1 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs @@ -18,12 +18,20 @@ public class PhysicsBodySensor : ISensor /// /// Construct a new PhysicsBodySensor /// - /// + /// The root Rigidbody. This has no Joints on it (but other Joints may connect to it). + /// Optional GameObject used to find Rigidbodies in the hierarchy. + /// Optional GameObject used to determine the root of the poses, /// /// - public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null) + public PhysicsBodySensor( + Rigidbody rootBody, + GameObject rootGameObject, + GameObject virtualRoot, + PhysicsSensorSettings settings, + string sensorName=null + ) { - var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject); + var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject, virtualRoot); m_PoseExtractor = poseExtractor; m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName; m_Settings = settings; diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs index 9109d9592e..5488be8666 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs @@ -1,5 +1,4 @@ using System; - using Unity.MLAgents.Sensors; namespace Unity.MLAgents.Extensions.Sensors @@ -95,25 +94,26 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting var offset = baseOffset; if (settings.UseModelSpace) { - var poses = poseExtractor.ModelSpacePoses; - var vels = poseExtractor.ModelSpaceVelocities; - - for(var i=0; i - /// Read access to the model space transforms. + /// Read iterator for the enabled model space transforms. /// - public IList ModelSpacePoses + public IEnumerable GetEnabledModelSpacePoses() { - get { return m_ModelSpacePoses; } + if (m_ModelSpacePoses == null) + { + yield break; + } + + for (var i = 0; i < m_ModelSpacePoses.Length; i++) + { + if (m_PoseEnabled[i]) + { + yield return m_ModelSpacePoses[i]; + } + } } /// - /// Read access to the local space transforms. + /// Read iterator for the enabled local space transforms. /// - public IList LocalSpacePoses + public IEnumerable GetEnabledLocalSpacePoses() { - get { return m_LocalSpacePoses; } + if (m_LocalSpacePoses == null) + { + yield break; + } + + for (var i = 0; i < m_LocalSpacePoses.Length; i++) + { + if (m_PoseEnabled[i]) + { + yield return m_LocalSpacePoses[i]; + } + } } /// - /// Read access to the model space linear velocities. + /// Read iterator for the enabled model space linear velocities. /// - public IList ModelSpaceVelocities + public IEnumerable GetEnabledModelSpaceVelocities() { - get { return m_ModelSpaceLinearVelocities; } + if (m_ModelSpaceLinearVelocities == null) + { + yield break; + } + + for (var i = 0; i < m_ModelSpaceLinearVelocities.Length; i++) + { + if (m_PoseEnabled[i]) + { + yield return m_ModelSpaceLinearVelocities[i]; + } + } } /// - /// Read access to the local space linear velocities. + /// Read iterator for the enabled local space linear velocities. /// - public IList LocalSpaceVelocities + public IEnumerable GetEnabledLocalSpaceVelocities() { - get { return m_LocalSpaceLinearVelocities; } + if (m_LocalSpaceLinearVelocities == null) + { + yield break; + } + + for (var i = 0; i < m_LocalSpaceLinearVelocities.Length; i++) + { + if (m_PoseEnabled[i]) + { + yield return m_LocalSpaceLinearVelocities[i]; + } + } } /// - /// Number of poses in the hierarchy (read-only). + /// Number of enabled poses in the hierarchy (read-only). + /// + public int NumEnabledPoses + { + get + { + if (m_PoseEnabled == null) + { + return 0; + } + + var numEnabled = 0; + for (var i = 0; i < m_PoseEnabled.Length; i++) + { + numEnabled += m_PoseEnabled[i] ? 1 : 0; + } + + return numEnabled; + } + } + + /// + /// Number of total poses in the hierarchy (read-only). /// public int NumPoses { - get { return m_ModelSpacePoses?.Length ?? 0; } + get { return m_ModelSpacePoses?.Length ?? 0; } } /// @@ -77,20 +145,43 @@ public int GetParentIndex(int index) return m_ParentIndices[index]; } + /// + /// Set whether the pose at the given index is enabled or disabled for observations. + /// + /// + /// + public void SetPoseEnabled(int index, bool val) + { + m_PoseEnabled[index] = val; + } + /// /// Initialize with the mapping of parent indices. /// The 0th element is assumed to be -1, indicating that it's the root. /// /// - protected void SetParentIndices(int[] parentIndices) + protected void Setup(int[] parentIndices) { +#if DEBUG + if (parentIndices[0] != -1) + { + throw new UnityAgentsException($"Expected parentIndices[0] to be -1, got {parentIndices[0]}"); + } +#endif m_ParentIndices = parentIndices; - var numTransforms = parentIndices.Length; - m_ModelSpacePoses = new Pose[numTransforms]; - m_LocalSpacePoses = new Pose[numTransforms]; + var numPoses = parentIndices.Length; + m_ModelSpacePoses = new Pose[numPoses]; + m_LocalSpacePoses = new Pose[numPoses]; - m_ModelSpaceLinearVelocities = new Vector3[numTransforms]; - m_LocalSpaceLinearVelocities = new Vector3[numTransforms]; + m_ModelSpaceLinearVelocities = new Vector3[numPoses]; + m_LocalSpaceLinearVelocities = new Vector3[numPoses]; + + m_PoseEnabled = new bool[numPoses]; + // All poses are enabled by default. Generally we'll want to disable the root though. + for (var i = 0; i < numPoses; i++) + { + m_PoseEnabled[i] = true; + } } /// @@ -98,14 +189,14 @@ protected void SetParentIndices(int[] parentIndices) /// /// /// - protected abstract Pose GetPoseAt(int index); + protected internal abstract Pose GetPoseAt(int index); /// /// Return the world space linear velocity of the i'th object. /// /// /// - protected abstract Vector3 GetLinearVelocityAt(int index); + protected internal abstract Vector3 GetLinearVelocityAt(int index); /// @@ -113,24 +204,27 @@ protected void SetParentIndices(int[] parentIndices) /// public void UpdateModelSpacePoses() { - if (m_ModelSpacePoses == null) + using (TimerStack.Instance.Scoped("UpdateModelSpacePoses")) { - return; - } + if (m_ModelSpacePoses == null) + { + return; + } - var rootWorldTransform = GetPoseAt(0); - var worldToModel = rootWorldTransform.Inverse(); - var rootLinearVel = GetLinearVelocityAt(0); + var rootWorldTransform = GetPoseAt(0); + var worldToModel = rootWorldTransform.Inverse(); + var rootLinearVel = GetLinearVelocityAt(0); - for (var i = 0; i < m_ModelSpacePoses.Length; i++) - { - var currentWorldSpacePose = GetPoseAt(i); - var currentModelSpacePose = worldToModel.Multiply(currentWorldSpacePose); - m_ModelSpacePoses[i] = currentModelSpacePose; + for (var i = 0; i < m_ModelSpacePoses.Length; i++) + { + var currentWorldSpacePose = GetPoseAt(i); + var currentModelSpacePose = worldToModel.Multiply(currentWorldSpacePose); + m_ModelSpacePoses[i] = currentModelSpacePose; - var currentBodyLinearVel = GetLinearVelocityAt(i); - var relativeVelocity = currentBodyLinearVel - rootLinearVel; - m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity; + var currentBodyLinearVel = GetLinearVelocityAt(i); + var relativeVelocity = currentBodyLinearVel - rootLinearVel; + m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity; + } } } @@ -139,30 +233,33 @@ public void UpdateModelSpacePoses() /// public void UpdateLocalSpacePoses() { - if (m_LocalSpacePoses == null) - { - return; - } - - for (var i = 0; i < m_LocalSpacePoses.Length; i++) + using (TimerStack.Instance.Scoped("UpdateLocalSpacePoses")) { - if (m_ParentIndices[i] != -1) + if (m_LocalSpacePoses == null) { - var parentTransform = GetPoseAt(m_ParentIndices[i]); - // This is slightly inefficient, since for a body with multiple children, we'll end up inverting - // the transform multiple times. Might be able to trade space for perf here. - var invParent = parentTransform.Inverse(); - var currentTransform = GetPoseAt(i); - m_LocalSpacePoses[i] = invParent.Multiply(currentTransform); - - var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]); - var currentLinearVel = GetLinearVelocityAt(i); - m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel); + return; } - else + + for (var i = 0; i < m_LocalSpacePoses.Length; i++) { - m_LocalSpacePoses[i] = Pose.identity; - m_LocalSpaceLinearVelocities[i] = Vector3.zero; + if (m_ParentIndices[i] != -1) + { + var parentTransform = GetPoseAt(m_ParentIndices[i]); + // This is slightly inefficient, since for a body with multiple children, we'll end up inverting + // the transform multiple times. Might be able to trade space for perf here. + var invParent = parentTransform.Inverse(); + var currentTransform = GetPoseAt(i); + m_LocalSpacePoses[i] = invParent.Multiply(currentTransform); + + var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]); + var currentLinearVel = GetLinearVelocityAt(i); + m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel); + } + else + { + m_LocalSpacePoses[i] = Pose.identity; + m_LocalSpaceLinearVelocities[i] = Vector3.zero; + } } } } @@ -183,7 +280,7 @@ public int GetNumPoseObservations(PhysicsSensorSettings settings) obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0; obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0; - return NumPoses * obsPerPose; + return NumEnabledPoses * obsPerPose; } internal void DrawModelSpace(Vector3 offset) diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs index 05b55ef737..44ff9a7641 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs @@ -12,11 +12,21 @@ public class RigidBodyPoseExtractor : PoseExtractor { Rigidbody[] m_Bodies; + /// + /// Optional game object used to determine the root of the poses, separate from the actual Rigidbodies + /// in the hierarchy. For locomotion + /// + GameObject m_VirtualRoot; + /// /// Initialize given a root RigidBody. /// - /// - public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null) + /// The root Rigidbody. This has no Joints on it (but other Joints may connect to it). + /// Optional GameObject used to find Rigidbodies in the hierarchy. + /// Optional GameObject used to determine the root of the poses, + /// separate from the actual Rigidbodies in the hierarchy. For locomotion tasks, with ragdolls, this provides + /// a stabilized refernece frame, which can improve learning. + public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null, GameObject virtualRoot = null) { if (rootBody == null) { @@ -32,18 +42,42 @@ public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = nu { rbs = rootGameObject.GetComponentsInChildren(); } - var bodyToIndex = new Dictionary(rbs.Length); - var parentIndices = new int[rbs.Length]; - if (rbs[0] != rootBody) + if (rbs == null || rbs.Length == 0) + { + Debug.Log("No rigid bodies found!"); + return; + } + + if (rbs[0] != rootBody) { Debug.Log("Expected root body at index 0"); return; } + // Adjust the array if we have a virtual root. + // This will be at index 0, and the "real" root will be parented to it. + if (virtualRoot != null) + { + var extendedRbs = new Rigidbody[rbs.Length + 1]; + for (var i = 0; i < rbs.Length; i++) + { + extendedRbs[i + 1] = rbs[i]; + } + + rbs = extendedRbs; + } + + var bodyToIndex = new Dictionary(rbs.Length); + var parentIndices = new int[rbs.Length]; + parentIndices[0] = -1; + for (var i = 0; i < rbs.Length; i++) { - bodyToIndex[rbs[i]] = i; + if(rbs[i] != null) + { + bodyToIndex[rbs[i]] = i; + } } var joints = rootBody.GetComponentsInChildren (); @@ -59,19 +93,44 @@ public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = nu parentIndices[childIndex] = parentIndex; } + if (virtualRoot != null) + { + // Make sure the original root treats the virtual root as its parent. + parentIndices[1] = 0; + m_VirtualRoot = virtualRoot; + } + m_Bodies = rbs; - SetParentIndices(parentIndices); + Setup(parentIndices); + + // By default, ignore the root + SetPoseEnabled(0, false); } /// - protected override Vector3 GetLinearVelocityAt(int index) + protected internal override Vector3 GetLinearVelocityAt(int index) { + if (index == 0 && m_VirtualRoot != null) + { + // No velocity on the virtual root + return Vector3.zero; + } return m_Bodies[index].velocity; } /// - protected override Pose GetPoseAt(int index) + protected internal override Pose GetPoseAt(int index) { + if (index == 0 && m_VirtualRoot != null) + { + // Use the GameObject's world transform + return new Pose + { + rotation = m_VirtualRoot.transform.rotation, + position = m_VirtualRoot.transform.position + }; + } + var body = m_Bodies[index]; return new Pose { rotation = body.rotation, position = body.position }; } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs index ce6cf05379..9a077a6594 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs @@ -13,6 +13,11 @@ public class RigidBodySensorComponent : SensorComponent /// public Rigidbody RootBody; + /// + /// Optional GameObject used to determine the root of the poses. + /// + public GameObject VirtualRoot; + /// /// Settings defining what types of observations will be generated. /// @@ -30,7 +35,7 @@ public class RigidBodySensorComponent : SensorComponent /// public override ISensor CreateSensor() { - return new PhysicsBodySensor(RootBody, gameObject, Settings, sensorName); + return new PhysicsBodySensor(RootBody, gameObject, VirtualRoot, Settings, sensorName); } /// @@ -43,7 +48,7 @@ public override int[] GetObservationShape() // TODO static method in PhysicsBodySensor? // TODO only update PoseExtractor when body changes? - var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject); + var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject, VirtualRoot); var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings); var numJointObservations = 0; diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs index 94e708ed68..f642c32c5d 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs @@ -100,17 +100,12 @@ public void TestBodiesWithJoint() // Local space 0f, 0f, 0f, // Root pos -#if UNITY_2020_2_OR_NEWER - 0f, 0f, 0f, // Root vel -#endif - 13.37f, 0f, 0f, // Attached pos -#if UNITY_2020_2_OR_NEWER - -1f, 1f, 0f, // Attached vel -#endif - 4.2f, 0f, 0f, // Leaf pos + #if UNITY_2020_2_OR_NEWER + 0f, 0f, 0f, // Root vel + -1f, 1f, 0f, // Attached vel 0f, -1f, 1f // Leaf vel #endif }; diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs index 627b4d7b8b..5f862d613d 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs @@ -8,19 +8,19 @@ public class PoseExtractorTests { class UselessPoseExtractor : PoseExtractor { - protected override Pose GetPoseAt(int index) + protected internal override Pose GetPoseAt(int index) { return Pose.identity; } - protected override Vector3 GetLinearVelocityAt(int index) + protected internal override Vector3 GetLinearVelocityAt(int index) { return Vector3.zero; } public void Init(int[] parentIndices) { - SetParentIndices(parentIndices); + Setup(parentIndices); } } @@ -60,10 +60,10 @@ public ChainPoseExtractor(int size) { parents[i] = i - 1; } - SetParentIndices(parents); + Setup(parents); } - protected override Pose GetPoseAt(int index) + protected internal override Pose GetPoseAt(int index) { var rotation = Quaternion.identity; var translation = offset + new Vector3(index, index, index); @@ -74,7 +74,7 @@ protected override Pose GetPoseAt(int index) }; } - protected override Vector3 GetLinearVelocityAt(int index) + protected internal override Vector3 GetLinearVelocityAt(int index) { return Vector3.zero; } @@ -91,23 +91,77 @@ public void TestChain() chain.UpdateModelSpacePoses(); chain.UpdateLocalSpacePoses(); - // Root transforms are currently always the identity. - Assert.IsTrue(chain.ModelSpacePoses[0] == Pose.identity); - Assert.IsTrue(chain.LocalSpacePoses[0] == Pose.identity); - // Check the non-root transforms - for (var i = 1; i < size; i++) + var modelPoseIndex = 0; + foreach (var modelSpace in chain.GetEnabledModelSpacePoses()) { - var modelSpace = chain.ModelSpacePoses[i]; - var expectedModelTranslation = new Vector3(i, i, i); - Assert.IsTrue(expectedModelTranslation == modelSpace.position); + if (modelPoseIndex == 0) + { + // Root transforms are currently always the identity. + Assert.IsTrue(modelSpace == Pose.identity); + } + else + { + var expectedModelTranslation = new Vector3(modelPoseIndex, modelPoseIndex, modelPoseIndex); + Assert.IsTrue(expectedModelTranslation == modelSpace.position); - var localSpace = chain.LocalSpacePoses[i]; - var expectedLocalTranslation = new Vector3(1, 1, 1); - Assert.IsTrue(expectedLocalTranslation == localSpace.position); + } + modelPoseIndex++; } + Assert.AreEqual(size, modelPoseIndex); + + var localPoseIndex = 0; + foreach (var localSpace in chain.GetEnabledLocalSpacePoses()) + { + if (localPoseIndex == 0) + { + // Root transforms are currently always the identity. + Assert.IsTrue(localSpace == Pose.identity); + } + else + { + var expectedLocalTranslation = new Vector3(1, 1, 1); + Assert.IsTrue(expectedLocalTranslation == localSpace.position, $"{expectedLocalTranslation} != {localSpace.position}"); + } + + localPoseIndex++; + } + Assert.AreEqual(size, localPoseIndex); } + class BadPoseExtractor : PoseExtractor + { + public BadPoseExtractor() + { + var size = 2; + var parents = new int[size]; + // Parents are intentionally invalid - expect -1 at root + for (var i = 0; i < size; i++) + { + parents[i] = i; + } + Setup(parents); + } + + protected internal override Pose GetPoseAt(int index) + { + return Pose.identity; + } + + protected internal override Vector3 GetLinearVelocityAt(int index) + { + return Vector3.zero; + } + } + + [Test] + public void TestExpectedRoot() + { + Assert.Throws(() => + { + var bad = new BadPoseExtractor(); + }); + } } public class PoseExtensionTests diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs index a5d8b5bcb5..2d157b88e0 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs @@ -1,6 +1,7 @@ using UnityEngine; using NUnit.Framework; using Unity.MLAgents.Extensions.Sensors; +using UnityEditor; namespace Unity.MLAgents.Extensions.Tests.Sensors { @@ -56,6 +57,63 @@ public void TestTwoBodies() var poseExtractor = new RigidBodyPoseExtractor(rb1); Assert.AreEqual(2, poseExtractor.NumPoses); + + rb1.position = new Vector3(1, 0, 0); + rb1.rotation = Quaternion.Euler(0, 13.37f, 0); + rb1.velocity = new Vector3(2, 0, 0); + + Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(0).position); + Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(0).rotation); + Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(0)); + } + + [Test] + public void TestTwoBodiesVirtualRoot() + { + // * virtualRoot + // * rootObj + // - rb1 + // * go2 + // - rb2 + // - joint + var virtualRoot = new GameObject("I am vroot"); + + var rootObj = new GameObject(); + var rb1 = rootObj.AddComponent(); + + var go2 = new GameObject(); + var rb2 = go2.AddComponent(); + go2.transform.SetParent(rootObj.transform); + + var joint = go2.AddComponent(); + joint.connectedBody = rb1; + + var poseExtractor = new RigidBodyPoseExtractor(rb1, null, virtualRoot); + Assert.AreEqual(3, poseExtractor.NumPoses); + + // "body" 0 has no parent + Assert.AreEqual(-1, poseExtractor.GetParentIndex(0)); + + // body 1 has parent 0 + Assert.AreEqual(0, poseExtractor.GetParentIndex(1)); + + var virtualRootPos = new Vector3(0,2,0); + var virtualRootRot = Quaternion.Euler(0, 42, 0); + virtualRoot.transform.position = virtualRootPos; + virtualRoot.transform.rotation = virtualRootRot; + + Assert.AreEqual(virtualRootPos, poseExtractor.GetPoseAt(0).position); + Assert.IsTrue(virtualRootRot == poseExtractor.GetPoseAt(0).rotation); + Assert.AreEqual(Vector3.zero, poseExtractor.GetLinearVelocityAt(0)); + + // Same as above test, but using index 1 + rb1.position = new Vector3(1, 0, 0); + rb1.rotation = Quaternion.Euler(0, 13.37f, 0); + rb1.velocity = new Vector3(2, 0, 0); + + Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(1).position); + Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(1).rotation); + Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(1)); } } } diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs index 279fc7007d..a6c8b9f366 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs @@ -45,14 +45,12 @@ public void TestSingleRigidbody() var sensor = sensorComponent.CreateSensor(); sensor.Update(); - var expected = new[] - { - 0f, 0f, 0f, // ModelSpaceLinearVelocity - 0f, 0f, 0f, // LocalSpaceTranslations - 0f, 0f, 0f, 1f // LocalSpaceRotations - }; - SensorTestHelper.CompareObservation(sensor, expected); + + // The root body is ignored since it always generates identity values + // and there are no other bodies to generate observations. + var expected = new float[0]; Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); + SensorTestHelper.CompareObservation(sensor, expected); } [Test] @@ -78,6 +76,7 @@ public void TestBodiesWithJoint() var joint2 = leafGameObj.AddComponent(); joint2.connectedBody = middleRb; + var virtualRoot = new GameObject(); var sensorComponent = rootObj.AddComponent(); sensorComponent.RootBody = rootRb; @@ -87,9 +86,12 @@ public void TestBodiesWithJoint() UseLocalSpaceTranslations = true, UseLocalSpaceLinearVelocity = true }; + sensorComponent.VirtualRoot = virtualRoot; var sensor = sensorComponent.CreateSensor(); sensor.Update(); + + // Note that the VirtualRoot is ignored from the observations var expected = new[] { // Model space @@ -99,16 +101,15 @@ public void TestBodiesWithJoint() // Local space 0f, 0f, 0f, // Root pos - 0f, 0f, 0f, // Root vel - 13.37f, 0f, 0f, // Attached pos - -1f, 1f, 0f, // Attached vel - 4.2f, 0f, 0f, // Leaf pos + + 1f, 0f, 0f, // Root vel (relative to virtual root) + -1f, 1f, 0f, // Attached vel 0f, -1f, 1f // Leaf vel }; - SensorTestHelper.CompareObservation(sensor, expected); Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); + SensorTestHelper.CompareObservation(sensor, expected); // Update the settings to only process joint observations sensorComponent.Settings = new PhysicsSensorSettings diff --git a/com.unity.ml-agents/Runtime/AssemblyInfo.cs b/com.unity.ml-agents/Runtime/AssemblyInfo.cs index 5a6e5ced39..4bc7a8bbb0 100644 --- a/com.unity.ml-agents/Runtime/AssemblyInfo.cs +++ b/com.unity.ml-agents/Runtime/AssemblyInfo.cs @@ -2,3 +2,4 @@ [assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")] [assembly: InternalsVisibleTo("Unity.ML-Agents.Editor")] +[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions")]