From 71b3a29c3f1ab83f978d7b959f2ec9ffc503f9dc Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Tue, 9 Oct 2018 14:23:56 -0400 Subject: [PATCH] Expose ActionProbabilities list from RankingResponse --- .../bindings/cs/rl.net.cli/RLSimulator.cs | 20 ++++++- .../rl.net.native/rl.net.ranking_response.cc | 17 +++++- .../rl.net.native/rl.net.ranking_response.h | 1 + .../bindings/cs/rl.net/RankingResponse.cs | 52 ++++++++++++++----- 4 files changed, 74 insertions(+), 16 deletions(-) diff --git a/reinforcement_learning/bindings/cs/rl.net.cli/RLSimulator.cs b/reinforcement_learning/bindings/cs/rl.net.cli/RLSimulator.cs index 4d569f5a2ce..ab0f3651bbd 100644 --- a/reinforcement_learning/bindings/cs/rl.net.cli/RLSimulator.cs +++ b/reinforcement_learning/bindings/cs/rl.net.cli/RLSimulator.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Text; using Rl.Net; namespace Rl.Net.Cli { @@ -11,6 +12,23 @@ public enum Topic : long MachineLearning } + internal static class RankingResponseExtensions + { + public static string ToDistributionString(this RankingResponse response) + { + StringBuilder stringBuilder = new StringBuilder("("); + + foreach (ActionProbability actionProbability in response) + { + stringBuilder.Append($"[{actionProbability.ActionIndex}, {actionProbability.Probability}]"); + } + + stringBuilder.Append(')'); + + return stringBuilder.ToString(); + } + } + internal class RLSimulator { public static readonly Random RandomSource = new Random(); @@ -75,7 +93,7 @@ private void Step() // TODO: Record stats this.stats.Record(person, actionTopic, outcome); - Console.WriteLine($" {this.stats.TotalActions}, ctxt, {person.Id}, action, {actionTopic}, outcome, {outcome}, dist, {"" /*todo*/}, {this.stats.GetStats(person, actionTopic)}"); + Console.WriteLine($" {this.stats.TotalActions}, ctxt, {person.Id}, action, {actionTopic}, outcome, {outcome}, dist, {responseContainer.ToDistributionString()}, {this.stats.GetStats(person, actionTopic)}"); } private void SafeRaiseError(ApiStatus errorStatus) diff --git a/reinforcement_learning/bindings/cs/rl.net.native/rl.net.ranking_response.cc b/reinforcement_learning/bindings/cs/rl.net.native/rl.net.ranking_response.cc index 5c5ffb374fa..23823ddf272 100644 --- a/reinforcement_learning/bindings/cs/rl.net.native/rl.net.ranking_response.cc +++ b/reinforcement_learning/bindings/cs/rl.net.native/rl.net.ranking_response.cc @@ -11,6 +11,16 @@ class ranking_enumerator_adapter {} public: + inline int check_current() + { + if (this->current != this->end) + { + return 1; + } + + return 0; + } + inline int move_next() { if ((this->current != this->end) && @@ -68,9 +78,14 @@ API void DeleteRankingEnumeratorAdapter(ranking_enumerator_adapter* adapter) delete adapter; } +API int RankingEnumeratorInit(ranking_enumerator_adapter* adapter) +{ + return adapter->check_current(); +} + API int RankingEnumeratorMoveNext(ranking_enumerator_adapter* adapter) { - return adapter->move_next(); + return adapter->move_next(); } API reinforcement_learning::action_prob GetRankingEnumeratorCurrent(ranking_enumerator_adapter* adapter) diff --git a/reinforcement_learning/bindings/cs/rl.net.native/rl.net.ranking_response.h b/reinforcement_learning/bindings/cs/rl.net.native/rl.net.ranking_response.h index c0983861ec6..64cd68e3801 100644 --- a/reinforcement_learning/bindings/cs/rl.net.native/rl.net.ranking_response.h +++ b/reinforcement_learning/bindings/cs/rl.net.native/rl.net.ranking_response.h @@ -24,6 +24,7 @@ extern "C" { API ranking_enumerator_adapter* CreateRankingEnumeratorAdapter(reinforcement_learning::ranking_response* ranking); API void DeleteRankingEnumeratorAdapter(ranking_enumerator_adapter* adapter); + API int RankingEnumeratorInit(ranking_enumerator_adapter* adapter); API int RankingEnumeratorMoveNext(ranking_enumerator_adapter* adapter); API reinforcement_learning::action_prob GetRankingEnumeratorCurrent(ranking_enumerator_adapter* adapter); } \ No newline at end of file diff --git a/reinforcement_learning/bindings/cs/rl.net/RankingResponse.cs b/reinforcement_learning/bindings/cs/rl.net/RankingResponse.cs index 759f916fbf5..9c81e0cde14 100644 --- a/reinforcement_learning/bindings/cs/rl.net/RankingResponse.cs +++ b/reinforcement_learning/bindings/cs/rl.net/RankingResponse.cs @@ -5,21 +5,21 @@ using System.Runtime.InteropServices; using Rl.Net.Native; +using System.Collections; namespace Rl.Net { [StructLayout(LayoutKind.Sequential)] - internal struct ActionProbability + public struct ActionProbability { - public UIntPtr ActionId; // If we expose this publicly, this will not be CLS-compliant - // No idea if that could cause issues going to .NET Core (probably - // not, but this is something we should check.), but having to do - // a conversion on every iteration feels very heavyweight. Can - // we change this contract to be defined as the signed version of - // the pointer type? - public float Probability; + private UIntPtr actionIndex; + private float probability; + + public long ActionIndex => (long)this.actionIndex.ToUInt64(); + + public float Probability => this.probability; } - public sealed class RankingResponse: NativeObject + public sealed class RankingResponse: NativeObject, IEnumerable { [DllImport("rl.net.native.dll")] private static extern IntPtr CreateRankingResponse(); @@ -63,9 +63,9 @@ public long Count } // TODO: Why does this method call, which seems like a "get" of a value, have an API status? - public bool TryGetChosenAction(out long action, ApiStatus status = null) + public bool TryGetChosenAction(out long actionIndex, ApiStatus status = null) { - action = -1; + actionIndex = -1; UIntPtr chosenAction; int result = GetRankingChosenAction(this.NativeHandle, out chosenAction, status.ToNativeHandleOrNullptr()); @@ -74,10 +74,20 @@ public bool TryGetChosenAction(out long action, ApiStatus status = null) return false; } - action = (long)(chosenAction.ToUInt64()); + actionIndex = (long)(chosenAction.ToUInt64()); return true; } + public IEnumerator GetEnumerator() + { + return new RankingResponseEnumerator(this); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + private class RankingResponseEnumerator : NativeObject, IEnumerator { [DllImport("rl.net.native.dll")] @@ -91,12 +101,17 @@ private static New BindConstructorArguments(RankingRe [DllImport("rl.net.native.dll")] private static extern void DeleteRankingEnumeratorAdapter(IntPtr rankingEnumeratorAdapter); + [DllImport("rl.net.native.dll")] + private static extern int RankingEnumeratorInit(IntPtr rankingEnumeratorAdapter); + [DllImport("rl.net.native.dll")] private static extern int RankingEnumeratorMoveNext(IntPtr rankingEnumeratorAdapter); [DllImport("rl.net.native.dll")] private static extern ActionProbability GetRankingEnumeratorCurrent(IntPtr rankingEnumeratorAdapter); + private bool initialState = true; + public RankingResponseEnumerator(RankingResponse rankingResponse) : base(BindConstructorArguments(rankingResponse), new Delete(DeleteRankingEnumeratorAdapter)) { } @@ -113,9 +128,18 @@ public ActionProbability Current public bool MoveNext() { - // The contract of result is to return 1 if true, 0 if false. - int result = RankingEnumeratorMoveNext(this.NativeHandle); + int result; + if (this.initialState) + { + this.initialState = false; + result = RankingEnumeratorInit(this.NativeHandle); + } + else + { + result = RankingEnumeratorMoveNext(this.NativeHandle); + } + // The contract of result is to return 1 if true, 0 if false. return result == 1; }