Skip to content

Commit

Permalink
Fix #2600 by preserving ExecutionContext and reuse it for subsequent …
Browse files Browse the repository at this point in the history
…binding method invocations (#2658)

* first idea to fix #2600

* Improve idea to work on .NET 4

* Make BindingInvoker work with Task<T>

* fix execution for failing binding methods

* Keep ExecutionContext in ScenarioContext

* Cleanup unit tests
  • Loading branch information
gasparnagy committed Nov 9, 2022
1 parent 697fe33 commit a2567a6
Show file tree
Hide file tree
Showing 7 changed files with 320 additions and 39 deletions.
36 changes: 31 additions & 5 deletions TechTalk.SpecFlow/Bindings/AsyncMethodHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,20 @@ private static bool IsTask(Type type)

private static bool IsTaskOfT(Type type, out Type typeArg)
{
typeArg = null;
var isTaskOfT = type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Task<>);
if (isTaskOfT)
while (type != null)
{
typeArg = type.GetGenericArguments()[0];
var isTaskOfT = type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Task<>);
if (isTaskOfT)
{
typeArg = type.GetGenericArguments()[0];
if (typeArg != typeof(void) && typeArg.FullName != "System.Threading.Tasks.VoidTaskResult")
return true;
}
type = type.BaseType;
}
return isTaskOfT;

typeArg = null;
return false;
}

private static bool IsValueTask(Type type)
Expand Down Expand Up @@ -100,5 +107,24 @@ public static bool IsAsyncVoid(IBindingMethod bindingMethod)
return IsAsyncVoid(runtimeBindingMethod.MethodInfo);
return false;
}

public static Task<object> ConvertToTaskOfObject(Task task)
{
if (task.GetType() == typeof(Task))
return ConvertTaskOfT(task, false);
if (task is Task<object> taskOfObj)
return taskOfObj;
if (IsTaskOfT(task.GetType(), out _))
return ConvertTaskOfT(task, true);
return ConvertTaskOfT(task, false);
}

// We are allowed to have async method here, because the synchronous part of the task (that might have changed the ExecutionContext)
// has been executed already.
private static async Task<object> ConvertTaskOfT(Task task, bool getValue)
{
await task;
return getValue ? task.GetType().GetProperty(nameof(Task<object>.Result))!.GetValue(task) : null;
}
}
}
53 changes: 47 additions & 6 deletions TechTalk.SpecFlow/Bindings/BindingDelegateInvoker.cs
Original file line number Diff line number Diff line change
@@ -1,28 +1,69 @@
using System;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

namespace TechTalk.SpecFlow.Bindings
{
public class BindingDelegateInvoker : IBindingDelegateInvoker
{
public virtual async Task<object> InvokeDelegateAsync(Delegate bindingDelegate, object[] invokeArgs)
public virtual async Task<object> InvokeDelegateAsync(Delegate bindingDelegate, object[] invokeArgs, ExecutionContextHolder executionContext)
{
// To be able to simulate the behavior of sequential async or sync steps in a test, we need to ensure that
// the next step continues with the ExecutionContext that the previous step finished with.
//
// Without preserving the ExecutionContext this would not happen because the async methods all the way up to the
// generated test method (for example this method) are discarding the ExecutionContext changes at the end, so the
// next step would start with an "empty" ExecutionContext again.
//
// It is important that no methods from here (until the user's binding method) is marked with 'async' otherwise that
// would again discard the ExecutionContext.
//
// The ExecutionContext only flows down, so async binding methods cannot directly change it, but even if all binding method
// is async the constructor of the binding classes are run in sync, so they should be able to change the ExecutionContext.
// (The binding classes are created as part of the 'bindingDelegate' this method receives.

try
{
return await InvokeInExecutionContext(executionContext?.Value, () => CreateDelegateInvocationTask(bindingDelegate, invokeArgs));
}
finally
{
if (executionContext != null)
executionContext.Value = ExecutionContext.Capture();
}
}

private Task<object> InvokeInExecutionContext(ExecutionContext executionContext, Func<Task<object>> callback)
{
if (executionContext == null)
return callback();

Task<object> result = Task.FromResult((object)null);
ExecutionContext.Run(executionContext, _ => { result = callback(); }, null);
return result;
}

// Important: this method MUST NOT be async because that would discard the ExecutionContext changes during execution!
private Task<object> CreateDelegateInvocationTask(Delegate bindingDelegate, object[] invokeArgs)
{
if (AsyncMethodHelper.IsAwaitable(bindingDelegate.Method.ReturnType))
return await InvokeBindingDelegateAsync(bindingDelegate, invokeArgs);
return InvokeBindingDelegateSync(bindingDelegate, invokeArgs);
return InvokeBindingDelegateAsync(bindingDelegate, invokeArgs);
return Task.FromResult(InvokeBindingDelegateSync(bindingDelegate, invokeArgs));
}

protected virtual object InvokeBindingDelegateSync(Delegate bindingDelegate, object[] invokeArgs)
{
return bindingDelegate.DynamicInvoke(invokeArgs);
}

protected virtual async Task<object> InvokeBindingDelegateAsync(Delegate bindingDelegate, object[] invokeArgs)
// Important: this method MUST NOT be async because that would discard the ExecutionContext changes during execution!
private Task<object> InvokeBindingDelegateAsync(Delegate bindingDelegate, object[] invokeArgs)
{
var result = bindingDelegate.DynamicInvoke(invokeArgs);
if (AsyncMethodHelper.IsAwaitableAsTask(result, out var task))
await task;
return result;
return AsyncMethodHelper.ConvertToTaskOfObject(task);
return Task.FromResult(result);
}
}
}
13 changes: 12 additions & 1 deletion TechTalk.SpecFlow/Bindings/BindingInvoker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Threading;
using System.Threading.Tasks;
using TechTalk.SpecFlow.Bindings.Reflection;
using TechTalk.SpecFlow.Compatibility;
Expand Down Expand Up @@ -54,7 +56,8 @@ public virtual async Task<object> InvokeBindingAsync(IBinding binding, IContextM
Array.Copy(arguments, 0, invokeArgs, 1, arguments.Length);
invokeArgs[0] = contextManager;

result = await bindingDelegateInvoker.InvokeDelegateAsync(bindingAction, invokeArgs);
var executionContextHolder = GetExecutionContextHolder(contextManager);
result = await bindingDelegateInvoker.InvokeDelegateAsync(bindingAction, invokeArgs, executionContextHolder);

stopwatch.Stop();
durationHolder.Duration = stopwatch.Elapsed;
Expand Down Expand Up @@ -90,6 +93,14 @@ public virtual async Task<object> InvokeBindingAsync(IBinding binding, IContextM
}
}

private ExecutionContextHolder GetExecutionContextHolder(IContextManager contextManager)
{
var scenarioContext = contextManager.ScenarioContext;
if (scenarioContext == null)
return null;
return scenarioContext.ScenarioContainer.Resolve<ExecutionContextHolder>();
}

protected virtual CultureInfoScope CreateCultureInfoScope(IContextManager contextManager)
{
return new CultureInfoScope(contextManager.FeatureContext);
Expand Down
8 changes: 8 additions & 0 deletions TechTalk.SpecFlow/Bindings/ExecutionContextHolder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using System.Runtime.CompilerServices;
using System.Threading;

namespace TechTalk.SpecFlow.Bindings;

public class ExecutionContextHolder : StrongBox<ExecutionContext>
{
}
2 changes: 1 addition & 1 deletion TechTalk.SpecFlow/Bindings/IBindingDelegateInvoker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ namespace TechTalk.SpecFlow.Bindings
{
public interface IBindingDelegateInvoker
{
Task<object> InvokeDelegateAsync(Delegate bindingDelegate, object[] invokeArgs);
Task<object> InvokeDelegateAsync(Delegate bindingDelegate, object[] invokeArgs, ExecutionContextHolder executionContext);
}
}
2 changes: 1 addition & 1 deletion TechTalk.SpecFlow/Infrastructure/TestExecutionEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ private async Task FireEventsAsync(HookType hookType)
{
await InvokeHookAsync(_bindingInvoker, hookBinding, hookType);
}
}
}
catch (Exception hookExceptionCaught)
{
hookException = hookExceptionCaught;
Expand Down
Loading

0 comments on commit a2567a6

Please sign in to comment.