This repository has been archived by the owner on Dec 19, 2018. It is now read-only.
/
ClientHandler.cs
132 lines (115 loc) · 5.14 KB
/
ClientHandler.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Hosting.Server;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Context = Microsoft.AspNetCore.Hosting.Internal.HostingApplication.Context;
namespace Microsoft.AspNetCore.TestHost
{
    /// <summary>
    /// This adapts HttpRequestMessages to ASP.NET Core requests, dispatches them through the pipeline, and returns the
    /// associated HttpResponseMessage.
    /// </summary>
    public class ClientHandler : HttpMessageHandler
    {
        private readonly IHttpApplication<Context> _application;
        private readonly PathString _pathBase;

        /// <summary>
        /// Create a new handler.
        /// </summary>
        /// <param name="pathBase">
        /// The base path. A single trailing '/' is removed so the value can be matched with
        /// <c>PathString.StartsWithSegments</c>.
        /// </param>
        /// <param name="application">The <see cref="IHttpApplication{TContext}"/>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="application"/> is <c>null</c>.</exception>
        public ClientHandler(PathString pathBase, IHttpApplication<Context> application)
        {
            _application = application ?? throw new ArgumentNullException(nameof(application));

            // PathString.StartsWithSegments that we use below requires the base path to not end in a slash.
            // The comparison is ordinal: a culture-sensitive EndsWith may ignore trailing zero-width characters
            // and report a match even when the last character is not '/', making Substring drop the wrong char.
            if (pathBase.HasValue && pathBase.Value.EndsWith("/", StringComparison.Ordinal))
            {
                pathBase = new PathString(pathBase.Value.Substring(0, pathBase.Value.Length - 1));
            }

            _pathBase = pathBase;
        }

        /// <summary>
        /// This adapts HttpRequestMessages to ASP.NET Core requests, dispatches them through the pipeline, and returns the
        /// associated HttpResponseMessage.
        /// </summary>
        /// <param name="request">
        /// The outgoing request. Its <see cref="HttpRequestMessage.RequestUri"/> supplies the scheme, path, query
        /// string and, unless a Host header was set explicitly, the host.
        /// </param>
        /// <param name="cancellationToken">Forwarded to the context builder when the request is dispatched.</param>
        /// <returns>
        /// A response whose status code, reason phrase, headers and body stream are taken from the processed context.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="request"/> is <c>null</c>.</exception>
        protected override async Task<HttpResponseMessage> SendAsync(
            HttpRequestMessage request,
            CancellationToken cancellationToken)
        {
            if (request == null)
            {
                throw new ArgumentNullException(nameof(request));
            }

            var contextBuilder = new HttpContextBuilder(_application);
            Stream responseBody = null;

            // A request without content gets an empty body so the server side always has a stream to read.
            var requestContent = request.Content ?? new StreamContent(Stream.Null);
            var body = await requestContent.ReadAsStreamAsync();

            contextBuilder.Configure(context =>
            {
                var req = context.Request;

                req.Protocol = "HTTP/" + request.Version.ToString(fieldCount: 2);
                req.Method = request.Method.ToString();
                req.Scheme = request.RequestUri.Scheme;

                req.Path = PathString.FromUriComponent(request.RequestUri);
                req.PathBase = PathString.Empty;
                if (req.Path.StartsWithSegments(_pathBase, out var remainder))
                {
                    req.Path = remainder;
                    req.PathBase = _pathBase;
                }

                req.QueryString = QueryString.FromUriComponent(request.RequestUri);

                foreach (var header in request.Headers)
                {
                    req.Headers.Append(header.Key, header.Value.ToArray());
                }

                // The headers are copied before the host is inferred. If Host were set first, an explicit Host
                // header on the request message would be Appended as a second value instead of being used.
                if (!req.Host.HasValue)
                {
                    var host = HostString.FromUriComponent(request.RequestUri);

                    // Drop the port when it is the scheme's default, matching what a real client would send.
                    req.Host = request.RequestUri.IsDefaultPort ? new HostString(host.Host) : host;
                }

                foreach (var header in requestContent.Headers)
                {
                    req.Headers.Append(header.Key, header.Value.ToArray());
                }

                if (body.CanSeek)
                {
                    // This body may have been consumed before, rewind it.
                    body.Seek(0, SeekOrigin.Begin);
                }

                req.Body = body;

                // Captured so the response message can expose whatever the application wrote.
                responseBody = context.Response.Body;
            });

            var httpContext = await contextBuilder.SendAsync(cancellationToken);

            var response = new HttpResponseMessage
            {
                StatusCode = (HttpStatusCode)httpContext.Response.StatusCode,
                ReasonPhrase = httpContext.Features.Get<IHttpResponseFeature>().ReasonPhrase,
                RequestMessage = request,
                Content = new StreamContent(responseBody),
            };

            foreach (var header in httpContext.Response.Headers)
            {
                // Content-level headers (e.g. Content-Type) are rejected by the response header collection and
                // belong on the content instead.
                if (!response.Headers.TryAddWithoutValidation(header.Key, (IEnumerable<string>)header.Value))
                {
                    bool success = response.Content.Headers.TryAddWithoutValidation(header.Key, (IEnumerable<string>)header.Value);

                    // Contract.Assert is a [Conditional] method: in builds without DEBUG/CONTRACTS_FULL this check
                    // is compiled out, and a header accepted by neither collection is dropped silently.
                    Contract.Assert(success, "Bad header");
                }
            }

            return response;
        }
    }
}