-
Notifications
You must be signed in to change notification settings - Fork 134
/
grpc.go
150 lines (133 loc) · 6.07 KB
/
grpc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package grpc
import (
"crypto/tls"
"fmt"
"net"
"runtime/debug"
"sync"
"time"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
_ "google.golang.org/grpc/encoding/gzip"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
"github.com/armadaproject/armada/internal/common/armadacontext"
"github.com/armadaproject/armada/internal/common/armadaerrors"
"github.com/armadaproject/armada/internal/common/auth"
"github.com/armadaproject/armada/internal/common/certs"
"github.com/armadaproject/armada/internal/common/grpc/configuration"
"github.com/armadaproject/armada/internal/common/requestid"
)
// CreateGrpcServer builds a gRPC server (through grpc.NewServer) configured the way this
// project expects: the given keepalive settings, a fixed chain of interceptors covering panic
// recovery, logging, authentication and Prometheus metrics, and, when tlsConfig.Enabled is
// set, TLS credentials whose certificate is served from a cached certificate service.
//
// authServices is the list of services that may authenticate a client; logrusOptions are
// forwarded unchanged to the logrus logging interceptors.
func CreateGrpcServer(
	keepaliveParams keepalive.ServerParameters,
	keepaliveEnforcementPolicy keepalive.EnforcementPolicy,
	authServices []auth.AuthService,
	tlsConfig configuration.TlsConfig,
	logrusOptions ...grpc_logrus.Option,
) *grpc.Server {
	// Cross-cutting concerns are implemented as gRPC interceptors, i.e., functions invoked
	// before the actual request handler. Unary and streaming calls use distinct interceptor
	// types, so each concern is registered on both chains together.
	var (
		unaryChain  []grpc.UnaryServerInterceptor
		streamChain []grpc.StreamServerInterceptor
	)
	register := func(unary grpc.UnaryServerInterceptor, stream grpc.StreamServerInterceptor) {
		unaryChain = append(unaryChain, unary)
		streamChain = append(streamChain, stream)
	}

	// Panic recovery.
	// NOTE This has to be registered first so that it also handles panics raised by any
	// interceptor registered after it.
	recoverWith := grpc_recovery.WithRecoveryHandler(panicRecoveryHandler)
	register(
		grpc_recovery.UnaryServerInterceptor(recoverWith),
		grpc_recovery.StreamServerInterceptor(recoverWith),
	)

	// Logging (using logrus).
	// Information stored in the request context is logged by default; the field extractor
	// copies information from the request payload (a protobuf) into the context so that it
	// is logged as well.
	baseEntry := log.NewEntry(log.StandardLogger())
	fieldTags := grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)
	register(
		grpc_ctxtags.UnaryServerInterceptor(fieldTags),
		grpc_ctxtags.StreamServerInterceptor(fieldTags),
	)
	register(
		requestid.UnaryServerInterceptor(false),
		requestid.StreamServerInterceptor(false),
	)
	register(
		armadaerrors.UnaryServerInterceptor(2000),
		armadaerrors.StreamServerInterceptor(2000),
	)
	register(
		grpc_logrus.UnaryServerInterceptor(baseEntry, logrusOptions...),
		grpc_logrus.StreamServerInterceptor(baseEntry, logrusOptions...),
	)

	// Authentication.
	// authServices lists the mechanisms that may authenticate a client (e.g., username/password
	// and OpenId); the middleware auth function is a combination of them.
	authenticate := auth.CreateMiddlewareAuthFunction(authServices)
	register(
		grpc_auth.UnaryServerInterceptor(authenticate),
		grpc_auth.StreamServerInterceptor(authenticate),
	)

	// Prometheus timeseries collection integration.
	grpc_prometheus.EnableHandlingTimeHistogram()
	register(grpc_prometheus.UnaryServerInterceptor, grpc_prometheus.StreamServerInterceptor)

	// Interceptors are handed to the server as options at creation time.
	opts := []grpc.ServerOption{
		grpc.KeepaliveParams(keepaliveParams),
		grpc.KeepaliveEnforcementPolicy(keepaliveEnforcementPolicy),
		grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(streamChain...)),
		grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(unaryChain...)),
	}
	if !tlsConfig.Enabled {
		return grpc.NewServer(opts...)
	}

	// TLS: the certificate is looked up from the cache on every handshake rather than being
	// loaded once. The cache is driven by a background goroutine using a background context.
	// NOTE(review): time.Minute is presumably the refresh interval; confirm against certs.
	certCache := certs.NewCachedCertificateService(tlsConfig.CertPath, tlsConfig.KeyPath, time.Minute)
	go certCache.Run(armadacontext.Background())

	lookupCertificate := func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
		if current := certCache.GetCertificate(); current != nil {
			return current, nil
		}
		return nil, fmt.Errorf("unexpectedly received nil from certificate cache")
	}
	transportCreds := credentials.NewTLS(&tls.Config{GetCertificate: lookupCertificate})
	opts = append(opts, grpc.Creds(transportCreds))

	return grpc.NewServer(opts...)
}
// Listen opens a TCP listener on the given port and serves grpcServer on it from a newly
// started goroutine. wg.Done is called once Serve returns without an error. Failures to
// listen or to serve are reported through log.Fatalf.
//
// TODO We don't need this function. Just do this at the caller.
func Listen(port uint16, grpcServer *grpc.Server, wg *sync.WaitGroup) {
	address := fmt.Sprintf(":%d", port)
	listener, listenErr := net.Listen("tcp", address)
	if listenErr != nil { // TODO Don't call fatal, return an error.
		log.Fatalf("failed to listen: %v", listenErr)
	}

	// Serve blocks, so it runs on its own goroutine and Listen returns immediately.
	serve := func() {
		defer log.Println("Stopping server.")
		log.Printf("Grpc listening on %d", port)

		serveErr := grpcServer.Serve(listener)
		if serveErr != nil {
			log.Fatalf("failed to serve: %v", serveErr)
		}
		wg.Done()
	}
	go serve()
}
// CreateShutdownHandler returns a function that blocks until ctx is done and then shuts down
// grpcServer. The server is given gracePeriod to perform a graceful shutdown; if GracefulStop
// has not returned by then, the server is forcibly stopped. The returned function always
// returns nil.
func CreateShutdownHandler(ctx *armadacontext.Context, gracePeriod time.Duration, grpcServer *grpc.Server) func() error {
	return func() error {
		<-ctx.Done()

		// Arm a timer that force-stops the server if the graceful shutdown overruns the grace
		// period. Unlike a goroutine sleeping for gracePeriod, the timer is cancelled as soon
		// as GracefulStop returns, so nothing lingers (or calls Stop on an already stopped
		// server) after a fast shutdown.
		forceStop := time.AfterFunc(gracePeriod, grpcServer.Stop)
		defer forceStop.Stop()

		grpcServer.GracefulStop()
		return nil
	}
}
// panicRecoveryHandler is the handler passed to the recovery interceptors; it is invoked with
// the recovered value p whenever a gRPC handler panics. It logs the cause together with the
// current stack trace and returns a gRPC status error with code Internal that mentions p.
func panicRecoveryHandler(p interface{}) (err error) {
	stackTrace := string(debug.Stack())
	log.Errorf("Request triggered panic with cause %v \n%s", p, stackTrace)

	err = status.Errorf(codes.Internal, "Internal server error caused by %v", p)
	return err
}