diff --git a/config/config.go b/config/config.go index f2024ba..00fc1a1 100755 --- a/config/config.go +++ b/config/config.go @@ -63,7 +63,7 @@ type Server struct { // ShutdownTimeout represents the duration before force shutdown. ShutdownTimeout string `yaml:"shutdownTimeout"` - // ShutdownDelay represents the delay duration between the health check server shutdown and the client sidecar server shutdown. + // ShutdownDelay represents the delay duration between the health check server shutdown and the api server shutdown. ShutdownDelay string `yaml:"shutdownDelay"` // TLS represents the TLS configuration of the authorization proxy. diff --git a/config/config_test.go b/config/config_test.go index e1929b3..2b20336 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -43,7 +43,7 @@ func TestNew(t *testing.T) { { name: "Test file content not valid", args: args{ - path: "../test/data/not_valid_config.yaml", + path: "../test/data/invalid_config.yaml", }, wantErr: fmt.Errorf("decode file failed: yaml: line 11: could not find expected ':'"), }, diff --git a/main.go b/main.go index f1659d8..1ae9624 100644 --- a/main.go +++ b/main.go @@ -118,13 +118,19 @@ func run(cfg config.Config) []error { signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT) + isSignal := false for { select { - case <-sigCh: + case sig := <-sigCh: + glg.Infof("authorization proxy received signal: %v", sig) + isSignal = true cancel() - glg.Warn("Got authorization-proxy server shutdown signal...") + glg.Warn("authorization proxy main process shutdown...") case errs := <-ech: - return errs + if !isSignal || len(errs) != 1 || errs[0] != ctx.Err() { + return errs + } + return nil } } } @@ -172,6 +178,8 @@ func main() { glg.Fatal(emsg) return } + glg.Info("authorization proxy main process shutdown success") + os.Exit(1) } func getVersion() string { diff --git a/main_test.go b/main_test.go index 11d3186..950eb2f 100644 --- a/main_test.go +++ b/main_test.go @@ -2,8 +2,12 @@ package main import ( "os" + "os/exec" "reflect" + "strconv" + "strings" "testing" + "time" "github.com/AthenZ/authorization-proxy/v4/config" "github.com/kpango/glg" @@ -70,6 +74,10 @@ func TestParseParams(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + defer func(oldArgs []string) { + // restore os.Args + os.Args = oldArgs + }(os.Args) if tt.beforeFunc != nil { tt.beforeFunc() } @@ -461,3 +469,135 @@ func Test_getVersion(t *testing.T) { }) } } + +func Test_main(t *testing.T) { + type test struct { + name string + beforeFunc func() + afterFunc func() + } + tests := []test{ + func() test { + var oldArgs []string + return test{ + name: "show version", + beforeFunc: func() { + oldArgs = os.Args + os.Args = []string{"authorization-proxy", "-version"} + }, + afterFunc: func() { + os.Args = oldArgs + }, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer tt.afterFunc() + tt.beforeFunc() + main() + }) + } +} + +func Test_mainExitCode(t *testing.T) { + tests := []struct { + name string + args []string + signal os.Signal + wantExitCode int + }{ + { + name: "normal exit", + args: []string{ + "-version", + }, + signal: nil, + wantExitCode: 0, + }, + { + name: "undefined flag", + args: []string{ + "-undefined_flag", + }, + signal: nil, + wantExitCode: 1, + }, + { + name: "run with log error", + args: []string{ + "-f", + "./test/data/invalid_log_config.yaml", + }, + signal: nil, + wantExitCode: 1, + }, + // TODO: need Athenz public key endpoint mock + /* + { + name: "run till termination SIGINT", + args: []string{ + "-f", + "./test/data/valid_config.yaml", + }, + signal: syscall.SIGINT, + wantExitCode: 1, + }, + { + name: "run till termination SIGTERM", + args: []string{ + "-f", + "./test/data/valid_config.yaml", + }, + signal: syscall.SIGTERM, + wantExitCode: 1, + }, + */ + } + + rc := os.Getenv("RUN_CASE") + if rc != "" { + c, err := strconv.Atoi(rc) + if err != nil { + panic(err) + } + tt := tests[c] + + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + os.Args = append([]string{"authorization-proxy"}, tt.args...) + + if tt.signal != nil { + // send signal + go func() { + proc, err := os.FindProcess(os.Getpid()) + if err != nil { + panic(err) + } + + time.Sleep(200 * time.Millisecond) + proc.Signal(tt.signal) + }() + } + + // run main + main() + return + } + + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var outbuf, errbuf strings.Builder + + cmd := exec.Command(os.Args[0], "-test.run=Test_mainExitCode") + cmd.Stdout = &outbuf + cmd.Stderr = &errbuf + cmd.Env = append(os.Environ(), "RUN_CASE="+strconv.Itoa(i)) + err := cmd.Run() + exitCode := cmd.ProcessState.ExitCode() + if exitCode != tt.wantExitCode { + t.Errorf("main() err = %v, stdout = %s, stderr = %s, exit code = %v, wantExitCode %v", err, outbuf.String(), errbuf.String(), exitCode, tt.wantExitCode) + } + }) + } +} diff --git a/service/server.go b/service/server.go index b3ca10e..e15d269 100644 --- a/service/server.go +++ b/service/server.go @@ -271,7 +271,7 @@ func (s *server) ListenAndServe(ctx context.Context) <-chan []error { return errs } - shutdownSrvs := func(errs []error) { + shutdownSrvs := func(errs []error) []error { if s.hcRunning { glg.Info("authorization proxy health check server will shutdown...") errs = appendErr(errs, s.hcShutdown(context.Background())) @@ -286,65 +286,50 @@ func (s *server) ListenAndServe(ctx context.Context) <-chan []error { } if s.dRunning { glg.Info("authorization proxy debug server will shutdown...") - appendErr(errs, s.dShutdown(context.Background())) + errs = appendErr(errs, s.dShutdown(context.Background())) } - glg.Info("authorization proxy has already shutdown gracefully") + if len(errs) == 0 { + glg.Info("authorization proxy has already shutdown gracefully") + } + return errs } errs := make([]error, 0, 3) + + handleErr := func(err error) { + if err != nil { + errs = append(errs, errors.Wrap(err, "close running servers and return any error")) + } + s.mu.RLock() + errs = shutdownSrvs(errs) + s.mu.RUnlock() + echan <- errs + } + for { select { case <-ctx.Done(): // when context receive done signal, close running servers and return any error s.mu.RLock() - shutdownSrvs(errs) + errs = shutdownSrvs(errs) s.mu.RUnlock() echan <- appendErr(errs, ctx.Err()) return case err := <-sech: // when authorization proxy server returns, close running servers and return any error - if err != nil { - errs = append(errs, errors.Wrap(err, "close running servers and return any error")) - } - - s.mu.RLock() - shutdownSrvs(errs) - s.mu.RUnlock() - echan <- errs + handleErr(err) return case err := <-gsech: // when authorization proxy grpc server returns, close running servers and return any error - if err != nil { - errs = append(errs, errors.Wrap(err, "close running servers and return any error")) - } - - s.mu.RLock() - shutdownSrvs(errs) - s.mu.RUnlock() - echan <- errs + handleErr(err) return case err := <-hech: // when health check server returns, close running servers and return any error - if err != nil { - errs = append(errs, errors.Wrap(err, "close running servers and return any error")) - } - - s.mu.RLock() - shutdownSrvs(errs) - s.mu.RUnlock() - echan <- errs + handleErr(err) return case err := <-dech: // when debug server returns, close running servers and return any error - if err != nil { - errs = append(errs, errors.Wrap(err, "close running servers and return any error")) - } - - s.mu.RLock() - shutdownSrvs(errs) - s.mu.RUnlock() - echan <- errs + handleErr(err) return - } } }() @@ -364,7 +349,7 @@ func (s *server) dShutdown(ctx context.Context) error { return s.dsrv.Shutdown(dctx) } -// apiShutdown returns any error when shutdown the authorization proxy server. +// apiShutdown returns any error when shutdown the authorization proxy API server. // Before shutdown the authorization proxy server, it will sleep config.ShutdownDelay to prevent any issue from K8s func (s *server) apiShutdown(ctx context.Context) error { time.Sleep(s.sdd) @@ -373,7 +358,7 @@ func (s *server) apiShutdown(ctx context.Context) error { return s.srv.Shutdown(sctx) } -// apiShutdown returns any error when shutdown the authorization proxy server. +// grpcShutdown returns any error when shutdown the authorization proxy gPRC server. // Before shutdown the authorization proxy server, it will sleep config.ShutdownDelay to prevent any issue from K8s func (s *server) grpcShutdown() { time.Sleep(s.sdd) diff --git a/test/data/not_valid_config.yaml b/test/data/invalid_config.yaml similarity index 100% rename from test/data/not_valid_config.yaml rename to test/data/invalid_config.yaml diff --git a/test/data/invalid_log_config.yaml b/test/data/invalid_log_config.yaml new file mode 100644 index 0000000..2981153 --- /dev/null +++ b/test/data/invalid_log_config.yaml @@ -0,0 +1,4 @@ +--- +version: v2.0.0 +log: + level: invalid diff --git a/usecase/authz_proxyd.go b/usecase/authz_proxyd.go index 990513c..ad769dc 100644 --- a/usecase/authz_proxyd.go +++ b/usecase/authz_proxyd.go @@ -101,7 +101,7 @@ func (g *authzProxyDaemon) Start(ctx context.Context) <-chan []error { pch := g.athenz.Start(ctx) for err := range pch { - if err != nil { + if err != nil && err != ctx.Err() { glg.Errorf("pch: %v", err) // count errors by cause cause := errors.Cause(err).Error() @@ -120,12 +120,11 @@ func (g *authzProxyDaemon) Start(ctx context.Context) <-chan []error { // handle proxy server error, return on server shutdown done eg.Go(func() error { errs := <-g.server.ListenAndServe(ctx) + if len(errs) == 0 || len(errs) == 1 && errors.Cause(errs[0]) == ctx.Err() { + return nil + } glg.Errorf("sch: %v", errs) - if len(errs) == 0 { - // cannot be nil so that the context can cancel - return errors.New("") - } var baseErr error for i, err := range errs { if i == 0 { @@ -156,8 +155,10 @@ func (g *authzProxyDaemon) Start(ctx context.Context) <-chan []error { perrs = append(perrs, errors.WithMessagef(errors.New(errMsg), "authorizerd: %d times appeared", count)) } - // proxy server go func, should always return not nil error - ech <- append(perrs, err) + if err != nil { + ech <- append(perrs, err) + } + ech <- perrs }() return ech diff --git a/usecase/authz_proxyd_test.go b/usecase/authz_proxyd_test.go index 7c5c699..4ab40e0 100644 --- a/usecase/authz_proxyd_test.go +++ b/usecase/authz_proxyd_test.go @@ -219,10 +219,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { args: args{ ctx: ctx, }, - wantErrs: []error{ - errors.WithMessage(context.Canceled, "authorizerd: 1 times appeared"), - context.Canceled, - }, + wantErrs: []error{}, checkFunc: func(got <-chan []error, wantErrs []error) error { cancel() mux := &sync.Mutex{} @@ -287,7 +284,6 @@ func Test_authzProxyDaemon_Start(t *testing.T) { ctx: ctx, }, wantErrs: []error{ - errors.WithMessage(context.Canceled, "authorizerd: 1 times appeared"), errors.WithMessage(dummyErr, "server fails"), }, checkFunc: func(got <-chan []error, wantErrs []error) error { @@ -366,8 +362,6 @@ func Test_authzProxyDaemon_Start(t *testing.T) { }, wantErrs: []error{ errors.WithMessage(errors.Cause(errors.WithMessage(dummyErr, "authorizer daemon fails")), "authorizerd: 3 times appeared"), - errors.WithMessage(context.Canceled, "authorizerd: 1 times appeared"), - context.Canceled, }, checkFunc: func(got <-chan []error, wantErrs []error) error { cancel() @@ -436,10 +430,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { args: args{ ctx: ctx, }, - wantErrs: []error{ - errors.WithMessage(context.Canceled, "authorizerd: 1 times appeared"), - errors.New(""), - }, + wantErrs: []error{}, checkFunc: func(got <-chan []error, wantErrs []error) error { cancel() mux := &sync.Mutex{} @@ -509,7 +500,6 @@ func Test_authzProxyDaemon_Start(t *testing.T) { ctx: ctx, }, wantErrs: []error{ - errors.WithMessage(context.Canceled, "authorizerd: 1 times appeared"), errors.Wrap(dummyErr, context.Canceled.Error()), }, checkFunc: func(got <-chan []error, wantErrs []error) error {