/
server.go
309 lines (275 loc) · 10.2 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
/* Copyright (c) 2019 Snowflake Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
*/
// Package server provides the server-side implementation of the
// sansshell proxy server.
package server
import (
"context"
"fmt"
"io"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth"
pb "github.com/Snowflake-Labs/sansshell/proxy"
"github.com/Snowflake-Labs/sansshell/telemetry/metrics"
)
// Metrics recorded by the proxy server.
var (
	// proxyReplyErrorCounter is incremented each time relaying a reply
	// back to the proxy client with stream.Send fails.
	proxyReplyErrorCounter = metrics.MetricDefinition{
		Name:        "proxy_reply_error",
		Description: "number of failure when sending reply to the client",
	}

	// proxyDispatchUnknownReqtypeCounter is incremented each time the
	// dispatcher receives a ProxyRequest whose request type it does not
	// recognize.
	proxyDispatchUnknownReqtypeCounter = metrics.MetricDefinition{
		Name:        "proxy_dispatch_unknown_reqtype",
		Description: "number of dispatch failure due to unknown proxy request type",
	}
)
// TargetDialer opens the connections the proxy server uses to reach the
// targets it has been asked to talk to.
//
// Implementations hide the low-level details of establishing a target
// connection (client credentials, deadlines, and the like), so the proxy
// can dial targets without needing to understand any of them.
type TargetDialer interface {
	// DialContext connects to target, taking dialOpts as additional
	// per-call dial options, and returns a closeable client connection.
	DialContext(ctx context.Context, target string, dialOpts ...grpc.DialOption) (ClientConnCloser, error)
}
// ClientConnCloser is a grpc.ClientConnInterface that can also be closed.
type ClientConnCloser interface {
	Close() error
	grpc.ClientConnInterface
}
// optionsDialer is a TargetDialer backed by grpc.DialContext that applies a
// fixed set of default dial options to every connection it opens.
type optionsDialer struct {
	// opts are applied to every dial, ahead of any per-call options.
	opts []grpc.DialOption
}

// DialContext dials target using the dialer's default options followed by
// dialOpts. See TargetDialer.DialContext.
//
// A fresh options slice is built on every call. Appending directly to
// o.opts could write into spare capacity of its backing array, which is
// shared by every caller of this dialer, so concurrent dials would race
// and could observe each other's per-call options.
func (o *optionsDialer) DialContext(ctx context.Context, target string, dialOpts ...grpc.DialOption) (ClientConnCloser, error) {
	opts := make([]grpc.DialOption, 0, len(o.opts)+len(dialOpts))
	opts = append(opts, o.opts...)
	opts = append(opts, dialOpts...)

	conn, err := grpc.DialContext(ctx, target, opts...)
	if err != nil {
		// Return an untyped nil: wrapping a nil *grpc.ClientConn in the
		// ClientConnCloser interface would yield a non-nil interface value.
		return nil, err
	}
	return conn, nil
}
// NewDialer returns a TargetDialer that connects to targets with
// grpc.DialContext, applying opts to every connection it makes.
func NewDialer(opts ...grpc.DialOption) TargetDialer {
	d := &optionsDialer{}
	d.opts = opts
	return d
}
// Server is the implementation of proxy.ProxyServer: it accepts proxy
// streams from clients and relays their requests to backend targets.
type Server struct {
	// serviceMap resolves a fully-qualified method name of the form
	// "/Package.Service/Method" to its ServiceMethod description.
	serviceMap map[string]*ServiceMethod

	// dialer opens proxy -> target connections.
	dialer TargetDialer

	// authorizer applies policy to proxy -> target requests.
	authorizer *rpcauth.Authorizer
}
// Register installs s as the Proxy service implementation on sr, which is
// usually a *grpc.Server.
func (s *Server) Register(sr grpc.ServiceRegistrar) {
	pb.RegisterProxyServer(sr, s)
}
// New returns a Server that opens target connections with dialer,
// authorizes proxy -> target requests with authorizer, and resolves
// service methods from the map built by LoadGlobalServiceMap (the global
// protobuf registry).
func New(dialer TargetDialer, authorizer *rpcauth.Authorizer) *Server {
	serviceMap := LoadGlobalServiceMap()
	return NewWithServiceMap(dialer, authorizer, serviceMap)
}
// NewWithServiceMap returns a Server that opens target connections with
// dialer, authorizes proxy -> target requests with authorizer, and resolves
// service methods from the caller-provided serviceMap instead of the
// global protobuf registry.
func NewWithServiceMap(dialer TargetDialer, authorizer *rpcauth.Authorizer, serviceMap map[string]*ServiceMethod) *Server {
	s := new(Server)
	s.serviceMap = serviceMap
	s.dialer = dialer
	s.authorizer = authorizer
	return s
}
// Proxy implements ProxyServer.Proxy, providing a single bidirectional
// stream that manages requests to a set of one or more backend target
// servers.
//
// Three goroutines cooperate under one errgroup:
//   - a sender, the only caller of stream.Send, relaying replyChan to the
//     client;
//   - a receiver, the only caller of stream.Recv, relaying client requests
//     to requestChan;
//   - a dispatcher, routing requests to target streams and closing
//     replyChan once every target stream has finished.
//
// The first error from any of them cancels the shared context, and the
// errgroup's result becomes the final RPC status.
func (s *Server) Proxy(stream pb.Proxy_ProxyServer) error {
	requestChan := make(chan *pb.ProxyRequest)
	replyChan := make(chan *pb.ProxyReply)
	group, ctx := errgroup.WithContext(stream.Context())

	// The TargetStreamSet manages all target streams opened on behalf of
	// this proxy connection.
	streamSet := NewTargetStreamSet(s.serviceMap, s.dialer, s.authorizer)

	// Sender. A stream may be used for Send and Recv concurrently, but
	// Send must not be called from more than one goroutine at a time.
	group.Go(func() error {
		return send(ctx, replyChan, stream)
	})

	// Receiver. As with Send, Recv must only be called from one goroutine.
	group.Go(func() error {
		// Recv blocks until the proxy stream itself is cancelled, so run
		// receive in its own goroutine and stop waiting for it as soon as
		// the group context is done (e.g. because dispatch failed). Any
		// error it produces afterwards can be discarded: the errgroup
		// already holds the status to return to the client.
		//
		// errChan has room for one value so the inner goroutine can always
		// deliver its result and exit, whether or not we are still waiting.
		// An unbuffered channel with a non-blocking send would drop the
		// result whenever receive finished before this goroutine reached
		// the select below, turning a real Recv error into a nil return.
		errChan := make(chan error, 1)
		go func() {
			errChan <- receive(ctx, stream, requestChan)
		}()
		select {
		case err := <-errChan:
			return err
		case <-ctx.Done():
			return ctx.Err()
		}
	})

	// Dispatcher: routes requests to the set of active target streams.
	group.Go(func() error {
		// Once dispatching ends no further replies will be produced;
		// closing replyChan tells the sender to exit. Deferred first so it
		// runs last, after cancel below.
		defer close(replyChan)

		// A derived context that can tear down every target stream at once.
		dispatchCtx, cancel := context.WithCancel(ctx)
		defer cancel()

		err := dispatch(dispatchCtx, stream, requestChan, replyChan, streamSet)
		if err != nil {
			// Abort all running target streams before waiting for them.
			cancel()
		}
		// Wait for running target streams to exit.
		streamSet.Wait()
		return err
	})

	// The final RPC status is the status of the errgroup.
	return group.Wait()
}
// send forwards every reply read from replyChan to the client over stream.
// It returns nil once replyChan has been closed and drained.
//
// If stream.Send fails, the failure is counted in proxyReplyErrorCounter and
// that error is returned immediately. ctx is used only for metrics; the loop
// does not watch ctx for cancellation and exits only via replyChan closing
// or a Send error.
func send(ctx context.Context, replyChan chan *pb.ProxyReply, stream pb.Proxy_ProxyServer) error {
	recorder := metrics.RecorderFromContextOrNoop(ctx)
	for {
		reply, ok := <-replyChan
		if !ok {
			return nil
		}
		err := stream.Send(reply)
		if err == nil {
			continue
		}
		recorder.CounterOrLog(ctx, proxyReplyErrorCounter, 1)
		return err
	}
}
// receive copies requests arriving on stream into requestChan. It stops when
// one of three things happens:
//   - the client half-closes the stream (io.EOF), which is reported as nil;
//   - stream.Recv fails, whose error is returned;
//   - ctx is done while waiting to hand off a request, which returns
//     ctx.Err().
//
// requestChan is always closed on return. The dispatcher treats that close
// as the signal to propagate CloseSend to every running target stream.
func receive(ctx context.Context, stream pb.Proxy_ProxyServer, requestChan chan *pb.ProxyRequest) error {
	defer close(requestChan)

	for {
		// Blocks until a message arrives, but returns early if the stream
		// context is cancelled.
		req, err := stream.Recv()
		switch {
		case err == io.EOF:
			// On the server side, io.EOF means the client called
			// CloseSend() and will send nothing more.
			return nil
		case err != nil:
			return err
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		case requestChan <- req:
		}
	}
}
// dispatch routes each request read from requestChan to the matching
// operation on streamSet. It returns when any of the following happens:
//   - ctx is done, which returns ctx.Err();
//   - a request fails, which returns that error;
//   - a request of unknown type arrives, which is counted in
//     proxyDispatchUnknownReqtypeCounter and reported as an error;
//   - requestChan is closed, which returns nil.
//
// When requestChan is closed (the client called CloseSend, or client
// Send/Recv failed), every remaining target stream is told via
// ClientCloseAll that no further requests will arrive. Target streams that
// finish report their IDs on an internal channel, and dispatch removes them
// from streamSet. Later messages for a removed ID are rejected by the
// stream set.
func dispatch(ctx context.Context, stream pb.Proxy_ProxyServer, requestChan chan *pb.ProxyRequest, replyChan chan *pb.ProxyReply, streamSet *TargetStreamSet) error {
	finished := make(chan uint64)
	recorder := metrics.RecorderFromContextOrNoop(ctx)
	peerAttached := false

	for {
		var (
			req  *pb.ProxyRequest
			open bool
		)
		select {
		case <-ctx.Done():
			// Cancellation reaches the target streams through ctx.
			return ctx.Err()
		case id := <-finished:
			// The target stream has sent its final ServerClose status.
			streamSet.Remove(id)
			continue
		case req, open = <-requestChan:
		}

		if !open {
			// On a client error, ctx cancellation should eventually stop the
			// target streams as well. Either way, let them know no more
			// client requests are coming.
			streamSet.ClientCloseAll()
			return nil
		}

		if !peerAttached {
			// Peer information may be incomplete until rpcauth has seen the
			// first message, so capture it only once a request has arrived.
			ctx = rpcauth.AddPeerToContext(ctx, rpcauth.PeerInputFromContext(stream.Context()))
			peerAttached = true
		}

		var err error
		switch req.Request.(type) {
		case *pb.ProxyRequest_StartStream:
			err = streamSet.Add(ctx, req.GetStartStream(), replyChan, finished)
		case *pb.ProxyRequest_StreamData:
			err = streamSet.Send(ctx, req.GetStreamData())
		case *pb.ProxyRequest_ClientCancel:
			err = streamSet.ClientCancel(req.GetClientCancel())
		case *pb.ProxyRequest_ClientClose:
			err = streamSet.ClientClose(req.GetClientClose())
		default:
			recorder.CounterOrLog(ctx, proxyDispatchUnknownReqtypeCounter, 1)
			err = fmt.Errorf("unhandled request type %T", req.Request)
		}
		if err != nil {
			return err
		}
	}
}