Permalink
Browse files

Merge branch 'master' of github.com:OmerMor/AsyncBridge

  • Loading branch information...
OmerMor committed Apr 13, 2012
2 parents e9fa1aa + 627d325 commit efaf22381fcff4794d91c940ef6a6e46c8399c30
@@ -36,6 +36,7 @@
<ItemGroup>
<Compile Include="AsyncTaskMethodBuilder.cs" />
<Compile Include="AsyncVoidMethodBuilder.cs" />
+ <Compile Include="ConfigurableTaskAwaitable.cs" />
<Compile Include="DelayTask.cs" />
<Compile Include="TaskUtils.cs" />
<Compile Include="VoidTaskResult.cs" />
@@ -0,0 +1,42 @@
+using System.Threading.Tasks;
+
+namespace AsyncBridge
+{
+ /// <summary>
+ /// An awaitable which wraps a class, maybe preventing it from capturing the SynchronizationContext
+ /// </summary>
+ public class ConfigurableTaskAwaitable<T>
+ {
+ private readonly Task<T> m_task;
+ private readonly bool m_useCapturedContext;
+
+ public ConfigurableTaskAwaitable(Task<T> task, bool useCapturedContext)
+ {
+ m_task = task;
+ m_useCapturedContext = useCapturedContext;
+ }
+
+ public TaskAwaiter<T> GetAwaiter()
+ {
+ return new TaskAwaiter<T>(m_task, m_useCapturedContext);
+ }
+ }
+
+ // ZOMG why isn't void an actual type
+ public class ConfigurableTaskAwaitable
+ {
+ private readonly Task m_task;
+ private readonly bool m_useCapturedContext;
+
+ public ConfigurableTaskAwaitable(Task task, bool useCapturedContext)
+ {
+ m_task = task;
+ m_useCapturedContext = useCapturedContext;
+ }
+
+ public TaskAwaiter GetAwaiter()
+ {
+ return new TaskAwaiter(m_task, m_useCapturedContext);
+ }
+ }
+}
@@ -5,18 +5,20 @@ namespace System.Threading.Tasks
public struct TaskAwaiter : INotifyCompletion
{
private readonly Task m_task;
+ private readonly bool m_useCapturedContext;
- internal TaskAwaiter(Task task)
+ internal TaskAwaiter(Task task, bool useCapturedContext = true)
{
m_task = task;
+ m_useCapturedContext = useCapturedContext;
}
internal static TaskScheduler TaskScheduler
{
get
{
var taskScheduler = SynchronizationContext.Current == null
- ? TaskScheduler.Default
+ ? TaskScheduler.Current
: TaskScheduler.FromCurrentSynchronizationContext();
return taskScheduler;
}
@@ -30,7 +32,9 @@ public bool IsCompleted
public void OnCompleted(Action continuation)
{
m_task.ContinueWith(
- delegate { continuation(); }, TaskScheduler);
+ delegate { continuation(); },
+ // I don't think continuing on the thread pool is what people really wanted when they called ConfigureAwait, but it's what the CTP did
+ m_useCapturedContext ? TaskScheduler : TaskScheduler.Default);
}
public void GetResult()
@@ -49,10 +53,12 @@ public void GetResult()
public struct TaskAwaiter<T> : INotifyCompletion
{
private readonly Task<T> m_task;
+ private readonly bool m_useCapturedContext;
- internal TaskAwaiter(Task<T> task)
+ public TaskAwaiter(Task<T> task, bool useCapturedContext = true)
{
m_task = task;
+ m_useCapturedContext = useCapturedContext;
}
public bool IsCompleted
@@ -63,7 +69,9 @@ public bool IsCompleted
public void OnCompleted(Action continuation)
{
m_task.ContinueWith(
- delegate { continuation(); }, TaskAwaiter.TaskScheduler);
+ delegate { continuation(); },
+ // I don't think continuing on the thread pool is what people really wanted when they called ConfigureAwait, but it's what the CTP did
+ m_useCapturedContext ? TaskAwaiter.TaskScheduler : TaskScheduler.Default);
}
public T GetResult()
@@ -56,6 +56,16 @@ public static YieldAwaitable Yield()
return new YieldAwaitable((object)SynchronizationContext.Current ?? TaskScheduler.Current);
}
+ public static ConfigurableTaskAwaitable<T> ConfigureAwait<T>(this Task<T> original, bool continueOnCapturedContext)
+ {
+ return new ConfigurableTaskAwaitable<T>(original, continueOnCapturedContext);
+ }
+
+ public static ConfigurableTaskAwaitable ConfigureAwait(this Task original, bool continueOnCapturedContext)
+ {
+ return new ConfigurableTaskAwaitable(original, continueOnCapturedContext);
+ }
+
// Methods which are implemented in terms of TaskFactory
public static Task<T[]> WhenAll<T>(params Task<T>[] tasks)
{
@@ -21,7 +21,7 @@ public YieldAwaiter GetAwaiter()
}
[StructLayout(LayoutKind.Sequential, Size = 1)]
- public struct YieldAwaiter : ICriticalNotifyCompletion
+ public struct YieldAwaiter : ICriticalNotifyCompletion, INotifyCompletion
{
private static readonly WaitCallback s_waitCallbackRunAction = runAction;
private static readonly SendOrPostCallback s_sendOrPostCallbackRunAction =
@@ -43,7 +43,13 @@
<Reference Include="System.Xml" />
</ItemGroup>
<ItemGroup>
- <Compile Include="..\AsyncBridge.Tests\*.cs" />
+ <Compile Include="..\AsyncBridge.Tests\DelayTest.cs" />
+ <Compile Include="..\AsyncBridge.Tests\SyncContextTests.cs" />
+ <Compile Include="..\AsyncBridge.Tests\Test.cs" />
+ <Compile Include="..\AsyncBridge.Tests\WhenAllTests.cs" />
+ <Compile Include="..\AsyncBridge.Tests\WhenAnyTests.cs" />
+ <Link>SyncContextTests.cs</Link>
+ </Compile>
<Compile Include="Properties\AssemblyInfo.cs" />
</ItemGroup>
<ItemGroup>
@@ -43,6 +43,7 @@
</ItemGroup>
<ItemGroup>
<Compile Include="DelayTest.cs" />
+ <Compile Include="SyncContextTests.cs" />
<Compile Include="Test.cs" />
<Compile Include="WhenAllTests.cs" />
<Compile Include="WhenAnyTests.cs" />
@@ -0,0 +1,100 @@
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace AsyncBridge.Tests
+{
+ [TestClass]
+ public class SyncContextTests
+ {
+ class MagicSynchronizationContext : SynchronizationContext
+ {
+ public static readonly MagicSynchronizationContext Instance = new MagicSynchronizationContext();
+
+ public override void Post(SendOrPostCallback d, object state)
+ {
+ base.Post(o =>
+ {
+ SetSynchronizationContext(this);
+ d(o);
+ }, state);
+ }
+ }
+
+ [TestMethod]
+ public async Task YieldSyncContext()
+ {
+ SynchronizationContext.SetSynchronizationContext(MagicSynchronizationContext.Instance);
+ await TaskUtils.Yield();
+ Assert.IsTrue(SynchronizationContext.Current is MagicSynchronizationContext);
+ }
+
+ [TestMethod]
+ public async Task FromResultSyncContext()
+ {
+ SynchronizationContext.SetSynchronizationContext(MagicSynchronizationContext.Instance);
+ int r = await TaskUtils.FromResult(4);
+ Assert.IsTrue(SynchronizationContext.Current is MagicSynchronizationContext);
+ Assert.AreEqual(4, r);
+ }
+
+ [TestMethod]
+ public async Task DelaySyncContext()
+ {
+ SynchronizationContext.SetSynchronizationContext(MagicSynchronizationContext.Instance);
+ await TaskUtils.Delay(1);
+ Assert.IsTrue(SynchronizationContext.Current is MagicSynchronizationContext);
+ }
+
+ [TestMethod]
+ public async Task SimpleTaskSyncContext()
+ {
+ SynchronizationContext.SetSynchronizationContext(MagicSynchronizationContext.Instance);
+ await WaitABit();
+ Assert.IsTrue(SynchronizationContext.Current is MagicSynchronizationContext);
+ }
+
+ [TestMethod]
+ public async Task ReturningTaskSyncContext()
+ {
+ SynchronizationContext.SetSynchronizationContext(MagicSynchronizationContext.Instance);
+ int r = await WaitAThing();
+ Assert.IsTrue(SynchronizationContext.Current is MagicSynchronizationContext);
+ Assert.AreEqual(6, r);
+ }
+
+ [TestMethod]
+ public async Task ConfiguredSimpleTaskSyncContext()
+ {
+ SynchronizationContext.SetSynchronizationContext(MagicSynchronizationContext.Instance);
+ await WaitABit().ConfigureAwait(false);
+ Assert.IsFalse(SynchronizationContext.Current is MagicSynchronizationContext);
+ }
+
+ [TestMethod]
+ public async Task ConfiguredReturningTaskSyncContext()
+ {
+ SynchronizationContext.SetSynchronizationContext(MagicSynchronizationContext.Instance);
+ int r = await WaitAThing().ConfigureAwait(false);
+ Assert.IsFalse(SynchronizationContext.Current is MagicSynchronizationContext);
+ Assert.AreEqual(6, r);
+ }
+
+ /// <summary>
+ /// Exercise our AsyncTaskMethodBuilder
+ /// </summary>
+ private async Task WaitABit()
+ {
+ await TaskUtils.Delay(1);
+ }
+
+ /// <summary>
+ /// Exercise our AsyncTaskMethodBuilder'1
+ /// </summary>
+ private async Task<int> WaitAThing()
+ {
+ await TaskUtils.Delay(1);
+ return await TaskUtils.FromResult(6);
+ }
+ }
+}

0 comments on commit efaf223

Please sign in to comment.