Skip to content
This repository has been archived by the owner on Dec 19, 2018. It is now read-only.

Commit

Permalink
[Fixes #852] TestHost: OnStarting and OnCompleted callbacks of respon…
Browse files Browse the repository at this point in the history
…se are not being awaited
  • Loading branch information
kichalla committed Sep 15, 2016
1 parent 98e35cc commit b6da89f
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 36 deletions.
25 changes: 14 additions & 11 deletions src/Microsoft.AspNetCore.TestHost/ClientHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ protected override async Task<HttpResponseMessage> SendAsync(
try
{
await _application.ProcessRequestAsync(state.Context);
state.CompleteResponse();
await state.CompleteResponseAsync();
state.ServerCleanup(exception: null);
}
catch (Exception ex)
Expand Down Expand Up @@ -165,7 +165,7 @@ internal RequestState(HttpRequestMessage request, PathString pathBase, IHttpAppl
}
}

_responseStream = new ResponseStream(ReturnResponseMessage, AbortRequest);
_responseStream = new ResponseStream(ReturnResponseMessageAsync, AbortRequest);
httpContext.Response.Body = _responseStream;
httpContext.Response.StatusCode = 200;
httpContext.RequestAborted = _requestAbortedSource.Token;
Expand All @@ -187,27 +187,30 @@ internal void AbortRequest()
_responseStream.Complete();
}

internal void CompleteResponse()
internal async Task CompleteResponseAsync()
{
_pipelineFinished = true;
ReturnResponseMessage();
await ReturnResponseMessageAsync();
_responseStream.Complete();
_responseFeature.FireOnResponseCompleted();
await _responseFeature.FireOnResponseCompletedAsync();
}

internal void ReturnResponseMessage()
internal async Task ReturnResponseMessageAsync()
{
if (!_responseTcs.Task.IsCompleted)
// Check if the response has already started because the TrySetResult below could happen a bit late
// (as it happens on a different thread) by which point the CompleteResponseAsync could run and calls this
// method again.
if (!Context.HttpContext.Response.HasStarted)
{
var response = GenerateResponse();
var response = await GenerateResponseAsync();
// Dispatch, as TrySetResult will synchronously execute the waiters callback and block our Write.
Task.Factory.StartNew(() => _responseTcs.TrySetResult(response));
var setResult = Task.Factory.StartNew(() => _responseTcs.TrySetResult(response));
}
}

private HttpResponseMessage GenerateResponse()
private async Task<HttpResponseMessage> GenerateResponseAsync()
{
_responseFeature.FireOnSendingHeaders();
await _responseFeature.FireOnSendingHeadersAsync();
var httpContext = Context.HttpContext;

var response = new HttpResponseMessage();
Expand Down
34 changes: 20 additions & 14 deletions src/Microsoft.AspNetCore.TestHost/ResponseFeature.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ namespace Microsoft.AspNetCore.TestHost
{
internal class ResponseFeature : IHttpResponseFeature
{
private Action _responseStarting = () => { };
private Action _responseCompleted = () => { };
private Func<Task> _responseStartingAsync = () => Task.FromResult(true);
private Func<Task> _responseCompletedAsync = () => Task.FromResult(true);

public ResponseFeature()
{
Expand All @@ -36,33 +36,39 @@ public ResponseFeature()

public void OnStarting(Func<object, Task> callback, object state)
{
var prior = _responseStarting;
_responseStarting = () =>
var prior = _responseStartingAsync;
_responseStartingAsync = async () =>
{
callback(state);
prior();
await callback(state);
await prior();
};
}

public void OnCompleted(Func<object, Task> callback, object state)
{
var prior = _responseCompleted;
_responseCompleted = () =>
var prior = _responseCompletedAsync;
_responseCompletedAsync = async () =>
{
callback(state);
prior();
try
{
await callback(state);
}
finally
{
await prior();
}
};
}

public void FireOnSendingHeaders()
public async Task FireOnSendingHeadersAsync()
{
_responseStarting();
await _responseStartingAsync();
HasStarted = true;
}

public void FireOnResponseCompleted()
public Task FireOnResponseCompletedAsync()
{
_responseCompleted();
return _responseCompletedAsync();
}
}
}
19 changes: 10 additions & 9 deletions src/Microsoft.AspNetCore.TestHost/ResponseStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,23 @@ internal class ResponseStream : Stream
private TaskCompletionSource<object> _readWaitingForData;
private object _signalReadLock;

private Action _onFirstWrite;
private Func<Task> _onFirstWriteAsync;
private bool _firstWrite;
private Action _abortRequest;

internal ResponseStream(Action onFirstWrite, Action abortRequest)
internal ResponseStream(Func<Task> onFirstWriteAsync, Action abortRequest)
{
if (onFirstWrite == null)
if (onFirstWriteAsync == null)
{
throw new ArgumentNullException(nameof(onFirstWrite));
throw new ArgumentNullException(nameof(onFirstWriteAsync));
}

if (abortRequest == null)
{
throw new ArgumentNullException(nameof(abortRequest));
}

_onFirstWrite = onFirstWrite;
_onFirstWriteAsync = onFirstWriteAsync;
_firstWrite = true;
_abortRequest = abortRequest;

Expand Down Expand Up @@ -98,7 +98,7 @@ public override void Flush()
_writeLock.Wait();
try
{
FirstWrite();
FirstWriteAsync().GetAwaiter().GetResult();
}
finally
{
Expand Down Expand Up @@ -230,13 +230,14 @@ public async override Task<int> ReadAsync(byte[] buffer, int offset, int count,
}

// Called under write-lock.
private void FirstWrite()
private Task FirstWriteAsync()
{
if (_firstWrite)
{
_firstWrite = false;
_onFirstWrite();
return _onFirstWriteAsync();
}
return Task.FromResult(true);
}

// Write with count 0 will still trigger OnFirstWrite
Expand All @@ -248,7 +249,7 @@ public override void Write(byte[] buffer, int offset, int count)
_writeLock.Wait();
try
{
FirstWrite();
FirstWriteAsync().GetAwaiter().GetResult();
if (count == 0)
{
return;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Threading.Tasks;
using Xunit;

namespace Microsoft.AspNetCore.TestHost
{
public class ResponseFeatureTests
{
[Fact]
public void StatusCode_DefaultsTo200()
public async Task StatusCode_DefaultsTo200()
{
// Arrange & Act
var responseInformation = new ResponseFeature();
Expand All @@ -17,7 +18,7 @@ public void StatusCode_DefaultsTo200()
Assert.Equal(200, responseInformation.StatusCode);
Assert.False(responseInformation.HasStarted);

responseInformation.FireOnSendingHeaders();
await responseInformation.FireOnSendingHeadersAsync();

Assert.True(responseInformation.HasStarted);
}
Expand Down

0 comments on commit b6da89f

Please sign in to comment.