Skip to content

Commit

Permalink
Moving domain randomization to C# (#4065)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewcoh committed Jun 12, 2020
1 parent 27b85fe commit 20527d1
Show file tree
Hide file tree
Showing 22 changed files with 672 additions and 454 deletions.
2 changes: 2 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to
### Major Changes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- The Parameter Randomization feature has been refactored to enable sampling of new parameters per episode to improve robustness. The
`resampling-interval` parameter has been removed and the config structure updated. More information [here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-ML-Agents.md). (#4065)

### Minor Changes
#### com.unity.ml-agents (C#)
Expand Down
70 changes: 70 additions & 0 deletions com.unity.ml-agents/Runtime/Sampler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
using System;
using System.Collections.Generic;
using Unity.MLAgents.Inference.Utils;
using UnityEngine;
using Random=System.Random;

namespace Unity.MLAgents
{

/// <summary>
/// Takes a list of floats that encode a sampling distribution and returns the sampling function.
/// </summary>
internal static class SamplerFactory
{

public static Func<float> CreateUniformSampler(float min, float max, int seed)
{
Random distr = new Random(seed);
return () => min + (float)distr.NextDouble() * (max - min);
}

public static Func<float> CreateGaussianSampler(float mean, float stddev, int seed)
{
RandomNormal distr = new RandomNormal(seed, mean, stddev);
return () => (float)distr.NextDouble();
}

public static Func<float> CreateMultiRangeUniformSampler(IList<float> intervals, int seed)
{
//RNG
Random distr = new Random(seed);
// Will be used to normalize intervalFuncs
float sumIntervalSizes = 0;
//The number of intervals
int numIntervals = (int)(intervals.Count/2);
// List that will store interval lengths
float[] intervalSizes = new float[numIntervals];
// List that will store uniform distributions
IList<Func<float>> intervalFuncs = new Func<float>[numIntervals];
// Collect all intervals and store as uniform distrus
// Collect all interval sizes
for(int i = 0; i < numIntervals; i++)
{
var min = intervals[2 * i];
var max = intervals[2 * i + 1];
var intervalSize = max - min;
sumIntervalSizes += intervalSize;
intervalSizes[i] = intervalSize;
intervalFuncs[i] = () => min + (float)distr.NextDouble() * intervalSize;
}
// Normalize interval lengths
for(int i = 0; i < numIntervals; i++)
{
intervalSizes[i] = intervalSizes[i] / sumIntervalSizes;
}
// Build cmf for intervals
for(int i = 1; i < numIntervals; i++)
{
intervalSizes[i] += intervalSizes[i - 1];
}
Multinomial intervalDistr = new Multinomial(seed + 1);
float MultiRange()
{
int sampledInterval = intervalDistr.Sample(intervalSizes);
return intervalFuncs[sampledInterval].Invoke();
}
return MultiRange;
}
}
}
11 changes: 11 additions & 0 deletions com.unity.ml-agents/Runtime/Sampler.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,30 @@ namespace Unity.MLAgents.SideChannels
/// </summary>
internal enum EnvironmentDataTypes
{
Float = 0
Float = 0,
Sampler = 1
}

/// <summary>
/// The types of distributions from which to sample reset parameters.
/// </summary>
internal enum SamplerType
{
/// <summary>
/// Samples a reset parameter from a uniform distribution.
/// </summary>
Uniform = 0,

/// <summary>
/// Samples a reset parameter from a Gaussian distribution.
/// </summary>
Gaussian = 1,

/// <summary>
/// Samples a reset parameter from a MultiRangeUniform distribution.
/// </summary>
MultiRangeUniform = 2

}

/// <summary>
Expand All @@ -18,7 +41,7 @@ internal enum EnvironmentDataTypes
/// </summary>
internal class EnvironmentParametersChannel : SideChannel
{
Dictionary<string, float> m_Parameters = new Dictionary<string, float>();
Dictionary<string, Func<float>> m_Parameters = new Dictionary<string, Func<float>>();
Dictionary<string, Action<float>> m_RegisteredActions =
new Dictionary<string, Action<float>>();

Expand All @@ -42,12 +65,40 @@ protected override void OnMessageReceived(IncomingMessage msg)
{
var value = msg.ReadFloat32();

m_Parameters[key] = value;
m_Parameters[key] = () => value;

Action<float> action;
m_RegisteredActions.TryGetValue(key, out action);
action?.Invoke(value);
}
else if ((int)EnvironmentDataTypes.Sampler == type)
{
int seed = msg.ReadInt32();
int samplerType = msg.ReadInt32();
Func<float> sampler = () => 0.0f;
if ((int)SamplerType.Uniform == samplerType)
{
float min = msg.ReadFloat32();
float max = msg.ReadFloat32();
sampler = SamplerFactory.CreateUniformSampler(min, max, seed);
}
else if ((int)SamplerType.Gaussian == samplerType)
{
float mean = msg.ReadFloat32();
float stddev = msg.ReadFloat32();

sampler = SamplerFactory.CreateGaussianSampler(mean, stddev, seed);
}
else if ((int)SamplerType.MultiRangeUniform == samplerType)
{
IList<float> intervals = msg.ReadFloatList();
sampler = SamplerFactory.CreateMultiRangeUniformSampler(intervals, seed);
}
else{
Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");
}
m_Parameters[key] = sampler;
}
else
{
Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");
Expand All @@ -63,9 +114,9 @@ protected override void OnMessageReceived(IncomingMessage msg)
/// <returns></returns>
public float GetWithDefault(string key, float defaultValue)
{
float valueOut;
Func<float> valueOut;
bool hasKey = m_Parameters.TryGetValue(key, out valueOut);
return hasKey ? valueOut : defaultValue;
return hasKey ? valueOut.Invoke() : defaultValue;
}

/// <summary>
Expand Down
109 changes: 109 additions & 0 deletions com.unity.ml-agents/Tests/Editor/SamplerTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
using System;
using NUnit.Framework;
using System.IO;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.SideChannels;

namespace Unity.MLAgents.Tests
{
public class SamplerTests
{
const int k_Seed = 1337;
const double k_Epsilon = 0.0001;
EnvironmentParametersChannel m_Channel;

public SamplerTests()
{
m_Channel = SideChannelsManager.GetSideChannel<EnvironmentParametersChannel>();
// if running test on its own
if (m_Channel == null)
{
m_Channel = new EnvironmentParametersChannel();
SideChannelsManager.RegisterSideChannel(m_Channel);
}
}
[Test]
public void UniformSamplerTest()
{
float min_value = 1.0f;
float max_value = 2.0f;
string parameter = "parameter1";
using (var outgoingMsg = new OutgoingMessage())
{
outgoingMsg.WriteString(parameter);
// 1 indicates this meessage is a Sampler
outgoingMsg.WriteInt32(1);
outgoingMsg.WriteInt32(k_Seed);
outgoingMsg.WriteInt32((int)SamplerType.Uniform);
outgoingMsg.WriteFloat32(min_value);
outgoingMsg.WriteFloat32(max_value);
byte[] message = GetByteMessage(m_Channel, outgoingMsg);
SideChannelsManager.ProcessSideChannelData(message);
}
Assert.AreEqual(1.208888f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
Assert.AreEqual(1.118017f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
}

[Test]
public void GaussianSamplerTest()
{
float mean = 3.0f;
float stddev = 0.2f;
string parameter = "parameter2";
using (var outgoingMsg = new OutgoingMessage())
{
outgoingMsg.WriteString(parameter);
// 1 indicates this meessage is a Sampler
outgoingMsg.WriteInt32(1);
outgoingMsg.WriteInt32(k_Seed);
outgoingMsg.WriteInt32((int)SamplerType.Gaussian);
outgoingMsg.WriteFloat32(mean);
outgoingMsg.WriteFloat32(stddev);
byte[] message = GetByteMessage(m_Channel, outgoingMsg);
SideChannelsManager.ProcessSideChannelData(message);
}
Assert.AreEqual(2.936162f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
Assert.AreEqual(2.951348f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
}

[Test]
public void MultiRangeUniformSamplerTest()
{
float[] intervals = new float[4];
intervals[0] = 1.2f;
intervals[1] = 2f;
intervals[2] = 3.2f;
intervals[3] = 4.1f;
string parameter = "parameter3";
using (var outgoingMsg = new OutgoingMessage())
{
outgoingMsg.WriteString(parameter);
// 1 indicates this meessage is a Sampler
outgoingMsg.WriteInt32(1);
outgoingMsg.WriteInt32(k_Seed);
outgoingMsg.WriteInt32((int)SamplerType.MultiRangeUniform);
outgoingMsg.WriteFloatList(intervals);
byte[] message = GetByteMessage(m_Channel, outgoingMsg);
SideChannelsManager.ProcessSideChannelData(message);
}
Assert.AreEqual(3.387999f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
Assert.AreEqual(1.294413f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
}

internal static byte[] GetByteMessage(SideChannel sideChannel, OutgoingMessage msg)
{
byte[] message = msg.ToByteArray();
using (var memStream = new MemoryStream())
{
using (var binaryWriter = new BinaryWriter(memStream))
{
binaryWriter.Write(sideChannel.ChannelId.ToByteArray());
binaryWriter.Write(message.Length);
binaryWriter.Write(message);
}
return memStream.ToArray();
}
}
}
}
11 changes: 11 additions & 0 deletions com.unity.ml-agents/Tests/Editor/SamplerTests.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 8 additions & 11 deletions config/ppo/3DBall_randomize.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,13 @@ behaviors:
threaded: true

parameter_randomization:
resampling-interval: 5000
mass:
sampler-type: uniform
min_value: 0.5
max_value: 10
gravity:
sampler-type: uniform
min_value: 7
max_value: 12
sampler_type: uniform
sampler_parameters:
min_value: 0.5
max_value: 10
scale:
sampler-type: uniform
min_value: 0.75
max_value: 3
sampler_type: uniform
sampler_parameters:
min_value: 0.75
max_value: 3

0 comments on commit 20527d1

Please sign in to comment.