Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
#if UNITY_2020_1_OR_NEWER

using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
public class ArticulationBodyJointExtractor : IJointExtractor
{
ArticulationBody m_Body;

public ArticulationBodyJointExtractor(ArticulationBody body)
{
m_Body = body;
}

public int NumObservations(PhysicsSensorSettings settings)
{
return NumObservations(m_Body, settings);
}

public static int NumObservations(ArticulationBody body, PhysicsSensorSettings settings)
{
if (body == null || body.isRoot)
{
return 0;
}

var totalCount = 0;
if (settings.UseJointPositionsAndAngles)
{
switch (body.jointType)
{
case ArticulationJointType.RevoluteJoint:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens in this case? Is this TODO later?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That should "fall through" to totalCount += 2 * body.dofCount right? Or is my understanding of switch statements obsolete?

I believe (should verify with the physics folks) that:
RevoluteJoint: 1 angular dof
SphericalJoint: 3 angular dof
FixedJoint: 0 dof
PrismaticJoint: 1 linear dof

so both RevoluteJoint and SphericalJoint are basically the same (just different values for body.dofCount)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fall through is what will happen, maybe a comment to display intent could help

case ArticulationJointType.SphericalJoint:
// Both RevoluteJoint and SphericalJoint have all angular components.
// We use sine and cosine of the angles for the observations.
totalCount += 2 * body.dofCount;
break;
case ArticulationJointType.FixedJoint:
// Since FixedJoint can't moved, there aren't any interesting observations for it.
break;
case ArticulationJointType.PrismaticJoint:
// One linear component
totalCount += body.dofCount;
break;
}
}

if (settings.UseJointForces)
{
totalCount += body.dofCount;
}

return totalCount;
}

public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset)
{
if (m_Body == null || m_Body.isRoot)
{
return 0;
}

var currentOffset = offset;

// Write joint positions
if (settings.UseJointPositionsAndAngles)
{
switch (m_Body.jointType)
{
case ArticulationJointType.RevoluteJoint:
case ArticulationJointType.SphericalJoint:
// All joint positions are angular
for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++)
{
var jointRotationRads = m_Body.jointPosition[dofIndex];
writer[currentOffset++] = Mathf.Sin(jointRotationRads);
writer[currentOffset++] = Mathf.Cos(jointRotationRads);
}
break;
case ArticulationJointType.FixedJoint:
// No observations
break;
case ArticulationJointType.PrismaticJoint:
writer[currentOffset++] = GetPrismaticValue();
break;
}
}

if (settings.UseJointForces)
{
for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++)
{
// take tanh to keep in [-1, 1]
writer[currentOffset++] = (float) System.Math.Tanh(m_Body.jointForce[dofIndex]);
}
}

return currentOffset - offset;
}

float GetPrismaticValue()
{
// Prismatic joints should have at most one free axis.
bool limited = false;
var drive = m_Body.xDrive;
if (m_Body.linearLockX == ArticulationDofLock.LimitedMotion)
{
drive = m_Body.xDrive;
limited = true;
}
else if (m_Body.linearLockY == ArticulationDofLock.LimitedMotion)
{
drive = m_Body.yDrive;
limited = true;
}
else if (m_Body.linearLockZ == ArticulationDofLock.LimitedMotion)
{
drive = m_Body.zDrive;
limited = true;
}

var jointPos = m_Body.jointPosition[0];
if (limited)
{
// If locked, interpolate between the limits.
var upperLimit = drive.upperLimit;
var lowerLimit = drive.lowerLimit;
if (upperLimit <= lowerLimit)
{
// Invalid limits (probably equal), so don't try to lerp
return 0;
}
var invLerped = Mathf.InverseLerp(lowerLimit, upperLimit, jointPos);

// Convert [0, 1] -> [-1, 1]
var normalized = 2.0f * invLerped - 1.0f;
return normalized;
}
// take tanh() to keep in [-1, 1]
return (float) System.Math.Tanh(jointPos);
}
}
}
#endif

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ protected override Pose GetPoseAt(int index)
var t = go.transform;
return new Pose { rotation = t.rotation, position = t.position };
}

internal ArticulationBody[] Bodies => m_Bodies;
}
}
#endif // UNITY_2020_1_OR_NEWER
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,16 @@ public override int[] GetObservationShape()
// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
var poseExtractor = new ArticulationBodyPoseExtractor(RootBody);
var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses);
return new[] { numTransformObservations };
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);
var numJointObservations = 0;
// Start from i=1 to ignore the root
for (var i = 1; i < poseExtractor.Bodies.Length; i++)
{
numJointObservations += ArticulationBodyJointExtractor.NumObservations(
poseExtractor.Bodies[i], Settings
);
}
return new[] { numPoseObservations + numJointObservations };
}
}

Expand Down
27 changes: 27 additions & 0 deletions com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// Interface for generating observations from a physical joint or constraint.
/// </summary>
public interface IJointExtractor
{
/// <summary>
/// Determine the number of observations that would be generated for the particular joint
/// using the provided PhysicsSensorSettings.
/// </summary>
/// <param name="settings"></param>
/// <returns>Number of floats that will be written.</returns>
int NumObservations(PhysicsSensorSettings settings);

/// <summary>
/// Write the observations to the ObservationWriter, starting at the specified offset.
/// </summary>
/// <param name="settings"></param>
/// <param name="writer"></param>
/// <param name="offset"></param>
/// <returns>Number of floats that were written.</returns>
int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset);
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ public class PhysicsBodySensor : ISensor
string m_SensorName;

PoseExtractor m_PoseExtractor;
IJointExtractor[] m_JointExtractors;
PhysicsSensorSettings m_Settings;

/// <summary>
Expand All @@ -22,23 +23,59 @@ public class PhysicsBodySensor : ISensor
/// <param name="sensorName"></param>
public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null)
{
m_PoseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);
var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);
m_PoseExtractor = poseExtractor;
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName;
m_Settings = settings;

var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
m_Shape = new[] { numTransformObservations };
var numJointExtractorObservations = 0;
var rigidBodies = poseExtractor.Bodies;
if (rigidBodies != null)
{
m_JointExtractors = new IJointExtractor[rigidBodies.Length - 1]; // skip the root
for (var i = 1; i < rigidBodies.Length; i++)
{
var jointExtractor = new RigidBodyJointExtractor(rigidBodies[i]);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
m_JointExtractors[i - 1] = jointExtractor;
}
}
else
{
m_JointExtractors = new IJointExtractor[0];
}

var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
}

#if UNITY_2020_1_OR_NEWER
public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null)
{
m_PoseExtractor = new ArticulationBodyPoseExtractor(rootBody);
var poseExtractor = new ArticulationBodyPoseExtractor(rootBody);
m_PoseExtractor = poseExtractor;
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootBody?.name}" : sensorName;
m_Settings = settings;

var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
m_Shape = new[] { numTransformObservations };
var numJointExtractorObservations = 0;
var articBodies = poseExtractor.Bodies;
if (articBodies != null)
{
m_JointExtractors = new IJointExtractor[articBodies.Length - 1]; // skip the root
for (var i = 1; i < articBodies.Length; i++)
{
var jointExtractor = new ArticulationBodyJointExtractor(articBodies[i]);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
m_JointExtractors[i - 1] = jointExtractor;
}
}
else
{
m_JointExtractors = new IJointExtractor[0];
}

var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
}
#endif

Expand All @@ -52,6 +89,10 @@ public int[] GetObservationShape()
public int Write(ObservationWriter writer)
{
var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
foreach (var jointExtractor in m_JointExtractors)
{
numWritten += jointExtractor.Write(m_Settings, writer, numWritten);
}
return numWritten;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ public struct PhysicsSensorSettings
/// </summary>
public bool UseLocalSpaceLinearVelocity;

/// <summary>
/// Whether to use joint-specific positions and angles as observations.
/// </summary>
public bool UseJointPositionsAndAngles;

/// <summary>
/// Whether to use the joint forces and torques that are applied by the solver as observations.
/// </summary>
public bool UseJointForces;

/// <summary>
/// Creates a PhysicsSensorSettings with reasonable default values.
/// </summary>
Expand Down Expand Up @@ -68,26 +78,6 @@ public bool UseLocalSpace
{
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; }
}


/// <summary>
/// The number of floats needed to represent a given number of transforms.
/// </summary>
/// <param name="numTransforms"></param>
/// <returns></returns>
public int TransformSize(int numTransforms)
{
int obsPerTransform = 0;
obsPerTransform += UseModelSpaceTranslations ? 3 : 0;
obsPerTransform += UseModelSpaceRotations ? 4 : 0;
obsPerTransform += UseLocalSpaceTranslations ? 3 : 0;
obsPerTransform += UseLocalSpaceRotations ? 4 : 0;

obsPerTransform += UseModelSpaceLinearVelocity ? 3 : 0;
obsPerTransform += UseLocalSpaceLinearVelocity ? 3 : 0;

return numTransforms * obsPerTransform;
}
}

internal static class ObservationWriterPhysicsExtensions
Expand Down
18 changes: 18 additions & 0 deletions com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,24 @@ public void UpdateLocalSpacePoses()
}
}

/// <summary>
/// Compute the number of floats needed to represent the poses for the given PhysicsSensorSettings.
/// </summary>
/// <param name="settings"></param>
/// <returns></returns>
public int GetNumPoseObservations(PhysicsSensorSettings settings)
{
int obsPerPose = 0;
obsPerPose += settings.UseModelSpaceTranslations ? 3 : 0;
obsPerPose += settings.UseModelSpaceRotations ? 4 : 0;
obsPerPose += settings.UseLocalSpaceTranslations ? 3 : 0;
obsPerPose += settings.UseLocalSpaceRotations ? 4 : 0;

obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0;
obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0;

return NumPoses * obsPerPose;
}

internal void DrawModelSpace(Vector3 offset)
{
Expand Down
Loading