// GrpcServiceBase.cs — 238 lines (207 loc), 8.59 KB
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Grpc.Core;
using ProtoBuf.Grpc;
using ProtoBuf.Grpc.Configuration;
using ServiceStack.Grpc;
using ServiceStack.Web;
using ServiceStack.Text;
namespace ServiceStack;
/// <summary>
/// Base class for gRPC services that route incoming gRPC calls through ServiceStack's
/// request pipeline and translate ServiceStack responses (HTTP status, headers and
/// <see cref="ResponseStatus"/>) into gRPC response headers and call status.
/// </summary>
public abstract class GrpcServiceBase : IGrpcService
{
    private ServiceStackHost? appHost;
    private ServiceStackHost AppHost => appHost ??= HostContext.AppHost;

    private RpcGateway? rpcGateway;

    /// <summary>The AppHost's RPC gateway used to execute requests, resolved lazily on first use.</summary>
    protected RpcGateway RpcGateway => rpcGateway ??= AppHost.RpcGateway;

    private GrpcFeature? feature;

    /// <summary>The registered <see cref="GrpcFeature"/> plugin, resolved lazily via HostContext.AssertPlugin.</summary>
    protected GrpcFeature Feature => feature ??= HostContext.AssertPlugin<GrpcFeature>();

    /// <summary>
    /// Copies the ServiceStack response headers into the gRPC response headers.
    /// For non-success responses (HTTP status &gt;= 300) it also adds the HTTP status code and,
    /// when available, the serialized <see cref="ResponseStatus"/>, and sets the gRPC call <see cref="Status"/>.
    /// </summary>
    /// <param name="httpRes">The response to write; it is cast to <see cref="GrpcResponse"/> unconditionally.</param>
    /// <param name="context">The current call context; its ServerCallContext must be non-null.</param>
    protected async Task WriteResponseHeadersAsync(IResponse httpRes, CallContext context)
    {
        var res = (GrpcResponse) httpRes;
        var nonSuccessStatus = res.StatusCode >= 300;

        // Error responses are always written so the client can observe the failure,
        // even when response headers are otherwise disabled.
        if (Feature.DisableResponseHeaders && !nonSuccessStatus)
            return;

        foreach (var header in Feature.IgnoreResponseHeaders.Safe())
        {
            res.Headers.Remove(header);
        }

        // Nothing to send for a successful response without headers.
        if (res.Headers.Count == 0 && !nonSuccessStatus)
            return;

        var headers = new global::Grpc.Core.Metadata();
        if (nonSuccessStatus)
            headers.Add(Keywords.HttpStatus, res.StatusCode.ToString());

        foreach (var entry in res.Headers)
        {
            headers.Add(entry.Key, entry.Value);
        }

        if (nonSuccessStatus)
        {
            var status = res.Dto.GetResponseStatus();
            if (status != null)
            {
                headers.Add(Keywords.GrpcResponseStatus,
                    GrpcMarshaller<ResponseStatus>.Instance.Serializer(status));
            }

            // Prefer the most specific description available for the gRPC status detail.
            var desc = status?.ErrorCode ?? res.StatusDescription ??
                status?.Message ?? HttpStatus.GetStatusDescription(res.StatusCode);
            context.ServerCallContext!.Status = Feature.ToGrpcStatus?.Invoke(httpRes) ?? ToGrpcStatus(res.StatusCode, desc);
        }

        await context.ServerCallContext!.WriteResponseHeadersAsync(headers).ConfigAwait();
    }

    /// <summary>
    /// Executes a <see cref="DynamicRequest"/> by converting its string params into a
    /// <typeparamref name="TRequest"/> instance (or a new instance when there are no params).
    /// Each param is also mirrored into the request headers as <c>query.{key}</c>.
    /// </summary>
    /// <param name="method">The HTTP method to execute the request as.</param>
    /// <param name="request">The dynamic request whose Params populate the typed request.</param>
    /// <param name="context">The current call context.</param>
    protected virtual Task<TResponse> ExecuteDynamic<TRequest,TResponse>(string method, DynamicRequest request, CallContext context)
    {
        var requestType = typeof(TRequest);
        AppHost.AssertFeatures(ServiceStack.Feature.Grpc);

        var to = request.Params.ToObjectDictionary();
        var typedRequest = to?.FromObjectDictionary(requestType) ?? requestType.CreateInstance();

        if (request.Params != null)
        {
            foreach (var entry in request.Params)
            {
                context.RequestHeaders?.Add("query." + entry.Key, entry.Value);
            }
        }

        return Execute<TResponse>(method, typedRequest, context);
    }

    /// <summary>
    /// Executes <paramref name="request"/> through the RPC gateway as a gRPC request, then writes
    /// the resulting response headers / gRPC status and returns the response.
    /// </summary>
    /// <param name="method">The HTTP method the request is executed as (e.g. <c>HttpMethods.Post</c>).</param>
    /// <param name="request">The typed request DTO; may be populated from call headers first.</param>
    /// <param name="context">The current call context.</param>
    protected virtual async Task<TResponse> Execute<TResponse>(string method, object request, CallContext context)
    {
        AppHost.AssertFeatures(ServiceStack.Feature.Grpc);
        if (!Feature.DisableRequestParamsInHeaders)
            PopulateRequestFromHeaders(request, context.CallOptions.Headers);

        var req = new GrpcRequest(context, request, method);
        using var scope = req.StartScope();

        var ret = await RpcGateway.ExecuteAsync<TResponse>(request, req).ConfigAwait();
        // Ensure the Dto is set so error details can be read from it when writing headers.
        req.Response.Dto ??= ret;
        await WriteResponseHeadersAsync(req.Response, context).ConfigAwait();
        return ret;
    }

    /// <summary>
    /// Populates properties of <paramref name="request"/> from gRPC call headers whose key matches a
    /// request property name, optionally prefixed with <c>query.</c>, <c>form.</c>, <c>cookie.</c> or <c>header.</c>.
    /// Binary (<c>-bin</c>) headers are assigned as bytes, all others as strings.
    /// </summary>
    /// <param name="request">The request DTO to populate.</param>
    /// <param name="headers">The call headers; null or empty headers are ignored.</param>
    protected virtual void PopulateRequestFromHeaders(object request, global::Grpc.Core.Metadata headers)
    {
        // CallOptions.Headers may be null, so guard before dereferencing.
        if (headers is null || headers.Count == 0)
            return;

        var props = TypeProperties.Get(request.GetType());
        var to = new Dictionary<string, object>();
        foreach (var entry in headers)
        {
            var key = HasParamPrefix(entry.Key)
                ? entry.Key.RightPart('.')
                : entry.Key;

            if (!props.PropertyMap.TryGetValue(key, out var accessor))
                continue;

            var propName = accessor.PropertyInfo.Name;
            to[propName] = entry.Key.EndsWith("-bin", StringComparison.Ordinal)
                ? entry.ValueBytes
                : entry.Value;
        }

        if (to.Count > 0)
            to.PopulateInstance(request);
    }

    // True when a header key starts with one of the request-param source prefixes that get stripped.
    private static bool HasParamPrefix(string key) =>
        key.StartsWith("query.", StringComparison.Ordinal) ||
        key.StartsWith("form.", StringComparison.Ordinal) ||
        key.StartsWith("cookie.", StringComparison.Ordinal) ||
        key.StartsWith("header.", StringComparison.Ordinal);

    /// <summary>
    /// Streams responses from the <see cref="IStreamService{TRequest,TResponse}"/> registered for
    /// <typeparamref name="TRequest"/> in <c>GrpcFeature.RequestServiceTypeMap</c>, after running the
    /// AppHost's pre-request and request filters.
    /// </summary>
    /// <remarks>
    /// Failures before the first item is produced are reported through response headers and the
    /// gRPC call status. Failures after streaming has begun end the stream.
    /// </remarks>
    /// <exception cref="NotSupportedException">No stream service is registered for <typeparamref name="TRequest"/>.</exception>
    protected virtual async IAsyncEnumerable<TResponse> Stream<TRequest,TResponse>(TRequest request, CallContext context)
    {
        AppHost.AssertFeatures(ServiceStack.Feature.Grpc);
        if (!Feature.DisableRequestParamsInHeaders)
            PopulateRequestFromHeaders(request, context.CallOptions.Headers);

        if (!Feature.RequestServiceTypeMap.TryGetValue(typeof(TRequest), out var serviceType))
            throw new NotSupportedException($"'{typeof(TRequest).Name}' was not registered in GrpcFeature.RegisterServices");

        var service = (IStreamService<TRequest,TResponse>) AppHost.Container.Resolve(serviceType);
        using var disposableService = service as IDisposable;

        var req = new GrpcRequest(context, request, HttpMethods.Post);
        using var scope = req.StartScope();
        var res = req.Response;
        if (service is IRequiresRequest requiresRequest)
            requiresRequest.Request = req;

        IAsyncEnumerable<TResponse>? response = default;
        try
        {
            // NOTE(review): when a filter short-circuits, no headers/status are written here.
            // Confirm the filter itself reports the outcome to the client.
            if (AppHost.ApplyPreRequestFilters(req, res))
                yield break;

            await AppHost.ApplyRequestFiltersAsync(req, res, request).ConfigAwait();
            if (res.IsClosed)
                yield break;

            response = service.Stream(request, context.CancellationToken);
        }
        catch (Exception e)
        {
            res.Dto = RpcGateway.CreateErrorResponse<TResponse>(res, e);
            await WriteResponseHeadersAsync(res, context).ConfigAwait();
            yield break; // error is reported via the response headers / call status
        }

        if (response == null)
            yield break;

        var enumerator = response.GetAsyncEnumerator();
        try
        {
            bool more;
            try
            {
                more = await enumerator.MoveNextAsync();
            }
            catch (Exception e)
            {
                // Nothing has been streamed yet, so the first failure can still be reported via headers/status.
                res.Dto = RpcGateway.CreateErrorResponse<TResponse>(res, e);
                await WriteResponseHeadersAsync(res, context).ConfigAwait();
                yield break;
            }

            // Only yield items that MoveNextAsync actually produced (an empty stream yields nothing).
            while (more)
            {
                yield return enumerator.Current;

                try
                {
                    more = await enumerator.MoveNextAsync();
                }
                catch (Exception)
                {
                    // Cancellation or a failure after streaming has begun ends the stream.
                    // NOTE(review): non-cancellation errors are swallowed here, so the client sees a normal end of stream.
                    yield break;
                }
            }
        }
        finally
        {
            // Runs on every exit path, including when the consumer stops enumerating early.
            await enumerator.DisposeAsync();
        }
    }

    /// <summary>
    /// Relays the items produced by <paramref name="service"/> for <paramref name="request"/>,
    /// honouring <paramref name="cancel"/> (also flowed from the consumer's enumerator cancellation).
    /// </summary>
    protected virtual async IAsyncEnumerable<TResponse> StreamService<TRequest,TResponse>(IStreamService<TRequest,TResponse> service,
        TRequest request, [EnumeratorCancellation] CancellationToken cancel)
    {
        var response = service.Stream(request, cancel);
        await foreach (var item in response.WithCancellation(cancel))
        {
            yield return item;
        }
    }

    /// <summary>
    /// Maps an HTTP status code to a gRPC <see cref="Status"/> carrying <paramref name="detail"/>.
    /// Unmapped codes become <see cref="StatusCode.Unknown"/>.
    /// </summary>
    /// <param name="httpStatus">The HTTP status code of the ServiceStack response.</param>
    /// <param name="detail">The status detail sent to the client.</param>
    protected Status ToGrpcStatus(int httpStatus, string detail)
    {
        // 400 -> Internal and 429/502/503/504 -> Unavailable match gRPC's HTTP-to-gRPC status mapping.
        // 404/409 use the more specific NotFound/AlreadyExists codes.
        var code = httpStatus switch
        {
            400 => StatusCode.Internal,
            401 => StatusCode.Unauthenticated,
            403 => StatusCode.PermissionDenied,
            404 => StatusCode.NotFound,
            409 => StatusCode.AlreadyExists,
            429 or 502 or 503 or 504 => StatusCode.Unavailable,
            _ => StatusCode.Unknown,
        };
        return new Status(code, detail);
    }
}