forked from improbable-eng/grpc-web
/
main.go
124 lines (109 loc) · 4.08 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
package main
import (
"fmt"
"log"
"net"
"net/http"
_ "net/http/pprof" // register in DefaultServerMux
"os"
"time"
"crypto/tls"
"github.com/sirupsen/logrus"
"github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
"github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/improbable-eng/grpc-web/go/grpcweb"
"github.com/mwitkow/go-conntrack"
"github.com/mwitkow/grpc-proxy/proxy"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/spf13/pflag"
"golang.org/x/net/context"
_ "golang.org/x/net/trace" // register in DefaultServerMux
"google.golang.org/grpc"
)
// Command-line flags controlling where the proxy listens and the HTTP server
// timeouts. They are registered with pflag at package initialization and are
// populated once pflag.Parse runs.
var (
	// flagBindAddr is the host/IP portion of every listen address.
	flagBindAddr = pflag.String(
		"server_bind_address",
		"0.0.0.0",
		"address to bind the server to",
	)

	// flagHttpPort is the TCP port of the plain-HTTP debug listener.
	flagHttpPort = pflag.Int(
		"server_http_debug_port",
		8080,
		"TCP port to listen on for HTTP1.1 debug calls. If 0, no insecure HTTP will be open.",
	)

	// flagHttpTlsPort is the TCP port of the TLS listener.
	flagHttpTlsPort = pflag.Int(
		"server_http_tls_port",
		8443,
		"TCP port to listen on for HTTPS (gRPC, gRPC-Web). If 0, no TLS will be open.",
	)

	// flagHttpMaxWriteTimeout is used as http.Server.WriteTimeout.
	flagHttpMaxWriteTimeout = pflag.Duration(
		"server_http_max_write_timeout",
		10*time.Second,
		"HTTP server config, max write duration.",
	)

	// flagHttpMaxReadTimeout is used as http.Server.ReadTimeout.
	flagHttpMaxReadTimeout = pflag.Duration(
		"server_http_max_read_timeout",
		10*time.Second,
		"HTTP server config, max read duration.",
	)
)
// main parses the flags, builds the proxying gRPC server, wraps it in the
// gRPC-Web compatibility layer and serves it on up to two listeners:
//
//   - a plain-HTTP debug listener (server_http_debug_port) that also serves
//     whatever is registered on http.DefaultServeMux (/metrics, plus the
//     handlers registered by the net/http/pprof and x/net/trace imports);
//   - a TLS listener (server_http_tls_port).
//
// A port of 0 disables the corresponding listener, as the flag help states.
// main blocks until one of the servers stops, then logs the error and exits
// with a non-zero status.
func main() {
	pflag.Parse()

	debugEnabled := *flagHttpPort != 0
	tlsEnabled := *flagHttpTlsPort != 0

	// Only build the TLS configuration when it will actually be used, so a
	// debug-only deployment does not fail on missing TLS settings.
	var serverTls *tls.Config
	if tlsEnabled {
		serverTls = buildServerTlsOrFail()
	}

	logrus.SetOutput(os.Stdout)
	logEntry := logrus.NewEntry(logrus.StandardLogger())

	// With both listeners disabled nothing would ever write to errChan and
	// main would block forever, so refuse to start instead.
	if !debugEnabled && !tlsEnabled {
		logrus.Fatal("both server_http_debug_port and server_http_tls_port are 0; nothing to serve")
	}

	grpcServer := buildGrpcProxyServer(logEntry)
	errChan := make(chan error)

	// gRPC-Web compatibility layer with CORS not restricted to registered
	// endpoints only.
	wrappedGrpc := grpcweb.WrapServer(grpcServer, grpcweb.WithCorsForRegisteredEndpointsOnly(false))

	http.Handle("/metrics", promhttp.Handler())

	if debugEnabled {
		// Debug server.
		debugServer := &http.Server{
			WriteTimeout: *flagHttpMaxWriteTimeout,
			ReadTimeout:  *flagHttpMaxReadTimeout,
			Handler: http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
				// Paths registered on the default mux (/metrics and the
				// debug handlers) are served from it; an empty pattern
				// means nothing matched, so the request goes to gRPC-Web.
				if _, pattern := http.DefaultServeMux.Handler(req); pattern != "" {
					http.DefaultServeMux.ServeHTTP(resp, req)
					return
				}
				wrappedGrpc.ServeHTTP(resp, req)
			}),
		}
		debugListener := buildListenerOrFail("http", *flagHttpPort)
		go func() {
			logrus.Infof("listening for http on: %v", debugListener.Addr().String())
			if err := debugServer.Serve(debugListener); err != nil {
				errChan <- fmt.Errorf("http_debug server error: %v", err)
			}
		}()
	}

	if tlsEnabled {
		// Serving (TLS) server.
		servingServer := &http.Server{
			WriteTimeout: *flagHttpMaxWriteTimeout,
			ReadTimeout:  *flagHttpMaxReadTimeout,
			Handler: http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
				wrappedGrpc.ServeHTTP(resp, req)
			}),
		}
		// NOTE(review): this listener is tracked under the same name ("http")
		// as the debug listener; confirm whether a distinct name is wanted.
		servingListener := buildListenerOrFail("http", *flagHttpTlsPort)
		servingListener = tls.NewListener(servingListener, serverTls)
		go func() {
			logrus.Infof("listening for http_tls on: %v", servingListener.Addr().String())
			if err := servingServer.Serve(servingListener); err != nil {
				errChan <- fmt.Errorf("http_tls server error: %v", err)
			}
		}()
	}

	// Block until the first server stops; report why instead of exiting
	// silently with a success status.
	// TODO(mwitkow): Add graceful shutdown.
	logrus.Fatalf("shutting down: %v", <-errChan)
}
// buildGrpcProxyServer returns a grpc.Server that forwards every call for an
// unknown service to a single backend connection obtained from
// dialBackendOrFail. Unary and streaming calls pass through the logrus and
// Prometheus interceptors. As a side effect it turns on gRPC tracing and
// replaces the gRPC library logger with logger.
func buildGrpcProxyServer(logger *logrus.Entry) *grpc.Server {
	// Package-level gRPC settings.
	grpc.EnableTracing = true
	grpc_logrus.ReplaceGrpcLogger(logger)

	// One backend connection is shared by all proxied calls; the director
	// ignores the method name and always hands it back.
	conn := dialBackendOrFail()
	routeToBackend := func(_ context.Context, _ string) (*grpc.ClientConn, error) {
		return conn, nil
	}

	// Options are built in the same order the original passed them.
	opts := []grpc.ServerOption{
		// The proxy codec is needed for the proxy to function.
		grpc.CustomCodec(proxy.Codec()),
		grpc.UnknownServiceHandler(proxy.TransparentHandler(routeToBackend)),
		// Logging and monitoring for unary calls.
		grpc_middleware.WithUnaryServerChain(
			grpc_logrus.UnaryServerInterceptor(logger),
			grpc_prometheus.UnaryServerInterceptor,
		),
		// Logging and monitoring for streaming calls.
		grpc_middleware.WithStreamServerChain(
			grpc_logrus.StreamServerInterceptor(logger),
			grpc_prometheus.StreamServerInterceptor,
		),
	}
	return grpc.NewServer(opts...)
}
// buildListenerOrFail opens a TCP listener on the configured bind address and
// the given port and wraps it with conntrack, using name as the tracking
// name. If listening fails, the process is terminated via log.Fatalf.
func buildListenerOrFail(name string, port int) net.Listener {
	host := *flagBindAddr
	// Accept an already-bracketed IPv6 literal ("[::1]") as well as a bare
	// one ("::1"): strip the brackets so JoinHostPort can add them back once.
	if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
		host = host[1 : n-1]
	}
	// JoinHostPort brackets hosts containing a colon, which plain "%s:%d"
	// formatting does not, so IPv6 bind addresses produce a valid address.
	addr := net.JoinHostPort(host, fmt.Sprint(port))

	listener, err := net.Listen("tcp", addr)
	if err != nil {
		log.Fatalf("failed listening for '%v' on %v: %v", name, port, err)
	}
	return conntrack.NewListener(listener,
		conntrack.TrackWithName(name),
		conntrack.TrackWithTcpKeepAlive(20*time.Second),
		conntrack.TrackWithTracing(),
	)
}