Skip to content

Commit

Permalink
Uses matrix maths for observation vectors.
Browse files Browse the repository at this point in the history
  • Loading branch information
RunSwimFlyRich committed Aug 12, 2019
1 parent 74d82a5 commit 18f6616
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ public class CrawlerAgent : Agent
bool isNewDecisionStep;
int currentDecisionStep;

private Transform workingTransform;
Quaternion lookRotation;
Matrix4x4 targetDirMatrix;

public override void InitializeAgent()
{
Expand All @@ -64,9 +65,7 @@ public override void InitializeAgent()
jdController.SetupBodyPart(leg2Lower);
jdController.SetupBodyPart(leg3Upper);
jdController.SetupBodyPart(leg3Lower);

workingTransform = new GameObject().transform;
}
}

/// <summary>
/// We only need to change the joint settings based on decision freq.
Expand Down Expand Up @@ -94,8 +93,11 @@ public void CollectObservationBodyPart(BodyPart bp)
var rb = bp.rb;
AddVectorObs(bp.groundContact.touchingGround ? 1 : 0); // Whether the bp touching the ground

AddVectorObs(workingTransform.InverseTransformVector(rb.velocity));
AddVectorObs(workingTransform.InverseTransformDirection(rb.angularVelocity));
Vector3 velocityRelativeToLookRotationToTarget = targetDirMatrix.inverse.MultiplyVector(rb.velocity);
AddVectorObs(velocityRelativeToLookRotationToTarget);

Vector3 angularVelocityRelativeToLookRotationToTarget = targetDirMatrix.inverse.MultiplyVector(rb.angularVelocity);
AddVectorObs(angularVelocityRelativeToLookRotationToTarget);

if (bp.rb.transform != body)
{
Expand All @@ -113,17 +115,24 @@ public override void CollectObservations()
jdController.GetCurrentJointForces();
// Normalize dir vector to help generalize

workingTransform.rotation = Quaternion.LookRotation(dirToTarget);
lookRotation = Quaternion.LookRotation(dirToTarget);
targetDirMatrix = Matrix4x4.TRS(Vector3.zero, lookRotation, Vector3.one);

// Forward & up to help with orientation
RaycastHit hit;
if (Physics.Raycast(body.position, Vector3.down, out hit, 10.0f))
{
AddVectorObs(hit.distance);
}
else
AddVectorObs(10.0f);
AddVectorObs(workingTransform.InverseTransformVector(body.forward));
AddVectorObs(workingTransform.InverseTransformVector(body.up));
AddVectorObs(10.0f);

Vector3 bodyForwardRelativeToLookRotationToTarget = targetDirMatrix.inverse.MultiplyVector(body.forward);
AddVectorObs(bodyForwardRelativeToLookRotationToTarget);

Vector3 bodyUpRelativeToLookRotationToTarget = targetDirMatrix.inverse.MultiplyVector(body.up);
AddVectorObs(bodyUpRelativeToLookRotationToTarget);

foreach (var bodyPart in jdController.bodyPartsDict.Values)
{
CollectObservationBodyPart(bodyPart);
Expand Down

0 comments on commit 18f6616

Please sign in to comment.