Skip to content

Commit

Permalink
Make AsyncMessagePump and AsyncDelegatePump ExecutionContext aware (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
Shane32 authored Oct 19, 2023
1 parent d70229a commit 6c72c2e
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 16 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ was posted and it throws an exception, or if the callback throws an exception, t
is handled by the `HandleErrorAsync` protected method, which can be overridden by a user in
a derived class. `DrainAsync` is provided to wait for pending messages to be processed.

Note that the ExecutionContext is captured during the `Post` call and restored while the
callback (and/or the error handler) is executed.

Constructors:

- `AsyncMessagePump(Func<T, Task> callback)`
Expand Down Expand Up @@ -83,6 +86,9 @@ If the timeout expires before the delegate is executed, a `TimeoutException` is
delegate is not executed. Similarly, if the cancellation token is triggered before the delegate
is executed, a `OperationCanceledException` is thrown and the delegate is not executed.

Note that the ExecutionContext is captured during the call to `SendAsync` and restored while
the delegate is executed.

Public methods:

- `void Post(Func<Task> message)`
Expand Down
2 changes: 1 addition & 1 deletion src/AsyncResetEvents/AsyncAutoResetEvent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ public void Set(bool backgroundThread = false)
}

// otherwise, we need to run the waiting code on a background thread
// the follwing is equivalent to Task.Run with state
// the following is equivalent to Task.Run with state
_ = Task.Factory.StartNew(
static mre => ((AsyncAutoResetEvent)mre!).Set(false),
this,
Expand Down
79 changes: 64 additions & 15 deletions src/AsyncResetEvents/AsyncMessagePump.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,56 @@ namespace Shane32.AsyncResetEvents;
/// </summary>
public class AsyncMessagePump<T>
{
private struct MessageTuple
private class MessageTuple
{
public T? Value;
public Task<T>? Delegate;
#if NETSTANDARD2_0_OR_GREATER || NET5_0_OR_GREATER
private readonly ExecutionContext? _context = ExecutionContext.Capture(); // only returns null when ExecutionContext.IsFlowSuppressed() == true, not for default/empty contexts
private object? _state;
#endif

public Task ExecuteAsync(Func<T, Task> callback)
{
#if NETSTANDARD2_0_OR_GREATER || NET5_0_OR_GREATER
if (_context != null) {
_state = callback;
ExecutionContext.Run(
_context,
static state => {
var messageTuple = (MessageTuple)state!;
var callback = (Func<T, Task>)messageTuple._state!;
var returnTask = messageTuple.ExecuteInternalAsync(callback);
messageTuple._state = returnTask;
},
this);
var returnTask = (Task)_state!;
_state = null;
return returnTask;
}
#endif
return ExecuteInternalAsync(callback);
}

private Task ExecuteInternalAsync(Func<T, Task> callback)
=> Delegate == null ? callback(Value!) ??
#if NETSTANDARD1_0
DelegateTuple.CompletedTask
#else
Task.CompletedTask
#endif
: ExecuteDelegateAsync(Delegate, callback);
private static async Task ExecuteDelegateAsync(Task<T> executeDelegate, Func<T, Task> callback)
{
var message = await executeDelegate.ConfigureAwait(false);
var callbackTask = callback(message);
if (callbackTask != null)
await callbackTask.ConfigureAwait(false);
}
}
private readonly Func<T, Task> _callback;
private readonly Func<T, Task> _wrappedCallback;
private readonly Queue<MessageTuple> _queue = new();
#if NET5_0_OR_GREATER
private TaskCompletionSource? _drainTask;
Expand All @@ -38,17 +81,31 @@ private struct MessageTuple
/// </summary>
public AsyncMessagePump(Func<T, Task> callback)
{
_callback = callback ?? throw new ArgumentNullException(nameof(callback));
if (callback == null)
throw new ArgumentNullException(nameof(callback));
_wrappedCallback = async obj => {
try {
await callback(obj).ConfigureAwait(false);
} catch (Exception ex) {
// if an error occurs within HandleErrorAsync it will be caught within CompleteAsync
await HandleErrorAsync(ex).ConfigureAwait(false);
}
};
}

/// <summary>
/// Initializes a new instances with the specified synchronous callback delegate.
/// </summary>
public AsyncMessagePump(Action<T> callback)
: this(ConvertCallback(callback))
{
}

private static Func<T, Task> ConvertCallback(Action<T> callback)
{
if (callback == null)
throw new ArgumentNullException(nameof(callback));
_callback = message => {
return message => {
callback(message);
#if NETSTANDARD1_0
return DelegateTuple.CompletedTask;
Expand Down Expand Up @@ -114,19 +171,11 @@ private async void CompleteAsync()
messageTuple = _queue.Peek();
}
while (true) {
// process the message
// process the message (_wrappedCallback contains error handling)
try {
var message = messageTuple.Delegate != null
? await messageTuple.Delegate.ConfigureAwait(false)
: messageTuple.Value!;
var callbackTask = _callback(message);
if (callbackTask != null)
await callbackTask.ConfigureAwait(false);
} catch (Exception ex) {
try {
await HandleErrorAsync(ex).ConfigureAwait(false);
} catch { }
await messageTuple.ExecuteAsync(_wrappedCallback).ConfigureAwait(false);
}
catch { }

// once the message has been passed along, dequeue it
lock (_queue) {
Expand Down
6 changes: 6 additions & 0 deletions src/Tests/AsyncDelegatePumpTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,22 @@ private void Verify(string expected)
[Fact]
public async Task Basic()
{
var al = new AsyncLocal<string?>();
al.Value = "a";
_pump.Post(async () => {
await Task.Delay(100);
Assert.Equal("a", al.Value);
WriteLine("100");
});
al.Value = "b";
_pump.Post(async () => {
Verify("100 ");
await Task.Delay(1);
Assert.Equal("b", al.Value);
WriteLine("1");
_reset.Set();
});
al.Value = null;
Verify("");
await _reset.WaitAsync();
Verify("100 1 ");
Expand Down
82 changes: 82 additions & 0 deletions src/Tests/AsyncMessagePumpTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,88 @@ public void DrainWorks_Synchronous()
Assert.True(task.IsCompleted);
}

[Fact]
public async Task ExecutionContextCopied()
{
var al = new AsyncLocal<string?>();
int success = 0;
int expectedSuccess = 3;
var tcs = new TaskCompletionSource<int>();
var pump = new AsyncMessagePump<string?>(async str => {
await Task.Delay(100);
if (str == al.Value) {
var incremented = Interlocked.Increment(ref success);
if (incremented == expectedSuccess)
tcs.SetResult(incremented);
}
});

al.Value = "1";
pump.Post("1");
al.Value = "2";
pump.Post(Task.FromResult<string?>("2"));
al.Value = "3";
pump.Post("3");
al.Value = null;

await tcs.Task;
}

#if NETCOREAPP2_1_OR_GREATER
[Fact]
public async Task ExecutionContextCopiedFromDefault()
{
Task t;
// erase all execution context
using (ExecutionContext.SuppressFlow()) {
t = Task.Run(async () => {
// now there is no execution context
Assert.True(IsDefaultExecutionContext());
var al = new AsyncLocal<string?>();
int success = 0;
int expectedSuccess = 4;
var tcs = new TaskCompletionSource<int>();
var pump = new AsyncMessagePump<string?>(async str => {
await Task.Delay(100);
if (str == null) {
if (IsDefaultExecutionContext()) {
var incremented = Interlocked.Increment(ref success);
if (incremented == expectedSuccess)
tcs.SetResult(incremented);
}
} else if (str == al.Value) {
var incremented = Interlocked.Increment(ref success);
if (incremented == expectedSuccess)
tcs.SetResult(incremented);
}
});
pump.Post((string?)null);
Assert.True(IsDefaultExecutionContext());
al.Value = "1";
Assert.False(IsDefaultExecutionContext());
pump.Post("1");
al.Value = "2";
pump.Post(Task.FromResult<string?>("2"));
al.Value = "3";
pump.Post("3");
al.Value = null;
await tcs.Task;
});
}
await t;

bool IsDefaultExecutionContext()
{
using var context = ExecutionContext.Capture();
var isDefaultProperty = typeof(ExecutionContext).GetProperty("IsDefault", System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic)!;
return (bool)isDefaultProperty.GetValue(context)!;
}
}
#endif

public class DerivedAsyncMessagePump : AsyncMessagePump<string>
{
private readonly StringBuilder _sb;
Expand Down

0 comments on commit 6c72c2e

Please sign in to comment.