/
server_router.go
129 lines (117 loc) · 4.75 KB
/
server_router.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package grpcmw
import (
"errors"
"fmt"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
)
// ServerRouter represents route resolver that allows to use the appropriate
// chain of interceptors for a given gRPC request with an interceptor register.
type ServerRouter interface {
// GetRegister returns the interceptor register of the router.
GetRegister() ServerInterceptorRegister
// SetRegister sets the interceptor register of the router.
SetRegister(reg ServerInterceptorRegister)
// UnaryResolver returns a `grpc.UnaryServerInterceptor` that uses the
// appropriate chain of interceptors with the given unary gRPC request.
UnaryResolver() grpc.UnaryServerInterceptor
// StreamResolver returns a `grpc.StreamServerInterceptor` that uses the
// appropriate chain of interceptors with the given stream gRPC request.
StreamResolver() grpc.StreamServerInterceptor
}
// serverRouter is the default ServerRouter implementation.
type serverRouter struct {
	// interceptors is the root (global level) of the interceptor hierarchy
	// that request routes are resolved against.
	interceptors ServerInterceptorRegister
}
// NewServerRouter returns a `ServerRouter` whose register is a new, empty
// `ServerInterceptorRegister` named "global".
//
// Routes are resolved using the request path format that gRPC defines for
// HTTP/2, described in
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
//
// That format is used to organise interceptors in four nested levels:
//   - global: interceptors called for every request.
//   - package: interceptors called for every request to a service of the
//     corresponding package.
//   - service: interceptors called for every request to a method of the
//     corresponding service.
//   - method: interceptors called for every request to that specific method.
func NewServerRouter() ServerRouter {
	global := NewServerInterceptorRegister("global")
	return &serverRouter{interceptors: global}
}
// resolveServerInterceptorRec walks the interceptor hierarchy starting at lvl,
// consuming one entry of pathTokens per level.
//
// cb, when non-nil, is invoked on every level visited, starting with lvl itself
// and ending with the level that is returned.
//
// The walk stops successfully and returns the current level when pathTokens is
// exhausted or the next token is the empty string.
//
// It fails with an error when a level that still has tokens to descend into is
// not a ServerInterceptorRegister.
//
// When a token is not registered at the current level:
//   - with force == false, it returns (nil, nil); cb has already been called
//     for every level visited up to that point;
//   - with force == true, it creates the missing level (a ServerInterceptor for
//     the last token, a ServerInterceptorRegister otherwise), registers it, and
//     continues the walk into it.
func resolveServerInterceptorRec(pathTokens []string, lvl ServerInterceptor, cb func(lvl ServerInterceptor), force bool) (ServerInterceptor, error) {
	current, remaining := lvl, pathTokens
	for {
		if cb != nil {
			cb(current)
		}
		if len(remaining) == 0 || remaining[0] == "" {
			return current, nil
		}

		register, isRegister := current.(ServerInterceptorRegister)
		if !isRegister {
			return nil, fmt.Errorf("Level %s does not implement grpcmw.ServerInterceptorRegister", current.Index())
		}

		name := remaining[0]
		next, found := register.Get(name)
		if !found {
			if !force {
				return nil, nil
			}
			// Only the deepest token becomes a leaf; intermediate levels
			// must be registers so the walk can continue through them.
			if len(remaining) == 1 {
				next = NewServerInterceptor(name)
			} else {
				next = NewServerInterceptorRegister(name)
			}
			register.Register(next)
		}

		current, remaining = next, remaining[1:]
	}
}
// resolveServerInterceptor matches route against routeRegexp and walks the
// hierarchy rooted at lvl with the captured groups as path tokens (see
// resolveServerInterceptorRec for the meaning of cb and force).
//
// It returns an "Invalid route" error when route does not match routeRegexp.
// NOTE(review): the groups are presumably package, service and method;
// confirm against the routeRegexp definition.
func resolveServerInterceptor(route string, lvl ServerInterceptor, cb func(lvl ServerInterceptor), force bool) (ServerInterceptor, error) {
	// TODO: Find a more efficient way to resolve the route
	if submatches := routeRegexp.FindStringSubmatch(route); len(submatches) > 0 {
		// Element 0 is the whole match; only the capture groups are tokens.
		return resolveServerInterceptorRec(submatches[1:], lvl, cb, force)
	}
	return nil, errors.New("Invalid route")
}
// UnaryResolver returns a `grpc.UnaryServerInterceptor` that uses the
// appropriate chain of interceptors with the given gRPC request.
//
// For every request, the route info.FullMethod is resolved against the
// router's register. The unary interceptor of each level visited is collected,
// from the global level down, into a new chain built with
// NewUnaryServerInterceptor. That chain then wraps handler.
//
// If a level along the route is not registered, resolution stops there and
// only the levels collected so far are used. If the route cannot be resolved
// at all, the request fails with codes.Internal.
func (r *serverRouter) UnaryResolver() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		// TODO: Find a more efficient way to chain the interceptors
		interceptor := NewUnaryServerInterceptor()
		_, err := resolveServerInterceptor(info.FullMethod, r.interceptors, func(lvl ServerInterceptor) {
			interceptor.AddInterceptor(lvl.UnaryServerInterceptor())
		}, false)
		if err != nil {
			// Pass the message as an argument, never as the format string:
			// a '%' in it would otherwise be parsed as a formatting verb.
			return nil, grpc.Errorf(codes.Internal, "%s", err)
		}
		return interceptor.Interceptor()(ctx, req, info, handler)
	}
}
// StreamResolver returns a `grpc.StreamServerInterceptor` that uses the
// appropriate chain of interceptors with the given stream gRPC request.
//
// For every stream, the route info.FullMethod is resolved against the router's
// register. The stream interceptor of each level visited is collected, from
// the global level down, into a new chain built with
// NewStreamServerInterceptor. That chain then wraps handler.
//
// If a level along the route is not registered, resolution stops there and
// only the levels collected so far are used. If the route cannot be resolved
// at all, the stream fails with codes.Internal.
func (r *serverRouter) StreamResolver() grpc.StreamServerInterceptor {
	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		// TODO: Find a more efficient way to chain the interceptors
		interceptor := NewStreamServerInterceptor()
		_, err := resolveServerInterceptor(info.FullMethod, r.interceptors, func(lvl ServerInterceptor) {
			interceptor.AddInterceptor(lvl.StreamServerInterceptor())
		}, false)
		if err != nil {
			// Pass the message as an argument, never as the format string:
			// a '%' in it would otherwise be parsed as a formatting verb.
			return grpc.Errorf(codes.Internal, "%s", err)
		}
		return interceptor.Interceptor()(srv, ss, info, handler)
	}
}
// GetRegister returns the `ServerInterceptorRegister` that routes are resolved
// against, i.e. the root (global level) of the interceptor chain.
func (r *serverRouter) GetRegister() ServerInterceptorRegister {
	root := r.interceptors
	return root
}
// SetRegister makes reg the register that routes are resolved against.
// The resolvers read the register on every request, so the change takes
// effect for requests resolved afterwards. reg is stored as-is.
// NOTE(review): there is no synchronization here; calling this while
// requests are being served looks racy. Confirm callers only set it at setup.
func (r *serverRouter) SetRegister(reg ServerInterceptorRegister) {
	router := r
	router.interceptors = reg
}