Skip to content

Commit

Permalink
Expose ActionProbabilities list from RankingResponse
Browse files Browse the repository at this point in the history
  • Loading branch information
lokitoth committed Oct 9, 2018
1 parent bafdf6e commit 71b3a29
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 16 deletions.
20 changes: 19 additions & 1 deletion reinforcement_learning/bindings/cs/rl.net.cli/RLSimulator.cs
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Rl.Net;

namespace Rl.Net.Cli {
Expand All @@ -11,6 +12,23 @@ public enum Topic : long
MachineLearning
}

internal static class RankingResponseExtensions
{
public static string ToDistributionString(this RankingResponse response)
{
StringBuilder stringBuilder = new StringBuilder("(");

foreach (ActionProbability actionProbability in response)
{
stringBuilder.Append($"[{actionProbability.ActionIndex}, {actionProbability.Probability}]");
}

stringBuilder.Append(')');

return stringBuilder.ToString();
}
}

internal class RLSimulator
{
public static readonly Random RandomSource = new Random();
Expand Down Expand Up @@ -75,7 +93,7 @@ private void Step()
// TODO: Record stats
this.stats.Record(person, actionTopic, outcome);

Console.WriteLine($" {this.stats.TotalActions}, ctxt, {person.Id}, action, {actionTopic}, outcome, {outcome}, dist, {"" /*todo*/}, {this.stats.GetStats(person, actionTopic)}");
Console.WriteLine($" {this.stats.TotalActions}, ctxt, {person.Id}, action, {actionTopic}, outcome, {outcome}, dist, {responseContainer.ToDistributionString()}, {this.stats.GetStats(person, actionTopic)}");
}

private void SafeRaiseError(ApiStatus errorStatus)
Expand Down
Expand Up @@ -11,6 +11,16 @@ class ranking_enumerator_adapter
{}

public:
inline int check_current()
{
if (this->current != this->end)
{
return 1;
}

return 0;
}

inline int move_next()
{
if ((this->current != this->end) &&
Expand Down Expand Up @@ -68,9 +78,14 @@ API void DeleteRankingEnumeratorAdapter(ranking_enumerator_adapter* adapter)
delete adapter;
}

API int RankingEnumeratorInit(ranking_enumerator_adapter* adapter)
{
return adapter->check_current();
}

API int RankingEnumeratorMoveNext(ranking_enumerator_adapter* adapter)
{
return adapter->move_next();
return adapter->move_next();
}

API reinforcement_learning::action_prob GetRankingEnumeratorCurrent(ranking_enumerator_adapter* adapter)
Expand Down
Expand Up @@ -24,6 +24,7 @@ extern "C" {
API ranking_enumerator_adapter* CreateRankingEnumeratorAdapter(reinforcement_learning::ranking_response* ranking);
API void DeleteRankingEnumeratorAdapter(ranking_enumerator_adapter* adapter);

API int RankingEnumeratorInit(ranking_enumerator_adapter* adapter);
API int RankingEnumeratorMoveNext(ranking_enumerator_adapter* adapter);
API reinforcement_learning::action_prob GetRankingEnumeratorCurrent(ranking_enumerator_adapter* adapter);
}
52 changes: 38 additions & 14 deletions reinforcement_learning/bindings/cs/rl.net/RankingResponse.cs
Expand Up @@ -5,21 +5,21 @@
using System.Runtime.InteropServices;

using Rl.Net.Native;
using System.Collections;

namespace Rl.Net {
[StructLayout(LayoutKind.Sequential)]
internal struct ActionProbability
public struct ActionProbability
{
public UIntPtr ActionId; // If we expose this publicly, this will not be CLS-compliant
// No idea if that could cause issues going to .NET Core (probably
// not, but this is something we should check.), but having to do
// a conversion on every iteration feels very heavyweight. Can
// we change this contract to be defined as the signed version of
// the pointer type?
public float Probability;
private UIntPtr actionIndex;
private float probability;

public long ActionIndex => (long)this.actionIndex.ToUInt64();

public float Probability => this.probability;
}

public sealed class RankingResponse: NativeObject<RankingResponse>
public sealed class RankingResponse: NativeObject<RankingResponse>, IEnumerable<ActionProbability>
{
[DllImport("rl.net.native.dll")]
private static extern IntPtr CreateRankingResponse();
Expand Down Expand Up @@ -63,9 +63,9 @@ public long Count
}

// TODO: Why does this method call, which seems like a "get" of a value, have an API status?
public bool TryGetChosenAction(out long action, ApiStatus status = null)
public bool TryGetChosenAction(out long actionIndex, ApiStatus status = null)
{
action = -1;
actionIndex = -1;
UIntPtr chosenAction;
int result = GetRankingChosenAction(this.NativeHandle, out chosenAction, status.ToNativeHandleOrNullptr());

Expand All @@ -74,10 +74,20 @@ public bool TryGetChosenAction(out long action, ApiStatus status = null)
return false;
}

action = (long)(chosenAction.ToUInt64());
actionIndex = (long)(chosenAction.ToUInt64());
return true;
}

public IEnumerator<ActionProbability> GetEnumerator()
{
return new RankingResponseEnumerator(this);
}

IEnumerator IEnumerable.GetEnumerator()
{
return this.GetEnumerator();
}

private class RankingResponseEnumerator : NativeObject<RankingResponseEnumerator>, IEnumerator<ActionProbability>
{
[DllImport("rl.net.native.dll")]
Expand All @@ -91,12 +101,17 @@ private static New<RankingResponseEnumerator> BindConstructorArguments(RankingRe
[DllImport("rl.net.native.dll")]
private static extern void DeleteRankingEnumeratorAdapter(IntPtr rankingEnumeratorAdapter);

[DllImport("rl.net.native.dll")]
private static extern int RankingEnumeratorInit(IntPtr rankingEnumeratorAdapter);

[DllImport("rl.net.native.dll")]
private static extern int RankingEnumeratorMoveNext(IntPtr rankingEnumeratorAdapter);

[DllImport("rl.net.native.dll")]
private static extern ActionProbability GetRankingEnumeratorCurrent(IntPtr rankingEnumeratorAdapter);

private bool initialState = true;

public RankingResponseEnumerator(RankingResponse rankingResponse) : base(BindConstructorArguments(rankingResponse), new Delete<RankingResponseEnumerator>(DeleteRankingEnumeratorAdapter))
{
}
Expand All @@ -113,9 +128,18 @@ public ActionProbability Current

public bool MoveNext()
{
// The contract of result is to return 1 if true, 0 if false.
int result = RankingEnumeratorMoveNext(this.NativeHandle);
int result;
if (this.initialState)
{
this.initialState = false;
result = RankingEnumeratorInit(this.NativeHandle);
}
else
{
result = RankingEnumeratorMoveNext(this.NativeHandle);
}

// The contract of result is to return 1 if true, 0 if false.
return result == 1;
}

Expand Down

0 comments on commit 71b3a29

Please sign in to comment.