From 2eb674671c14fb3093255cd0c7578c44b926b749 Mon Sep 17 00:00:00 2001 From: t4niwa <114040262+t4niwa@users.noreply.github.com> Date: Wed, 30 Nov 2022 13:50:02 +0900 Subject: [PATCH 01/58] add option status code log from origin (#4) * add status-code log Signed-off-by: taniwa * update Signed-off-by: taniwa * add option for origin log Signed-off-by: taniwa * update Signed-off-by: taniwa * fix Signed-off-by: taniwa * add handler_test Signed-off-by: taniwa Signed-off-by: taniwa Signed-off-by: Kyo Fujisaki --- config/config.go | 14 ++++++++++++++ config/config_test.go | 6 ++++++ handler/handler.go | 14 ++++++++++++++ handler/handler_test.go | 19 +++++++++++++++++++ test/data/example_config.yaml | 5 +++++ 5 files changed, 58 insertions(+) mode change 100755 => 100644 test/data/example_config.yaml diff --git a/config/config.go b/config/config.go index 0e54b81..f2024ba 100755 --- a/config/config.go +++ b/config/config.go @@ -154,6 +154,9 @@ type Proxy struct { // Transport exposes http.Transport parameters Transport Transport `yaml:"transport,omitempty"` + + // OriginLog represents log configuration from origin + OriginLog OriginLog `yaml:"originLog"` } // Authorization represents the detail authorization configuration. @@ -287,6 +290,17 @@ type Transport struct { ForceAttemptHTTP2 bool `yaml:"forceAttemptHTTP2,omitempty"` } +// OriginLog represents log configuration from origin +type OriginLog struct { + StatusCode StatusCode `yaml:"statusCode"` +} + +// StatusCode represents statuscode log configuration +type StatusCode struct { + Enable bool `yaml:"enable"` + Exclude []int `yaml:"exclude"` +} + // New returns the decoded configuration YAML file as *Config struct. Returns non-nil error if any. func New(path string) (*Config, error) { f, err := os.OpenFile(path, os.O_RDONLY, 0o600) diff --git a/config/config_test.go b/config/config_test.go index 38840d1..e1929b3 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -127,6 +127,12 @@ func TestNew(t *testing.T) { ReadBufferSize: 0, ForceAttemptHTTP2: true, }, + OriginLog: OriginLog{ + StatusCode: StatusCode{ + Enable: true, + Exclude: []int{200}, + }, + }, }, Authorization: Authorization{ PublicKey: PublicKey{ diff --git a/handler/handler.go b/handler/handler.go index 8ae88fe..8418fd6 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -45,6 +45,19 @@ func New(cfg config.Proxy, bp httputil.BufferPool, prov service.Authorizationd) host := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) + var modifyResponse func(res *http.Response) error = nil + if cfg.OriginLog.StatusCode.Enable { + modifyResponse = func(res *http.Response) error { + for _, statusCode := range cfg.OriginLog.StatusCode.Exclude { + if statusCode == res.StatusCode { + return nil + } + } + glg.Infof("Origin request: %s %s, Response: status code: %d", res.Request.Method, res.Request.URL, res.StatusCode) + return nil + } + } + return &httputil.ReverseProxy{ BufferPool: bp, Director: func(r *http.Request) { @@ -71,6 +84,7 @@ func New(cfg config.Proxy, bp httputil.BufferPool, prov service.Authorizationd) *r = *req }, + ModifyResponse: modifyResponse, Transport: &transport{ prov: prov, RoundTripper: transportFromCfg(cfg.Transport), diff --git a/handler/handler_test.go b/handler/handler_test.go index 8f8d761..f30815f 100644 --- a/handler/handler_test.go +++ b/handler/handler_test.go @@ -534,6 +534,25 @@ func TestNew(t *testing.T) { return nil }, }, + { + name: "check originlog is used", + args: args{ + cfg: config.Proxy{ + OriginLog: config.OriginLog{ + StatusCode: config.StatusCode{ + Enable: true, + Exclude: []int{}, + }, + }, + }, + }, + checkFunc: func(h http.Handler) error { + if h.(*httputil.ReverseProxy).ModifyResponse == nil { + return errors.Errorf("unexpected ModifyResponse") + } + return nil + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/test/data/example_config.yaml b/test/data/example_config.yaml old mode 100755 new mode 100644 index 757b983..dfd6ecb --- a/test/data/example_config.yaml +++ b/test/data/example_config.yaml @@ -44,6 +44,11 @@ proxy: writeBufferSize: 0 readBufferSize: 0 forceAttemptHTTP2: true + originLog: + statusCode: + enable: true + exclude: + - 200 authorization: athenzDomains: - provider-domain1 From 2afba0bb831b382756b40f186eae413dfd4e5b57 Mon Sep 17 00:00:00 2001 From: Windz Date: Wed, 7 Dec 2022 16:22:16 +0900 Subject: [PATCH 02/58] fix error & fatal in normal shutdown (#6) * draft Signed-off-by: wfan * bug fix: server error not appened Signed-off-by: wfan * add main unit test Signed-off-by: wfan * remove - Signed-off-by: wfan * refactor: handle err for shared functionalities Signed-off-by: Jeongwoo Kim - jekim * fix comment Signed-off-by: wfan Signed-off-by: wfan Signed-off-by: Jeongwoo Kim - jekim Co-authored-by: Jeongwoo Kim - jekim Signed-off-by: Kyo Fujisaki --- config/config.go | 2 +- config/config_test.go | 2 +- main.go | 14 +- main_test.go | 140 ++++++++++++++++++ service/server.go | 63 +++----- ..._valid_config.yaml => invalid_config.yaml} | 0 test/data/invalid_log_config.yaml | 4 + usecase/authz_proxyd.go | 15 +- usecase/authz_proxyd_test.go | 14 +- 9 files changed, 191 insertions(+), 63 deletions(-) rename test/data/{not_valid_config.yaml => invalid_config.yaml} (100%) create mode 100644 test/data/invalid_log_config.yaml diff --git a/config/config.go b/config/config.go index f2024ba..00fc1a1 100755 --- a/config/config.go +++ b/config/config.go @@ -63,7 +63,7 @@ type Server struct { // ShutdownTimeout represents the duration before force shutdown. ShutdownTimeout string `yaml:"shutdownTimeout"` - // ShutdownDelay represents the delay duration between the health check server shutdown and the client sidecar server shutdown. + // ShutdownDelay represents the delay duration between the health check server shutdown and the api server shutdown. ShutdownDelay string `yaml:"shutdownDelay"` // TLS represents the TLS configuration of the authorization proxy. diff --git a/config/config_test.go b/config/config_test.go index e1929b3..2b20336 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -43,7 +43,7 @@ func TestNew(t *testing.T) { { name: "Test file content not valid", args: args{ - path: "../test/data/not_valid_config.yaml", + path: "../test/data/invalid_config.yaml", }, wantErr: fmt.Errorf("decode file failed: yaml: line 11: could not find expected ':'"), }, diff --git a/main.go b/main.go index f1659d8..1ae9624 100644 --- a/main.go +++ b/main.go @@ -118,13 +118,19 @@ func run(cfg config.Config) []error { signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT) + isSignal := false for { select { - case <-sigCh: + case sig := <-sigCh: + glg.Infof("authorization proxy received signal: %v", sig) + isSignal = true cancel() - glg.Warn("Got authorization-proxy server shutdown signal...") + glg.Warn("authorization proxy main process shutdown...") case errs := <-ech: - return errs + if !isSignal || len(errs) != 1 || errs[0] != ctx.Err() { + return errs + } + return nil } } } @@ -172,6 +178,8 @@ func main() { glg.Fatal(emsg) return } + glg.Info("authorization proxy main process shutdown success") + os.Exit(1) } func getVersion() string { diff --git a/main_test.go b/main_test.go index 11d3186..950eb2f 100644 --- a/main_test.go +++ b/main_test.go @@ -2,8 +2,12 @@ package main import ( "os" + "os/exec" "reflect" + "strconv" + "strings" "testing" + "time" "github.com/AthenZ/authorization-proxy/v4/config" "github.com/kpango/glg" @@ -70,6 +74,10 @@ func TestParseParams(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + defer func(oldArgs []string) { + // restore os.Args + os.Args = oldArgs + }(os.Args) if tt.beforeFunc != nil { tt.beforeFunc() } @@ -461,3 +469,135 @@ func Test_getVersion(t *testing.T) { }) } } + +func Test_main(t *testing.T) { + type test struct { + name string + beforeFunc func() + afterFunc func() + } + tests := []test{ + func() test { + var oldArgs []string + return test{ + name: "show version", + beforeFunc: func() { + oldArgs = os.Args + os.Args = []string{"authorization-proxy", "-version"} + }, + afterFunc: func() { + os.Args = oldArgs + }, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer tt.afterFunc() + tt.beforeFunc() + main() + }) + } +} + +func Test_mainExitCode(t *testing.T) { + tests := []struct { + name string + args []string + signal os.Signal + wantExitCode int + }{ + { + name: "normal exit", + args: []string{ + "-version", + }, + signal: nil, + wantExitCode: 0, + }, + { + name: "undefined flag", + args: []string{ + "-undefined_flag", + }, + signal: nil, + wantExitCode: 1, + }, + { + name: "run with log error", + args: []string{ + "-f", + "./test/data/invalid_log_config.yaml", + }, + signal: nil, + wantExitCode: 1, + }, + // TODO: need Athenz public key endpoint mock + /* + { + name: "run till termination SIGINT", + args: []string{ + "-f", + "./test/data/valid_config.yaml", + }, + signal: syscall.SIGINT, + wantExitCode: 1, + }, + { + name: "run till termination SIGTERM", + args: []string{ + "-f", + "./test/data/valid_config.yaml", + }, + signal: syscall.SIGTERM, + wantExitCode: 1, + }, + */ + } + + rc := os.Getenv("RUN_CASE") + if rc != "" { + c, err := strconv.Atoi(rc) + if err != nil { + panic(err) + } + tt := tests[c] + + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + os.Args = append([]string{"authorization-proxy"}, tt.args...) + + if tt.signal != nil { + // send signal + go func() { + proc, err := os.FindProcess(os.Getpid()) + if err != nil { + panic(err) + } + + time.Sleep(200 * time.Millisecond) + proc.Signal(tt.signal) + }() + } + + // run main + main() + return + } + + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var outbuf, errbuf strings.Builder + + cmd := exec.Command(os.Args[0], "-test.run=Test_mainExitCode") + cmd.Stdout = &outbuf + cmd.Stderr = &errbuf + cmd.Env = append(os.Environ(), "RUN_CASE="+strconv.Itoa(i)) + err := cmd.Run() + exitCode := cmd.ProcessState.ExitCode() + if exitCode != tt.wantExitCode { + t.Errorf("main() err = %v, stdout = %s, stderr = %s, exit code = %v, wantExitCode %v", err, outbuf.String(), errbuf.String(), exitCode, tt.wantExitCode) + } + }) + } +} diff --git a/service/server.go b/service/server.go index b3ca10e..e15d269 100644 --- a/service/server.go +++ b/service/server.go @@ -271,7 +271,7 @@ func (s *server) ListenAndServe(ctx context.Context) <-chan []error { return errs } - shutdownSrvs := func(errs []error) { + shutdownSrvs := func(errs []error) []error { if s.hcRunning { glg.Info("authorization proxy health check server will shutdown...") errs = appendErr(errs, s.hcShutdown(context.Background())) @@ -286,65 +286,50 @@ func (s *server) ListenAndServe(ctx context.Context) <-chan []error { } if s.dRunning { glg.Info("authorization proxy debug server will shutdown...") - appendErr(errs, s.dShutdown(context.Background())) + errs = appendErr(errs, s.dShutdown(context.Background())) } - glg.Info("authorization proxy has already shutdown gracefully") + if len(errs) == 0 { + glg.Info("authorization proxy has already shutdown gracefully") + } + return errs } errs := make([]error, 0, 3) + + handleErr := func(err error) { + if err != nil { + errs = append(errs, errors.Wrap(err, "close running servers and return any error")) + } + s.mu.RLock() + errs = shutdownSrvs(errs) + s.mu.RUnlock() + echan <- errs + } + for { select { case <-ctx.Done(): // when context receive done signal, close running servers and return any error s.mu.RLock() - shutdownSrvs(errs) + errs = shutdownSrvs(errs) s.mu.RUnlock() echan <- appendErr(errs, ctx.Err()) return case err := <-sech: // when authorization proxy server returns, close running servers and return any error - if err != nil { - errs = append(errs, errors.Wrap(err, "close running servers and return any error")) - } - - s.mu.RLock() - shutdownSrvs(errs) - s.mu.RUnlock() - echan <- errs + handleErr(err) return case err := <-gsech: // when authorization proxy grpc server returns, close running servers and return any error - if err != nil { - errs = append(errs, errors.Wrap(err, "close running servers and return any error")) - } - - s.mu.RLock() - shutdownSrvs(errs) - s.mu.RUnlock() - echan <- errs + handleErr(err) return case err := <-hech: // when health check server returns, close running servers and return any error - if err != nil { - errs = append(errs, errors.Wrap(err, "close running servers and return any error")) - } - - s.mu.RLock() - shutdownSrvs(errs) - s.mu.RUnlock() - echan <- errs + handleErr(err) return case err := <-dech: // when debug server returns, close running servers and return any error - if err != nil { - errs = append(errs, errors.Wrap(err, "close running servers and return any error")) - } - - s.mu.RLock() - shutdownSrvs(errs) - s.mu.RUnlock() - echan <- errs + handleErr(err) return - } } }() @@ -364,7 +349,7 @@ func (s *server) dShutdown(ctx context.Context) error { return s.dsrv.Shutdown(dctx) } -// apiShutdown returns any error when shutdown the authorization proxy server. +// apiShutdown returns any error when shutdown the authorization proxy API server. // Before shutdown the authorization proxy server, it will sleep config.ShutdownDelay to prevent any issue from K8s func (s *server) apiShutdown(ctx context.Context) error { time.Sleep(s.sdd) @@ -373,7 +358,7 @@ func (s *server) apiShutdown(ctx context.Context) error { return s.srv.Shutdown(sctx) } -// apiShutdown returns any error when shutdown the authorization proxy server. +// grpcShutdown returns any error when shutdown the authorization proxy gPRC server. // Before shutdown the authorization proxy server, it will sleep config.ShutdownDelay to prevent any issue from K8s func (s *server) grpcShutdown() { time.Sleep(s.sdd) diff --git a/test/data/not_valid_config.yaml b/test/data/invalid_config.yaml similarity index 100% rename from test/data/not_valid_config.yaml rename to test/data/invalid_config.yaml diff --git a/test/data/invalid_log_config.yaml b/test/data/invalid_log_config.yaml new file mode 100644 index 0000000..2981153 --- /dev/null +++ b/test/data/invalid_log_config.yaml @@ -0,0 +1,4 @@ +--- +version: v2.0.0 +log: + level: invalid diff --git a/usecase/authz_proxyd.go b/usecase/authz_proxyd.go index 990513c..ad769dc 100644 --- a/usecase/authz_proxyd.go +++ b/usecase/authz_proxyd.go @@ -101,7 +101,7 @@ func (g *authzProxyDaemon) Start(ctx context.Context) <-chan []error { pch := g.athenz.Start(ctx) for err := range pch { - if err != nil { + if err != nil && err != ctx.Err() { glg.Errorf("pch: %v", err) // count errors by cause cause := errors.Cause(err).Error() @@ -120,12 +120,11 @@ func (g *authzProxyDaemon) Start(ctx context.Context) <-chan []error { // handle proxy server error, return on server shutdown done eg.Go(func() error { errs := <-g.server.ListenAndServe(ctx) + if len(errs) == 0 || len(errs) == 1 && errors.Cause(errs[0]) == ctx.Err() { + return nil + } glg.Errorf("sch: %v", errs) - if len(errs) == 0 { - // cannot be nil so that the context can cancel - return errors.New("") - } var baseErr error for i, err := range errs { if i == 0 { @@ -156,8 +155,10 @@ func (g *authzProxyDaemon) Start(ctx context.Context) <-chan []error { perrs = append(perrs, errors.WithMessagef(errors.New(errMsg), "authorizerd: %d times appeared", count)) } - // proxy server go func, should always return not nil error - ech <- append(perrs, err) + if err != nil { + ech <- append(perrs, err) + } + ech <- perrs }() return ech diff --git a/usecase/authz_proxyd_test.go b/usecase/authz_proxyd_test.go index 7c5c699..4ab40e0 100644 --- a/usecase/authz_proxyd_test.go +++ b/usecase/authz_proxyd_test.go @@ -219,10 +219,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { args: args{ ctx: ctx, }, - wantErrs: []error{ - errors.WithMessage(context.Canceled, "authorizerd: 1 times appeared"), - context.Canceled, - }, + wantErrs: []error{}, checkFunc: func(got <-chan []error, wantErrs []error) error { cancel() mux := &sync.Mutex{} @@ -287,7 +284,6 @@ func Test_authzProxyDaemon_Start(t *testing.T) { ctx: ctx, }, wantErrs: []error{ - errors.WithMessage(context.Canceled, "authorizerd: 1 times appeared"), errors.WithMessage(dummyErr, "server fails"), }, checkFunc: func(got <-chan []error, wantErrs []error) error { @@ -366,8 +362,6 @@ func Test_authzProxyDaemon_Start(t *testing.T) { }, wantErrs: []error{ errors.WithMessage(errors.Cause(errors.WithMessage(dummyErr, "authorizer daemon fails")), "authorizerd: 3 times appeared"), - errors.WithMessage(context.Canceled, "authorizerd: 1 times appeared"), - context.Canceled, }, checkFunc: func(got <-chan []error, wantErrs []error) error { cancel() @@ -436,10 +430,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { args: args{ ctx: ctx, }, - wantErrs: []error{ - errors.WithMessage(context.Canceled, "authorizerd: 1 times appeared"), - errors.New(""), - }, + wantErrs: []error{}, checkFunc: func(got <-chan []error, wantErrs []error) error { cancel() mux := &sync.Mutex{} @@ -509,7 +500,6 @@ func Test_authzProxyDaemon_Start(t *testing.T) { ctx: ctx, }, wantErrs: []error{ - errors.WithMessage(context.Canceled, "authorizerd: 1 times appeared"), errors.Wrap(dummyErr, context.Canceled.Error()), }, checkFunc: func(got <-chan []error, wantErrs []error) error { From 74f011fa7306a1d4c47bf15ef59663692d2166f3 Mon Sep 17 00:00:00 2001 From: Windz Date: Mon, 26 Dec 2022 14:00:38 +0900 Subject: [PATCH 03/58] add resource prefix config (#12) * add resource prefix config Signed-off-by: wfan * add unit test Signed-off-by: wfan * fix test Signed-off-by: wfan * upgrade authorizer Signed-off-by: wfan * upgrade go.mod Signed-off-by: wfan Signed-off-by: wfan Signed-off-by: Kyo Fujisaki --- config/config.go | 3 +++ config/config_test.go | 1 + go.mod | 38 +++++++++++++++++------------------ go.sum | 28 +++++++++++++------------- test/data/example_config.yaml | 1 + usecase/authz_proxyd.go | 4 ++++ usecase/authz_proxyd_test.go | 17 ++++++++++++++++ 7 files changed, 59 insertions(+), 33 deletions(-) diff --git a/config/config.go b/config/config.go index 00fc1a1..97dea45 100755 --- a/config/config.go +++ b/config/config.go @@ -220,6 +220,9 @@ type Policy struct { // MappingRules represents translation rules for determining action and resource. MappingRules map[string][]authorizerd.Rule `yaml:"mappingRules"` + + // ResourcePrefix represents prefix prepended to mapped resource. + ResourcePrefix string `yaml:"resourcePrefix"` } // JWK represents the configuration to fetch Athenz JWK. diff --git a/config/config_test.go b/config/config_test.go index 2b20336..065b5be 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -175,6 +175,7 @@ func TestNew(t *testing.T) { }, }, }, + ResourcePrefix: "/public", }, JWK: JWK{ RefreshPeriod: "", diff --git a/go.mod b/go.mod index 359978a..6b5b3f2 100644 --- a/go.mod +++ b/go.mod @@ -3,31 +3,31 @@ module github.com/AthenZ/authorization-proxy/v4 go 1.19 replace ( - cloud.google.com/go => cloud.google.com/go v0.106.0 + cloud.google.com/go => cloud.google.com/go v0.107.0 github.com/golang/mock => github.com/golang/mock v1.6.0 github.com/golang/protobuf => github.com/golang/protobuf v1.5.2 github.com/google/go-cmp => github.com/google/go-cmp v0.5.9 - github.com/google/pprof => github.com/google/pprof v0.0.0-20221112000123-84eb7ad69597 + github.com/google/pprof => github.com/google/pprof v0.0.0-20221219190121-3cb0bae90811 github.com/mwitkow/grpc-proxy => github.com/mwitkow/grpc-proxy v0.0.0-20181017164139-0f1106ef9c76 - golang.org/x/crypto => golang.org/x/crypto v0.2.0 - golang.org/x/exp => golang.org/x/exp v0.0.0-20221114191408-850992195362 - golang.org/x/image => golang.org/x/image v0.1.0 + golang.org/x/crypto => golang.org/x/crypto v0.4.0 + golang.org/x/exp => golang.org/x/exp v0.0.0-20221217163422-3c43f8badb15 + golang.org/x/image => golang.org/x/image v0.2.0 golang.org/x/lint => golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 golang.org/x/mobile => golang.org/x/mobile v0.0.0-20221110043201-43a038452099 golang.org/x/mod => golang.org/x/mod v0.7.0 - golang.org/x/net => golang.org/x/net v0.2.0 - golang.org/x/oauth2 => golang.org/x/oauth2 v0.2.0 + golang.org/x/net => golang.org/x/net v0.4.0 + golang.org/x/oauth2 => golang.org/x/oauth2 v0.3.0 golang.org/x/sync => golang.org/x/sync v0.1.0 - golang.org/x/sys => golang.org/x/sys v0.2.0 - golang.org/x/term => golang.org/x/term v0.2.0 - golang.org/x/text => golang.org/x/text v0.4.0 - golang.org/x/time => golang.org/x/time v0.2.0 - golang.org/x/tools => golang.org/x/tools v0.3.0 + golang.org/x/sys => golang.org/x/sys v0.3.0 + golang.org/x/term => golang.org/x/term v0.3.0 + golang.org/x/text => golang.org/x/text v0.5.0 + golang.org/x/time => golang.org/x/time v0.3.0 + golang.org/x/tools => golang.org/x/tools v0.4.0 golang.org/x/xerrors => golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 - google.golang.org/api => google.golang.org/api v0.103.0 + google.golang.org/api => google.golang.org/api v0.105.0 google.golang.org/appengine => google.golang.org/appengine v1.6.7 - google.golang.org/genproto => google.golang.org/genproto v0.0.0-20221114212237-e4508ebdbee1 - google.golang.org/grpc => google.golang.org/grpc v1.50.1 + google.golang.org/genproto => google.golang.org/genproto v0.0.0-20221207170731-23e4bf6bdc37 + google.golang.org/grpc => google.golang.org/grpc v1.51.0 google.golang.org/protobuf => google.golang.org/protobuf v1.28.1 ) @@ -37,7 +37,7 @@ require ( github.com/mwitkow/grpc-proxy v0.0.0-20181017164139-0f1106ef9c76 github.com/pkg/errors v0.9.1 golang.org/x/sync v0.1.0 - google.golang.org/grpc v1.50.1 + google.golang.org/grpc v1.51.0 google.golang.org/protobuf v1.28.1 gopkg.in/yaml.v2 v2.4.0 ) @@ -60,8 +60,8 @@ require ( github.com/lestrrat-go/option v1.0.0 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect golang.org/x/crypto v0.1.0 // indirect - golang.org/x/net v0.2.0 // indirect - golang.org/x/sys v0.2.0 // indirect - golang.org/x/text v0.4.0 // indirect + golang.org/x/net v0.3.0 // indirect + golang.org/x/sys v0.3.0 // indirect + golang.org/x/text v0.5.0 // indirect google.golang.org/genproto v0.0.0-20220713161829-9c7dac0a6568 // indirect ) diff --git a/go.sum b/go.sum index 59cfdb9..af6e64d 100644 --- a/go.sum +++ b/go.sum @@ -59,23 +59,23 @@ github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaD go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY= -golang.org/x/crypto v0.2.0 h1:BRXPfhNivWL5Yq0BGQ39a2sW6t44aODpfxkWjYdzewE= -golang.org/x/crypto v0.2.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= +golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8= +golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU= -golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= +golang.org/x/net v0.4.0 h1:Q5QPcMlvfxFTAPV0+07Xz/MpK9NTXu2VDUuy0FeMfaU= +golang.org/x/net v0.4.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= -golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= -golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k= -google.golang.org/genproto v0.0.0-20221114212237-e4508ebdbee1 h1:jCw9YRd2s40X9Vxi4zKsPRvSPlHWNqadVkpbMsCPzPQ= -google.golang.org/genproto v0.0.0-20221114212237-e4508ebdbee1/go.mod h1:rZS5c/ZVYMaOGBfO68GWtjOw/eLaZM1X6iVtgjZ+EWg= -google.golang.org/grpc v1.50.1 h1:DS/BukOZWp8s6p4Dt/tOaJaTQyPyOoCcrjroHuCeLzY= -google.golang.org/grpc v1.50.1/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI= +golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= +golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= +golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= +golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ= +google.golang.org/genproto v0.0.0-20221207170731-23e4bf6bdc37 h1:jmIfw8+gSvXcZSgaFAGyInDXeWzUhvYH57G/5GKMn70= +google.golang.org/genproto v0.0.0-20221207170731-23e4bf6bdc37/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM= +google.golang.org/grpc v1.51.0 h1:E1eGv1FTqoLIdnBCZufiSHgKjlqG6fKFf6pPWtMTh8U= +google.golang.org/grpc v1.51.0/go.mod h1:wgNDFcnuBGmxLKI/qn4T+m5BtEBYXJPvibbUPsAIPww= google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/test/data/example_config.yaml b/test/data/example_config.yaml index dfd6ecb..eabb86a 100644 --- a/test/data/example_config.yaml +++ b/test/data/example_config.yaml @@ -81,6 +81,7 @@ authorization: action: action path: "/path1/{path2}?param={value}" resource: "{path2}.{value}" + resourcePrefix: /public jwk: refreshPeriod: "" retryDelay: "" diff --git a/usecase/authz_proxyd.go b/usecase/authz_proxyd.go index ad769dc..ba16043 100644 --- a/usecase/authz_proxyd.go +++ b/usecase/authz_proxyd.go @@ -219,6 +219,10 @@ func newAuthzD(cfg config.Config) (service.Authorizationd, error) { } policyOpts = append(policyOpts, authorizerd.WithTranslator(translator)) } + + if prefix := authzCfg.Policy.ResourcePrefix; prefix != "" { + policyOpts = append(policyOpts, authorizerd.WithResourcePrefix(prefix)) + } } var rtOpts []authorizerd.Option if authzCfg.RoleToken.Enable { diff --git a/usecase/authz_proxyd_test.go b/usecase/authz_proxyd_test.go index 4ab40e0..b85abc3 100644 --- a/usecase/authz_proxyd_test.go +++ b/usecase/authz_proxyd_test.go @@ -767,6 +767,23 @@ func Test_newAuthzD(t *testing.T) { }, want: true, }, + { + name: "test success ResourcePrefix set", + args: args{ + cfg: config.Config{ + Authorization: config.Authorization{ + Policy: config.Policy{ + ResourcePrefix: "/public", + }, + RoleToken: config.RoleToken{ + Enable: true, + RoleAuthHeader: "Athenz-Role-Auth", + }, + }, + }, + }, + want: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From 150d873a65583eb6c6d6e8986afd80455b57bbe9 Mon Sep 17 00:00:00 2001 From: Windz Date: Wed, 18 Jan 2023 13:48:56 +0900 Subject: [PATCH 04/58] new `noAuthPaths` option supporting wildcard characters (#15) * draft Signed-off-by: wfan * add unit test Signed-off-by: wfan * fix quote Signed-off-by: wfan * fix empty regex Signed-off-by: wfan * Update handler/error.go Co-authored-by: ssunorz <42366422+ssunorz@users.noreply.github.com> Signed-off-by: Windz Signed-off-by: wfan Signed-off-by: Windz Co-authored-by: ssunorz <42366422+ssunorz@users.noreply.github.com> Signed-off-by: Kyo Fujisaki --- config/config.go | 3 + config/config_test.go | 7 +- handler/error.go | 3 + handler/handler.go | 16 ++ handler/handler_test.go | 88 ++++++++ handler/transport.go | 26 ++- handler/transport_test.go | 383 ++++++++++++++++++++++++++++++++++ test/data/example_config.yaml | 4 + 8 files changed, 522 insertions(+), 8 deletions(-) diff --git a/config/config.go b/config/config.go index 97dea45..f228032 100755 --- a/config/config.go +++ b/config/config.go @@ -146,6 +146,9 @@ type Proxy struct { // Tips for performance: define your health check endpoint with a different length from the most frequently used endpoint, for example, use `/healthcheck` (len: 12) when `/most_used` (len: 10), instead of `/healthccc` (len: 10) OriginHealthCheckPaths []string `yaml:"originHealthCheckPaths"` + // NoAuthPaths represents endpoints that requires NO authorization. Wildcard characters supported in Athenz policy are supported too. + NoAuthPaths []string `yaml:"noAuthPaths"` + // PreserveHost represents whether to preserve the host header from the request. PreserveHost bool `yaml:"preserveHost"` diff --git a/config/config_test.go b/config/config_test.go index 065b5be..8627890 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -111,7 +111,12 @@ func TestNew(t *testing.T) { Port: 80, BufferSize: 4096, OriginHealthCheckPaths: []string{}, - PreserveHost: true, + NoAuthPaths: []string{ + "/no-auth/any/*", + "/no-auth/single/a?c", + "/no-auth/no-regex/^$|([{", + }, + PreserveHost: true, Transport: Transport{ TLSHandshakeTimeout: 10 * time.Second, DisableKeepAlives: false, diff --git a/handler/error.go b/handler/error.go index 4edf8be..6c06fd7 100644 --- a/handler/error.go +++ b/handler/error.go @@ -48,4 +48,7 @@ const ( // ErrRoleTokenNotFound "role token not found" ErrRoleTokenNotFound = "role token not found" + + // ErrInvalidProxyConfig "invalid proxy config". + ErrInvalidProxyConfig = "invalid proxy config" ) diff --git a/handler/handler.go b/handler/handler.go index 8418fd6..82a6011 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -29,6 +29,7 @@ import ( "github.com/kpango/glg" "github.com/pkg/errors" + "github.com/AthenZ/athenz-authorizer/v5/policy" "github.com/AthenZ/authorization-proxy/v4/config" "github.com/AthenZ/authorization-proxy/v4/service" ) @@ -89,6 +90,7 @@ func New(cfg config.Proxy, bp httputil.BufferPool, prov service.Authorizationd) prov: prov, RoundTripper: transportFromCfg(cfg.Transport), cfg: cfg, + noAuthPaths: mapPathToAssertion(cfg.NoAuthPaths), }, ErrorHandler: handleError, } @@ -156,6 +158,20 @@ func transportFromCfg(cfg config.Transport) *http.Transport { return t } +func mapPathToAssertion(paths []string) []*policy.Assertion { + as := make([]*policy.Assertion, len(paths)) + for i, p := range paths { + var err error + as[i], err = policy.NewAssertion("", ":"+p, "") + if err != nil { + // NewAssertion() escapes all regex characters and should NOT return ANY errors. + glg.Errorf("Invalid proxy.noAuthPaths: %s", p) + panic(ErrInvalidProxyConfig) + } + } + return as +} + func handleError(rw http.ResponseWriter, r *http.Request, err error) { if r != nil && r.Body != nil { io.Copy(ioutil.Discard, r.Body) diff --git a/handler/handler_test.go b/handler/handler_test.go index f30815f..e16ad17 100644 --- a/handler/handler_test.go +++ b/handler/handler_test.go @@ -15,6 +15,7 @@ import ( "time" authorizerd "github.com/AthenZ/athenz-authorizer/v5" + "github.com/AthenZ/athenz-authorizer/v5/policy" "github.com/AthenZ/authorization-proxy/v4/config" "github.com/AthenZ/authorization-proxy/v4/infra" "github.com/AthenZ/authorization-proxy/v4/service" @@ -625,6 +626,93 @@ func Test_transportFromCfg(t *testing.T) { } } +func Test_mapPathToAssertion(t *testing.T) { + type args struct { + paths []string + } + tests := []struct { + name string + args args + want []*policy.Assertion + wantPanic any + }{ + { + name: "nil list", + args: args{ + paths: nil, + }, + want: []*policy.Assertion{}, + }, + { + name: "empty list", + args: args{ + paths: []string{}, + }, + want: []*policy.Assertion{}, + }, + { + name: "single assertion", + args: args{ + paths: []string{ + "/path/656", + }, + }, + want: func() (as []*policy.Assertion) { + a, err := policy.NewAssertion("", ":/path/656", "") + if err != nil { + panic(err) + } + as = append(as, a) + return as + }(), + }, + { + name: "multiple assertion", + args: args{ + paths: []string{ + "/path/672", + "/path/673", + }, + }, + want: func() (as []*policy.Assertion) { + a1, err := policy.NewAssertion("", ":/path/672", "") + if err != nil { + panic(err) + } + a2, err := policy.NewAssertion("", ":/path/673", "") + if err != nil { + panic(err) + } + as = append(as, a1, a2) + return as + }(), + }, + // { + // name: "invalid assertion", + // args: args{ + // paths: []string{ + // "no invalid value", + // }, + // }, + // want: nil, + // wantPanic: ErrInvalidProxyConfig, + // }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + err := recover() + if err != tt.wantPanic { + t.Errorf("mapPathToAssertion() panic = %v, want panic %v", err, tt.wantPanic) + } + }() + if got := mapPathToAssertion(tt.args.paths); !reflect.DeepEqual(got, tt.want) { + t.Errorf("mapPathToAssertion() = %v, want %v", got, tt.want) + } + }) + } +} + func Test_handleError(t *testing.T) { type args struct { rw http.ResponseWriter diff --git a/handler/transport.go b/handler/transport.go index 99ab0dd..62f3a50 100644 --- a/handler/transport.go +++ b/handler/transport.go @@ -22,6 +22,7 @@ import ( "strings" authorizerd "github.com/AthenZ/athenz-authorizer/v5" + "github.com/AthenZ/athenz-authorizer/v5/policy" "github.com/AthenZ/authorization-proxy/v4/config" "github.com/AthenZ/authorization-proxy/v4/service" @@ -32,18 +33,29 @@ import ( type transport struct { http.RoundTripper - prov service.Authorizationd - cfg config.Proxy + prov service.Authorizationd + cfg config.Proxy + noAuthPaths []*policy.Assertion } // Based on the following. // https://github.com/golang/oauth2/blob/bf48bf16ab8d622ce64ec6ce98d2c98f916b6303/transport.go func (t *transport) RoundTrip(r *http.Request) (*http.Response, error) { - for _, urlPath := range t.cfg.OriginHealthCheckPaths { - if urlPath == r.URL.Path { - glg.Info("Authorization checking skipped on: " + r.URL.Path) - r.TLS = nil - return t.RoundTripper.RoundTrip(r) + // bypass authoriztion + if len(r.URL.Path) != 0 { // prevent bypassing empty path on default config + for _, urlPath := range t.cfg.OriginHealthCheckPaths { + if urlPath == r.URL.Path { + glg.Info("Authorization checking skipped on: " + r.URL.Path) + r.TLS = nil + return t.RoundTripper.RoundTrip(r) + } + } + for _, ass := range t.noAuthPaths { + if ass.ResourceRegexp.MatchString(strings.ToLower(r.URL.Path)) { + glg.Infof("Authorization checking skipped by %s on: %s", ass.ResourceRegexpString, r.URL.Path) + r.TLS = nil + return t.RoundTripper.RoundTrip(r) + } } } diff --git a/handler/transport_test.go b/handler/transport_test.go index 68eb7dc..afcfc53 100644 --- a/handler/transport_test.go +++ b/handler/transport_test.go @@ -2,11 +2,13 @@ package handler import ( "errors" + "io" "net/http" "reflect" "testing" authorizerd "github.com/AthenZ/athenz-authorizer/v5" + "github.com/AthenZ/athenz-authorizer/v5/policy" "github.com/AthenZ/authorization-proxy/v4/config" "github.com/AthenZ/authorization-proxy/v4/service" ) @@ -26,10 +28,18 @@ func (r *readCloseCounter) Close() error { } func Test_transport_RoundTrip(t *testing.T) { + wrapAssertion := func(s string) *policy.Assertion { + a, err := policy.NewAssertion("", ":"+s, "") + if err != nil { + panic(err) + } + return a + } type fields struct { RoundTripper http.RoundTripper prov service.Authorizationd cfg config.Proxy + noAuthPaths []*policy.Assertion } type args struct { r *http.Request @@ -287,6 +297,90 @@ func Test_transport_RoundTrip(t *testing.T) { wantErr: true, wantCloseCount: 1, }, + { + name: "NoAuthPaths match, bypass role token verification", + fields: fields{ + RoundTripper: &RoundTripperMock{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + }, nil + }, + }, + prov: &service.AuthorizerdMock{ + VerifyFunc: func(r *http.Request, act, res string) (authorizerd.Principal, error) { + return nil, errors.New("role token error") + }, + }, + cfg: config.Proxy{}, + noAuthPaths: []*policy.Assertion{ + wrapAssertion("/no-auth"), + }, + }, + args: args{ + r: func() *http.Request { + r, _ := http.NewRequest("GET", "http://athenz.io/no-auth", nil) + return r + }(), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + want: &http.Response{ + StatusCode: 200, + }, + wantErr: false, + wantCloseCount: 0, + }, + { + name: "NoAuthPaths NONE match, verify role token", + fields: fields{ + RoundTripper: nil, + prov: &service.AuthorizerdMock{ + VerifyFunc: func(r *http.Request, act, res string) (authorizerd.Principal, error) { + return nil, errors.New("role token error") + }, + }, + cfg: config.Proxy{}, + noAuthPaths: []*policy.Assertion{ + wrapAssertion("/no-auth"), + }, + }, + args: args{ + r: func() *http.Request { + r, _ := http.NewRequest("GET", "http://athenz.io/no-auth/", nil) + return r + }(), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + wantErr: true, + wantCloseCount: 1, + }, + { + name: "NoAuthPaths NOT set, verify role token", + fields: fields{ + RoundTripper: nil, + prov: &service.AuthorizerdMock{ + VerifyFunc: func(r *http.Request, act, res string) (authorizerd.Principal, error) { + return nil, errors.New("role token error") + }, + }, + cfg: config.Proxy{}, + }, + args: args{ + r: func() *http.Request { + r, _ := http.NewRequest("GET", "http://athenz.io/no-auth", nil) + return r + }(), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + wantErr: true, + wantCloseCount: 1, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -294,6 +388,7 @@ func Test_transport_RoundTrip(t *testing.T) { RoundTripper: tt.fields.RoundTripper, prov: tt.fields.prov, cfg: tt.fields.cfg, + noAuthPaths: tt.fields.noAuthPaths, } if tt.args.body != nil { tt.args.r.Body = tt.args.body @@ -314,3 +409,291 @@ func Test_transport_RoundTrip(t *testing.T) { }) } } + +func Test_transport_RoundTrip_WildcardBypass(t *testing.T) { + wrapAssertion := func(s string) *policy.Assertion { + a, err := policy.NewAssertion("", ":"+s, "") + if err != nil { + panic(err) + } + return a + } + wrapRequest := func(method, url string, body io.Reader) *http.Request { + r, err := http.NewRequest(method, url, body) + if err != nil { + panic(err) + } + return r + } + type fields struct { + RoundTripper http.RoundTripper + prov service.Authorizationd + cfg config.Proxy + noAuthPaths []*policy.Assertion + } + type args struct { + r *http.Request + body *readCloseCounter + } + tests := []struct { + name string + fields fields + argss []args + want *http.Response + wantErr bool + wantCloseCount int + }{ + { + name: "NoAuthPaths '*' match, bypass role token verification", + fields: fields{ + RoundTripper: &RoundTripperMock{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + }, nil + }, + }, + prov: &service.AuthorizerdMock{ + VerifyFunc: func(r *http.Request, act, res string) (authorizerd.Principal, error) { + return nil, errors.New("role token error") + }, + }, + cfg: config.Proxy{}, + noAuthPaths: []*policy.Assertion{ + wrapAssertion("/no-auth*"), + }, + }, + argss: []args{ + { + r: wrapRequest("GET", "http://athenz.io/no-auth", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + { + r: wrapRequest("GET", "http://athenz.io/no-authhhhhh", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + { + r: wrapRequest("GET", "http://athenz.io/no-auth/", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + { + r: wrapRequest("GET", "http://athenz.io/no-auth/abc", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + { + r: wrapRequest("GET", "http://athenz.io/no-auth/abc/483", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + }, + want: &http.Response{ + StatusCode: 200, + }, + wantErr: false, + wantCloseCount: 0, + }, + { + name: "NoAuthPaths '?' match, bypass role token verification", + fields: fields{ + RoundTripper: &RoundTripperMock{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + }, nil + }, + }, + prov: &service.AuthorizerdMock{ + VerifyFunc: func(r *http.Request, act, res string) (authorizerd.Principal, error) { + return nil, errors.New("role token error") + }, + }, + cfg: config.Proxy{}, + noAuthPaths: []*policy.Assertion{ + wrapAssertion("/no-auth/a??c"), + }, + }, + argss: []args{ + { + r: wrapRequest("GET", "http://athenz.io/no-auth/aaac", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + { + r: wrapRequest("GET", "http://athenz.io/no-auth/accc", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + { + r: wrapRequest("GET", "http://athenz.io/no-auth/abbc", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + }, + want: &http.Response{ + StatusCode: 200, + }, + wantErr: false, + wantCloseCount: 0, + }, + { + name: "NoAuthPaths '?' NOT match, verify role token", + fields: fields{ + RoundTripper: nil, + prov: &service.AuthorizerdMock{ + VerifyFunc: func(r *http.Request, act, res string) (authorizerd.Principal, error) { + return nil, errors.New("role token error") + }, + }, + cfg: config.Proxy{}, + noAuthPaths: []*policy.Assertion{ + wrapAssertion("/no-auth/a??c"), + }, + }, + argss: []args{ + { + r: wrapRequest("GET", "http://athenz.io/no-auth/aaaa", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + { + r: wrapRequest("GET", "http://athenz.io/no-auth/cccc", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + { + r: wrapRequest("GET", "http://athenz.io/no-auth/abbbc", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + { + r: wrapRequest("GET", "http://athenz.io/no-auth/123456", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + }, + wantErr: true, + wantCloseCount: 1, + }, + { + name: "NoAuthPaths empty string NOT match, verify role token", + fields: fields{ + RoundTripper: nil, + prov: &service.AuthorizerdMock{ + VerifyFunc: func(r *http.Request, act, res string) (authorizerd.Principal, error) { + return nil, errors.New("role token error") + }, + }, + cfg: config.Proxy{}, + noAuthPaths: []*policy.Assertion{ + wrapAssertion(""), + }, + }, + argss: []args{ + { + r: wrapRequest("GET", "http://athenz.io", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + { + r: wrapRequest("GET", "http://athenz.io/", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + }, + wantErr: true, + wantCloseCount: 1, + }, + { + name: "NoAuthPaths NO escape, verify role token", + fields: fields{ + RoundTripper: &RoundTripperMock{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + }, nil + }, + }, + prov: &service.AuthorizerdMock{ + VerifyFunc: func(r *http.Request, act, res string) (authorizerd.Principal, error) { + return nil, errors.New("role token error") + }, + }, + cfg: config.Proxy{}, + noAuthPaths: []*policy.Assertion{ + wrapAssertion("/no-auth/wildcard/\\*"), + wrapAssertion("/no-auth/single/\\?"), + wrapAssertion("/no-auth/escape/\\\\"), + }, + }, + argss: []args{ + { + r: wrapRequest("GET", "http://athenz.io/no-auth/wildcard/*", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + { + r: wrapRequest("GET", "http://athenz.io/no-auth/single/?", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + { + r: wrapRequest("GET", "http://athenz.io/no-auth/escape/\\", nil), + body: &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + }, + }, + }, + wantErr: true, + wantCloseCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := &transport{ + RoundTripper: tt.fields.RoundTripper, + prov: tt.fields.prov, + cfg: tt.fields.cfg, + noAuthPaths: tt.fields.noAuthPaths, + } + for _, args := range tt.argss { + if args.body != nil { + args.r.Body = args.body + } + got, err := tr.RoundTrip(args.r) + if (err != nil) != tt.wantErr { + t.Errorf("transport.RoundTrip() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("transport.RoundTrip() = %v, want %v", got, tt.want) + } + if args.body != nil { + if args.body.CloseCount != tt.wantCloseCount { + t.Errorf("Body was closed %d times, expected %d", args.body.CloseCount, tt.wantCloseCount) + } + } + } + }) + } +} diff --git a/test/data/example_config.yaml b/test/data/example_config.yaml index eabb86a..a3a30e5 100644 --- a/test/data/example_config.yaml +++ b/test/data/example_config.yaml @@ -29,6 +29,10 @@ proxy: port: 80 bufferSize: 4096 originHealthCheckPaths: [] + noAuthPaths: + - "/no-auth/any/*" + - "/no-auth/single/a?c" + - "/no-auth/no-regex/^$|([{" preserveHost: true transport: tlsHandshakeTimeout: "10s" From 67a3c7d3526e3780370e0e46cae234f9ac08f62f Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 1 Feb 2023 11:08:20 +0900 Subject: [PATCH 05/58] Add cert refresh period configulation Signed-off-by: Kyo Fujisaki --- config/config.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/config/config.go b/config/config.go index f228032..2843d3e 100755 --- a/config/config.go +++ b/config/config.go @@ -89,6 +89,9 @@ type TLS struct { // CAPath represents the CA certificate chain file path for verifying client certificates. CAPath string `yaml:"caPath"` + + // CertRefreshPeriod represents + CertRefreshPeriod string `yaml:"certRefreshPeriod"` } // HealthCheck represents the health check server configuration. From d3d6a1888434c3942d22d14ec984c7a75dc82d8d Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 1 Feb 2023 11:09:25 +0900 Subject: [PATCH 06/58] Add parse cert refresh period and Add Refresh logic Signed-off-by: Kyo Fujisaki --- service/server.go | 89 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 86 insertions(+), 3 deletions(-) diff --git a/service/server.go b/service/server.go index e15d269..5d57ea9 100644 --- a/service/server.go +++ b/service/server.go @@ -17,11 +17,15 @@ limitations under the License. package service import ( + "bytes" "context" + "crypto/sha256" + "crypto/tls" "fmt" "io" "net" "net/http" + "os" "strconv" "sync" "time" @@ -38,6 +42,7 @@ import ( // Server represents a authorization proxy server behavior type Server interface { ListenAndServe(context.Context) <-chan []error + RefreshCertificate(context.Context) error } type server struct { @@ -51,6 +56,13 @@ type server struct { grpcSrvRunning bool grpcCloser io.Closer + // server tls + srvCrt *tls.Certificate + srvCrtHash []byte + srvCrtKeyHash []byte + srvCrtMu sync.RWMutex + crtRefreshPeriod time.Duration + // Health Check server hcsrv *http.Server hcRunning bool @@ -107,7 +119,7 @@ func NewServer(opts ...Option) (Server, error) { } if s.cfg.TLS.Enable { - cfg, err := NewTLSConfig(s.cfg.TLS) + cfg, err := NewTLSConfig(s.cfg.TLS, s) if err != nil { return nil, err } @@ -140,6 +152,13 @@ func NewServer(opts ...Option) (Server, error) { s.dsrv.SetKeepAlivesEnabled(true) } + if s.cfg.TLS.CertRefreshPeriod != "" { + s.crtRefreshPeriod, err = time.ParseDuration(s.cfg.TLS.CertRefreshPeriod) + if err != nil { + glg.Warn(err) + } + } + s.sdt, err = time.ParseDuration(s.cfg.ShutdownTimeout) if err != nil { glg.Warn(err) @@ -393,8 +412,7 @@ func (s *server) listenAndServeAPI() error { if !s.cfg.TLS.Enable { return s.srv.ListenAndServe() } - - cfg, err := NewTLSConfig(s.cfg.TLS) + cfg, err := NewTLSConfig(s.cfg.TLS, s) if err == nil && cfg != nil { s.srv.TLSConfig = cfg } @@ -426,3 +444,68 @@ func (s *server) grpcSrvEnable() bool { func (s *server) debugSrvEnable() bool { return s.cfg.Debug.Enable } + +func hash(file string) ([]byte, error) { + f, err := os.Open(file) + if err != nil { + return nil, err + } + defer f.Close() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return nil, err + } + + return h.Sum(nil), nil +} + +// getCertificate return server TLS certificate. +func (s *server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { + s.srvCrtMu.RLock() + defer s.srvCrtMu.RUnlock() + return s.srvCrt, nil +} + +// RefreshCertificate is refresh certificate function. +func (s *server) RefreshCertificate(ctx context.Context) error { + ticker := time.NewTicker(s.crtRefreshPeriod) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return nil + case <-ticker.C: + serverCertHash, err := hash(s.cfg.TLS.CertPath) + if err != nil { + glg.Error("Failed to refresh server certificate: %s.", err.Error()) + continue + } + serverCertKeyHash, err := hash(s.cfg.TLS.KeyPath) + if err != nil { + glg.Error("Failed to refresh server certificate: %s.", err.Error()) + continue + } + + s.srvCrtMu.Lock() + + different := !bytes.Equal(s.srvCrtHash, serverCertHash) || + !bytes.Equal(s.srvCrtKeyHash, serverCertKeyHash) + + if different { // load and store + newCert, err := tls.LoadX509KeyPair(s.cfg.TLS.CertPath, s.cfg.TLS.KeyPath) + if err != nil { + glg.Error("Failed to refresh server certificate: %s.", err.Error()) + s.srvCrtMu.Unlock() + continue + } + s.srvCrt = &newCert + s.srvCrtHash = serverCertHash + s.srvCrtKeyHash = serverCertKeyHash + glg.Info("Refreshed server certificate.") + } + + s.srvCrtMu.Unlock() + } + } +} From ac3201277f3ea7e8e888c2d8c1763203bf841d0a Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 1 Feb 2023 11:09:46 +0900 Subject: [PATCH 07/58] Run cert refresh logic Signed-off-by: Kyo Fujisaki --- usecase/authz_proxyd.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/usecase/authz_proxyd.go b/usecase/authz_proxyd.go index ba16043..4fdc69a 100644 --- a/usecase/authz_proxyd.go +++ b/usecase/authz_proxyd.go @@ -136,6 +136,13 @@ func (g *authzProxyDaemon) Start(ctx context.Context) <-chan []error { return baseErr }) + // handle cert refresh goroutine erorr + if _, err := time.ParseDuration(g.cfg.Server.TLS.CertRefreshPeriod); g.cfg.Server.TLS.Enable && g.cfg.Server.TLS.CertRefreshPeriod != "" && err == nil { + eg.Go(func() error { + return g.server.RefreshCertificate(ctx) + }) + } + // wait for shutdown, and summarize errors go func() { defer close(ech) From 0c786f2a3906bf10fd586ab13ff012aa5176cd43 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 1 Feb 2023 11:10:23 +0900 Subject: [PATCH 08/58] Load certificate Signed-off-by: Kyo Fujisaki --- service/tls.go | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/service/tls.go b/service/tls.go index e6385c7..320d9d8 100644 --- a/service/tls.go +++ b/service/tls.go @@ -29,7 +29,7 @@ import ( // It reads TLS configuration and initializes *tls.Config struct. // It initializes TLS configuration, for example the CA certificate and key to start TLS server. // Server and CA Certificate, and private key will read from files from file paths defined in environment variables. -func NewTLSConfig(cfg config.TLS) (*tls.Config, error) { +func NewTLSConfig(cfg config.TLS, s *server) (*tls.Config, error) { t := &tls.Config{ MinVersion: tls.VersionTLS12, CurvePreferences: []tls.CurveID{ @@ -65,7 +65,8 @@ func NewTLSConfig(cfg config.TLS) (*tls.Config, error) { // tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, // Go 1.8 only // tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, // Go 1.8 only // }, - ClientAuth: tls.NoClientCert, + ClientAuth: tls.NoClientCert, + GetCertificate: s.getCertificate, } cert := config.GetActualValue(cfg.CertPath) @@ -73,12 +74,26 @@ func NewTLSConfig(cfg config.TLS) (*tls.Config, error) { ca := config.GetActualValue(cfg.CAPath) if cert != "" && key != "" { + s.srvCrtMu.Lock() + defer s.srvCrtMu.Unlock() + crt, err := tls.LoadX509KeyPair(cert, key) if err != nil { return nil, errors.Wrap(err, "tls.LoadX509KeyPair(cert, key)") } - t.Certificates = make([]tls.Certificate, 1) - t.Certificates[0] = crt + + crtHash, err := hash(cert) + if err != nil { + return nil, errors.Wrap(err, "hash(cert)") + } + + crtKeyHash, err := hash(key) + if err != nil { + return nil, errors.Wrap(err, "hash(key)") + } + s.srvCrt = &crt + s.srvCrtHash = crtHash + s.srvCrtKeyHash = crtKeyHash } if ca != "" { @@ -89,7 +104,6 @@ func NewTLSConfig(cfg config.TLS) (*tls.Config, error) { t.ClientCAs = pool t.ClientAuth = tls.RequireAndVerifyClientCert } - return t, nil } From 8e3ac7be5aef78bcc3b1303882ff198dc5c56971 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 1 Feb 2023 17:45:27 +0900 Subject: [PATCH 09/58] Add new server struct option Signed-off-by: Kyo Fujisaki --- service/option.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/service/option.go b/service/option.go index 3e67c36..ed0bf21 100644 --- a/service/option.go +++ b/service/option.go @@ -1,6 +1,7 @@ package service import ( + "crypto/tls" "io" "net/http" @@ -39,6 +40,13 @@ func WithGRPCCloser(c io.Closer) Option { } } +// WithTLSConfig returns a TLS Config functional option +func WithTLSConfig(t *tls.Config) Option { + return func(s *server) { + s.tlsConifg = t + } +} + // WithGRPCServer returns a gRPC Server functional option func WithGRPCServer(srv *grpc.Server) Option { return func(s *server) { From 4af61cf1ece1ced4e6aec25d27b86c294f6d7a70 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 1 Feb 2023 17:47:15 +0900 Subject: [PATCH 10/58] Add TLSCertificateCache, New function Signed-off-by: Kyo Fujisaki --- service/tls.go | 166 ++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 157 insertions(+), 9 deletions(-) diff --git a/service/tls.go b/service/tls.go index 320d9d8..795b2e0 100644 --- a/service/tls.go +++ b/service/tls.go @@ -17,19 +17,40 @@ limitations under the License. package service import ( + "bytes" + "context" "crypto/tls" "crypto/x509" "io/ioutil" + "sync" + "time" "github.com/AthenZ/authorization-proxy/v4/config" + "github.com/kpango/glg" "github.com/pkg/errors" ) +type TLSCertificateCache struct { + // server tls + serverCert *tls.Certificate + serverCertHash []byte + serverCertKeyHash []byte + serverCertPath string + serverCertKeyPath string + serverCertMutex sync.RWMutex + certRefreshPeriod time.Duration +} + +type TLSConfigWithTLSCertificateCache struct { + TLSConfig *tls.Config + TLSCertficateCache *TLSCertificateCache +} + // NewTLSConfig returns a *tls.Config struct or error. // It reads TLS configuration and initializes *tls.Config struct. // It initializes TLS configuration, for example the CA certificate and key to start TLS server. // Server and CA Certificate, and private key will read from files from file paths defined in environment variables. -func NewTLSConfig(cfg config.TLS, s *server) (*tls.Config, error) { +func NewTLSConfig(cfg config.TLS) (*tls.Config, error) { t := &tls.Config{ MinVersion: tls.VersionTLS12, CurvePreferences: []tls.CurveID{ @@ -65,8 +86,7 @@ func NewTLSConfig(cfg config.TLS, s *server) (*tls.Config, error) { // tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, // Go 1.8 only // tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, // Go 1.8 only // }, - ClientAuth: tls.NoClientCert, - GetCertificate: s.getCertificate, + ClientAuth: tls.NoClientCert, } cert := config.GetActualValue(cfg.CertPath) @@ -74,9 +94,74 @@ func NewTLSConfig(cfg config.TLS, s *server) (*tls.Config, error) { ca := config.GetActualValue(cfg.CAPath) if cert != "" && key != "" { - s.srvCrtMu.Lock() - defer s.srvCrtMu.Unlock() + crt, err := tls.LoadX509KeyPair(cert, key) + if err != nil { + return nil, errors.Wrap(err, "tls.LoadX509KeyPair(cert, key)") + } + t.Certificates = make([]tls.Certificate, 1) + t.Certificates[0] = crt + } + + if ca != "" { + pool, err := NewX509CertPool(ca) + if err != nil { + return nil, errors.Wrap(err, "NewX509CertPool(ca)") + } + t.ClientCAs = pool + t.ClientAuth = tls.RequireAndVerifyClientCert + } + + return t, nil +} + +func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCertificateCache, error) { + tcc := &TLSCertificateCache{} + t := &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, + }, + SessionTicketsDisabled: true, + // PreferServerCipherSuites: true, + // CipherSuites: []uint16{ + // tls.TLS_RSA_WITH_RC4_128_SHA, + // tls.TLS_RSA_WITH_AES_128_CBC_SHA, + // tls.TLS_RSA_WITH_AES_256_CBC_SHA, + // tls.TLS_RSA_WITH_AES_128_CBC_SHA256, + // tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + // tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + // tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + // tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + // tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + // tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, + // tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + // tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + // tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + // tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + // tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + // tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + // tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + // tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + // tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, // Maybe this is work on TLS 1.2 + // tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, // TLS1.3 Feature + // tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, // TLS1.3 Feature + // tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, // Go 1.8 only + // tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, // Go 1.8 only + // }, + ClientAuth: tls.NoClientCert, + GetCertificate: tcc.getCertificate, + } + var err error + + cert := config.GetActualValue(cfg.CertPath) + key := config.GetActualValue(cfg.KeyPath) + ca := config.GetActualValue(cfg.CAPath) + + if cert != "" && key != "" { crt, err := tls.LoadX509KeyPair(cert, key) if err != nil { return nil, errors.Wrap(err, "tls.LoadX509KeyPair(cert, key)") @@ -91,9 +176,18 @@ func NewTLSConfig(cfg config.TLS, s *server) (*tls.Config, error) { if err != nil { return nil, errors.Wrap(err, "hash(key)") } - s.srvCrt = &crt - s.srvCrtHash = crtHash - s.srvCrtKeyHash = crtKeyHash + tcc.serverCert = &crt + tcc.serverCertHash = crtHash + tcc.serverCertKeyHash = crtKeyHash + tcc.serverCertPath = cert + tcc.serverCertKeyPath = key + } + + if cfg.CertRefreshPeriod != "" { + tcc.certRefreshPeriod, err = time.ParseDuration(cfg.CertRefreshPeriod) + if err != nil { + return nil, errors.Wrap(err, "ParseDuration(cfg.CertRefreshPeriod)") + } } if ca != "" { @@ -104,7 +198,11 @@ func NewTLSConfig(cfg config.TLS, s *server) (*tls.Config, error) { t.ClientCAs = pool t.ClientAuth = tls.RequireAndVerifyClientCert } - return t, nil + + return &TLSConfigWithTLSCertificateCache{ + TLSConfig: t, + TLSCertficateCache: tcc, + }, nil } // NewX509CertPool returns *x509.CertPool struct or error. @@ -123,3 +221,53 @@ func NewX509CertPool(path string) (*x509.CertPool, error) { } return pool, errors.Wrap(err, "x509.SystemCertPool()") } + +// getCertificate return server TLS certificate. +func (tcc *TLSCertificateCache) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { + tcc.serverCertMutex.RLock() + defer tcc.serverCertMutex.RUnlock() + return tcc.serverCert, nil +} + +// RefreshCertificate is refresh certificate function. +func (tcc *TLSCertificateCache) RefreshCertificate(ctx context.Context) error { + ticker := time.NewTicker(tcc.certRefreshPeriod) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return nil + case <-ticker.C: + serverCertHash, err := hash(tcc.serverCertPath) + if err != nil { + glg.Error("Failed to refresh server certificate: %s.", err.Error()) + continue + } + serverCertKeyHash, err := hash(tcc.serverCertKeyPath) + if err != nil { + glg.Error("Failed to refresh server certificate: %s.", err.Error()) + continue + } + + tcc.serverCertMutex.Lock() + + different := !bytes.Equal(tcc.serverCertHash, serverCertHash) || + !bytes.Equal(tcc.serverCertKeyHash, serverCertKeyHash) + + if different { // load and store + newCert, err := tls.LoadX509KeyPair(tcc.serverCertPath, tcc.serverCertKeyPath) + if err != nil { + glg.Error("Failed to refresh server certificate: %s.", err.Error()) + tcc.serverCertMutex.Unlock() + continue + } + tcc.serverCert = &newCert + tcc.serverCertHash = serverCertHash + tcc.serverCertKeyHash = serverCertKeyHash + glg.Info("Refreshed server certificate.") + } + + tcc.serverCertMutex.Unlock() + } + } +} From 54a97d45b7c03f9bace40afc67a1c929448c9c4f Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 1 Feb 2023 17:48:18 +0900 Subject: [PATCH 11/58] Move TLS config setting to authz_proxyd.go Signed-off-by: Kyo Fujisaki --- service/server.go | 84 +++-------------------------------------- usecase/authz_proxyd.go | 45 ++++++++++++++++------ 2 files changed, 40 insertions(+), 89 deletions(-) diff --git a/service/server.go b/service/server.go index 5d57ea9..c91df73 100644 --- a/service/server.go +++ b/service/server.go @@ -17,7 +17,6 @@ limitations under the License. package service import ( - "bytes" "context" "crypto/sha256" "crypto/tls" @@ -42,7 +41,6 @@ import ( // Server represents a authorization proxy server behavior type Server interface { ListenAndServe(context.Context) <-chan []error - RefreshCertificate(context.Context) error } type server struct { @@ -56,12 +54,7 @@ type server struct { grpcSrvRunning bool grpcCloser io.Closer - // server tls - srvCrt *tls.Certificate - srvCrtHash []byte - srvCrtKeyHash []byte - srvCrtMu sync.RWMutex - crtRefreshPeriod time.Duration + tlsConifg *tls.Config // Health Check server hcsrv *http.Server @@ -119,12 +112,7 @@ func NewServer(opts ...Option) (Server, error) { } if s.cfg.TLS.Enable { - cfg, err := NewTLSConfig(s.cfg.TLS, s) - if err != nil { - return nil, err - } - - gopts = append(gopts, grpc.Creds(credentials.NewTLS(cfg))) + gopts = append(gopts, grpc.Creds(credentials.NewTLS(s.tlsConifg))) } s.grpcSrv = grpc.NewServer(gopts...) @@ -134,6 +122,9 @@ func NewServer(opts ...Option) (Server, error) { Handler: s.srvHandler, } s.srv.SetKeepAlivesEnabled(true) + if s.cfg.TLS.Enable { + s.srv.TLSConfig = s.tlsConifg + } } if s.hcSrvEnable() { @@ -152,13 +143,6 @@ func NewServer(opts ...Option) (Server, error) { s.dsrv.SetKeepAlivesEnabled(true) } - if s.cfg.TLS.CertRefreshPeriod != "" { - s.crtRefreshPeriod, err = time.ParseDuration(s.cfg.TLS.CertRefreshPeriod) - if err != nil { - glg.Warn(err) - } - } - s.sdt, err = time.ParseDuration(s.cfg.ShutdownTimeout) if err != nil { glg.Warn(err) @@ -412,13 +396,7 @@ func (s *server) listenAndServeAPI() error { if !s.cfg.TLS.Enable { return s.srv.ListenAndServe() } - cfg, err := NewTLSConfig(s.cfg.TLS, s) - if err == nil && cfg != nil { - s.srv.TLSConfig = cfg - } - if err != nil { - glg.Error(errors.Wrap(err, "cannot NewTLSConfig(s.cfg.TLS)")) - } + return s.srv.ListenAndServeTLS("", "") } @@ -459,53 +437,3 @@ func hash(file string) ([]byte, error) { return h.Sum(nil), nil } - -// getCertificate return server TLS certificate. -func (s *server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { - s.srvCrtMu.RLock() - defer s.srvCrtMu.RUnlock() - return s.srvCrt, nil -} - -// RefreshCertificate is refresh certificate function. -func (s *server) RefreshCertificate(ctx context.Context) error { - ticker := time.NewTicker(s.crtRefreshPeriod) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return nil - case <-ticker.C: - serverCertHash, err := hash(s.cfg.TLS.CertPath) - if err != nil { - glg.Error("Failed to refresh server certificate: %s.", err.Error()) - continue - } - serverCertKeyHash, err := hash(s.cfg.TLS.KeyPath) - if err != nil { - glg.Error("Failed to refresh server certificate: %s.", err.Error()) - continue - } - - s.srvCrtMu.Lock() - - different := !bytes.Equal(s.srvCrtHash, serverCertHash) || - !bytes.Equal(s.srvCrtKeyHash, serverCertKeyHash) - - if different { // load and store - newCert, err := tls.LoadX509KeyPair(s.cfg.TLS.CertPath, s.cfg.TLS.KeyPath) - if err != nil { - glg.Error("Failed to refresh server certificate: %s.", err.Error()) - s.srvCrtMu.Unlock() - continue - } - s.srvCrt = &newCert - s.srvCrtHash = serverCertHash - s.srvCrtKeyHash = serverCertKeyHash - glg.Info("Refreshed server certificate.") - } - - s.srvCrtMu.Unlock() - } - } -} diff --git a/usecase/authz_proxyd.go b/usecase/authz_proxyd.go index 4fdc69a..02a1893 100644 --- a/usecase/authz_proxyd.go +++ b/usecase/authz_proxyd.go @@ -42,10 +42,11 @@ type AuthzProxyDaemon interface { } type authzProxyDaemon struct { - cfg config.Config - athenz service.Authorizationd - server service.Server - grpcServer service.Server + cfg config.Config + athenz service.Authorizationd + server service.Server + grpcServer service.Server + tlsCertificateCache *service.TLSCertificateCache } // New returns a Authorization Proxy daemon, or error occurred. @@ -64,21 +65,43 @@ func New(cfg config.Config) (AuthzProxyDaemon, error) { handler.WithAuthorizationd(athenz), ) - srv, err := service.NewServer( + serverOption := []service.Option{ service.WithServerConfig(cfg.Server), service.WithRestHandler(handler.New(cfg.Proxy, infra.NewBuffer(cfg.Proxy.BufferSize), athenz)), service.WithDebugHandler(debugMux), service.WithGRPCHandler(gh), service.WithGRPCCloser(closer), - ) + } + + var tlsConfig *tls.Config + var tlsCertificateCache *service.TLSCertificateCache + if cfg.Server.TLS.Enable { + if cfg.Server.TLS.CertRefreshPeriod != "" { + configWithCache, err := service.NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS) + if err != nil { + return nil, errors.Wrap(err, "cannot NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS)") + } + tlsConfig = configWithCache.TLSConfig + tlsCertificateCache = configWithCache.TLSCertficateCache + } else { + tlsConfig, err = service.NewTLSConfig(cfg.Server.TLS) + if err != nil { + return nil, errors.Wrap(err, "cannot NewTLSConfig(cfg.Server.TLS)") + } + } + serverOption = append(serverOption, service.WithTLSConfig(tlsConfig)) + } + + srv, err := service.NewServer(serverOption...) if err != nil { return nil, err } return &authzProxyDaemon{ - cfg: cfg, - athenz: athenz, - server: srv, + cfg: cfg, + athenz: athenz, + server: srv, + tlsCertificateCache: tlsCertificateCache, }, nil } @@ -137,9 +160,9 @@ func (g *authzProxyDaemon) Start(ctx context.Context) <-chan []error { }) // handle cert refresh goroutine erorr - if _, err := time.ParseDuration(g.cfg.Server.TLS.CertRefreshPeriod); g.cfg.Server.TLS.Enable && g.cfg.Server.TLS.CertRefreshPeriod != "" && err == nil { + if g.cfg.Server.TLS.Enable && g.cfg.Server.TLS.CertRefreshPeriod != "" { eg.Go(func() error { - return g.server.RefreshCertificate(ctx) + return g.tlsCertificateCache.RefreshCertificate(ctx) }) } From ee84eab2ab186ad02a48f00a6aa84e59bfb51841 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 1 Feb 2023 18:23:15 +0900 Subject: [PATCH 12/58] Add comments Signed-off-by: Kyo Fujisaki --- service/tls.go | 9 +++++++-- usecase/authz_proxyd.go | 2 ++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/service/tls.go b/service/tls.go index 795b2e0..0b4bba2 100644 --- a/service/tls.go +++ b/service/tls.go @@ -30,8 +30,8 @@ import ( "github.com/pkg/errors" ) +// TLSCertificateCache represents refresh certificate type TLSCertificateCache struct { - // server tls serverCert *tls.Certificate serverCertHash []byte serverCertKeyHash []byte @@ -114,6 +114,11 @@ func NewTLSConfig(cfg config.TLS) (*tls.Config, error) { return t, nil } +// NewTLSConfigWithTLSCertificateCache returns a *TLSConfigWithTLSCertificateCache struct or error. +// It use to enable the certificate auto-reload feature. +// It reads TLS configuration and initializes *tls.Config / TLSCertificateCache struct. +// It initializes TLS configuration, for example the CA certificate and key to start TLS server. +// Server and CA Certificate, and private key will read from files from file paths defined in environment variables. func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCertificateCache, error) { tcc := &TLSCertificateCache{} t := &tls.Config{ @@ -229,7 +234,7 @@ func (tcc *TLSCertificateCache) getCertificate(h *tls.ClientHelloInfo) (*tls.Cer return tcc.serverCert, nil } -// RefreshCertificate is refresh certificate function. +// RefreshCertificate is refresh certificate for TLS. func (tcc *TLSCertificateCache) RefreshCertificate(ctx context.Context) error { ticker := time.NewTicker(tcc.certRefreshPeriod) defer ticker.Stop() diff --git a/usecase/authz_proxyd.go b/usecase/authz_proxyd.go index 02a1893..36c6224 100644 --- a/usecase/authz_proxyd.go +++ b/usecase/authz_proxyd.go @@ -76,6 +76,7 @@ func New(cfg config.Config) (AuthzProxyDaemon, error) { var tlsConfig *tls.Config var tlsCertificateCache *service.TLSCertificateCache if cfg.Server.TLS.Enable { + // Enable auto-reload if CertRefreshPeriod is set. if cfg.Server.TLS.CertRefreshPeriod != "" { configWithCache, err := service.NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS) if err != nil { @@ -160,6 +161,7 @@ func (g *authzProxyDaemon) Start(ctx context.Context) <-chan []error { }) // handle cert refresh goroutine erorr + // prevent run RefreshCertificate if Enable is false and CertRefreshPeriod is set if g.cfg.Server.TLS.Enable && g.cfg.Server.TLS.CertRefreshPeriod != "" { eg.Go(func() error { return g.tlsCertificateCache.RefreshCertificate(ctx) From d907cb14b068f3b36f33d9fd8f17ce7b75830f0b Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Thu, 2 Feb 2023 15:25:28 +0900 Subject: [PATCH 13/58] Move TLS cert invalid test Signed-off-by: Kyo Fujisaki --- service/server_test.go | 28 ------------------------ usecase/authz_proxyd_test.go | 42 +++++++++++++++++++++++++++++++----- 2 files changed, 37 insertions(+), 33 deletions(-) diff --git a/service/server_test.go b/service/server_test.go index 42c1007..fa2aafe 100644 --- a/service/server_test.go +++ b/service/server_test.go @@ -202,34 +202,6 @@ func TestNewServer(t *testing.T) { return nil }, }, - { - name: "return error when grpc TLS cert invalid", - args: args{ - opts: []Option{ - WithGRPCHandler(func(srv interface{}, stream grpc.ServerStream) error { - return nil - }), - WithServerConfig(config.Server{ - Port: 9999, - TLS: config.TLS{ - Enable: true, - CertPath: "../test/data/invalid_dummyServer.crt", - KeyPath: "../test/data/invalid_dummyServer.key", - }, - }), - }, - }, - wantErr: errors.New("tls.LoadX509KeyPair(cert, key): tls: failed to find any PEM data in certificate input"), - checkFunc: func(got, want Server, gotErr, wantErr error) error { - if gotErr.Error() != wantErr.Error() { - return errors.Errorf("got error is not matched with want error, got: %v, want: %v", gotErr, wantErr) - } - if !reflect.DeepEqual(got, want) { - return fmt.Errorf("not matched") - } - return nil - }, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/usecase/authz_proxyd_test.go b/usecase/authz_proxyd_test.go index b85abc3..450e853 100644 --- a/usecase/authz_proxyd_test.go +++ b/usecase/authz_proxyd_test.go @@ -21,10 +21,11 @@ func TestNew(t *testing.T) { cfg config.Config } type test struct { - name string - args args - checkFunc func(AuthzProxyDaemon) error - wantErr bool + name string + args args + checkFunc func(AuthzProxyDaemon) error + wantErr bool + wantErrStr string } tests := []test{ func() test { @@ -91,7 +92,28 @@ func TestNew(t *testing.T) { }, }, }, - wantErr: true, + wantErr: true, + wantErrStr: "cannot newAuthzD(cfg): error create pubkeyd: invalid refresh period: time: invalid duration \"dummy\"", + }, { + name: "return error when grpc TLS cert invalid", + args: args{ + cfg: config.Config{ + Authorization: config.Authorization{ + RoleToken: config.RoleToken{ + Enable: true, + }, + }, + Server: config.Server{ + TLS: config.TLS{ + Enable: true, + CertPath: "../test/data/invalid_dummyServer.crt", + KeyPath: "../test/data/invalid_dummyServer.key", + }, + }, + }, + }, + wantErr: true, + wantErrStr: "cannot NewTLSConfig(cfg.Server.TLS): tls.LoadX509KeyPair(cert, key): tls: failed to find any PEM data in certificate input", }, } for _, tt := range tests { @@ -101,6 +123,16 @@ func TestNew(t *testing.T) { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return } + if tt.wantErr { + if err == nil { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err.Error() != tt.wantErrStr { + t.Errorf("New() error = %v, wantErrStr = %v", err, tt.wantErrStr) + return + } + } if tt.checkFunc != nil { if err = tt.checkFunc(got); err != nil { t.Errorf("New() error = %v", err) From 0fe265d103e91de525c411733a8a10826ddcc521 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Thu, 2 Feb 2023 18:39:09 +0900 Subject: [PATCH 14/58] Add WithTLSConfig test Signed-off-by: Kyo Fujisaki --- service/option_test.go | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/service/option_test.go b/service/option_test.go index 670e2a5..9ecf976 100644 --- a/service/option_test.go +++ b/service/option_test.go @@ -1,6 +1,7 @@ package service import ( + "crypto/tls" "io" "net/http" "net/http/httptest" @@ -206,6 +207,42 @@ func TestWithGRPCServer(t *testing.T) { } } +func TestWithTLSConfig(t *testing.T) { + type args struct { + t *tls.Config + } + tests := []struct { + name string + args args + checkFunc func(Option) error + }{ + { + name: "set success", + args: args{ + t: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + }, + checkFunc: func(o Option) error { + srv := &server{} + o(srv) + if srv.tlsConifg.MinVersion != tls.VersionTLS12 { + return errors.New("value cannot set") + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WithTLSConfig(tt.args.t) + if err := tt.checkFunc(got); err != nil { + t.Errorf("WithTLSConfig() error = %v", err) + } + }) + } +} + func TestWithDebugHandler(t *testing.T) { type args struct { h http.Handler From 4755b3a93cba6f84abdcfd87a0cc1f2a75e77b65 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Thu, 2 Feb 2023 18:42:52 +0900 Subject: [PATCH 15/58] Move hash() function Signed-off-by: Kyo Fujisaki --- service/server.go | 17 ----------------- service/tls.go | 18 ++++++++++++++++++ 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/service/server.go b/service/server.go index c91df73..5275940 100644 --- a/service/server.go +++ b/service/server.go @@ -18,13 +18,11 @@ package service import ( "context" - "crypto/sha256" "crypto/tls" "fmt" "io" "net" "net/http" - "os" "strconv" "sync" "time" @@ -422,18 +420,3 @@ func (s *server) grpcSrvEnable() bool { func (s *server) debugSrvEnable() bool { return s.cfg.Debug.Enable } - -func hash(file string) ([]byte, error) { - f, err := os.Open(file) - if err != nil { - return nil, err - } - defer f.Close() - - h := sha256.New() - if _, err := io.Copy(h, f); err != nil { - return nil, err - } - - return h.Sum(nil), nil -} diff --git a/service/tls.go b/service/tls.go index 0b4bba2..18705ce 100644 --- a/service/tls.go +++ b/service/tls.go @@ -19,9 +19,12 @@ package service import ( "bytes" "context" + "crypto/sha256" "crypto/tls" "crypto/x509" + "io" "io/ioutil" + "os" "sync" "time" @@ -276,3 +279,18 @@ func (tcc *TLSCertificateCache) RefreshCertificate(ctx context.Context) error { } } } + +func hash(file string) ([]byte, error) { + f, err := os.Open(file) + if err != nil { + return nil, err + } + defer f.Close() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return nil, err + } + + return h.Sum(nil), nil +} From 8bcf71504fcb5fa096c3d9735763ed2659b53426 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Thu, 2 Feb 2023 19:16:21 +0900 Subject: [PATCH 16/58] Add enable / tlsConfig check Signed-off-by: Kyo Fujisaki --- service/server.go | 4 ++++ service/server_test.go | 1 + 2 files changed, 5 insertions(+) diff --git a/service/server.go b/service/server.go index 5275940..b73dd9b 100644 --- a/service/server.go +++ b/service/server.go @@ -103,6 +103,10 @@ func NewServer(opts ...Option) (Server, error) { o(s) } + if s.cfg.TLS.Enable && s.tlsConifg == nil { + return nil, errors.New("s.cfg.TLS.Enable is true, but s.tlsConifg is nil.") + } + if s.grpcSrvEnable() { gopts := []grpc.ServerOption{ grpc.CustomCodec(proxy.Codec()), diff --git a/service/server_test.go b/service/server_test.go index fa2aafe..3d3f708 100644 --- a/service/server_test.go +++ b/service/server_test.go @@ -171,6 +171,7 @@ func TestNewServer(t *testing.T) { KeyPath: "../test/data/dummyServer.key", }, }), + WithTLSConfig(&tls.Config{}), }, }, want: &server{ From cdc7e63af968025e87222d49487b3ec3598ba8b0 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Fri, 3 Feb 2023 21:20:11 +0900 Subject: [PATCH 17/58] Add HTTPS server test Signed-off-by: Kyo Fujisaki --- service/server_test.go | 74 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/service/server_test.go b/service/server_test.go index 3d3f708..cf769fc 100644 --- a/service/server_test.go +++ b/service/server_test.go @@ -3,6 +3,7 @@ package service import ( "context" "crypto/tls" + "crypto/x509" "fmt" "io" "io/ioutil" @@ -129,6 +130,67 @@ func TestNewServer(t *testing.T) { return nil }, }, + { + name: "Check HTTPS server address and certificate", + args: args{ + opts: []Option{ + WithServerConfig(config.Server{ + Port: 9999, + TLS: config.TLS{ + Enable: true, + }, + HealthCheck: config.HealthCheck{ + Port: 8080, + Endpoint: "/healthz", + }, + }), + WithRestHandler(func() http.Handler { + return nil + }()), + WithTLSConfig(func() *tls.Config { + cfg, err := NewTLSConfig(config.TLS{ + Enable: true, + CertPath: "../test/data/dummyServer.crt", + KeyPath: "../test/data/dummyServer.key", + }) + if err != nil { + return nil + } + return cfg + }()), + }, + }, + want: &server{ + srv: &http.Server{ + Addr: fmt.Sprintf(":%d", 9999), + TLSConfig: func() *tls.Config { + cfg, err := NewTLSConfig(config.TLS{ + Enable: true, + CertPath: "../test/data/dummyServer.crt", + KeyPath: "../test/data/dummyServer.key", + }) + if err != nil { + return nil + } + return cfg + }(), + }, + }, + checkFunc: func(got, want Server, gotErr, wantErr error) error { + if !errors.Is(gotErr, wantErr) { + return errors.Errorf("got error is not matched with want error, got: %s, want: %s", gotErr, wantErr) + } + if got.(*server).srv.Addr != want.(*server).srv.Addr { + return fmt.Errorf("Server Addr not equals\tgot: %s\twant: %s", got.(*server).srv.Addr, want.(*server).srv.Addr) + } + gotCert, _ := x509.ParseCertificate(got.(*server).srv.TLSConfig.Certificates[0].Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.(*server).srv.TLSConfig.Certificates[0].Certificate[0]) + if gotCert.SerialNumber == nil || gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { + return fmt.Errorf("Certificate SerialNumber not equals\tgot: %s\twant: %s", got.(*server).srv.TLSConfig.Certificates[0].Leaf.Subject.CommonName, want.(*server).srv.TLSConfig.Certificates[0].Leaf.Subject.CommonName) + } + return nil + }, + }, { name: "Check GRPC server not nil", args: args{ @@ -171,7 +233,17 @@ func TestNewServer(t *testing.T) { KeyPath: "../test/data/dummyServer.key", }, }), - WithTLSConfig(&tls.Config{}), + WithTLSConfig(func() *tls.Config { + cfg, err := NewTLSConfig(config.TLS{ + Enable: true, + CertPath: "../test/data/dummyServer.crt", + KeyPath: "../test/data/dummyServer.key", + }) + if err != nil { + return nil + } + return cfg + }()), }, }, want: &server{ From e78e4952a3fe25c60587457632e238b7aa90a03d Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Mon, 6 Feb 2023 13:57:37 +0900 Subject: [PATCH 18/58] Use atomic.Value cache Signed-off-by: Kyo Fujisaki --- service/tls.go | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/service/tls.go b/service/tls.go index 18705ce..bbe75e4 100644 --- a/service/tls.go +++ b/service/tls.go @@ -25,7 +25,7 @@ import ( "io" "io/ioutil" "os" - "sync" + "sync/atomic" "time" "github.com/AthenZ/authorization-proxy/v4/config" @@ -35,12 +35,11 @@ import ( // TLSCertificateCache represents refresh certificate type TLSCertificateCache struct { - serverCert *tls.Certificate + serverCert atomic.Value serverCertHash []byte serverCertKeyHash []byte serverCertPath string serverCertKeyPath string - serverCertMutex sync.RWMutex certRefreshPeriod time.Duration } @@ -184,7 +183,7 @@ func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCerti if err != nil { return nil, errors.Wrap(err, "hash(key)") } - tcc.serverCert = &crt + tcc.serverCert.Store(&crt) tcc.serverCertHash = crtHash tcc.serverCertKeyHash = crtKeyHash tcc.serverCertPath = cert @@ -232,9 +231,7 @@ func NewX509CertPool(path string) (*x509.CertPool, error) { // getCertificate return server TLS certificate. func (tcc *TLSCertificateCache) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { - tcc.serverCertMutex.RLock() - defer tcc.serverCertMutex.RUnlock() - return tcc.serverCert, nil + return tcc.serverCert.Load().(*tls.Certificate), nil } // RefreshCertificate is refresh certificate for TLS. @@ -246,6 +243,7 @@ func (tcc *TLSCertificateCache) RefreshCertificate(ctx context.Context) error { case <-ctx.Done(): return nil case <-ticker.C: + glg.Info("Checking to refresh server certificate.") serverCertHash, err := hash(tcc.serverCertPath) if err != nil { glg.Error("Failed to refresh server certificate: %s.", err.Error()) @@ -257,25 +255,20 @@ func (tcc *TLSCertificateCache) RefreshCertificate(ctx context.Context) error { continue } - tcc.serverCertMutex.Lock() - different := !bytes.Equal(tcc.serverCertHash, serverCertHash) || !bytes.Equal(tcc.serverCertKeyHash, serverCertKeyHash) - if different { // load and store + if different { newCert, err := tls.LoadX509KeyPair(tcc.serverCertPath, tcc.serverCertKeyPath) if err != nil { glg.Error("Failed to refresh server certificate: %s.", err.Error()) - tcc.serverCertMutex.Unlock() continue } - tcc.serverCert = &newCert + tcc.serverCert.Store(&newCert) tcc.serverCertHash = serverCertHash tcc.serverCertKeyHash = serverCertKeyHash glg.Info("Refreshed server certificate.") } - - tcc.serverCertMutex.Unlock() } } } From 7303bc68163963f76fad7f1924eaa1079a84baa0 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Mon, 6 Feb 2023 15:20:46 +0900 Subject: [PATCH 19/58] Add NewServer error test Signed-off-by: Kyo Fujisaki --- service/server_test.go | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/service/server_test.go b/service/server_test.go index cf769fc..bf6c567 100644 --- a/service/server_test.go +++ b/service/server_test.go @@ -275,6 +275,30 @@ func TestNewServer(t *testing.T) { return nil }, }, + { + name: "Check TLS.Enable is true and tlsConfig is nil, return error", + args: args{ + opts: []Option{ + WithServerConfig(config.Server{ + Port: 9999, + TLS: config.TLS{ + Enable: true, + }, + }), + }, + }, + want: nil, + wantErr: errors.New("s.cfg.TLS.Enable is true, but s.tlsConifg is nil."), + checkFunc: func(got, want Server, gotErr, wantErr error) error { + if gotErr.Error() != wantErr.Error() { + return errors.Errorf("got error is not matched with want error, got: %s, want: %s", gotErr, wantErr) + } + if got != nil { + return fmt.Errorf("want: nil, got: %s", got) + } + return nil + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From f5f796fb0d950f0c7670df2aff60092bb629aa4e Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Mon, 6 Feb 2023 17:32:52 +0900 Subject: [PATCH 20/58] Add lock for update Signed-off-by: Kyo Fujisaki --- service/tls.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/service/tls.go b/service/tls.go index bbe75e4..3b9d2d5 100644 --- a/service/tls.go +++ b/service/tls.go @@ -25,6 +25,7 @@ import ( "io" "io/ioutil" "os" + "sync" "sync/atomic" "time" @@ -40,6 +41,7 @@ type TLSCertificateCache struct { serverCertKeyHash []byte serverCertPath string serverCertKeyPath string + serverCertMutex sync.Mutex certRefreshPeriod time.Duration } @@ -231,6 +233,7 @@ func NewX509CertPool(path string) (*x509.CertPool, error) { // getCertificate return server TLS certificate. func (tcc *TLSCertificateCache) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { + // serverCert is atomic.Value, so this can read it without lock. return tcc.serverCert.Load().(*tls.Certificate), nil } @@ -254,7 +257,9 @@ func (tcc *TLSCertificateCache) RefreshCertificate(ctx context.Context) error { glg.Error("Failed to refresh server certificate: %s.", err.Error()) continue } - + // A lock for when there are other features to update. + // serverCert is atomic.Value, so this can read it without lock. + tcc.serverCertMutex.Lock() different := !bytes.Equal(tcc.serverCertHash, serverCertHash) || !bytes.Equal(tcc.serverCertKeyHash, serverCertKeyHash) @@ -262,6 +267,7 @@ func (tcc *TLSCertificateCache) RefreshCertificate(ctx context.Context) error { newCert, err := tls.LoadX509KeyPair(tcc.serverCertPath, tcc.serverCertKeyPath) if err != nil { glg.Error("Failed to refresh server certificate: %s.", err.Error()) + tcc.serverCertMutex.Unlock() continue } tcc.serverCert.Store(&newCert) @@ -269,6 +275,7 @@ func (tcc *TLSCertificateCache) RefreshCertificate(ctx context.Context) error { tcc.serverCertKeyHash = serverCertKeyHash glg.Info("Refreshed server certificate.") } + tcc.serverCertMutex.Unlock() } } } From 6511e47cc1bcded52206a128bf7f1cd1a9b0f1cd Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Mon, 6 Feb 2023 17:36:21 +0900 Subject: [PATCH 21/58] Delete comment Signed-off-by: Kyo Fujisaki --- service/tls.go | 58 +++----------------------------------------------- 1 file changed, 3 insertions(+), 55 deletions(-) diff --git a/service/tls.go b/service/tls.go index 3b9d2d5..b96e00d 100644 --- a/service/tls.go +++ b/service/tls.go @@ -64,33 +64,7 @@ func NewTLSConfig(cfg config.TLS) (*tls.Config, error) { tls.X25519, }, SessionTicketsDisabled: true, - // PreferServerCipherSuites: true, - // CipherSuites: []uint16{ - // tls.TLS_RSA_WITH_RC4_128_SHA, - // tls.TLS_RSA_WITH_AES_128_CBC_SHA, - // tls.TLS_RSA_WITH_AES_256_CBC_SHA, - // tls.TLS_RSA_WITH_AES_128_CBC_SHA256, - // tls.TLS_RSA_WITH_AES_128_GCM_SHA256, - // tls.TLS_RSA_WITH_AES_256_GCM_SHA384, - // tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, - // tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - // tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - // tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, - // tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - // tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - // tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - // tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, - // tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - // tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - // tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - // tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - // tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, // Maybe this is work on TLS 1.2 - // tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, // TLS1.3 Feature - // tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, // TLS1.3 Feature - // tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, // Go 1.8 only - // tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, // Go 1.8 only - // }, - ClientAuth: tls.NoClientCert, + ClientAuth: tls.NoClientCert, } cert := config.GetActualValue(cfg.CertPath) @@ -134,34 +108,8 @@ func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCerti tls.X25519, }, SessionTicketsDisabled: true, - // PreferServerCipherSuites: true, - // CipherSuites: []uint16{ - // tls.TLS_RSA_WITH_RC4_128_SHA, - // tls.TLS_RSA_WITH_AES_128_CBC_SHA, - // tls.TLS_RSA_WITH_AES_256_CBC_SHA, - // tls.TLS_RSA_WITH_AES_128_CBC_SHA256, - // tls.TLS_RSA_WITH_AES_128_GCM_SHA256, - // tls.TLS_RSA_WITH_AES_256_GCM_SHA384, - // tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, - // tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - // tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - // tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, - // tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - // tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - // tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - // tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, - // tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - // tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - // tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - // tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - // tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, // Maybe this is work on TLS 1.2 - // tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, // TLS1.3 Feature - // tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, // TLS1.3 Feature - // tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, // Go 1.8 only - // tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, // Go 1.8 only - // }, - ClientAuth: tls.NoClientCert, - GetCertificate: tcc.getCertificate, + ClientAuth: tls.NoClientCert, + GetCertificate: tcc.getCertificate, } var err error From 17aa4623dbc93eb7e480482b8f632f72d8f89bd2 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Mon, 6 Feb 2023 20:20:00 +0900 Subject: [PATCH 22/58] Add TestNewTLSConfigWithTLSCertificateCache Signed-off-by: Kyo Fujisaki --- service/tls_test.go | 487 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 481 insertions(+), 6 deletions(-) diff --git a/service/tls_test.go b/service/tls_test.go index d8c278b..dad4958 100644 --- a/service/tls_test.go +++ b/service/tls_test.go @@ -7,7 +7,9 @@ import ( "fmt" "io/ioutil" "strings" + "sync/atomic" "testing" + "time" "github.com/AthenZ/authorization-proxy/v4/config" ) @@ -49,7 +51,7 @@ func TestNewTLSConfig(t *testing.T) { }, SessionTicketsDisabled: true, Certificates: func() []tls.Certificate { - cert, _ := tls.LoadX509KeyPair(defaultArgs.CertPath, defaultArgs.KeyPath) + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) return []tls.Certificate{cert} }(), ClientAuth: tls.RequireAndVerifyClientCert, @@ -74,7 +76,7 @@ func TestNewTLSConfig(t *testing.T) { }, SessionTicketsDisabled: true, Certificates: func() []tls.Certificate { - cert, _ := tls.LoadX509KeyPair(defaultArgs.CertPath, defaultArgs.KeyPath) + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) return []tls.Certificate{cert} }(), ClientAuth: tls.RequireAndVerifyClientCert, @@ -112,7 +114,7 @@ func TestNewTLSConfig(t *testing.T) { }, SessionTicketsDisabled: true, Certificates: func() []tls.Certificate { - cert, _ := tls.LoadX509KeyPair(defaultArgs.CertPath, defaultArgs.KeyPath) + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) return []tls.Certificate{cert} }(), ClientAuth: tls.RequireAndVerifyClientCert, @@ -137,7 +139,7 @@ func TestNewTLSConfig(t *testing.T) { }, SessionTicketsDisabled: true, Certificates: func() []tls.Certificate { - cert, _ := tls.LoadX509KeyPair(defaultArgs.CertPath, defaultArgs.KeyPath) + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) return []tls.Certificate{cert} }(), ClientAuth: tls.RequireAndVerifyClientCert, @@ -171,7 +173,7 @@ func TestNewTLSConfig(t *testing.T) { }, SessionTicketsDisabled: true, Certificates: func() []tls.Certificate { - cert, _ := tls.LoadX509KeyPair(defaultArgs.CertPath, defaultArgs.KeyPath) + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) return []tls.Certificate{cert} }(), ClientAuth: tls.RequireAndVerifyClientCert, @@ -254,7 +256,7 @@ func TestNewTLSConfig(t *testing.T) { }, SessionTicketsDisabled: true, Certificates: func() []tls.Certificate { - cert, _ := tls.LoadX509KeyPair(defaultArgs.CertPath, defaultArgs.KeyPath) + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) return []tls.Certificate{cert} }(), ClientAuth: tls.RequireAndVerifyClientCert, @@ -294,6 +296,479 @@ func TestNewTLSConfig(t *testing.T) { } } +func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { + type args struct { + CertPath string + KeyPath string + CAPath string + cfg config.TLS + } + defaultArgs := args{ + cfg: config.TLS{ + CertPath: "../test/data/dummyServer.crt", + KeyPath: "../test/data/dummyServer.key", + CAPath: "../test/data/dummyCa.pem", + }, + } + var defaultServerCert atomic.Value + defaultServerCertData, err := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + if err != nil { + t.Errorf("LoadX509KeyPair failed: %s", err) + return + } + defaultServerCert.Store(&defaultServerCertData) + defaultServerCerttHash, err := hash(defaultArgs.cfg.CertPath) + if err != nil { + t.Errorf("hash failed: %s", err) + return + } + defaultServerCerttKeyHash, _ := hash(defaultArgs.cfg.KeyPath) + if err != nil { + t.Errorf("hash failed: %s", err) + return + } + + tests := []struct { + name string + args args + want *TLSConfigWithTLSCertificateCache + beforeFunc func(args args) + checkFunc func(*TLSConfigWithTLSCertificateCache, *TLSConfigWithTLSCertificateCache) error + afterFunc func(args args) + wantErr error + }{ + { + name: "return value MinVersion test.", + args: defaultArgs, + want: &TLSConfigWithTLSCertificateCache{ + &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, + }, + SessionTicketsDisabled: true, + Certificates: func() []tls.Certificate { + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + return []tls.Certificate{cert} + }(), + ClientAuth: tls.RequireAndVerifyClientCert, + }, + &TLSCertificateCache{ + serverCert: defaultServerCert, + serverCertHash: defaultServerCerttHash, + serverCertKeyHash: defaultServerCerttKeyHash, + serverCertPath: defaultArgs.cfg.CertPath, + serverCertKeyPath: defaultArgs.cfg.KeyPath, + certRefreshPeriod: 0, + }, + }, + checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { + if got.TLSConfig.MinVersion != want.TLSConfig.MinVersion { + return fmt.Errorf("MinVersion not Matched :\tgot %d\twant %d", got.TLSConfig.MinVersion, want.TLSConfig.MinVersion) + } + gotCert, _ := x509.ParseCertificate(got.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { + return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) + } + if got.TLSCertficateCache.certRefreshPeriod != want.TLSCertficateCache.certRefreshPeriod { + return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertficateCache.certRefreshPeriod, want.TLSCertficateCache.certRefreshPeriod) + } + return nil + }, + }, + { + name: "return value CurvePreferences test.", + args: defaultArgs, + want: &TLSConfigWithTLSCertificateCache{&tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, + }, + SessionTicketsDisabled: true, + Certificates: func() []tls.Certificate { + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + return []tls.Certificate{cert} + }(), + ClientAuth: tls.RequireAndVerifyClientCert, + }, + &TLSCertificateCache{ + serverCert: defaultServerCert, + serverCertHash: defaultServerCerttHash, + serverCertKeyHash: defaultServerCerttKeyHash, + serverCertPath: defaultArgs.cfg.CertPath, + serverCertKeyPath: defaultArgs.cfg.KeyPath, + certRefreshPeriod: 0, + }, + }, + checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { + if len(got.TLSConfig.CurvePreferences) != len(want.TLSConfig.CurvePreferences) { + return fmt.Errorf("CurvePreferences not Matched length:\tgot %d\twant %d", len(got.TLSConfig.CurvePreferences), len(want.TLSConfig.CurvePreferences)) + } + for _, actualValue := range got.TLSConfig.CurvePreferences { + var match bool + for _, expectedValue := range want.TLSConfig.CurvePreferences { + if actualValue == expectedValue { + match = true + break + } + } + + if !match { + return fmt.Errorf("CurvePreferences not Find :\twant %d", want.TLSConfig.MinVersion) + } + } + gotCert, _ := x509.ParseCertificate(got.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { + return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) + } + if got.TLSCertficateCache.certRefreshPeriod != want.TLSCertficateCache.certRefreshPeriod { + return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertficateCache.certRefreshPeriod, want.TLSCertficateCache.certRefreshPeriod) + } + return nil + }, + }, + { + name: "return value SessionTicketsDisabled test.", + args: defaultArgs, + want: &TLSConfigWithTLSCertificateCache{ + &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, + }, + SessionTicketsDisabled: true, + Certificates: func() []tls.Certificate { + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + return []tls.Certificate{cert} + }(), + ClientAuth: tls.RequireAndVerifyClientCert, + }, + &TLSCertificateCache{ + serverCert: defaultServerCert, + serverCertHash: defaultServerCerttHash, + serverCertKeyHash: defaultServerCerttKeyHash, + serverCertPath: defaultArgs.cfg.CertPath, + serverCertKeyPath: defaultArgs.cfg.KeyPath, + certRefreshPeriod: 0, + }, + }, + checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { + if got.TLSConfig.SessionTicketsDisabled != want.TLSConfig.SessionTicketsDisabled { + return fmt.Errorf("SessionTicketsDisabled not matched :\tgot %v\twant %v", got.TLSConfig.SessionTicketsDisabled, want.TLSConfig.SessionTicketsDisabled) + } + gotCert, _ := x509.ParseCertificate(got.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { + return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) + } + if got.TLSCertficateCache.certRefreshPeriod != want.TLSCertficateCache.certRefreshPeriod { + return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertficateCache.certRefreshPeriod, want.TLSCertficateCache.certRefreshPeriod) + } + return nil + }, + }, + { + name: "return value Certificates test.", + args: defaultArgs, + want: &TLSConfigWithTLSCertificateCache{ + &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, + }, + SessionTicketsDisabled: true, + Certificates: func() []tls.Certificate { + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + return []tls.Certificate{cert} + }(), + ClientAuth: tls.RequireAndVerifyClientCert, + }, + &TLSCertificateCache{ + serverCert: defaultServerCert, + serverCertHash: defaultServerCerttHash, + serverCertKeyHash: defaultServerCerttKeyHash, + serverCertPath: defaultArgs.cfg.CertPath, + serverCertKeyPath: defaultArgs.cfg.KeyPath, + certRefreshPeriod: 0, + }, + }, + checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { + for _, wantVal := range want.TLSConfig.Certificates { + notExist := false + for _, gotVal := range got.TLSConfig.Certificates { + if gotVal.PrivateKey == wantVal.PrivateKey { + notExist = true + break + } + } + if notExist { + return fmt.Errorf("Certificates PrivateKey not Matched :\twant %s", wantVal.PrivateKey) + } + } + gotCert, _ := x509.ParseCertificate(got.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { + return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) + } + if got.TLSCertficateCache.certRefreshPeriod != want.TLSCertficateCache.certRefreshPeriod { + return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertficateCache.certRefreshPeriod, want.TLSCertficateCache.certRefreshPeriod) + } + return nil + }, + }, + { + name: "return value ClientAuth test.", + args: defaultArgs, + want: &TLSConfigWithTLSCertificateCache{ + &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, + }, + SessionTicketsDisabled: true, + Certificates: func() []tls.Certificate { + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + return []tls.Certificate{cert} + }(), + ClientAuth: tls.RequireAndVerifyClientCert, + }, + &TLSCertificateCache{ + serverCert: defaultServerCert, + serverCertHash: defaultServerCerttHash, + serverCertKeyHash: defaultServerCerttKeyHash, + serverCertPath: defaultArgs.cfg.CertPath, + serverCertKeyPath: defaultArgs.cfg.KeyPath, + certRefreshPeriod: 0, + }, + }, + checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { + if got.TLSConfig.ClientAuth != want.TLSConfig.ClientAuth { + return fmt.Errorf("ClientAuth not Matched :\tgot %d \twant %d", got.TLSConfig.ClientAuth, want.TLSConfig.ClientAuth) + } + gotCert, _ := x509.ParseCertificate(got.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { + return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) + } + if got.TLSCertficateCache.certRefreshPeriod != want.TLSCertficateCache.certRefreshPeriod { + return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertficateCache.certRefreshPeriod, want.TLSCertficateCache.certRefreshPeriod) + } + return nil + }, + }, + { + name: "return value certRefreshPeriod test.", + args: args{ + cfg: config.TLS{ + CertPath: "../test/data/dummyServer.crt", + KeyPath: "../test/data/dummyServer.key", + CAPath: "../test/data/dummyCa.pem", + CertRefreshPeriod: "12345s", + }, + }, + want: &TLSConfigWithTLSCertificateCache{ + &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, + }, + SessionTicketsDisabled: true, + Certificates: func() []tls.Certificate { + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + return []tls.Certificate{cert} + }(), + ClientAuth: tls.RequireAndVerifyClientCert, + }, + &TLSCertificateCache{ + serverCert: defaultServerCert, + serverCertHash: defaultServerCerttHash, + serverCertKeyHash: defaultServerCerttKeyHash, + serverCertPath: defaultArgs.cfg.CertPath, + serverCertKeyPath: defaultArgs.cfg.KeyPath, + certRefreshPeriod: 12345 * time.Second, + }, + }, + checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { + if got.TLSConfig.ClientAuth != want.TLSConfig.ClientAuth { + return fmt.Errorf("ClientAuth not Matched :\tgot %d \twant %d", got.TLSConfig.ClientAuth, want.TLSConfig.ClientAuth) + } + gotCert, _ := x509.ParseCertificate(got.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { + return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) + } + if got.TLSCertficateCache.certRefreshPeriod != want.TLSCertficateCache.certRefreshPeriod { + return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertficateCache.certRefreshPeriod, want.TLSCertficateCache.certRefreshPeriod) + } + return nil + }, + }, + { + name: "cert file not found return value Certificates test.", + args: args{ + cfg: config.TLS{ + CertPath: "", + }, + }, + want: &TLSConfigWithTLSCertificateCache{ + &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, + }, + SessionTicketsDisabled: true, + Certificates: nil, + ClientAuth: tls.RequireAndVerifyClientCert, + }, + &TLSCertificateCache{}, + }, + checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { + if got.TLSConfig.Certificates != nil { + return fmt.Errorf("Certificates not nil") + } + return nil + }, + }, + { + name: "cert file not found return value ClientAuth test.", + args: args{ + cfg: config.TLS{ + CertPath: "", + }, + }, + want: &TLSConfigWithTLSCertificateCache{ + &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, + }, + SessionTicketsDisabled: true, + Certificates: nil, + ClientAuth: tls.RequireAndVerifyClientCert, + }, + &TLSCertificateCache{}, + }, + checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { + if got.TLSConfig.Certificates != nil { + return fmt.Errorf("Certificates not nil") + } + return nil + }, + }, + { + name: "CA file not found return value ClientAuth test.", + args: args{ + cfg: config.TLS{ + CertPath: "", + CAPath: "", + }, + }, + + want: &TLSConfigWithTLSCertificateCache{ + &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, + }, + SessionTicketsDisabled: true, + Certificates: func() []tls.Certificate { + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + return []tls.Certificate{cert} + }(), + ClientAuth: tls.RequireAndVerifyClientCert, + }, + &TLSCertificateCache{}, + }, + checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { + if got.TLSConfig.ClientAuth != 0 { + return fmt.Errorf("ClientAuth is :\t%d", got.TLSConfig.ClientAuth) + } + return nil + }, + }, + { + name: "CertRefreshPeriod invalid return error test.", + args: args{ + cfg: config.TLS{ + CertPath: "../test/data/dummyServer.crt", + KeyPath: "../test/data/dummyServer.key", + CAPath: "../test/data/dummyCa.pem", + CertRefreshPeriod: "invalid duration", + }, + }, + want: nil, + wantErr: errors.New("ParseDuration(cfg.CertRefreshPeriod): time: invalid duration \"invalid duration\""), + checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { + if got != nil { + return fmt.Errorf("got not nil :\tgot %d \twant %d", &got, &want) + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.beforeFunc != nil { + tt.beforeFunc(tt.args) + } + + got, err := NewTLSConfigWithTLSCertificateCache(tt.args.cfg) + if tt.wantErr == nil && err != nil { + t.Errorf("NewTLSConfigWithTLSCertificateCache() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr != nil { + if tt.wantErr.Error() != err.Error() { + t.Errorf("NewTLSConfigWithTLSCertificateCache() error = %v, wantErr %v", err, tt.wantErr) + return + } + } + + if tt.checkFunc != nil { + err = tt.checkFunc(got, tt.want) + if err != nil { + t.Errorf("NewTLSConfigWithTLSCertificateCache() error = %v", err) + return + } + } + + if tt.afterFunc != nil { + tt.afterFunc(tt.args) + } + }) + } +} + func TestNewX509CertPool(t *testing.T) { type args struct { path string From 9930e2404b795c0a9df3ec17b86ac3793fa6c012 Mon Sep 17 00:00:00 2001 From: Tomohiro Hirata - tomohira Date: Mon, 6 Feb 2023 18:43:08 +0900 Subject: [PATCH 23/58] Implement test for authz_proxyd.New Signed-off-by: Kyo Fujisaki --- usecase/authz_proxyd_test.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/usecase/authz_proxyd_test.go b/usecase/authz_proxyd_test.go index 450e853..5486a5b 100644 --- a/usecase/authz_proxyd_test.go +++ b/usecase/authz_proxyd_test.go @@ -114,6 +114,25 @@ func TestNew(t *testing.T) { }, wantErr: true, wantErrStr: "cannot NewTLSConfig(cfg.Server.TLS): tls.LoadX509KeyPair(cert, key): tls: failed to find any PEM data in certificate input", + }, { + name: "return error when CertRefreshPeriod invalid (failed to parse)", + args: args{ + cfg: config.Config{ + Authorization: config.Authorization{ + RoleToken: config.RoleToken{ + Enable: true, + }, + }, + Server: config.Server{ + TLS: config.TLS{ + Enable: true, + CertRefreshPeriod: "abcdefg", + }, + }, + }, + }, + wantErr: true, + wantErrStr: "cannot NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS): ParseDuration(cfg.CertRefreshPeriod): time: invalid duration \"abcdefg\"", }, } for _, tt := range tests { From 6a7dff0ae307aa3c7b570f1795a358d473e35da0 Mon Sep 17 00:00:00 2001 From: Tomohiro Hirata - tomohira Date: Mon, 6 Feb 2023 18:44:44 +0900 Subject: [PATCH 24/58] Fix typo Signed-off-by: Kyo Fujisaki --- usecase/authz_proxyd.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/usecase/authz_proxyd.go b/usecase/authz_proxyd.go index 36c6224..b852ebe 100644 --- a/usecase/authz_proxyd.go +++ b/usecase/authz_proxyd.go @@ -160,7 +160,7 @@ func (g *authzProxyDaemon) Start(ctx context.Context) <-chan []error { return baseErr }) - // handle cert refresh goroutine erorr + // handle cert refresh goroutine error // prevent run RefreshCertificate if Enable is false and CertRefreshPeriod is set if g.cfg.Server.TLS.Enable && g.cfg.Server.TLS.CertRefreshPeriod != "" { eg.Go(func() error { From 56170d17348358b2da41f4b5fc3e533833510468 Mon Sep 17 00:00:00 2001 From: Tomohiro Hirata - tomohira Date: Mon, 6 Feb 2023 19:27:24 +0900 Subject: [PATCH 25/58] Add test for authz_proxyd.Start Signed-off-by: Kyo Fujisaki --- usecase/authz_proxyd_test.go | 74 ++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/usecase/authz_proxyd_test.go b/usecase/authz_proxyd_test.go index 5486a5b..d66dcc3 100644 --- a/usecase/authz_proxyd_test.go +++ b/usecase/authz_proxyd_test.go @@ -585,6 +585,80 @@ func Test_authzProxyDaemon_Start(t *testing.T) { }, } }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + dummyErr := errors.New("dummy") + return test{ + name: "Daemon stops when TLS.Enable = false and CertRefreshPeriod is set", + fields: fields{ + cfg: config.Config{ + Server: config.Server{ + TLS: config.TLS{ + Enable: false, + CertRefreshPeriod: "3d", + }, + }, + }, + athenz: &service.AuthorizerdMock{ + StartFunc: func(ctx context.Context) <-chan error { + ech := make(chan error) + go func() { + defer close(ech) + <-ctx.Done() + ech <- ctx.Err() + }() + return ech + }, + }, + server: &service.ServerMock{ + ListenAndServeFunc: func(ctx context.Context) <-chan []error { + ech := make(chan []error) + go func() { + defer close(ech) + ech <- []error{errors.WithMessage(dummyErr, "server fails")} + }() + return ech + }, + }, + }, + args: args{ + ctx: ctx, + }, + wantErrs: []error{ + errors.WithMessage(dummyErr, "server fails"), + }, + checkFunc: func(got <-chan []error, wantErrs []error) error { + mux := &sync.Mutex{} + + gotErrs := make([][]error, 0) + mux.Lock() + go func() { + defer mux.Unlock() + err, ok := <-got + if !ok { + return + } + gotErrs = append(gotErrs, err) + }() + time.Sleep(time.Second) + + mux.Lock() + defer mux.Unlock() + + // check only send errors once and the errors are expected ignoring order + sort.Slice(gotErrs[0], getLessErrorFunc(gotErrs[0])) + sort.Slice(wantErrs, getLessErrorFunc(wantErrs)) + gotErrsStr := fmt.Sprintf("%v", gotErrs[0]) + wantErrsStr := fmt.Sprintf("%v", wantErrs) + if len(gotErrs) != 1 || !reflect.DeepEqual(gotErrsStr, wantErrsStr) { + return errors.Errorf("Invalid err, got: %v, want: %v", gotErrsStr, wantErrsStr) + } + + cancel() + return nil + }, + } + }(), } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From a1ba82f2cf9fe804d7bcb70e5e9794b38485bc74 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Mon, 6 Feb 2023 20:38:18 +0900 Subject: [PATCH 26/58] Add TestTLSCertificateCache_getCertificate Signed-off-by: Kyo Fujisaki --- service/tls_test.go | 63 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/service/tls_test.go b/service/tls_test.go index dad4958..ee9de18 100644 --- a/service/tls_test.go +++ b/service/tls_test.go @@ -6,7 +6,9 @@ import ( "errors" "fmt" "io/ioutil" + "reflect" "strings" + "sync" "sync/atomic" "testing" "time" @@ -842,3 +844,64 @@ func TestNewX509CertPool(t *testing.T) { }) } } + +func TestTLSCertificateCache_getCertificate(t *testing.T) { + type fields struct { + serverCert atomic.Value + serverCertHash []byte + serverCertKeyHash []byte + serverCertPath string + serverCertKeyPath string + serverCertMutex sync.Mutex + certRefreshPeriod time.Duration + } + type args struct { + h *tls.ClientHelloInfo + } + var defaultServerCert atomic.Value + defaultServerCertData, err := tls.LoadX509KeyPair("../test/data/dummyServer.crt", "../test/data/dummyServer.key") + if err != nil { + t.Errorf("LoadX509KeyPair failed: %s", err) + return + } + defaultServerCert.Store(&defaultServerCertData) + tests := []struct { + name string + fields fields + args args + want *tls.Certificate + wantErr bool + }{ + { + name: "Check return serverCert", + fields: fields{ + serverCert: defaultServerCert, + }, + args: args{ + h: &tls.ClientHelloInfo{}, + }, + want: defaultServerCert.Load().(*tls.Certificate), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tcc := &TLSCertificateCache{ + serverCert: tt.fields.serverCert, + serverCertHash: tt.fields.serverCertHash, + serverCertKeyHash: tt.fields.serverCertKeyHash, + serverCertPath: tt.fields.serverCertPath, + serverCertKeyPath: tt.fields.serverCertKeyPath, + serverCertMutex: tt.fields.serverCertMutex, + certRefreshPeriod: tt.fields.certRefreshPeriod, + } + got, err := tcc.getCertificate(tt.args.h) + if (err != nil) != tt.wantErr { + t.Errorf("TLSCertificateCache.getCertificate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("TLSCertificateCache.getCertificate() = %v, want %v", got, tt.want) + } + }) + } +} From 56a98624979150e1241048dcfeea261fc9516916 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Tue, 7 Feb 2023 14:27:07 +0900 Subject: [PATCH 27/58] Add TLSCertificateCache_RefreshCertificate template Signed-off-by: Kyo Fujisaki --- service/tls_test.go | 74 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/service/tls_test.go b/service/tls_test.go index ee9de18..4870be4 100644 --- a/service/tls_test.go +++ b/service/tls_test.go @@ -1,6 +1,7 @@ package service import ( + "context" "crypto/tls" "crypto/x509" "errors" @@ -905,3 +906,76 @@ func TestTLSCertificateCache_getCertificate(t *testing.T) { }) } } + +func TestTLSCertificateCache_RefreshCertificate(t *testing.T) { + type fields struct { + serverCert atomic.Value + serverCertHash []byte + serverCertKeyHash []byte + serverCertPath string + serverCertKeyPath string + serverCertMutex sync.Mutex + certRefreshPeriod time.Duration + } + type args struct { + ctx context.Context + } + type test struct { + name string + fields fields + args args + want error + checkFunc func(*TLSCertificateCache, error, error) error + afterFunc func() error + } + tests := []test{ + func() test { + ctx, cancelFunc := context.WithCancel(context.Background()) + //key := "../test/data/dummyServer.key" + //cert := "../test/data/dummyServer.crt" + + return test{ + name: "Test refresh function can start and stop", + fields: fields{}, + args: args{ + ctx: ctx, + }, + checkFunc: func(tcc *TLSCertificateCache, got error, want error) error { + time.Sleep(time.Millisecond * 150) + return nil + }, + afterFunc: func() error { + cancelFunc() + return nil + }, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.afterFunc != nil { + defer func() { + if err := tt.afterFunc(); err != nil { + t.Errorf("afterFunc error, error: %v", err) + return + } + }() + } + + tcc := &TLSCertificateCache{ + serverCert: tt.fields.serverCert, + serverCertHash: tt.fields.serverCertHash, + serverCertKeyHash: tt.fields.serverCertKeyHash, + serverCertPath: tt.fields.serverCertPath, + serverCertKeyPath: tt.fields.serverCertKeyPath, + serverCertMutex: tt.fields.serverCertMutex, + certRefreshPeriod: tt.fields.certRefreshPeriod, + } + + got := tcc.RefreshCertificate(tt.args.ctx) + if err := tt.checkFunc(tcc, got, tt.want); err != nil { + t.Errorf("TLSCertificateCache.RefreshCertificate() error = %v, want %v", err, tt.want) + } + }) + } +} From 0029638ee0a8b8e5d3b6be12d62ae5b0dca3ee74 Mon Sep 17 00:00:00 2001 From: Tomohiro Hirata - tomohira Date: Tue, 7 Feb 2023 14:20:59 +0900 Subject: [PATCH 28/58] Fix param format for test Signed-off-by: Kyo Fujisaki --- usecase/authz_proxyd_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/usecase/authz_proxyd_test.go b/usecase/authz_proxyd_test.go index d66dcc3..3f449cf 100644 --- a/usecase/authz_proxyd_test.go +++ b/usecase/authz_proxyd_test.go @@ -54,6 +54,10 @@ func TestNew(t *testing.T) { HealthCheck: config.HealthCheck{ Endpoint: "/dummy", }, + TLS: config.TLS{ + Enable: true, + CertRefreshPeriod: "24h", + }, }, Proxy: config.Proxy{ BufferSize: 512, From 98ef5ec1e058ce7239e770abde821b2cc24376c6 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Tue, 7 Feb 2023 16:19:46 +0900 Subject: [PATCH 29/58] Add refresh testcase(still in progress) Signed-off-by: Kyo Fujisaki --- service/tls_test.go | 136 +++++++++++++++++++++++++++++++---- test/data/newDummyServer.crt | 19 +++++ 2 files changed, 142 insertions(+), 13 deletions(-) create mode 100644 test/data/newDummyServer.crt diff --git a/service/tls_test.go b/service/tls_test.go index 4870be4..238f139 100644 --- a/service/tls_test.go +++ b/service/tls_test.go @@ -6,7 +6,9 @@ import ( "crypto/x509" "errors" "fmt" + "io" "io/ioutil" + "os" "reflect" "strings" "sync" @@ -921,31 +923,131 @@ func TestTLSCertificateCache_RefreshCertificate(t *testing.T) { ctx context.Context } type test struct { - name string - fields fields - args args - want error - checkFunc func(*TLSCertificateCache, error, error) error - afterFunc func() error + name string + fields fields + args args + want error + beforeFunc func() error + checkFunc func(*TLSCertificateCache, error, error) error + afterFunc func() error + } + copyCert := func(srcPath, dstPath string) error { + src, err := os.Open(srcPath) + if err != nil { + t.Errorf("test cert copy failed: %s", err) + return err + } + defer src.Close() + + dst, err := os.Create(dstPath) + if err != nil { + t.Errorf("test cert copy failed: %s", err) + return err + } + defer dst.Close() + + _, err = io.Copy(dst, src) + if err != nil { + t.Errorf("test cert copy failed: %s", err) + return err + } + return nil + } + testCertPath := "../test/data/test.crt" + testCertKeyPath := "../test/data/test.key" + oldCertPath := "../test/data/dummyServer.crt" + oldCertKeyPath := "../test/data/dummyServer.key" + newCertPath := "../test/data/newDummyServer.crt" + + var defaultServerCert atomic.Value + defaultServerCertData, err := tls.LoadX509KeyPair("../test/data/dummyServer.crt", "../test/data/dummyServer.key") + if err != nil { + t.Errorf("LoadX509KeyPair failed: %s", err) + return + } + defaultServerCert.Store(&defaultServerCertData) + defaultServerCerttHash, err := hash(oldCertPath) + if err != nil { + t.Errorf("hash failed: %s", err) + return + } + defaultServerCerttKeyHash, _ := hash(oldCertKeyPath) + if err != nil { + t.Errorf("hash failed: %s", err) + return } + // newCert key == oldCert key + newCert, err := tls.LoadX509KeyPair(newCertPath, oldCertKeyPath) + if err != nil { + t.Errorf("LoadX509KeyPair failed: %s", err) + return + } + tests := []test{ func() test { ctx, cancelFunc := context.WithCancel(context.Background()) - //key := "../test/data/dummyServer.key" - //cert := "../test/data/dummyServer.crt" return test{ - name: "Test refresh function can start and stop", - fields: fields{}, + name: "Test refresh server cert", + fields: fields{ + serverCert: defaultServerCert, + serverCertHash: defaultServerCerttHash, + serverCertKeyHash: defaultServerCerttKeyHash, + serverCertPath: testCertPath, + serverCertKeyPath: testCertKeyPath, + certRefreshPeriod: 1 * time.Second, + serverCertMutex: sync.Mutex{}, + }, args: args{ ctx: ctx, }, + beforeFunc: func() error { + err := copyCert(oldCertPath, testCertPath) + if err != nil { + return err + } + err = copyCert(oldCertKeyPath, testCertKeyPath) + if err != nil { + return err + } + return nil + }, checkFunc: func(tcc *TLSCertificateCache, got error, want error) error { - time.Sleep(time.Millisecond * 150) + if got != nil { + return got + } + cachedCert := tcc.serverCert.Load() + cc, _ := x509.ParseCertificate(cachedCert.(tls.Certificate).Certificate[0]) + dc, _ := x509.ParseCertificate(defaultServerCertData.Certificate[0]) + if cc.SerialNumber != dc.SerialNumber { + return errors.New("Serial Number not Matched") + } + // refresh certificate + err = copyCert(newCertPath, testCertPath) + if err != nil { + return err + } + time.Sleep(1 * time.Second) + cachedCert = tcc.serverCert.Load() + cc, _ = x509.ParseCertificate(cachedCert.(tls.Certificate).Certificate[0]) + nc, _ := x509.ParseCertificate(newCert.Certificate[0]) + if cc.SerialNumber != nc.SerialNumber { + return errors.New("Serial Number not Matched") + } return nil }, afterFunc: func() error { cancelFunc() + err := os.Remove(testCertPath) + if err != nil { + t.Errorf("test cert remove failed: %s", err) + return err + } + err = os.Remove(testCertKeyPath) + if err != nil { + t.Errorf("test cert remove failed: %s", err) + return err + } return nil }, } @@ -961,6 +1063,12 @@ func TestTLSCertificateCache_RefreshCertificate(t *testing.T) { } }() } + if tt.beforeFunc != nil { + if err := tt.beforeFunc(); err != nil { + t.Errorf("beforeFunc error, error: %v", err) + return + } + } tcc := &TLSCertificateCache{ serverCert: tt.fields.serverCert, @@ -971,8 +1079,10 @@ func TestTLSCertificateCache_RefreshCertificate(t *testing.T) { serverCertMutex: tt.fields.serverCertMutex, certRefreshPeriod: tt.fields.certRefreshPeriod, } - - got := tcc.RefreshCertificate(tt.args.ctx) + //errCh := make(chan error) + go func() error { + return tcc.RefreshCertificate(tt.args.ctx) + }() if err := tt.checkFunc(tcc, got, tt.want); err != nil { t.Errorf("TLSCertificateCache.RefreshCertificate() error = %v, want %v", err, tt.want) } diff --git a/test/data/newDummyServer.crt b/test/data/newDummyServer.crt new file mode 100644 index 0000000..3e5e4c0 --- /dev/null +++ b/test/data/newDummyServer.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIC/jCCAeYCCQDYE18V6q0sozANBgkqhkiG9w0BAQUFADBBMQswCQYDVQQGEwJK +UDEOMAwGA1UECAwFVG9reW8xEjAQBgNVBAcMCWNoaXlvZGFrdTEOMAwGA1UECgwF +eWFob28wHhcNMjMwMjA3MDU0NzEwWhcNMzMwMjA0MDU0NzEwWjBBMQswCQYDVQQG +EwJKUDEOMAwGA1UECAwFVG9reW8xEjAQBgNVBAcMCWNoaXlvZGFrdTEOMAwGA1UE +CgwFeWFob28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDUBKwhDHGl +yj5GFOhDlyH67BsW9l/A2rna21WVnq69ANN9WmaiFp32C+2BrqF8DYwtPyoaTgXd ++WQ8wmtjlYwp378P743vBgf6tazZYKbCa7PZsNMLCdI0KurXf6iwOkCY7zQQcOji +STUVtfr009LFgc8KBduxBiyEe1tXoN1qrjUZgty/05EtW09QFE2XqVFQS4jN6CC4 +whmN3J7yWHyZv3K1lO0aeC31kIOqw0/DxD/Dj/RMJBpz0q3jeYJwvM3rE4DNajTW +6BLiPWVkLW4H3paRE6jYfyCRazE8tLLWHzGGqPWUElV2rKWzdV6OnyWLFxWC2RK/ +71IuZS1UDqmHAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAAlLPl+OXQnoSkCGUvVa +Dz/9qCBbho38VBABRjN9ZwxRZXhF6La38gTpjsbu5iPidND+rEGqjQQIszahoNk2 +gcec23SQfpEswnoi3lHO/sEcxchKT6VD2RTgOCV9LQFebbu3EmH7JAsz8xmgWVMR +4T34E+KkMMq3ScUGXcbX812q3CoTglG2HVwJIzW+wYF6l7aVsrHSI51yKAs41fO5 +UY9zacsWYvEeOet6WXep0/G7gJuxJ2W7d6EP82AfduiOW7LLBUUny4eukbnKljaG +uR2iTOkr/PWcxMpO0+iNyvRpyUwc3A1KsL7h2z8/VH9QEbCMO3xl0GTTvDcFltAk +ug8= +-----END CERTIFICATE----- From 75247f11a68cf2235e9c3cdb2997ed92d731de56 Mon Sep 17 00:00:00 2001 From: wfan Date: Tue, 7 Feb 2023 16:19:07 +0900 Subject: [PATCH 30/58] fix server.go unit test Signed-off-by: Kyo Fujisaki --- service/server_test.go | 150 +++++++++++++++++++++++++++++------------ 1 file changed, 106 insertions(+), 44 deletions(-) diff --git a/service/server_test.go b/service/server_test.go index bf6c567..57e9b86 100644 --- a/service/server_test.go +++ b/service/server_test.go @@ -589,8 +589,14 @@ func Test_server_ListenAndServe(t *testing.T) { } }(), func() test { - key := "../test/data/dummyServer.key" - cert := "../test/data/dummyServer.crt" + tc, err := NewTLSConfig(config.TLS{ + Enable: true, + CertPath: "../test/data/dummyServer.crt", + KeyPath: "../test/data/dummyServer.key", + }) + if err != nil { + panic(err) + } handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) @@ -609,8 +615,9 @@ func Test_server_ListenAndServe(t *testing.T) { fields: fields{ srv: func() *http.Server { srv := &http.Server{ - Addr: fmt.Sprintf(":%d", apiSrvPort), - Handler: handler, + Addr: fmt.Sprintf(":%d", apiSrvPort), + Handler: handler, + TLSConfig: tc, } srv.SetKeepAlivesEnabled(true) @@ -637,9 +644,7 @@ func Test_server_ListenAndServe(t *testing.T) { cfg: config.Server{ Port: apiSrvPort, TLS: config.TLS{ - Enable: true, - CertPath: cert, - KeyPath: key, + Enable: true, }, HealthCheck: config.HealthCheck{ Port: hcSrvPort, @@ -683,7 +688,6 @@ func Test_server_ListenAndServe(t *testing.T) { }, } }(), - func() test { ctx, cancel := context.WithCancel(context.Background()) @@ -789,8 +793,14 @@ func Test_server_ListenAndServe(t *testing.T) { }(), func() test { - key := "../test/data/dummyServer.key" - cert := "../test/data/dummyServer.crt" + tc, err := NewTLSConfig(config.TLS{ + Enable: true, + CertPath: "../test/data/dummyServer.crt", + KeyPath: "../test/data/dummyServer.key", + }) + if err != nil { + panic(err) + } handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) @@ -809,8 +819,9 @@ func Test_server_ListenAndServe(t *testing.T) { fields: fields{ srv: func() *http.Server { srv := &http.Server{ - Addr: fmt.Sprintf(":%d", apiSrvPort), - Handler: handler, + Addr: fmt.Sprintf(":%d", apiSrvPort), + Handler: handler, + TLSConfig: tc, } srv.SetKeepAlivesEnabled(true) @@ -837,9 +848,7 @@ func Test_server_ListenAndServe(t *testing.T) { cfg: config.Server{ Port: apiSrvPort, TLS: config.TLS{ - Enable: true, - CertPath: cert, - KeyPath: key, + Enable: true, }, HealthCheck: config.HealthCheck{ Port: hcSrvPort, @@ -884,8 +893,14 @@ func Test_server_ListenAndServe(t *testing.T) { } }(), func() test { - key := "../test/data/dummyServer.key" - cert := "../test/data/dummyServer.crt" + tc, err := NewTLSConfig(config.TLS{ + Enable: true, + CertPath: "../test/data/dummyServer.crt", + KeyPath: "../test/data/dummyServer.key", + }) + if err != nil { + panic(err) + } handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) @@ -904,8 +919,9 @@ func Test_server_ListenAndServe(t *testing.T) { fields: fields{ srv: func() *http.Server { srv := &http.Server{ - Addr: fmt.Sprintf(":%d", apiSrvPort), - Handler: handler, + Addr: fmt.Sprintf(":%d", apiSrvPort), + Handler: handler, + TLSConfig: tc, } srv.SetKeepAlivesEnabled(true) @@ -932,9 +948,7 @@ func Test_server_ListenAndServe(t *testing.T) { cfg: config.Server{ Port: apiSrvPort, TLS: config.TLS{ - Enable: true, - CertPath: cert, - KeyPath: key, + Enable: true, }, HealthCheck: config.HealthCheck{ Port: hcSrvPort, @@ -980,8 +994,14 @@ func Test_server_ListenAndServe(t *testing.T) { }(), func() test { ctx, cancelFunc := context.WithCancel(context.Background()) - key := "../test/data/dummyServer.key" - cert := "../test/data/dummyServer.crt" + tc, err := NewTLSConfig(config.TLS{ + Enable: true, + CertPath: "../test/data/dummyServer.crt", + KeyPath: "../test/data/dummyServer.key", + }) + if err != nil { + panic(err) + } handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) @@ -1000,8 +1020,9 @@ func Test_server_ListenAndServe(t *testing.T) { fields: fields{ srv: func() *http.Server { srv := &http.Server{ - Addr: fmt.Sprintf(":%d", apiSrvPort), - Handler: handler, + Addr: fmt.Sprintf(":%d", apiSrvPort), + Handler: handler, + TLSConfig: tc, } srv.SetKeepAlivesEnabled(true) @@ -1028,9 +1049,7 @@ func Test_server_ListenAndServe(t *testing.T) { cfg: config.Server{ Port: apiSrvPort, TLS: config.TLS{ - Enable: true, - CertPath: cert, - KeyPath: key, + Enable: true, }, HealthCheck: config.HealthCheck{ Port: hcSrvPort, @@ -1073,8 +1092,14 @@ func Test_server_ListenAndServe(t *testing.T) { }(), func() test { ctx, cancelFunc := context.WithCancel(context.Background()) - key := "../test/data/dummyServer.key" - cert := "../test/data/dummyServer.crt" + tc, err := NewTLSConfig(config.TLS{ + Enable: true, + CertPath: "../test/data/dummyServer.crt", + KeyPath: "../test/data/dummyServer.key", + }) + if err != nil { + panic(err) + } handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) @@ -1093,8 +1118,9 @@ func Test_server_ListenAndServe(t *testing.T) { fields: fields{ srv: func() *http.Server { srv := &http.Server{ - Addr: fmt.Sprintf(":%d", apiSrvPort), - Handler: handler, + Addr: fmt.Sprintf(":%d", apiSrvPort), + Handler: handler, + TLSConfig: tc, } srv.SetKeepAlivesEnabled(true) @@ -1121,9 +1147,7 @@ func Test_server_ListenAndServe(t *testing.T) { cfg: config.Server{ Port: apiSrvPort, TLS: config.TLS{ - Enable: true, - CertPath: cert, - KeyPath: key, + Enable: true, }, HealthCheck: config.HealthCheck{ // Port: hcSrvPort, @@ -1633,25 +1657,63 @@ func Test_server_listenAndServeAPI(t *testing.T) { want error } tests := []test{ - func() test { - key := "../test/data/dummyServer.key" - cert := "../test/data/dummyServer.crt" + { + name: "Test HTTP server startup", + fields: fields{ + srv: &http.Server{ + Handler: func() http.Handler { + return nil + }(), + Addr: fmt.Sprintf(":%d", 9999), + }, + cfg: config.Server{ + Port: 9999, + TLS: config.TLS{ + Enable: false, + }, + }, + }, + checkFunc: func(s *server, want error) error { + // listenAndServeAPI function is blocking, so we need to set timer to shutdown the process + go func() { + time.Sleep(time.Second * 1) + if err := s.srv.Shutdown(context.Background()); err != nil { + panic(err) + } + }() + + got := s.listenAndServeAPI() + if got != want { + return fmt.Errorf("got:\t%v\nwant:\t%v", got, want) + } + return nil + }, + want: http.ErrServerClosed, + }, + func() test { + tc, err := NewTLSConfig(config.TLS{ + Enable: true, + CertPath: "../test/data/dummyServer.crt", + KeyPath: "../test/data/dummyServer.key", + }) + if err != nil { + panic(err) + } return test{ - name: "Test server startup", + name: "Test HTTPS server startup", fields: fields{ srv: &http.Server{ Handler: func() http.Handler { return nil }(), - Addr: fmt.Sprintf(":%d", 9999), + Addr: fmt.Sprintf(":%d", 9999), + TLSConfig: tc, }, cfg: config.Server{ Port: 9999, TLS: config.TLS{ - Enable: true, - CertPath: cert, - KeyPath: key, + Enable: true, }, }, }, From 8b87f4bb3fab8f7c6df54afccfcac320058db2cd Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Tue, 7 Feb 2023 16:22:58 +0900 Subject: [PATCH 31/58] Revert "Add refresh testcase(still in progress)" This reverts commit fd9a57843e59b4ede32de87fa63430c5a3e7c62f. Signed-off-by: Kyo Fujisaki --- service/tls_test.go | 136 ++++------------------------------- test/data/newDummyServer.crt | 19 ----- 2 files changed, 13 insertions(+), 142 deletions(-) delete mode 100644 test/data/newDummyServer.crt diff --git a/service/tls_test.go b/service/tls_test.go index 238f139..4870be4 100644 --- a/service/tls_test.go +++ b/service/tls_test.go @@ -6,9 +6,7 @@ import ( "crypto/x509" "errors" "fmt" - "io" "io/ioutil" - "os" "reflect" "strings" "sync" @@ -923,131 +921,31 @@ func TestTLSCertificateCache_RefreshCertificate(t *testing.T) { ctx context.Context } type test struct { - name string - fields fields - args args - want error - beforeFunc func() error - checkFunc func(*TLSCertificateCache, error, error) error - afterFunc func() error - } - copyCert := func(srcPath, dstPath string) error { - src, err := os.Open(srcPath) - if err != nil { - t.Errorf("test cert copy failed: %s", err) - return err - } - defer src.Close() - - dst, err := os.Create(dstPath) - if err != nil { - t.Errorf("test cert copy failed: %s", err) - return err - } - defer dst.Close() - - _, err = io.Copy(dst, src) - if err != nil { - t.Errorf("test cert copy failed: %s", err) - return err - } - return nil - } - testCertPath := "../test/data/test.crt" - testCertKeyPath := "../test/data/test.key" - oldCertPath := "../test/data/dummyServer.crt" - oldCertKeyPath := "../test/data/dummyServer.key" - newCertPath := "../test/data/newDummyServer.crt" - - var defaultServerCert atomic.Value - defaultServerCertData, err := tls.LoadX509KeyPair("../test/data/dummyServer.crt", "../test/data/dummyServer.key") - if err != nil { - t.Errorf("LoadX509KeyPair failed: %s", err) - return - } - defaultServerCert.Store(&defaultServerCertData) - defaultServerCerttHash, err := hash(oldCertPath) - if err != nil { - t.Errorf("hash failed: %s", err) - return - } - defaultServerCerttKeyHash, _ := hash(oldCertKeyPath) - if err != nil { - t.Errorf("hash failed: %s", err) - return - } - // newCert key == oldCert key - newCert, err := tls.LoadX509KeyPair(newCertPath, oldCertKeyPath) - if err != nil { - t.Errorf("LoadX509KeyPair failed: %s", err) - return + name string + fields fields + args args + want error + checkFunc func(*TLSCertificateCache, error, error) error + afterFunc func() error } - tests := []test{ func() test { ctx, cancelFunc := context.WithCancel(context.Background()) + //key := "../test/data/dummyServer.key" + //cert := "../test/data/dummyServer.crt" return test{ - name: "Test refresh server cert", - fields: fields{ - serverCert: defaultServerCert, - serverCertHash: defaultServerCerttHash, - serverCertKeyHash: defaultServerCerttKeyHash, - serverCertPath: testCertPath, - serverCertKeyPath: testCertKeyPath, - certRefreshPeriod: 1 * time.Second, - serverCertMutex: sync.Mutex{}, - }, + name: "Test refresh function can start and stop", + fields: fields{}, args: args{ ctx: ctx, }, - beforeFunc: func() error { - err := copyCert(oldCertPath, testCertPath) - if err != nil { - return err - } - err = copyCert(oldCertKeyPath, testCertKeyPath) - if err != nil { - return err - } - return nil - }, checkFunc: func(tcc *TLSCertificateCache, got error, want error) error { - if got != nil { - return got - } - cachedCert := tcc.serverCert.Load() - cc, _ := x509.ParseCertificate(cachedCert.(tls.Certificate).Certificate[0]) - dc, _ := x509.ParseCertificate(defaultServerCertData.Certificate[0]) - if cc.SerialNumber != dc.SerialNumber { - return errors.New("Serial Number not Matched") - } - // refresh certificate - err = copyCert(newCertPath, testCertPath) - if err != nil { - return err - } - time.Sleep(1 * time.Second) - cachedCert = tcc.serverCert.Load() - cc, _ = x509.ParseCertificate(cachedCert.(tls.Certificate).Certificate[0]) - nc, _ := x509.ParseCertificate(newCert.Certificate[0]) - if cc.SerialNumber != nc.SerialNumber { - return errors.New("Serial Number not Matched") - } + time.Sleep(time.Millisecond * 150) return nil }, afterFunc: func() error { cancelFunc() - err := os.Remove(testCertPath) - if err != nil { - t.Errorf("test cert remove failed: %s", err) - return err - } - err = os.Remove(testCertKeyPath) - if err != nil { - t.Errorf("test cert remove failed: %s", err) - return err - } return nil }, } @@ -1063,12 +961,6 @@ func TestTLSCertificateCache_RefreshCertificate(t *testing.T) { } }() } - if tt.beforeFunc != nil { - if err := tt.beforeFunc(); err != nil { - t.Errorf("beforeFunc error, error: %v", err) - return - } - } tcc := &TLSCertificateCache{ serverCert: tt.fields.serverCert, @@ -1079,10 +971,8 @@ func TestTLSCertificateCache_RefreshCertificate(t *testing.T) { serverCertMutex: tt.fields.serverCertMutex, certRefreshPeriod: tt.fields.certRefreshPeriod, } - //errCh := make(chan error) - go func() error { - return tcc.RefreshCertificate(tt.args.ctx) - }() + + got := tcc.RefreshCertificate(tt.args.ctx) if err := tt.checkFunc(tcc, got, tt.want); err != nil { t.Errorf("TLSCertificateCache.RefreshCertificate() error = %v, want %v", err, tt.want) } diff --git a/test/data/newDummyServer.crt b/test/data/newDummyServer.crt deleted file mode 100644 index 3e5e4c0..0000000 --- a/test/data/newDummyServer.crt +++ /dev/null @@ -1,19 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIC/jCCAeYCCQDYE18V6q0sozANBgkqhkiG9w0BAQUFADBBMQswCQYDVQQGEwJK -UDEOMAwGA1UECAwFVG9reW8xEjAQBgNVBAcMCWNoaXlvZGFrdTEOMAwGA1UECgwF -eWFob28wHhcNMjMwMjA3MDU0NzEwWhcNMzMwMjA0MDU0NzEwWjBBMQswCQYDVQQG -EwJKUDEOMAwGA1UECAwFVG9reW8xEjAQBgNVBAcMCWNoaXlvZGFrdTEOMAwGA1UE -CgwFeWFob28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDUBKwhDHGl -yj5GFOhDlyH67BsW9l/A2rna21WVnq69ANN9WmaiFp32C+2BrqF8DYwtPyoaTgXd -+WQ8wmtjlYwp378P743vBgf6tazZYKbCa7PZsNMLCdI0KurXf6iwOkCY7zQQcOji -STUVtfr009LFgc8KBduxBiyEe1tXoN1qrjUZgty/05EtW09QFE2XqVFQS4jN6CC4 -whmN3J7yWHyZv3K1lO0aeC31kIOqw0/DxD/Dj/RMJBpz0q3jeYJwvM3rE4DNajTW -6BLiPWVkLW4H3paRE6jYfyCRazE8tLLWHzGGqPWUElV2rKWzdV6OnyWLFxWC2RK/ -71IuZS1UDqmHAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAAlLPl+OXQnoSkCGUvVa -Dz/9qCBbho38VBABRjN9ZwxRZXhF6La38gTpjsbu5iPidND+rEGqjQQIszahoNk2 -gcec23SQfpEswnoi3lHO/sEcxchKT6VD2RTgOCV9LQFebbu3EmH7JAsz8xmgWVMR -4T34E+KkMMq3ScUGXcbX812q3CoTglG2HVwJIzW+wYF6l7aVsrHSI51yKAs41fO5 -UY9zacsWYvEeOet6WXep0/G7gJuxJ2W7d6EP82AfduiOW7LLBUUny4eukbnKljaG -uR2iTOkr/PWcxMpO0+iNyvRpyUwc3A1KsL7h2z8/VH9QEbCMO3xl0GTTvDcFltAk -ug8= ------END CERTIFICATE----- From 54184d11d5c369829cedf591d06dc99b62b50ce6 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Tue, 7 Feb 2023 16:53:00 +0900 Subject: [PATCH 32/58] Revert "Revert "Add refresh testcase(still in progress)"" This reverts commit f5003fdb4952a050b5728d91e184d5011b8487dc. Signed-off-by: Kyo Fujisaki --- service/tls_test.go | 136 +++++++++++++++++++++++++++++++---- test/data/newDummyServer.crt | 19 +++++ 2 files changed, 142 insertions(+), 13 deletions(-) create mode 100644 test/data/newDummyServer.crt diff --git a/service/tls_test.go b/service/tls_test.go index 4870be4..238f139 100644 --- a/service/tls_test.go +++ b/service/tls_test.go @@ -6,7 +6,9 @@ import ( "crypto/x509" "errors" "fmt" + "io" "io/ioutil" + "os" "reflect" "strings" "sync" @@ -921,31 +923,131 @@ func TestTLSCertificateCache_RefreshCertificate(t *testing.T) { ctx context.Context } type test struct { - name string - fields fields - args args - want error - checkFunc func(*TLSCertificateCache, error, error) error - afterFunc func() error + name string + fields fields + args args + want error + beforeFunc func() error + checkFunc func(*TLSCertificateCache, error, error) error + afterFunc func() error + } + copyCert := func(srcPath, dstPath string) error { + src, err := os.Open(srcPath) + if err != nil { + t.Errorf("test cert copy failed: %s", err) + return err + } + defer src.Close() + + dst, err := os.Create(dstPath) + if err != nil { + t.Errorf("test cert copy failed: %s", err) + return err + } + defer dst.Close() + + _, err = io.Copy(dst, src) + if err != nil { + t.Errorf("test cert copy failed: %s", err) + return err + } + return nil + } + testCertPath := "../test/data/test.crt" + testCertKeyPath := "../test/data/test.key" + oldCertPath := "../test/data/dummyServer.crt" + oldCertKeyPath := "../test/data/dummyServer.key" + newCertPath := "../test/data/newDummyServer.crt" + + var defaultServerCert atomic.Value + defaultServerCertData, err := tls.LoadX509KeyPair("../test/data/dummyServer.crt", "../test/data/dummyServer.key") + if err != nil { + t.Errorf("LoadX509KeyPair failed: %s", err) + return + } + defaultServerCert.Store(&defaultServerCertData) + defaultServerCerttHash, err := hash(oldCertPath) + if err != nil { + t.Errorf("hash failed: %s", err) + return + } + defaultServerCerttKeyHash, _ := hash(oldCertKeyPath) + if err != nil { + t.Errorf("hash failed: %s", err) + return } + // newCert key == oldCert key + newCert, err := tls.LoadX509KeyPair(newCertPath, oldCertKeyPath) + if err != nil { + t.Errorf("LoadX509KeyPair failed: %s", err) + return + } + tests := []test{ func() test { ctx, cancelFunc := context.WithCancel(context.Background()) - //key := "../test/data/dummyServer.key" - //cert := "../test/data/dummyServer.crt" return test{ - name: "Test refresh function can start and stop", - fields: fields{}, + name: "Test refresh server cert", + fields: fields{ + serverCert: defaultServerCert, + serverCertHash: defaultServerCerttHash, + serverCertKeyHash: defaultServerCerttKeyHash, + serverCertPath: testCertPath, + serverCertKeyPath: testCertKeyPath, + certRefreshPeriod: 1 * time.Second, + serverCertMutex: sync.Mutex{}, + }, args: args{ ctx: ctx, }, + beforeFunc: func() error { + err := copyCert(oldCertPath, testCertPath) + if err != nil { + return err + } + err = copyCert(oldCertKeyPath, testCertKeyPath) + if err != nil { + return err + } + return nil + }, checkFunc: func(tcc *TLSCertificateCache, got error, want error) error { - time.Sleep(time.Millisecond * 150) + if got != nil { + return got + } + cachedCert := tcc.serverCert.Load() + cc, _ := x509.ParseCertificate(cachedCert.(tls.Certificate).Certificate[0]) + dc, _ := x509.ParseCertificate(defaultServerCertData.Certificate[0]) + if cc.SerialNumber != dc.SerialNumber { + return errors.New("Serial Number not Matched") + } + // refresh certificate + err = copyCert(newCertPath, testCertPath) + if err != nil { + return err + } + time.Sleep(1 * time.Second) + cachedCert = tcc.serverCert.Load() + cc, _ = x509.ParseCertificate(cachedCert.(tls.Certificate).Certificate[0]) + nc, _ := x509.ParseCertificate(newCert.Certificate[0]) + if cc.SerialNumber != nc.SerialNumber { + return errors.New("Serial Number not Matched") + } return nil }, afterFunc: func() error { cancelFunc() + err := os.Remove(testCertPath) + if err != nil { + t.Errorf("test cert remove failed: %s", err) + return err + } + err = os.Remove(testCertKeyPath) + if err != nil { + t.Errorf("test cert remove failed: %s", err) + return err + } return nil }, } @@ -961,6 +1063,12 @@ func TestTLSCertificateCache_RefreshCertificate(t *testing.T) { } }() } + if tt.beforeFunc != nil { + if err := tt.beforeFunc(); err != nil { + t.Errorf("beforeFunc error, error: %v", err) + return + } + } tcc := &TLSCertificateCache{ serverCert: tt.fields.serverCert, @@ -971,8 +1079,10 @@ func TestTLSCertificateCache_RefreshCertificate(t *testing.T) { serverCertMutex: tt.fields.serverCertMutex, certRefreshPeriod: tt.fields.certRefreshPeriod, } - - got := tcc.RefreshCertificate(tt.args.ctx) + //errCh := make(chan error) + go func() error { + return tcc.RefreshCertificate(tt.args.ctx) + }() if err := tt.checkFunc(tcc, got, tt.want); err != nil { t.Errorf("TLSCertificateCache.RefreshCertificate() error = %v, want %v", err, tt.want) } diff --git a/test/data/newDummyServer.crt b/test/data/newDummyServer.crt new file mode 100644 index 0000000..3e5e4c0 --- /dev/null +++ b/test/data/newDummyServer.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIC/jCCAeYCCQDYE18V6q0sozANBgkqhkiG9w0BAQUFADBBMQswCQYDVQQGEwJK +UDEOMAwGA1UECAwFVG9reW8xEjAQBgNVBAcMCWNoaXlvZGFrdTEOMAwGA1UECgwF +eWFob28wHhcNMjMwMjA3MDU0NzEwWhcNMzMwMjA0MDU0NzEwWjBBMQswCQYDVQQG +EwJKUDEOMAwGA1UECAwFVG9reW8xEjAQBgNVBAcMCWNoaXlvZGFrdTEOMAwGA1UE +CgwFeWFob28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDUBKwhDHGl +yj5GFOhDlyH67BsW9l/A2rna21WVnq69ANN9WmaiFp32C+2BrqF8DYwtPyoaTgXd ++WQ8wmtjlYwp378P743vBgf6tazZYKbCa7PZsNMLCdI0KurXf6iwOkCY7zQQcOji +STUVtfr009LFgc8KBduxBiyEe1tXoN1qrjUZgty/05EtW09QFE2XqVFQS4jN6CC4 +whmN3J7yWHyZv3K1lO0aeC31kIOqw0/DxD/Dj/RMJBpz0q3jeYJwvM3rE4DNajTW +6BLiPWVkLW4H3paRE6jYfyCRazE8tLLWHzGGqPWUElV2rKWzdV6OnyWLFxWC2RK/ +71IuZS1UDqmHAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAAlLPl+OXQnoSkCGUvVa +Dz/9qCBbho38VBABRjN9ZwxRZXhF6La38gTpjsbu5iPidND+rEGqjQQIszahoNk2 +gcec23SQfpEswnoi3lHO/sEcxchKT6VD2RTgOCV9LQFebbu3EmH7JAsz8xmgWVMR +4T34E+KkMMq3ScUGXcbX812q3CoTglG2HVwJIzW+wYF6l7aVsrHSI51yKAs41fO5 +UY9zacsWYvEeOet6WXep0/G7gJuxJ2W7d6EP82AfduiOW7LLBUUny4eukbnKljaG +uR2iTOkr/PWcxMpO0+iNyvRpyUwc3A1KsL7h2z8/VH9QEbCMO3xl0GTTvDcFltAk +ug8= +-----END CERTIFICATE----- From a03adb85bec43a07f6a59b81d45965e94fddaf70 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Tue, 7 Feb 2023 16:55:40 +0900 Subject: [PATCH 33/58] Commentout work in progress test Signed-off-by: Kyo Fujisaki --- service/tls_test.go | 351 ++++++++++++++++++++++---------------------- 1 file changed, 174 insertions(+), 177 deletions(-) diff --git a/service/tls_test.go b/service/tls_test.go index 238f139..1ae99d7 100644 --- a/service/tls_test.go +++ b/service/tls_test.go @@ -1,14 +1,11 @@ package service import ( - "context" "crypto/tls" "crypto/x509" "errors" "fmt" - "io" "io/ioutil" - "os" "reflect" "strings" "sync" @@ -909,183 +906,183 @@ func TestTLSCertificateCache_getCertificate(t *testing.T) { } } -func TestTLSCertificateCache_RefreshCertificate(t *testing.T) { - type fields struct { - serverCert atomic.Value - serverCertHash []byte - serverCertKeyHash []byte - serverCertPath string - serverCertKeyPath string - serverCertMutex sync.Mutex - certRefreshPeriod time.Duration - } - type args struct { - ctx context.Context - } - type test struct { - name string - fields fields - args args - want error - beforeFunc func() error - checkFunc func(*TLSCertificateCache, error, error) error - afterFunc func() error - } - copyCert := func(srcPath, dstPath string) error { - src, err := os.Open(srcPath) - if err != nil { - t.Errorf("test cert copy failed: %s", err) - return err - } - defer src.Close() +// func TestTLSCertificateCache_RefreshCertificate(t *testing.T) { +// type fields struct { +// serverCert atomic.Value +// serverCertHash []byte +// serverCertKeyHash []byte +// serverCertPath string +// serverCertKeyPath string +// serverCertMutex sync.Mutex +// certRefreshPeriod time.Duration +// } +// type args struct { +// ctx context.Context +// } +// type test struct { +// name string +// fields fields +// args args +// want error +// beforeFunc func() error +// checkFunc func(*TLSCertificateCache, error, error) error +// afterFunc func() error +// } +// copyCert := func(srcPath, dstPath string) error { +// src, err := os.Open(srcPath) +// if err != nil { +// t.Errorf("test cert copy failed: %s", err) +// return err +// } +// defer src.Close() - dst, err := os.Create(dstPath) - if err != nil { - t.Errorf("test cert copy failed: %s", err) - return err - } - defer dst.Close() +// dst, err := os.Create(dstPath) +// if err != nil { +// t.Errorf("test cert copy failed: %s", err) +// return err +// } +// defer dst.Close() - _, err = io.Copy(dst, src) - if err != nil { - t.Errorf("test cert copy failed: %s", err) - return err - } - return nil - } - testCertPath := "../test/data/test.crt" - testCertKeyPath := "../test/data/test.key" - oldCertPath := "../test/data/dummyServer.crt" - oldCertKeyPath := "../test/data/dummyServer.key" - newCertPath := "../test/data/newDummyServer.crt" +// _, err = io.Copy(dst, src) +// if err != nil { +// t.Errorf("test cert copy failed: %s", err) +// return err +// } +// return nil +// } +// testCertPath := "../test/data/test.crt" +// testCertKeyPath := "../test/data/test.key" +// oldCertPath := "../test/data/dummyServer.crt" +// oldCertKeyPath := "../test/data/dummyServer.key" +// newCertPath := "../test/data/newDummyServer.crt" - var defaultServerCert atomic.Value - defaultServerCertData, err := tls.LoadX509KeyPair("../test/data/dummyServer.crt", "../test/data/dummyServer.key") - if err != nil { - t.Errorf("LoadX509KeyPair failed: %s", err) - return - } - defaultServerCert.Store(&defaultServerCertData) - defaultServerCerttHash, err := hash(oldCertPath) - if err != nil { - t.Errorf("hash failed: %s", err) - return - } - defaultServerCerttKeyHash, _ := hash(oldCertKeyPath) - if err != nil { - t.Errorf("hash failed: %s", err) - return - } - // newCert key == oldCert key - newCert, err := tls.LoadX509KeyPair(newCertPath, oldCertKeyPath) - if err != nil { - t.Errorf("LoadX509KeyPair failed: %s", err) - return - } +// var defaultServerCert atomic.Value +// defaultServerCertData, err := tls.LoadX509KeyPair("../test/data/dummyServer.crt", "../test/data/dummyServer.key") +// if err != nil { +// t.Errorf("LoadX509KeyPair failed: %s", err) +// return +// } +// defaultServerCert.Store(&defaultServerCertData) +// defaultServerCerttHash, err := hash(oldCertPath) +// if err != nil { +// t.Errorf("hash failed: %s", err) +// return +// } +// defaultServerCerttKeyHash, _ := hash(oldCertKeyPath) +// if err != nil { +// t.Errorf("hash failed: %s", err) +// return +// } +// // newCert key == oldCert key +// newCert, err := tls.LoadX509KeyPair(newCertPath, oldCertKeyPath) +// if err != nil { +// t.Errorf("LoadX509KeyPair failed: %s", err) +// return +// } - tests := []test{ - func() test { - ctx, cancelFunc := context.WithCancel(context.Background()) +// tests := []test{ +// func() test { +// ctx, cancelFunc := context.WithCancel(context.Background()) - return test{ - name: "Test refresh server cert", - fields: fields{ - serverCert: defaultServerCert, - serverCertHash: defaultServerCerttHash, - serverCertKeyHash: defaultServerCerttKeyHash, - serverCertPath: testCertPath, - serverCertKeyPath: testCertKeyPath, - certRefreshPeriod: 1 * time.Second, - serverCertMutex: sync.Mutex{}, - }, - args: args{ - ctx: ctx, - }, - beforeFunc: func() error { - err := copyCert(oldCertPath, testCertPath) - if err != nil { - return err - } - err = copyCert(oldCertKeyPath, testCertKeyPath) - if err != nil { - return err - } - return nil - }, - checkFunc: func(tcc *TLSCertificateCache, got error, want error) error { - if got != nil { - return got - } - cachedCert := tcc.serverCert.Load() - cc, _ := x509.ParseCertificate(cachedCert.(tls.Certificate).Certificate[0]) - dc, _ := x509.ParseCertificate(defaultServerCertData.Certificate[0]) - if cc.SerialNumber != dc.SerialNumber { - return errors.New("Serial Number not Matched") - } - // refresh certificate - err = copyCert(newCertPath, testCertPath) - if err != nil { - return err - } - time.Sleep(1 * time.Second) - cachedCert = tcc.serverCert.Load() - cc, _ = x509.ParseCertificate(cachedCert.(tls.Certificate).Certificate[0]) - nc, _ := x509.ParseCertificate(newCert.Certificate[0]) - if cc.SerialNumber != nc.SerialNumber { - return errors.New("Serial Number not Matched") - } - return nil - }, - afterFunc: func() error { - cancelFunc() - err := os.Remove(testCertPath) - if err != nil { - t.Errorf("test cert remove failed: %s", err) - return err - } - err = os.Remove(testCertKeyPath) - if err != nil { - t.Errorf("test cert remove failed: %s", err) - return err - } - return nil - }, - } - }(), - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.afterFunc != nil { - defer func() { - if err := tt.afterFunc(); err != nil { - t.Errorf("afterFunc error, error: %v", err) - return - } - }() - } - if tt.beforeFunc != nil { - if err := tt.beforeFunc(); err != nil { - t.Errorf("beforeFunc error, error: %v", err) - return - } - } +// return test{ +// name: "Test refresh server cert", +// fields: fields{ +// serverCert: defaultServerCert, +// serverCertHash: defaultServerCerttHash, +// serverCertKeyHash: defaultServerCerttKeyHash, +// serverCertPath: testCertPath, +// serverCertKeyPath: testCertKeyPath, +// certRefreshPeriod: 1 * time.Second, +// serverCertMutex: sync.Mutex{}, +// }, +// args: args{ +// ctx: ctx, +// }, +// beforeFunc: func() error { +// err := copyCert(oldCertPath, testCertPath) +// if err != nil { +// return err +// } +// err = copyCert(oldCertKeyPath, testCertKeyPath) +// if err != nil { +// return err +// } +// return nil +// }, +// checkFunc: func(tcc *TLSCertificateCache, got error, want error) error { +// if got != nil { +// return got +// } +// cachedCert := tcc.serverCert.Load() +// cc, _ := x509.ParseCertificate(cachedCert.(tls.Certificate).Certificate[0]) +// dc, _ := x509.ParseCertificate(defaultServerCertData.Certificate[0]) +// if cc.SerialNumber != dc.SerialNumber { +// return errors.New("Serial Number not Matched") +// } +// // refresh certificate +// err = copyCert(newCertPath, testCertPath) +// if err != nil { +// return err +// } +// time.Sleep(1 * time.Second) +// cachedCert = tcc.serverCert.Load() +// cc, _ = x509.ParseCertificate(cachedCert.(tls.Certificate).Certificate[0]) +// nc, _ := x509.ParseCertificate(newCert.Certificate[0]) +// if cc.SerialNumber != nc.SerialNumber { +// return errors.New("Serial Number not Matched") +// } +// return nil +// }, +// afterFunc: func() error { +// cancelFunc() +// err := os.Remove(testCertPath) +// if err != nil { +// t.Errorf("test cert remove failed: %s", err) +// return err +// } +// err = os.Remove(testCertKeyPath) +// if err != nil { +// t.Errorf("test cert remove failed: %s", err) +// return err +// } +// return nil +// }, +// } +// }(), +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// if tt.afterFunc != nil { +// defer func() { +// if err := tt.afterFunc(); err != nil { +// t.Errorf("afterFunc error, error: %v", err) +// return +// } +// }() +// } +// if tt.beforeFunc != nil { +// if err := tt.beforeFunc(); err != nil { +// t.Errorf("beforeFunc error, error: %v", err) +// return +// } +// } - tcc := &TLSCertificateCache{ - serverCert: tt.fields.serverCert, - serverCertHash: tt.fields.serverCertHash, - serverCertKeyHash: tt.fields.serverCertKeyHash, - serverCertPath: tt.fields.serverCertPath, - serverCertKeyPath: tt.fields.serverCertKeyPath, - serverCertMutex: tt.fields.serverCertMutex, - certRefreshPeriod: tt.fields.certRefreshPeriod, - } - //errCh := make(chan error) - go func() error { - return tcc.RefreshCertificate(tt.args.ctx) - }() - if err := tt.checkFunc(tcc, got, tt.want); err != nil { - t.Errorf("TLSCertificateCache.RefreshCertificate() error = %v, want %v", err, tt.want) - } - }) - } -} +// tcc := &TLSCertificateCache{ +// serverCert: tt.fields.serverCert, +// serverCertHash: tt.fields.serverCertHash, +// serverCertKeyHash: tt.fields.serverCertKeyHash, +// serverCertPath: tt.fields.serverCertPath, +// serverCertKeyPath: tt.fields.serverCertKeyPath, +// serverCertMutex: tt.fields.serverCertMutex, +// certRefreshPeriod: tt.fields.certRefreshPeriod, +// } +// //errCh := make(chan error) +// go func() error { +// return tcc.RefreshCertificate(tt.args.ctx) +// }() +// if err := tt.checkFunc(tcc, got, tt.want); err != nil { +// t.Errorf("TLSCertificateCache.RefreshCertificate() error = %v, want %v", err, tt.want) +// } +// }) +// } +// } From 8f7e551b65b38775320b7ececafd55bba7703a03 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 8 Feb 2023 12:25:18 +0900 Subject: [PATCH 34/58] Fix error message Signed-off-by: Kyo Fujisaki --- service/server_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/service/server_test.go b/service/server_test.go index 57e9b86..3fc2054 100644 --- a/service/server_test.go +++ b/service/server_test.go @@ -186,7 +186,7 @@ func TestNewServer(t *testing.T) { gotCert, _ := x509.ParseCertificate(got.(*server).srv.TLSConfig.Certificates[0].Certificate[0]) wantCert, _ := x509.ParseCertificate(want.(*server).srv.TLSConfig.Certificates[0].Certificate[0]) if gotCert.SerialNumber == nil || gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { - return fmt.Errorf("Certificate SerialNumber not equals\tgot: %s\twant: %s", got.(*server).srv.TLSConfig.Certificates[0].Leaf.Subject.CommonName, want.(*server).srv.TLSConfig.Certificates[0].Leaf.Subject.CommonName) + return fmt.Errorf("Certificate SerialNumber not equals\tgot: %s\twant: %s", gotCert.SerialNumber.String(), wantCert.SerialNumber.String()) } return nil }, From ab7a24e89789bf7157a894d288b15d3672bbeabf Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 8 Feb 2023 15:15:19 +0900 Subject: [PATCH 35/58] Add TestTLSCertificateCache_RefreshCertificate Signed-off-by: Kyo Fujisaki --- service/tls_test.go | 442 ++++++++++++++++----------- test/data/invalid_newDummyServer.crt | 19 ++ 2 files changed, 287 insertions(+), 174 deletions(-) create mode 100644 test/data/invalid_newDummyServer.crt diff --git a/service/tls_test.go b/service/tls_test.go index 1ae99d7..ae5e0c8 100644 --- a/service/tls_test.go +++ b/service/tls_test.go @@ -1,11 +1,14 @@ package service import ( + "context" "crypto/tls" "crypto/x509" "errors" "fmt" + "io" "io/ioutil" + "os" "reflect" "strings" "sync" @@ -906,183 +909,274 @@ func TestTLSCertificateCache_getCertificate(t *testing.T) { } } -// func TestTLSCertificateCache_RefreshCertificate(t *testing.T) { -// type fields struct { -// serverCert atomic.Value -// serverCertHash []byte -// serverCertKeyHash []byte -// serverCertPath string -// serverCertKeyPath string -// serverCertMutex sync.Mutex -// certRefreshPeriod time.Duration -// } -// type args struct { -// ctx context.Context -// } -// type test struct { -// name string -// fields fields -// args args -// want error -// beforeFunc func() error -// checkFunc func(*TLSCertificateCache, error, error) error -// afterFunc func() error -// } -// copyCert := func(srcPath, dstPath string) error { -// src, err := os.Open(srcPath) -// if err != nil { -// t.Errorf("test cert copy failed: %s", err) -// return err -// } -// defer src.Close() +func TestTLSCertificateCache_RefreshCertificate(t *testing.T) { + type fields struct { + serverCert atomic.Value + serverCertHash []byte + serverCertKeyHash []byte + serverCertPath string + serverCertKeyPath string + serverCertMutex sync.Mutex + certRefreshPeriod time.Duration + } + type args struct { + ctx context.Context + } + type test struct { + name string + fields fields + args args + want error + beforeFunc func() error + checkFunc func(*TLSCertificateCache, error) error + afterFunc func() error + } + copyCert := func(srcPath, dstPath string) error { + src, err := os.Open(srcPath) + if err != nil { + t.Errorf("test cert copy failed: %s", err) + return err + } + defer src.Close() -// dst, err := os.Create(dstPath) -// if err != nil { -// t.Errorf("test cert copy failed: %s", err) -// return err -// } -// defer dst.Close() + dst, err := os.Create(dstPath) + if err != nil { + t.Errorf("test cert copy failed: %s", err) + return err + } + defer dst.Close() -// _, err = io.Copy(dst, src) -// if err != nil { -// t.Errorf("test cert copy failed: %s", err) -// return err -// } -// return nil -// } -// testCertPath := "../test/data/test.crt" -// testCertKeyPath := "../test/data/test.key" -// oldCertPath := "../test/data/dummyServer.crt" -// oldCertKeyPath := "../test/data/dummyServer.key" -// newCertPath := "../test/data/newDummyServer.crt" + _, err = io.Copy(dst, src) + if err != nil { + t.Errorf("test cert copy failed: %s", err) + return err + } + return nil + } + testCertPath := "../test/data/test.crt" + testCertKeyPath := "../test/data/test.key" + oldCertPath := "../test/data/dummyServer.crt" + oldCertKeyPath := "../test/data/dummyServer.key" + newCertPath := "../test/data/newDummyServer.crt" + invalidNewCertPath := "../test/data/invalid_newDummyServer.crt" -// var defaultServerCert atomic.Value -// defaultServerCertData, err := tls.LoadX509KeyPair("../test/data/dummyServer.crt", "../test/data/dummyServer.key") -// if err != nil { -// t.Errorf("LoadX509KeyPair failed: %s", err) -// return -// } -// defaultServerCert.Store(&defaultServerCertData) -// defaultServerCerttHash, err := hash(oldCertPath) -// if err != nil { -// t.Errorf("hash failed: %s", err) -// return -// } -// defaultServerCerttKeyHash, _ := hash(oldCertKeyPath) -// if err != nil { -// t.Errorf("hash failed: %s", err) -// return -// } -// // newCert key == oldCert key -// newCert, err := tls.LoadX509KeyPair(newCertPath, oldCertKeyPath) -// if err != nil { -// t.Errorf("LoadX509KeyPair failed: %s", err) -// return -// } + var oldCert atomic.Value + oldCertData, err := tls.LoadX509KeyPair("../test/data/dummyServer.crt", "../test/data/dummyServer.key") + if err != nil { + t.Errorf("LoadX509KeyPair failed: %s", err) + return + } + oldCert.Store(&oldCertData) + oldCertHash, err := hash(oldCertPath) + if err != nil { + t.Errorf("hash failed: %s", err) + return + } + oldCertKeyHash, _ := hash(oldCertKeyPath) + if err != nil { + t.Errorf("hash failed: %s", err) + return + } + // newCert key == oldCert key + newCert, err := tls.LoadX509KeyPair(newCertPath, oldCertKeyPath) + if err != nil { + t.Errorf("LoadX509KeyPair failed: %s", err) + return + } -// tests := []test{ -// func() test { -// ctx, cancelFunc := context.WithCancel(context.Background()) + tests := []test{ + func() test { + ctx, cancelFunc := context.WithCancel(context.Background()) -// return test{ -// name: "Test refresh server cert", -// fields: fields{ -// serverCert: defaultServerCert, -// serverCertHash: defaultServerCerttHash, -// serverCertKeyHash: defaultServerCerttKeyHash, -// serverCertPath: testCertPath, -// serverCertKeyPath: testCertKeyPath, -// certRefreshPeriod: 1 * time.Second, -// serverCertMutex: sync.Mutex{}, -// }, -// args: args{ -// ctx: ctx, -// }, -// beforeFunc: func() error { -// err := copyCert(oldCertPath, testCertPath) -// if err != nil { -// return err -// } -// err = copyCert(oldCertKeyPath, testCertKeyPath) -// if err != nil { -// return err -// } -// return nil -// }, -// checkFunc: func(tcc *TLSCertificateCache, got error, want error) error { -// if got != nil { -// return got -// } -// cachedCert := tcc.serverCert.Load() -// cc, _ := x509.ParseCertificate(cachedCert.(tls.Certificate).Certificate[0]) -// dc, _ := x509.ParseCertificate(defaultServerCertData.Certificate[0]) -// if cc.SerialNumber != dc.SerialNumber { -// return errors.New("Serial Number not Matched") -// } -// // refresh certificate -// err = copyCert(newCertPath, testCertPath) -// if err != nil { -// return err -// } -// time.Sleep(1 * time.Second) -// cachedCert = tcc.serverCert.Load() -// cc, _ = x509.ParseCertificate(cachedCert.(tls.Certificate).Certificate[0]) -// nc, _ := x509.ParseCertificate(newCert.Certificate[0]) -// if cc.SerialNumber != nc.SerialNumber { -// return errors.New("Serial Number not Matched") -// } -// return nil -// }, -// afterFunc: func() error { -// cancelFunc() -// err := os.Remove(testCertPath) -// if err != nil { -// t.Errorf("test cert remove failed: %s", err) -// return err -// } -// err = os.Remove(testCertKeyPath) -// if err != nil { -// t.Errorf("test cert remove failed: %s", err) -// return err -// } -// return nil -// }, -// } -// }(), -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// if tt.afterFunc != nil { -// defer func() { -// if err := tt.afterFunc(); err != nil { -// t.Errorf("afterFunc error, error: %v", err) -// return -// } -// }() -// } -// if tt.beforeFunc != nil { -// if err := tt.beforeFunc(); err != nil { -// t.Errorf("beforeFunc error, error: %v", err) -// return -// } -// } + return test{ + name: "Test refresh server cert and stop", + fields: fields{ + serverCert: oldCert, + serverCertHash: oldCertHash, + serverCertKeyHash: oldCertKeyHash, + serverCertPath: testCertPath, + serverCertKeyPath: testCertKeyPath, + certRefreshPeriod: 500 * time.Millisecond, + serverCertMutex: sync.Mutex{}, + }, + args: args{ + ctx: ctx, + }, + beforeFunc: func() error { + err := copyCert(oldCertPath, testCertPath) + if err != nil { + return err + } + err = copyCert(oldCertKeyPath, testCertKeyPath) + if err != nil { + return err + } + return nil + }, + checkFunc: func(tcc *TLSCertificateCache, want error) error { + cachedCert := tcc.serverCert.Load() + cc, _ := x509.ParseCertificate(cachedCert.(*tls.Certificate).Certificate[0]) + oc, _ := x509.ParseCertificate(oldCertData.Certificate[0]) + if cc.SerialNumber.String() != oc.SerialNumber.String() { + return errors.New("cached cert / old cert Serial Number not Matched") + } + // refresh certificate + err = copyCert(newCertPath, testCertPath) + if err != nil { + return err + } + // wait refresh period + time.Sleep(1 * time.Second) + cachedCert = tcc.serverCert.Load() + cc, _ = x509.ParseCertificate(cachedCert.(*tls.Certificate).Certificate[0]) + nc, _ := x509.ParseCertificate(newCert.Certificate[0]) + // check cert refreshed + if cc.SerialNumber.String() != nc.SerialNumber.String() { + return errors.New("cert not refreshed") + } + // refresh stop + cancelFunc() + err = copyCert(oldCertPath, testCertPath) + if err != nil { + return err + } + time.Sleep(1 * time.Second) + if cc.SerialNumber.String() == oc.SerialNumber.String() { + return errors.New("refresh not stopped") + } + return nil + }, + afterFunc: func() error { + cancelFunc() + err := os.Remove(testCertPath) + if err != nil { + t.Errorf("test cert remove failed: %s", err) + return err + } + err = os.Remove(testCertKeyPath) + if err != nil { + t.Errorf("test cert remove failed: %s", err) + return err + } + return nil + }, + } + }(), + func() test { + ctx, cancelFunc := context.WithCancel(context.Background()) + + return test{ + name: "Test invalid cert not refresh, next period refresh success", + fields: fields{ + serverCert: oldCert, + serverCertHash: oldCertHash, + serverCertKeyHash: oldCertKeyHash, + serverCertPath: testCertPath, + serverCertKeyPath: testCertKeyPath, + certRefreshPeriod: 500 * time.Millisecond, + serverCertMutex: sync.Mutex{}, + }, + args: args{ + ctx: ctx, + }, + beforeFunc: func() error { + err := copyCert(oldCertPath, testCertPath) + if err != nil { + return err + } + err = copyCert(oldCertKeyPath, testCertKeyPath) + if err != nil { + return err + } + return nil + }, + checkFunc: func(tcc *TLSCertificateCache, want error) error { + cachedCert := tcc.serverCert.Load() + cc, _ := x509.ParseCertificate(cachedCert.(*tls.Certificate).Certificate[0]) + oc, _ := x509.ParseCertificate(oldCertData.Certificate[0]) + if cc.SerialNumber.String() != oc.SerialNumber.String() { + return errors.New("cached cert / old cert Serial Number not Matched") + } + // refresh certificate but invalid + err = copyCert(invalidNewCertPath, testCertPath) + if err != nil { + return err + } + // wait refresh period + time.Sleep(1 * time.Second) + cachedCert = tcc.serverCert.Load() + cc, _ = x509.ParseCertificate(cachedCert.(*tls.Certificate).Certificate[0]) + // check cert not refreshed + if cc.SerialNumber.String() != oc.SerialNumber.String() { + return errors.New("cert refreshed") + } + // refresh certificate + err = copyCert(newCertPath, testCertPath) + if err != nil { + return err + } + // wait refresh period + time.Sleep(1 * time.Second) + cachedCert = tcc.serverCert.Load() + cc, _ = x509.ParseCertificate(cachedCert.(*tls.Certificate).Certificate[0]) + nc, _ := x509.ParseCertificate(newCert.Certificate[0]) + // check cert refreshed + if cc.SerialNumber.String() != nc.SerialNumber.String() { + return errors.New("cert not refreshed") + } + cancelFunc() + return nil + }, + afterFunc: func() error { + cancelFunc() + err := os.Remove(testCertPath) + if err != nil { + t.Errorf("test cert remove failed: %s", err) + return err + } + err = os.Remove(testCertKeyPath) + if err != nil { + t.Errorf("test cert remove failed: %s", err) + return err + } + return nil + }, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.afterFunc != nil { + defer func() { + if err := tt.afterFunc(); err != nil { + t.Errorf("afterFunc error, error: %v", err) + return + } + }() + } + if tt.beforeFunc != nil { + if err := tt.beforeFunc(); err != nil { + t.Errorf("beforeFunc error, error: %v", err) + return + } + } -// tcc := &TLSCertificateCache{ -// serverCert: tt.fields.serverCert, -// serverCertHash: tt.fields.serverCertHash, -// serverCertKeyHash: tt.fields.serverCertKeyHash, -// serverCertPath: tt.fields.serverCertPath, -// serverCertKeyPath: tt.fields.serverCertKeyPath, -// serverCertMutex: tt.fields.serverCertMutex, -// certRefreshPeriod: tt.fields.certRefreshPeriod, -// } -// //errCh := make(chan error) -// go func() error { -// return tcc.RefreshCertificate(tt.args.ctx) -// }() -// if err := tt.checkFunc(tcc, got, tt.want); err != nil { -// t.Errorf("TLSCertificateCache.RefreshCertificate() error = %v, want %v", err, tt.want) -// } -// }) -// } -// } + tcc := &TLSCertificateCache{ + serverCert: tt.fields.serverCert, + serverCertHash: tt.fields.serverCertHash, + serverCertKeyHash: tt.fields.serverCertKeyHash, + serverCertPath: tt.fields.serverCertPath, + serverCertKeyPath: tt.fields.serverCertKeyPath, + serverCertMutex: tt.fields.serverCertMutex, + certRefreshPeriod: tt.fields.certRefreshPeriod, + } + //errCh := make(chan error) + go func() error { + return tcc.RefreshCertificate(tt.args.ctx) + }() + if err := tt.checkFunc(tcc, tt.want); err != nil { + t.Errorf("TLSCertificateCache.RefreshCertificate() error = %v, want %v", err, tt.want) + } + }) + } +} diff --git a/test/data/invalid_newDummyServer.crt b/test/data/invalid_newDummyServer.crt new file mode 100644 index 0000000..f62c633 --- /dev/null +++ b/test/data/invalid_newDummyServer.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIC/jCCAeYCCQDYE18V6q0sozANBgkqhkiG9w0BAQUFADBBMQswCQYDVQQGEwJK +UDEOMAwGA1UECAwFVG9reW8xEjAQBgNVBAcMCWNoaXlvZGFrdTEOMAwGA1UECgwF +eWFob28wHhcNMjMwMjA3MDU0NzEwWhcNMzMwMjA0MDU0NzEwWjBBMQswCQYDVQQG +EwJKUDEOMAwGA1UECAwFVG9reW8xEjAQBgNVBAcMCWNoaXlvZGFrdTEOMAwGA1UE +CgwFeWFob28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDUBKwhDHGl +yj5GFOhDlyH67BsW9l/A2rna21WVnq69ANN9WmaiFp32C+2BrqF8DYwtPyoaTgXd ++WQ8wmtjlYwp378P743vBgf6tazZYKbCa7PZsNMLCdI0KurXf6iwOkCY7zQQcOji +STUVtfr009LFgc8KBduxBiyEe1tXoN1qrjUZgty/05EtW09QFE2XqVFQS4jN6CC4 +whmN3J7yWHyZv3K1lO0aeC31kIOqw0/DxD/Dj/RMJBpz0q3jeYJwvM3rE4DNajTW +6BLiPWVkLW4H3paRE6jYfyCRazE8tLLWHzGGqPWUElV2rKWzdV6OnyWLFxWC2RK/ +71IuZS1UDqmHAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAAlLPl+OXQnoSkCGUvVa +Dz/9qCBbho38VBABRjN9ZwxRZXhF6La38gTpjsbu5iPidND+rEGqjQQIszahoNk2 +gcec23SQfpEswnoi3lHO/sEcxchKT6VD2RTgOCV9LQFebbu3EmH7JAsz8xmgWVMR +4T34E+KkMMq3ScUGXcbX812q3CoTglG2HVwJIzW+wYF6l7aVsrHSI51yKAs41fO5 +UY9zacsWYvEeOet6WXep0/G7gJuxJ2W7d6EP82AfduiOW7LLBUUny4eukbnKljaG +uR2iTOkr/PWcxMpO0+iNyvRpyUwc3A1KsL7h2z8/VH9QEbCMO3xl0GTTvDcFltAk +ug8=!invalid_cert! +-----END CERTIFICATE----- From 2ccaf6e81552a0a4c85aea34879362a4df9f43f8 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 8 Feb 2023 17:05:44 +0900 Subject: [PATCH 36/58] Add CertRefreshPeriod 0 check Signed-off-by: Kyo Fujisaki --- usecase/authz_proxyd.go | 4 +- usecase/authz_proxyd_test.go | 247 +++++++++++++++++++++++++++++++++-- 2 files changed, 239 insertions(+), 12 deletions(-) diff --git a/usecase/authz_proxyd.go b/usecase/authz_proxyd.go index b852ebe..dd36321 100644 --- a/usecase/authz_proxyd.go +++ b/usecase/authz_proxyd.go @@ -77,7 +77,7 @@ func New(cfg config.Config) (AuthzProxyDaemon, error) { var tlsCertificateCache *service.TLSCertificateCache if cfg.Server.TLS.Enable { // Enable auto-reload if CertRefreshPeriod is set. - if cfg.Server.TLS.CertRefreshPeriod != "" { + if cfg.Server.TLS.CertRefreshPeriod != "" && cfg.Server.TLS.CertRefreshPeriod != "0" { configWithCache, err := service.NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS) if err != nil { return nil, errors.Wrap(err, "cannot NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS)") @@ -162,7 +162,7 @@ func (g *authzProxyDaemon) Start(ctx context.Context) <-chan []error { // handle cert refresh goroutine error // prevent run RefreshCertificate if Enable is false and CertRefreshPeriod is set - if g.cfg.Server.TLS.Enable && g.cfg.Server.TLS.CertRefreshPeriod != "" { + if g.cfg.Server.TLS.Enable && g.cfg.Server.TLS.CertRefreshPeriod != "" && g.cfg.Server.TLS.CertRefreshPeriod != "0" { eg.Go(func() error { return g.tlsCertificateCache.RefreshCertificate(ctx) }) diff --git a/usecase/authz_proxyd_test.go b/usecase/authz_proxyd_test.go index 3f449cf..13e6fb0 100644 --- a/usecase/authz_proxyd_test.go +++ b/usecase/authz_proxyd_test.go @@ -55,7 +55,7 @@ func TestNew(t *testing.T) { Endpoint: "/dummy", }, TLS: config.TLS{ - Enable: true, + Enable: true, CertRefreshPeriod: "24h", }, }, @@ -81,6 +81,128 @@ func TestNew(t *testing.T) { if got.(*authzProxyDaemon).server == nil { return errors.New("got.server is nil") } + if got.(*authzProxyDaemon).tlsCertificateCache == nil { + return errors.New("got.tlsCertificateCache is nil") + } + return nil + }, + } + }(), + func() test { + cfg := config.Config{ + Athenz: config.Athenz{ + URL: "athenz.io", + }, + Authorization: config.Authorization{ + AthenzDomains: []string{"dummyDom1", "dummyDom2"}, + PublicKey: config.PublicKey{ + SysAuthDomain: "dummy.sys.auth", + RefreshPeriod: "10s", + ETagExpiry: "10s", + ETagPurgePeriod: "10s", + }, + Policy: config.Policy{ + ExpiryMargin: "10s", + RefreshPeriod: "10s", + PurgePeriod: "10s", + }, + AccessToken: config.AccessToken{ + Enable: true, + }, + }, + Server: config.Server{ + HealthCheck: config.HealthCheck{ + Endpoint: "/dummy", + }, + TLS: config.TLS{ + Enable: true, + CertRefreshPeriod: "0", + }, + }, + Proxy: config.Proxy{ + BufferSize: 512, + }, + } + return test{ + name: "new CertRefreshPeriod is 0, tlsCertificateCache is nil.", + args: args{ + cfg: cfg, + }, + checkFunc: func(got AuthzProxyDaemon) error { + if got == nil { + return errors.New("got is nil") + } + if !reflect.DeepEqual(got.(*authzProxyDaemon).cfg, cfg) { + return errors.New("got.cfg does not equal") + } + if got.(*authzProxyDaemon).athenz == nil { + return errors.New("got.athenz is nil") + } + if got.(*authzProxyDaemon).server == nil { + return errors.New("got.server is nil") + } + if got.(*authzProxyDaemon).tlsCertificateCache != nil { + return errors.New("got.tlsCertificateCache is not nil") + } + return nil + }, + } + }(), + func() test { + cfg := config.Config{ + Athenz: config.Athenz{ + URL: "athenz.io", + }, + Authorization: config.Authorization{ + AthenzDomains: []string{"dummyDom1", "dummyDom2"}, + PublicKey: config.PublicKey{ + SysAuthDomain: "dummy.sys.auth", + RefreshPeriod: "10s", + ETagExpiry: "10s", + ETagPurgePeriod: "10s", + }, + Policy: config.Policy{ + ExpiryMargin: "10s", + RefreshPeriod: "10s", + PurgePeriod: "10s", + }, + AccessToken: config.AccessToken{ + Enable: true, + }, + }, + Server: config.Server{ + HealthCheck: config.HealthCheck{ + Endpoint: "/dummy", + }, + TLS: config.TLS{ + Enable: true, + }, + }, + Proxy: config.Proxy{ + BufferSize: 512, + }, + } + return test{ + name: "new CertRefreshPeriod not set, tlsCertificateCache is nil.", + args: args{ + cfg: cfg, + }, + checkFunc: func(got AuthzProxyDaemon) error { + if got == nil { + return errors.New("got is nil") + } + if !reflect.DeepEqual(got.(*authzProxyDaemon).cfg, cfg) { + return errors.New("got.cfg does not equal") + } + if got.(*authzProxyDaemon).athenz == nil { + return errors.New("got.athenz is nil") + } + if got.(*authzProxyDaemon).server == nil { + return errors.New("got.server is nil") + } + if got.(*authzProxyDaemon).tlsCertificateCache != nil { + return errors.New("got.tlsCertificateCache is not nil") + } return nil }, } @@ -129,7 +251,7 @@ func TestNew(t *testing.T) { }, Server: config.Server{ TLS: config.TLS{ - Enable: true, + Enable: true, CertRefreshPeriod: "abcdefg", }, }, @@ -220,9 +342,10 @@ func Test_authzProxyDaemon_Init(t *testing.T) { func Test_authzProxyDaemon_Start(t *testing.T) { type fields struct { - cfg config.Config - athenz service.Authorizationd - server service.Server + cfg config.Config + athenz service.Authorizationd + server service.Server + tlsCertificateCache *service.TLSCertificateCache } type args struct { ctx context.Context @@ -239,12 +362,24 @@ func Test_authzProxyDaemon_Start(t *testing.T) { return errs[i].Error() < errs[j].Error() } } + defaultTLSConfig := config.TLS{ + Enable: true, + CertPath: "../test/data/dummyServer.crt", + KeyPath: "../test/data/dummyServer.key", + CertRefreshPeriod: "5s", + } + defaultTLSConfigWithTLSCertificateCache, _ := service.NewTLSConfigWithTLSCertificateCache(defaultTLSConfig) tests := []test{ func() test { ctx, cancel := context.WithCancel(context.Background()) return test{ name: "Daemon start success", fields: fields{ + cfg: config.Config{ + Server: config.Server{ + TLS: defaultTLSConfig, + }, + }, athenz: &service.AuthorizerdMock{ StartFunc: func(ctx context.Context) <-chan error { ech := make(chan error) @@ -270,6 +405,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { return ech }, }, + tlsCertificateCache: defaultTLSConfigWithTLSCertificateCache.TLSCertficateCache, }, args: args{ ctx: ctx, @@ -455,6 +591,11 @@ func Test_authzProxyDaemon_Start(t *testing.T) { return test{ name: "Daemon start end successfully, server shutdown without error", fields: fields{ + cfg: config.Config{ + Server: config.Server{ + TLS: defaultTLSConfig, + }, + }, athenz: &service.AuthorizerdMock{ StartFunc: func(ctx context.Context) <-chan error { ech := make(chan error) @@ -481,6 +622,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { return ech }, }, + tlsCertificateCache: defaultTLSConfigWithTLSCertificateCache.TLSCertficateCache, }, args: args{ ctx: ctx, @@ -524,6 +666,11 @@ func Test_authzProxyDaemon_Start(t *testing.T) { return test{ name: "Daemon start end successfully, server shutdown >1 errors", fields: fields{ + cfg: config.Config{ + Server: config.Server{ + TLS: defaultTLSConfig, + }, + }, athenz: &service.AuthorizerdMock{ StartFunc: func(ctx context.Context) <-chan error { ech := make(chan error) @@ -550,6 +697,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { return ech }, }, + tlsCertificateCache: defaultTLSConfigWithTLSCertificateCache.TLSCertficateCache, }, args: args{ ctx: ctx, @@ -593,13 +741,15 @@ func Test_authzProxyDaemon_Start(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) dummyErr := errors.New("dummy") return test{ - name: "Daemon stops when TLS.Enable = false and CertRefreshPeriod is set", + name: "Cert refrsh daemon stops when TLS.Enable = false and CertRefreshPeriod is set", fields: fields{ cfg: config.Config{ Server: config.Server{ TLS: config.TLS{ - Enable: false, + Enable: false, CertRefreshPeriod: "3d", + CertPath: "../test/data/dummyServer.crt", + KeyPath: "../test/data/dummyServer.key", }, }, }, @@ -638,6 +788,82 @@ func Test_authzProxyDaemon_Start(t *testing.T) { mux.Lock() go func() { defer mux.Unlock() + // this can be execute == through eg.Wait() == refresh daemon is not running + err, ok := <-got + if !ok { + return + } + gotErrs = append(gotErrs, err) + }() + time.Sleep(time.Second) + + mux.Lock() + defer mux.Unlock() + + // check only send errors once and the errors are expected ignoring order + sort.Slice(gotErrs[0], getLessErrorFunc(gotErrs[0])) + sort.Slice(wantErrs, getLessErrorFunc(wantErrs)) + gotErrsStr := fmt.Sprintf("%v", gotErrs[0]) + wantErrsStr := fmt.Sprintf("%v", wantErrs) + if len(gotErrs) != 1 || !reflect.DeepEqual(gotErrsStr, wantErrsStr) { + return errors.Errorf("Invalid err, got: %v, want: %v", gotErrsStr, wantErrsStr) + } + + cancel() + return nil + }, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + dummyErr := errors.New("dummy") + return test{ + name: "Cert refrsh daemon stops when TLS.Enable = true and CertRefreshPeriod is 0", + fields: fields{ + cfg: config.Config{ + Server: config.Server{ + TLS: config.TLS{ + Enable: true, + CertRefreshPeriod: "0", + }, + }, + }, + athenz: &service.AuthorizerdMock{ + StartFunc: func(ctx context.Context) <-chan error { + ech := make(chan error) + go func() { + defer close(ech) + <-ctx.Done() + ech <- ctx.Err() + }() + return ech + }, + }, + server: &service.ServerMock{ + ListenAndServeFunc: func(ctx context.Context) <-chan []error { + ech := make(chan []error) + go func() { + defer close(ech) + ech <- []error{errors.WithMessage(dummyErr, "server fails")} + }() + return ech + }, + }, + }, + args: args{ + ctx: ctx, + }, + wantErrs: []error{ + errors.WithMessage(dummyErr, "server fails"), + }, + checkFunc: func(got <-chan []error, wantErrs []error) error { + mux := &sync.Mutex{} + + gotErrs := make([][]error, 0) + mux.Lock() + go func() { + defer mux.Unlock() + // this can be execute == through eg.Wait() == refresh daemon is not running err, ok := <-got if !ok { return @@ -667,9 +893,10 @@ func Test_authzProxyDaemon_Start(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { g := &authzProxyDaemon{ - cfg: tt.fields.cfg, - athenz: tt.fields.athenz, - server: tt.fields.server, + cfg: tt.fields.cfg, + athenz: tt.fields.athenz, + server: tt.fields.server, + tlsCertificateCache: tt.fields.tlsCertificateCache, } got := g.Start(tt.args.ctx) if err := tt.checkFunc(got, tt.wantErrs); err != nil { From 0764566dfeb46072a12e8e437e1f036bf7dbea27 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 8 Feb 2023 17:05:53 +0900 Subject: [PATCH 37/58] Add comment Signed-off-by: Kyo Fujisaki --- service/tls_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/service/tls_test.go b/service/tls_test.go index ae5e0c8..415aa66 100644 --- a/service/tls_test.go +++ b/service/tls_test.go @@ -379,6 +379,7 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } + // config.TLS.certRefreshPeriod is not set, TLSCertificateCache.certRefreshPeriod is 0 if got.TLSCertficateCache.certRefreshPeriod != want.TLSCertficateCache.certRefreshPeriod { return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertficateCache.certRefreshPeriod, want.TLSCertficateCache.certRefreshPeriod) } From 22ca473427f9ec0b3d52226d011f8ed49f34765d Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 8 Feb 2023 17:12:40 +0900 Subject: [PATCH 38/58] Add certRefreshPeriod option Signed-off-by: Kyo Fujisaki --- test/data/example_config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/test/data/example_config.yaml b/test/data/example_config.yaml index a3a30e5..3daf991 100644 --- a/test/data/example_config.yaml +++ b/test/data/example_config.yaml @@ -10,6 +10,7 @@ server: certPath: "test/data/dummyServer.crt" keyPath: "test/data/dummyServer.key" caPath: "test/data/dummyCa.pem" + certRefreshPeriod: "24h" healthCheck: port: 6082 endpoint: /healthz From ed054111b9a09d37fa1586de4beb6593bbd03c83 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 8 Feb 2023 17:15:56 +0900 Subject: [PATCH 39/58] Add comment Signed-off-by: Kyo Fujisaki --- config/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config.go b/config/config.go index 2843d3e..a50c63f 100755 --- a/config/config.go +++ b/config/config.go @@ -90,7 +90,7 @@ type TLS struct { // CAPath represents the CA certificate chain file path for verifying client certificates. CAPath string `yaml:"caPath"` - // CertRefreshPeriod represents + // CertRefreshPeriod represents the time to read the certificate again. CertRefreshPeriod string `yaml:"certRefreshPeriod"` } From bbd299a0280c27547bd51ef23357385c5f47c3b3 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 8 Feb 2023 17:24:08 +0900 Subject: [PATCH 40/58] Fix test for config(add CertRefreshPeriod) Signed-off-by: Kyo Fujisaki --- config/config_test.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/config/config_test.go b/config/config_test.go index 8627890..fb3822b 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -84,10 +84,11 @@ func TestNew(t *testing.T) { ShutdownTimeout: "10s", ShutdownDelay: "9s", TLS: TLS{ - Enable: true, - CertPath: "test/data/dummyServer.crt", - KeyPath: "test/data/dummyServer.key", - CAPath: "test/data/dummyCa.pem", + Enable: true, + CertPath: "test/data/dummyServer.crt", + KeyPath: "test/data/dummyServer.key", + CAPath: "test/data/dummyCa.pem", + CertRefreshPeriod: "24h", }, HealthCheck: HealthCheck{ Port: 6082, From b24733150d7c31623b45c974bd757f5d04904221 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 8 Feb 2023 18:45:14 +0900 Subject: [PATCH 41/58] Fix typo Signed-off-by: Kyo Fujisaki --- service/option.go | 2 +- service/option_test.go | 2 +- service/server.go | 10 +++++----- service/server_test.go | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/service/option.go b/service/option.go index ed0bf21..49b3b05 100644 --- a/service/option.go +++ b/service/option.go @@ -43,7 +43,7 @@ func WithGRPCCloser(c io.Closer) Option { // WithTLSConfig returns a TLS Config functional option func WithTLSConfig(t *tls.Config) Option { return func(s *server) { - s.tlsConifg = t + s.tlsConfig = t } } diff --git a/service/option_test.go b/service/option_test.go index 9ecf976..2d951c3 100644 --- a/service/option_test.go +++ b/service/option_test.go @@ -226,7 +226,7 @@ func TestWithTLSConfig(t *testing.T) { checkFunc: func(o Option) error { srv := &server{} o(srv) - if srv.tlsConifg.MinVersion != tls.VersionTLS12 { + if srv.tlsConfig.MinVersion != tls.VersionTLS12 { return errors.New("value cannot set") } return nil diff --git a/service/server.go b/service/server.go index b73dd9b..fc882ae 100644 --- a/service/server.go +++ b/service/server.go @@ -52,7 +52,7 @@ type server struct { grpcSrvRunning bool grpcCloser io.Closer - tlsConifg *tls.Config + tlsConfig *tls.Config // Health Check server hcsrv *http.Server @@ -103,8 +103,8 @@ func NewServer(opts ...Option) (Server, error) { o(s) } - if s.cfg.TLS.Enable && s.tlsConifg == nil { - return nil, errors.New("s.cfg.TLS.Enable is true, but s.tlsConifg is nil.") + if s.cfg.TLS.Enable && s.tlsConfig == nil { + return nil, errors.New("s.cfg.TLS.Enable is true, but s.tlsConfig is nil.") } if s.grpcSrvEnable() { @@ -114,7 +114,7 @@ func NewServer(opts ...Option) (Server, error) { } if s.cfg.TLS.Enable { - gopts = append(gopts, grpc.Creds(credentials.NewTLS(s.tlsConifg))) + gopts = append(gopts, grpc.Creds(credentials.NewTLS(s.tlsConfig))) } s.grpcSrv = grpc.NewServer(gopts...) @@ -125,7 +125,7 @@ func NewServer(opts ...Option) (Server, error) { } s.srv.SetKeepAlivesEnabled(true) if s.cfg.TLS.Enable { - s.srv.TLSConfig = s.tlsConifg + s.srv.TLSConfig = s.tlsConfig } } diff --git a/service/server_test.go b/service/server_test.go index 3fc2054..817c9d7 100644 --- a/service/server_test.go +++ b/service/server_test.go @@ -288,7 +288,7 @@ func TestNewServer(t *testing.T) { }, }, want: nil, - wantErr: errors.New("s.cfg.TLS.Enable is true, but s.tlsConifg is nil."), + wantErr: errors.New("s.cfg.TLS.Enable is true, but s.tlsConfig is nil."), checkFunc: func(got, want Server, gotErr, wantErr error) error { if gotErr.Error() != wantErr.Error() { return errors.Errorf("got error is not matched with want error, got: %s, want: %s", gotErr, wantErr) From b653f3cc8688bc6eb645498be7b12d1689997f9e Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 8 Feb 2023 18:47:59 +0900 Subject: [PATCH 42/58] Fix typo Signed-off-by: Kyo Fujisaki --- service/tls.go | 8 +++--- service/tls_test.go | 48 ++++++++++++++++++------------------ usecase/authz_proxyd.go | 2 +- usecase/authz_proxyd_test.go | 6 ++--- 4 files changed, 32 insertions(+), 32 deletions(-) diff --git a/service/tls.go b/service/tls.go index b96e00d..5cfb562 100644 --- a/service/tls.go +++ b/service/tls.go @@ -46,8 +46,8 @@ type TLSCertificateCache struct { } type TLSConfigWithTLSCertificateCache struct { - TLSConfig *tls.Config - TLSCertficateCache *TLSCertificateCache + TLSConfig *tls.Config + TLSCertificateCache *TLSCertificateCache } // NewTLSConfig returns a *tls.Config struct or error. @@ -157,8 +157,8 @@ func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCerti } return &TLSConfigWithTLSCertificateCache{ - TLSConfig: t, - TLSCertficateCache: tcc, + TLSConfig: t, + TLSCertificateCache: tcc, }, nil } diff --git a/service/tls_test.go b/service/tls_test.go index 415aa66..f5edb76 100644 --- a/service/tls_test.go +++ b/service/tls_test.go @@ -374,14 +374,14 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { if got.TLSConfig.MinVersion != want.TLSConfig.MinVersion { return fmt.Errorf("MinVersion not Matched :\tgot %d\twant %d", got.TLSConfig.MinVersion, want.TLSConfig.MinVersion) } - gotCert, _ := x509.ParseCertificate(got.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + gotCert, _ := x509.ParseCertificate(got.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } // config.TLS.certRefreshPeriod is not set, TLSCertificateCache.certRefreshPeriod is 0 - if got.TLSCertficateCache.certRefreshPeriod != want.TLSCertficateCache.certRefreshPeriod { - return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertficateCache.certRefreshPeriod, want.TLSCertficateCache.certRefreshPeriod) + if got.TLSCertificateCache.certRefreshPeriod != want.TLSCertificateCache.certRefreshPeriod { + return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertificateCache.certRefreshPeriod, want.TLSCertificateCache.certRefreshPeriod) } return nil }, @@ -430,13 +430,13 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { return fmt.Errorf("CurvePreferences not Find :\twant %d", want.TLSConfig.MinVersion) } } - gotCert, _ := x509.ParseCertificate(got.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + gotCert, _ := x509.ParseCertificate(got.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } - if got.TLSCertficateCache.certRefreshPeriod != want.TLSCertficateCache.certRefreshPeriod { - return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertficateCache.certRefreshPeriod, want.TLSCertficateCache.certRefreshPeriod) + if got.TLSCertificateCache.certRefreshPeriod != want.TLSCertificateCache.certRefreshPeriod { + return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertificateCache.certRefreshPeriod, want.TLSCertificateCache.certRefreshPeriod) } return nil }, @@ -473,13 +473,13 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { if got.TLSConfig.SessionTicketsDisabled != want.TLSConfig.SessionTicketsDisabled { return fmt.Errorf("SessionTicketsDisabled not matched :\tgot %v\twant %v", got.TLSConfig.SessionTicketsDisabled, want.TLSConfig.SessionTicketsDisabled) } - gotCert, _ := x509.ParseCertificate(got.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + gotCert, _ := x509.ParseCertificate(got.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } - if got.TLSCertficateCache.certRefreshPeriod != want.TLSCertficateCache.certRefreshPeriod { - return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertficateCache.certRefreshPeriod, want.TLSCertficateCache.certRefreshPeriod) + if got.TLSCertificateCache.certRefreshPeriod != want.TLSCertificateCache.certRefreshPeriod { + return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertificateCache.certRefreshPeriod, want.TLSCertificateCache.certRefreshPeriod) } return nil }, @@ -525,13 +525,13 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { return fmt.Errorf("Certificates PrivateKey not Matched :\twant %s", wantVal.PrivateKey) } } - gotCert, _ := x509.ParseCertificate(got.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + gotCert, _ := x509.ParseCertificate(got.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } - if got.TLSCertficateCache.certRefreshPeriod != want.TLSCertficateCache.certRefreshPeriod { - return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertficateCache.certRefreshPeriod, want.TLSCertficateCache.certRefreshPeriod) + if got.TLSCertificateCache.certRefreshPeriod != want.TLSCertificateCache.certRefreshPeriod { + return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertificateCache.certRefreshPeriod, want.TLSCertificateCache.certRefreshPeriod) } return nil }, @@ -568,13 +568,13 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { if got.TLSConfig.ClientAuth != want.TLSConfig.ClientAuth { return fmt.Errorf("ClientAuth not Matched :\tgot %d \twant %d", got.TLSConfig.ClientAuth, want.TLSConfig.ClientAuth) } - gotCert, _ := x509.ParseCertificate(got.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + gotCert, _ := x509.ParseCertificate(got.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } - if got.TLSCertficateCache.certRefreshPeriod != want.TLSCertficateCache.certRefreshPeriod { - return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertficateCache.certRefreshPeriod, want.TLSCertficateCache.certRefreshPeriod) + if got.TLSCertificateCache.certRefreshPeriod != want.TLSCertificateCache.certRefreshPeriod { + return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertificateCache.certRefreshPeriod, want.TLSCertificateCache.certRefreshPeriod) } return nil }, @@ -618,13 +618,13 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { if got.TLSConfig.ClientAuth != want.TLSConfig.ClientAuth { return fmt.Errorf("ClientAuth not Matched :\tgot %d \twant %d", got.TLSConfig.ClientAuth, want.TLSConfig.ClientAuth) } - gotCert, _ := x509.ParseCertificate(got.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSCertficateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + gotCert, _ := x509.ParseCertificate(got.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } - if got.TLSCertficateCache.certRefreshPeriod != want.TLSCertficateCache.certRefreshPeriod { - return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertficateCache.certRefreshPeriod, want.TLSCertficateCache.certRefreshPeriod) + if got.TLSCertificateCache.certRefreshPeriod != want.TLSCertificateCache.certRefreshPeriod { + return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertificateCache.certRefreshPeriod, want.TLSCertificateCache.certRefreshPeriod) } return nil }, diff --git a/usecase/authz_proxyd.go b/usecase/authz_proxyd.go index dd36321..87d77e9 100644 --- a/usecase/authz_proxyd.go +++ b/usecase/authz_proxyd.go @@ -83,7 +83,7 @@ func New(cfg config.Config) (AuthzProxyDaemon, error) { return nil, errors.Wrap(err, "cannot NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS)") } tlsConfig = configWithCache.TLSConfig - tlsCertificateCache = configWithCache.TLSCertficateCache + tlsCertificateCache = configWithCache.TLSCertificateCache } else { tlsConfig, err = service.NewTLSConfig(cfg.Server.TLS) if err != nil { diff --git a/usecase/authz_proxyd_test.go b/usecase/authz_proxyd_test.go index 13e6fb0..e00dff6 100644 --- a/usecase/authz_proxyd_test.go +++ b/usecase/authz_proxyd_test.go @@ -405,7 +405,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { return ech }, }, - tlsCertificateCache: defaultTLSConfigWithTLSCertificateCache.TLSCertficateCache, + tlsCertificateCache: defaultTLSConfigWithTLSCertificateCache.TLSCertificateCache, }, args: args{ ctx: ctx, @@ -622,7 +622,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { return ech }, }, - tlsCertificateCache: defaultTLSConfigWithTLSCertificateCache.TLSCertficateCache, + tlsCertificateCache: defaultTLSConfigWithTLSCertificateCache.TLSCertificateCache, }, args: args{ ctx: ctx, @@ -697,7 +697,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { return ech }, }, - tlsCertificateCache: defaultTLSConfigWithTLSCertificateCache.TLSCertficateCache, + tlsCertificateCache: defaultTLSConfigWithTLSCertificateCache.TLSCertificateCache, }, args: args{ ctx: ctx, From e214e9a844c3dcc9885891127a170375c66d8b83 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 8 Feb 2023 18:49:40 +0900 Subject: [PATCH 43/58] Fix comment Signed-off-by: Kyo Fujisaki --- service/tls.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/service/tls.go b/service/tls.go index 5cfb562..259778e 100644 --- a/service/tls.go +++ b/service/tls.go @@ -93,8 +93,8 @@ func NewTLSConfig(cfg config.TLS) (*tls.Config, error) { } // NewTLSConfigWithTLSCertificateCache returns a *TLSConfigWithTLSCertificateCache struct or error. -// It use to enable the certificate auto-reload feature. -// It reads TLS configuration and initializes *tls.Config / TLSCertificateCache struct. +// It uses to enable the certificate auto-reload feature. +// It reads TLS configuration and initializes *tls.Config / *TLSCertificateCache struct. // It initializes TLS configuration, for example the CA certificate and key to start TLS server. // Server and CA Certificate, and private key will read from files from file paths defined in environment variables. func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCertificateCache, error) { From 359673585b19228fc66dc8c1d5a60e7c4e3ea0fd Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 8 Feb 2023 21:50:09 +0900 Subject: [PATCH 44/58] Fix CertRefreshPeriod 0 check logic Signed-off-by: Kyo Fujisaki --- usecase/authz_proxyd.go | 26 ++++++++++++-- usecase/authz_proxyd_test.go | 70 +++++++++++++++++++++++++++++++----- 2 files changed, 85 insertions(+), 11 deletions(-) diff --git a/usecase/authz_proxyd.go b/usecase/authz_proxyd.go index 87d77e9..305c972 100644 --- a/usecase/authz_proxyd.go +++ b/usecase/authz_proxyd.go @@ -76,8 +76,11 @@ func New(cfg config.Config) (AuthzProxyDaemon, error) { var tlsConfig *tls.Config var tlsCertificateCache *service.TLSCertificateCache if cfg.Server.TLS.Enable { - // Enable auto-reload if CertRefreshPeriod is set. - if cfg.Server.TLS.CertRefreshPeriod != "" && cfg.Server.TLS.CertRefreshPeriod != "0" { + ivd, err := isValidDuration(cfg.Server.TLS.CertRefreshPeriod) + if err != nil { + return nil, errors.Wrap(err, "cannot isValidDuration(cfg.Server.TLS.CertRefreshPeriod)") + } + if ivd { configWithCache, err := service.NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS) if err != nil { return nil, errors.Wrap(err, "cannot NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS)") @@ -162,7 +165,8 @@ func (g *authzProxyDaemon) Start(ctx context.Context) <-chan []error { // handle cert refresh goroutine error // prevent run RefreshCertificate if Enable is false and CertRefreshPeriod is set - if g.cfg.Server.TLS.Enable && g.cfg.Server.TLS.CertRefreshPeriod != "" && g.cfg.Server.TLS.CertRefreshPeriod != "0" { + ivd, _ := isValidDuration(g.cfg.Server.TLS.CertRefreshPeriod) + if g.cfg.Server.TLS.Enable && ivd { eg.Go(func() error { return g.tlsCertificateCache.RefreshCertificate(ctx) }) @@ -330,3 +334,19 @@ func newAuthzD(cfg config.Config) (service.Authorizationd, error) { } return authorizerd.New(authzOpts...) } + +// isValidDuration returns whether duration is valid. +// "" -> false, "abcdefg" -> false, "0s" -> false, "123s" -> true +func isValidDuration(durationString string) (bool, error) { + if durationString != "" { + crp, err := time.ParseDuration(durationString) + if err != nil { + return false, err + } + if crp == 0 { + return false, nil + } + return true, nil + } + return false, nil +} diff --git a/usecase/authz_proxyd_test.go b/usecase/authz_proxyd_test.go index e00dff6..5ff3b23 100644 --- a/usecase/authz_proxyd_test.go +++ b/usecase/authz_proxyd_test.go @@ -12,7 +12,6 @@ import ( authorizerd "github.com/AthenZ/athenz-authorizer/v5" "github.com/AthenZ/authorization-proxy/v4/config" "github.com/AthenZ/authorization-proxy/v4/service" - "github.com/pkg/errors" ) @@ -116,7 +115,7 @@ func TestNew(t *testing.T) { }, TLS: config.TLS{ Enable: true, - CertRefreshPeriod: "0", + CertRefreshPeriod: "0s", }, }, Proxy: config.Proxy{ @@ -124,7 +123,7 @@ func TestNew(t *testing.T) { }, } return test{ - name: "new CertRefreshPeriod is 0, tlsCertificateCache is nil.", + name: "CertRefreshPeriod is 0, tlsCertificateCache is nil.", args: args{ cfg: cfg, }, @@ -258,7 +257,7 @@ func TestNew(t *testing.T) { }, }, wantErr: true, - wantErrStr: "cannot NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS): ParseDuration(cfg.CertRefreshPeriod): time: invalid duration \"abcdefg\"", + wantErrStr: "cannot isValidDuration(cfg.Server.TLS.CertRefreshPeriod): time: invalid duration \"abcdefg\"", }, } for _, tt := range tests { @@ -741,13 +740,13 @@ func Test_authzProxyDaemon_Start(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) dummyErr := errors.New("dummy") return test{ - name: "Cert refrsh daemon stops when TLS.Enable = false and CertRefreshPeriod is set", + name: "Cert refrsh daemon stops when TLS.Enable = false and TLS.CertRefreshPeriod is set", fields: fields{ cfg: config.Config{ Server: config.Server{ TLS: config.TLS{ Enable: false, - CertRefreshPeriod: "3d", + CertRefreshPeriod: "3h", CertPath: "../test/data/dummyServer.crt", KeyPath: "../test/data/dummyServer.key", }, @@ -818,13 +817,13 @@ func Test_authzProxyDaemon_Start(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) dummyErr := errors.New("dummy") return test{ - name: "Cert refrsh daemon stops when TLS.Enable = true and CertRefreshPeriod is 0", + name: "Cert refrsh daemon stops when TLS.Enable = true and TLS.CertRefreshPeriod is 0", fields: fields{ cfg: config.Config{ Server: config.Server{ TLS: config.TLS{ Enable: true, - CertRefreshPeriod: "0", + CertRefreshPeriod: "0s", }, }, }, @@ -1156,3 +1155,58 @@ func Test_newAuthzD(t *testing.T) { }) } } + +func Test_isValidDuration(t *testing.T) { + type args struct { + durationString string + } + tests := []struct { + name string + args args + want bool + wantErr error + }{ + { + name: "test true, valid duration", + args: args{ + durationString: "123s", + }, + want: true, + }, + { + name: "test false, empty string", + args: args{ + durationString: "", + }, + want: false, + }, + { + name: "test false, zero", + args: args{ + durationString: "0h", + }, + want: false, + }, + { + name: "test false and error, abcdefg", + args: args{ + durationString: "abcdefg", + }, + want: false, + wantErr: errors.New("time: invalid duration \"abcdefg\""), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := isValidDuration(tt.args.durationString) + if tt.wantErr != nil { + if err.Error() != tt.wantErr.Error() { + t.Errorf("isValidDuration() error = %s, wantErr %s", err.Error(), tt.wantErr.Error()) + } + } + if got != tt.want { + t.Errorf("isValidDuration() = %v, want %v", got, tt.want) + } + }) + } +} From a94e93d5455eb45a8fead9c12edb6b785f61435e Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Wed, 8 Feb 2023 22:01:13 +0900 Subject: [PATCH 45/58] Remove dot Signed-off-by: Kyo Fujisaki --- service/tls.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/service/tls.go b/service/tls.go index 259778e..102c353 100644 --- a/service/tls.go +++ b/service/tls.go @@ -194,7 +194,7 @@ func (tcc *TLSCertificateCache) RefreshCertificate(ctx context.Context) error { case <-ctx.Done(): return nil case <-ticker.C: - glg.Info("Checking to refresh server certificate.") + glg.Info("Checking to refresh server certificate") serverCertHash, err := hash(tcc.serverCertPath) if err != nil { glg.Error("Failed to refresh server certificate: %s.", err.Error()) @@ -221,7 +221,7 @@ func (tcc *TLSCertificateCache) RefreshCertificate(ctx context.Context) error { tcc.serverCert.Store(&newCert) tcc.serverCertHash = serverCertHash tcc.serverCertKeyHash = serverCertKeyHash - glg.Info("Refreshed server certificate.") + glg.Info("Refreshed server certificate") } tcc.serverCertMutex.Unlock() } From c3100c8c93f79b8151e0f584e4fc341e188758df Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Thu, 9 Feb 2023 14:18:13 +0900 Subject: [PATCH 46/58] Use NewTLSConfig in NewTLSConfigWithTLSCertificateCache Signed-off-by: Kyo Fujisaki --- service/tls.go | 31 ++++++++----------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/service/tls.go b/service/tls.go index 102c353..21fae13 100644 --- a/service/tls.go +++ b/service/tls.go @@ -98,25 +98,19 @@ func NewTLSConfig(cfg config.TLS) (*tls.Config, error) { // It initializes TLS configuration, for example the CA certificate and key to start TLS server. // Server and CA Certificate, and private key will read from files from file paths defined in environment variables. func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCertificateCache, error) { + var err error + tcc := &TLSCertificateCache{} - t := &tls.Config{ - MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{ - tls.CurveP521, - tls.CurveP384, - tls.CurveP256, - tls.X25519, - }, - SessionTicketsDisabled: true, - ClientAuth: tls.NoClientCert, - GetCertificate: tcc.getCertificate, + t, err := NewTLSConfig(cfg) + if err != nil { + return nil, err } - - var err error + // Set it to nil because the condition for GetCertificate to be executed is "It will only be called if the client supplies SNI information or if Certificates is empty". + t.Certificates = nil + t.GetCertificate = tcc.getCertificate cert := config.GetActualValue(cfg.CertPath) key := config.GetActualValue(cfg.KeyPath) - ca := config.GetActualValue(cfg.CAPath) if cert != "" && key != "" { crt, err := tls.LoadX509KeyPair(cert, key) @@ -147,15 +141,6 @@ func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCerti } } - if ca != "" { - pool, err := NewX509CertPool(ca) - if err != nil { - return nil, errors.Wrap(err, "NewX509CertPool(ca)") - } - t.ClientCAs = pool - t.ClientAuth = tls.RequireAndVerifyClientCert - } - return &TLSConfigWithTLSCertificateCache{ TLSConfig: t, TLSCertificateCache: tcc, From beed8a61c71f9099b13eeb150f155c992903c3ab Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Thu, 9 Feb 2023 14:45:24 +0900 Subject: [PATCH 47/58] Revert "Use NewTLSConfig in NewTLSConfigWithTLSCertificateCache" This reverts commit 74006c8ee0789518b828ec1be34dd92eae89832c. Signed-off-by: Kyo Fujisaki --- service/tls.go | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/service/tls.go b/service/tls.go index 21fae13..102c353 100644 --- a/service/tls.go +++ b/service/tls.go @@ -98,19 +98,25 @@ func NewTLSConfig(cfg config.TLS) (*tls.Config, error) { // It initializes TLS configuration, for example the CA certificate and key to start TLS server. // Server and CA Certificate, and private key will read from files from file paths defined in environment variables. func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCertificateCache, error) { - var err error - tcc := &TLSCertificateCache{} - t, err := NewTLSConfig(cfg) - if err != nil { - return nil, err + t := &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, + }, + SessionTicketsDisabled: true, + ClientAuth: tls.NoClientCert, + GetCertificate: tcc.getCertificate, } - // Set it to nil because the condition for GetCertificate to be executed is "It will only be called if the client supplies SNI information or if Certificates is empty". - t.Certificates = nil - t.GetCertificate = tcc.getCertificate + + var err error cert := config.GetActualValue(cfg.CertPath) key := config.GetActualValue(cfg.KeyPath) + ca := config.GetActualValue(cfg.CAPath) if cert != "" && key != "" { crt, err := tls.LoadX509KeyPair(cert, key) @@ -141,6 +147,15 @@ func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCerti } } + if ca != "" { + pool, err := NewX509CertPool(ca) + if err != nil { + return nil, errors.Wrap(err, "NewX509CertPool(ca)") + } + t.ClientCAs = pool + t.ClientAuth = tls.RequireAndVerifyClientCert + } + return &TLSConfigWithTLSCertificateCache{ TLSConfig: t, TLSCertificateCache: tcc, From 2b5cccd44b1209e5914907bca00bcebe00210f87 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Thu, 9 Feb 2023 15:05:26 +0900 Subject: [PATCH 48/58] Use NewTLSConfigWithTLSCertificateCache in NewTLSConfig Signed-off-by: Kyo Fujisaki --- service/tls.go | 32 ++++++++------------------------ 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/service/tls.go b/service/tls.go index 102c353..74a3bab 100644 --- a/service/tls.go +++ b/service/tls.go @@ -55,41 +55,25 @@ type TLSConfigWithTLSCertificateCache struct { // It initializes TLS configuration, for example the CA certificate and key to start TLS server. // Server and CA Certificate, and private key will read from files from file paths defined in environment variables. func NewTLSConfig(cfg config.TLS) (*tls.Config, error) { - t := &tls.Config{ - MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{ - tls.CurveP521, - tls.CurveP384, - tls.CurveP256, - tls.X25519, - }, - SessionTicketsDisabled: true, - ClientAuth: tls.NoClientCert, + t, err := NewTLSConfigWithTLSCertificateCache(cfg) + if err != nil { + return nil, err } + // GetCertificate can only be used with TLSCertificateCache. + t.TLSConfig.GetCertificate = nil cert := config.GetActualValue(cfg.CertPath) key := config.GetActualValue(cfg.KeyPath) - ca := config.GetActualValue(cfg.CAPath) - if cert != "" && key != "" { crt, err := tls.LoadX509KeyPair(cert, key) if err != nil { return nil, errors.Wrap(err, "tls.LoadX509KeyPair(cert, key)") } - t.Certificates = make([]tls.Certificate, 1) - t.Certificates[0] = crt - } - - if ca != "" { - pool, err := NewX509CertPool(ca) - if err != nil { - return nil, errors.Wrap(err, "NewX509CertPool(ca)") - } - t.ClientCAs = pool - t.ClientAuth = tls.RequireAndVerifyClientCert + t.TLSConfig.Certificates = make([]tls.Certificate, 1) + t.TLSConfig.Certificates[0] = crt } - return t, nil + return t.TLSConfig, nil } // NewTLSConfigWithTLSCertificateCache returns a *TLSConfigWithTLSCertificateCache struct or error. From d0cb64d1333ec537b4c4f165451166518b566f78 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Thu, 9 Feb 2023 16:14:09 +0900 Subject: [PATCH 49/58] Move CertRefreshPeriod check to NewTLSConfigWithTLSCertificateCache Signed-off-by: Kyo Fujisaki --- service/tls.go | 89 +++++++++++++++++++++--------------- service/tls_test.go | 55 ++++++++++++++++++++++ usecase/authz_proxyd.go | 38 ++------------- usecase/authz_proxyd_test.go | 55 ---------------------- 4 files changed, 112 insertions(+), 125 deletions(-) diff --git a/service/tls.go b/service/tls.go index 74a3bab..d246969 100644 --- a/service/tls.go +++ b/service/tls.go @@ -59,20 +59,6 @@ func NewTLSConfig(cfg config.TLS) (*tls.Config, error) { if err != nil { return nil, err } - // GetCertificate can only be used with TLSCertificateCache. - t.TLSConfig.GetCertificate = nil - - cert := config.GetActualValue(cfg.CertPath) - key := config.GetActualValue(cfg.KeyPath) - if cert != "" && key != "" { - crt, err := tls.LoadX509KeyPair(cert, key) - if err != nil { - return nil, errors.Wrap(err, "tls.LoadX509KeyPair(cert, key)") - } - t.TLSConfig.Certificates = make([]tls.Certificate, 1) - t.TLSConfig.Certificates[0] = crt - } - return t.TLSConfig, nil } @@ -82,7 +68,7 @@ func NewTLSConfig(cfg config.TLS) (*tls.Config, error) { // It initializes TLS configuration, for example the CA certificate and key to start TLS server. // Server and CA Certificate, and private key will read from files from file paths defined in environment variables. func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCertificateCache, error) { - tcc := &TLSCertificateCache{} + var tcc *TLSCertificateCache t := &tls.Config{ MinVersion: tls.VersionTLS12, CurvePreferences: []tls.CurveID{ @@ -93,7 +79,6 @@ func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCerti }, SessionTicketsDisabled: true, ClientAuth: tls.NoClientCert, - GetCertificate: tcc.getCertificate, } var err error @@ -102,32 +87,46 @@ func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCerti key := config.GetActualValue(cfg.KeyPath) ca := config.GetActualValue(cfg.CAPath) - if cert != "" && key != "" { - crt, err := tls.LoadX509KeyPair(cert, key) + isEnableCertRefresh, err := isValidDuration(cfg.CertRefreshPeriod) + if err != nil { + return nil, errors.Wrap(err, "cannot isValidDuration(cfg.CertRefreshPeriod)") + } + if isEnableCertRefresh { + t.GetCertificate = tcc.getCertificate + tcc = &TLSCertificateCache{} + tcc.certRefreshPeriod, err = time.ParseDuration(cfg.CertRefreshPeriod) if err != nil { - return nil, errors.Wrap(err, "tls.LoadX509KeyPair(cert, key)") + return nil, errors.Wrap(err, "ParseDuration(cfg.CertRefreshPeriod)") } + if cert != "" && key != "" { + crt, err := tls.LoadX509KeyPair(cert, key) + if err != nil { + return nil, errors.Wrap(err, "tls.LoadX509KeyPair(cert, key)") + } - crtHash, err := hash(cert) - if err != nil { - return nil, errors.Wrap(err, "hash(cert)") - } + crtHash, err := hash(cert) + if err != nil { + return nil, errors.Wrap(err, "hash(cert)") + } - crtKeyHash, err := hash(key) - if err != nil { - return nil, errors.Wrap(err, "hash(key)") + crtKeyHash, err := hash(key) + if err != nil { + return nil, errors.Wrap(err, "hash(key)") + } + tcc.serverCert.Store(&crt) + tcc.serverCertHash = crtHash + tcc.serverCertKeyHash = crtKeyHash + tcc.serverCertPath = cert + tcc.serverCertKeyPath = key } - tcc.serverCert.Store(&crt) - tcc.serverCertHash = crtHash - tcc.serverCertKeyHash = crtKeyHash - tcc.serverCertPath = cert - tcc.serverCertKeyPath = key - } - - if cfg.CertRefreshPeriod != "" { - tcc.certRefreshPeriod, err = time.ParseDuration(cfg.CertRefreshPeriod) - if err != nil { - return nil, errors.Wrap(err, "ParseDuration(cfg.CertRefreshPeriod)") + } else { + if cert != "" && key != "" { + crt, err := tls.LoadX509KeyPair(cert, key) + if err != nil { + return nil, errors.Wrap(err, "tls.LoadX509KeyPair(cert, key)") + } + t.Certificates = make([]tls.Certificate, 1) + t.Certificates[0] = crt } } @@ -226,3 +225,19 @@ func hash(file string) ([]byte, error) { return h.Sum(nil), nil } + +// isValidDuration returns whether duration is valid. +// "" -> false, "abcdefg" -> false, "0s" -> false, "123s" -> true +func isValidDuration(durationString string) (bool, error) { + if durationString != "" { + crp, err := time.ParseDuration(durationString) + if err != nil { + return false, err + } + if crp == 0 { + return false, nil + } + return true, nil + } + return false, nil +} diff --git a/service/tls_test.go b/service/tls_test.go index f5edb76..a69f529 100644 --- a/service/tls_test.go +++ b/service/tls_test.go @@ -1181,3 +1181,58 @@ func TestTLSCertificateCache_RefreshCertificate(t *testing.T) { }) } } + +func Test_isValidDuration(t *testing.T) { + type args struct { + durationString string + } + tests := []struct { + name string + args args + want bool + wantErr error + }{ + { + name: "test true, valid duration", + args: args{ + durationString: "123s", + }, + want: true, + }, + { + name: "test false, empty string", + args: args{ + durationString: "", + }, + want: false, + }, + { + name: "test false, zero", + args: args{ + durationString: "0h", + }, + want: false, + }, + { + name: "test false and error, abcdefg", + args: args{ + durationString: "abcdefg", + }, + want: false, + wantErr: errors.New("time: invalid duration \"abcdefg\""), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := isValidDuration(tt.args.durationString) + if tt.wantErr != nil { + if err.Error() != tt.wantErr.Error() { + t.Errorf("isValidDuration() error = %s, wantErr %s", err.Error(), tt.wantErr.Error()) + } + } + if got != tt.want { + t.Errorf("isValidDuration() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/usecase/authz_proxyd.go b/usecase/authz_proxyd.go index 305c972..a158217 100644 --- a/usecase/authz_proxyd.go +++ b/usecase/authz_proxyd.go @@ -76,23 +76,12 @@ func New(cfg config.Config) (AuthzProxyDaemon, error) { var tlsConfig *tls.Config var tlsCertificateCache *service.TLSCertificateCache if cfg.Server.TLS.Enable { - ivd, err := isValidDuration(cfg.Server.TLS.CertRefreshPeriod) + configWithCache, err := service.NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS) if err != nil { - return nil, errors.Wrap(err, "cannot isValidDuration(cfg.Server.TLS.CertRefreshPeriod)") - } - if ivd { - configWithCache, err := service.NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS) - if err != nil { - return nil, errors.Wrap(err, "cannot NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS)") - } - tlsConfig = configWithCache.TLSConfig - tlsCertificateCache = configWithCache.TLSCertificateCache - } else { - tlsConfig, err = service.NewTLSConfig(cfg.Server.TLS) - if err != nil { - return nil, errors.Wrap(err, "cannot NewTLSConfig(cfg.Server.TLS)") - } + return nil, errors.Wrap(err, "cannot NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS)") } + tlsConfig = configWithCache.TLSConfig + tlsCertificateCache = configWithCache.TLSCertificateCache serverOption = append(serverOption, service.WithTLSConfig(tlsConfig)) } @@ -165,8 +154,7 @@ func (g *authzProxyDaemon) Start(ctx context.Context) <-chan []error { // handle cert refresh goroutine error // prevent run RefreshCertificate if Enable is false and CertRefreshPeriod is set - ivd, _ := isValidDuration(g.cfg.Server.TLS.CertRefreshPeriod) - if g.cfg.Server.TLS.Enable && ivd { + if g.cfg.Server.TLS.Enable && g.tlsCertificateCache != nil { eg.Go(func() error { return g.tlsCertificateCache.RefreshCertificate(ctx) }) @@ -334,19 +322,3 @@ func newAuthzD(cfg config.Config) (service.Authorizationd, error) { } return authorizerd.New(authzOpts...) } - -// isValidDuration returns whether duration is valid. -// "" -> false, "abcdefg" -> false, "0s" -> false, "123s" -> true -func isValidDuration(durationString string) (bool, error) { - if durationString != "" { - crp, err := time.ParseDuration(durationString) - if err != nil { - return false, err - } - if crp == 0 { - return false, nil - } - return true, nil - } - return false, nil -} diff --git a/usecase/authz_proxyd_test.go b/usecase/authz_proxyd_test.go index 5ff3b23..b5b638c 100644 --- a/usecase/authz_proxyd_test.go +++ b/usecase/authz_proxyd_test.go @@ -1155,58 +1155,3 @@ func Test_newAuthzD(t *testing.T) { }) } } - -func Test_isValidDuration(t *testing.T) { - type args struct { - durationString string - } - tests := []struct { - name string - args args - want bool - wantErr error - }{ - { - name: "test true, valid duration", - args: args{ - durationString: "123s", - }, - want: true, - }, - { - name: "test false, empty string", - args: args{ - durationString: "", - }, - want: false, - }, - { - name: "test false, zero", - args: args{ - durationString: "0h", - }, - want: false, - }, - { - name: "test false and error, abcdefg", - args: args{ - durationString: "abcdefg", - }, - want: false, - wantErr: errors.New("time: invalid duration \"abcdefg\""), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := isValidDuration(tt.args.durationString) - if tt.wantErr != nil { - if err.Error() != tt.wantErr.Error() { - t.Errorf("isValidDuration() error = %s, wantErr %s", err.Error(), tt.wantErr.Error()) - } - } - if got != tt.want { - t.Errorf("isValidDuration() = %v, want %v", got, tt.want) - } - }) - } -} From 19667a7dfd267b4d41f2d14a42d182bfbaf96fe9 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Thu, 9 Feb 2023 16:43:50 +0900 Subject: [PATCH 50/58] Fix comment and log Signed-off-by: Kyo Fujisaki --- config/config.go | 2 +- service/tls.go | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/config/config.go b/config/config.go index a50c63f..bae1040 100755 --- a/config/config.go +++ b/config/config.go @@ -90,7 +90,7 @@ type TLS struct { // CAPath represents the CA certificate chain file path for verifying client certificates. CAPath string `yaml:"caPath"` - // CertRefreshPeriod represents the time to read the certificate again. + // CertRefreshPeriod represents the duration to read the server certificate again. CertRefreshPeriod string `yaml:"certRefreshPeriod"` } diff --git a/service/tls.go b/service/tls.go index d246969..2ddee9f 100644 --- a/service/tls.go +++ b/service/tls.go @@ -34,7 +34,7 @@ import ( "github.com/pkg/errors" ) -// TLSCertificateCache represents refresh certificate +// TLSCertificateCache caches a certificate type TLSCertificateCache struct { serverCert atomic.Value serverCertHash []byte @@ -162,13 +162,13 @@ func NewX509CertPool(path string) (*x509.CertPool, error) { return pool, errors.Wrap(err, "x509.SystemCertPool()") } -// getCertificate return server TLS certificate. +// getCertificate returns the cached certificate. func (tcc *TLSCertificateCache) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { // serverCert is atomic.Value, so this can read it without lock. return tcc.serverCert.Load().(*tls.Certificate), nil } -// RefreshCertificate is refresh certificate for TLS. +// RefreshCertificate refreshes the cached certificate asynchronously. func (tcc *TLSCertificateCache) RefreshCertificate(ctx context.Context) error { ticker := time.NewTicker(tcc.certRefreshPeriod) defer ticker.Stop() @@ -177,7 +177,7 @@ func (tcc *TLSCertificateCache) RefreshCertificate(ctx context.Context) error { case <-ctx.Done(): return nil case <-ticker.C: - glg.Info("Checking to refresh server certificate") + glg.Info("Start refreshing server certificate") serverCertHash, err := hash(tcc.serverCertPath) if err != nil { glg.Error("Failed to refresh server certificate: %s.", err.Error()) @@ -188,7 +188,7 @@ func (tcc *TLSCertificateCache) RefreshCertificate(ctx context.Context) error { glg.Error("Failed to refresh server certificate: %s.", err.Error()) continue } - // A lock for when there are other features to update. + // lock the whole struct before write (prevent race from multiple calls). // serverCert is atomic.Value, so this can read it without lock. tcc.serverCertMutex.Lock() different := !bytes.Equal(tcc.serverCertHash, serverCertHash) || From 76edebdf368576908e31d5a9c0d55ae943982a93 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Thu, 9 Feb 2023 20:17:27 +0900 Subject: [PATCH 51/58] Add process for compatibility Signed-off-by: Kyo Fujisaki --- service/tls.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/service/tls.go b/service/tls.go index 2ddee9f..5d382f7 100644 --- a/service/tls.go +++ b/service/tls.go @@ -59,10 +59,24 @@ func NewTLSConfig(cfg config.TLS) (*tls.Config, error) { if err != nil { return nil, err } + // GetCertificate can only be used with TLSCertificateCache. + t.TLSConfig.GetCertificate = nil + cert := config.GetActualValue(cfg.CertPath) + key := config.GetActualValue(cfg.KeyPath) + if cert != "" && key != "" { + crt, err := tls.LoadX509KeyPair(cert, key) + if err != nil { + return nil, errors.Wrap(err, "tls.LoadX509KeyPair(cert, key)") + } + t.TLSConfig.Certificates = make([]tls.Certificate, 1) + t.TLSConfig.Certificates[0] = crt + } return t.TLSConfig, nil } // NewTLSConfigWithTLSCertificateCache returns a *TLSConfigWithTLSCertificateCache struct or error. +// cfg.CertRefreshPeriod is set(cert refresh enable), returns TLSCertificateCache: not nil / TLSConfig.GetCertificate: not nil / TLSConfig.Certificates: nil +// cfg.CertRefreshPeriod is not set(cert refresh disable), returns TLSCertificateCache: nil / TLSConfig.GetCertificate: nil / TLSConfig.Certificates: not nil // It uses to enable the certificate auto-reload feature. // It reads TLS configuration and initializes *tls.Config / *TLSCertificateCache struct. // It initializes TLS configuration, for example the CA certificate and key to start TLS server. @@ -92,6 +106,7 @@ func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCerti return nil, errors.Wrap(err, "cannot isValidDuration(cfg.CertRefreshPeriod)") } if isEnableCertRefresh { + // GetCertificate can only be used with TLSCertificateCache. t.GetCertificate = tcc.getCertificate tcc = &TLSCertificateCache{} tcc.certRefreshPeriod, err = time.ParseDuration(cfg.CertRefreshPeriod) From 04c9cdcbd2ceed42fd4895a46e2cb881eb04f1eb Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Thu, 9 Feb 2023 20:17:56 +0900 Subject: [PATCH 52/58] Fix test for tls.go Signed-off-by: Kyo Fujisaki --- service/tls_test.go | 234 ++++++++++++++++++++++++++++---------------- 1 file changed, 148 insertions(+), 86 deletions(-) diff --git a/service/tls_test.go b/service/tls_test.go index a69f529..15b3271 100644 --- a/service/tls_test.go +++ b/service/tls_test.go @@ -165,6 +165,45 @@ func TestNewTLSConfig(t *testing.T) { return nil }, }, + { + name: "if certRefreshPeriod set, return TLSConfig.Certificates", + args: args{ + cfg: config.TLS{ + CertPath: "../test/data/dummyServer.crt", + KeyPath: "../test/data/dummyServer.key", + CAPath: "../test/data/dummyCa.pem", + CertRefreshPeriod: "12345s", + }, + }, + want: &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, + }, + SessionTicketsDisabled: true, + Certificates: func() []tls.Certificate { + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + return []tls.Certificate{cert} + }(), + ClientAuth: tls.RequireAndVerifyClientCert, + }, + checkFunc: func(got, want *tls.Config) error { + // config.TLS.certRefreshPeriod is not set, GetCertificate is nil + if got.GetCertificate != nil { + return fmt.Errorf("GetCertificate is not nil") + } + // config.TLS.certRefreshPeriod is not set, TLSConfig.Certificates is set + gotCert, _ := x509.ParseCertificate(got.Certificates[0].Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.Certificates[0].Certificate[0]) + if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { + return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) + } + return nil + }, + }, { name: "return value ClientAuth test.", args: defaultArgs, @@ -346,7 +385,7 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { name: "return value MinVersion test.", args: defaultArgs, want: &TLSConfigWithTLSCertificateCache{ - &tls.Config{ + TLSConfig: &tls.Config{ MinVersion: tls.VersionTLS12, CurvePreferences: []tls.CurveID{ tls.CurveP521, @@ -359,59 +398,52 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) return []tls.Certificate{cert} }(), - ClientAuth: tls.RequireAndVerifyClientCert, - }, - &TLSCertificateCache{ - serverCert: defaultServerCert, - serverCertHash: defaultServerCerttHash, - serverCertKeyHash: defaultServerCerttKeyHash, - serverCertPath: defaultArgs.cfg.CertPath, - serverCertKeyPath: defaultArgs.cfg.KeyPath, - certRefreshPeriod: 0, + ClientAuth: tls.RequireAndVerifyClientCert, + GetCertificate: nil, }, + TLSCertificateCache: nil, }, checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { if got.TLSConfig.MinVersion != want.TLSConfig.MinVersion { return fmt.Errorf("MinVersion not Matched :\tgot %d\twant %d", got.TLSConfig.MinVersion, want.TLSConfig.MinVersion) } - gotCert, _ := x509.ParseCertificate(got.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + // config.TLS.certRefreshPeriod is not set, TLSCertificateCache is nil + if got.TLSCertificateCache != want.TLSCertificateCache { + return fmt.Errorf("TLSCertificateCache is not nil\tgot: %v", got.TLSCertificateCache) + } + // config.TLS.certRefreshPeriod is not set, TLSConfig.GetCertificate is nil + if got.TLSConfig.GetCertificate != nil { + return fmt.Errorf("TLSConfig.GetCertificate is not nil") + } + // config.TLS.certRefreshPeriod is not set, TLSConfig.Certificates is set + gotCert, _ := x509.ParseCertificate(got.TLSConfig.Certificates[0].Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSConfig.Certificates[0].Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } - // config.TLS.certRefreshPeriod is not set, TLSCertificateCache.certRefreshPeriod is 0 - if got.TLSCertificateCache.certRefreshPeriod != want.TLSCertificateCache.certRefreshPeriod { - return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertificateCache.certRefreshPeriod, want.TLSCertificateCache.certRefreshPeriod) - } return nil }, }, { name: "return value CurvePreferences test.", args: defaultArgs, - want: &TLSConfigWithTLSCertificateCache{&tls.Config{ - MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{ - tls.CurveP521, - tls.CurveP384, - tls.CurveP256, - tls.X25519, - }, - SessionTicketsDisabled: true, - Certificates: func() []tls.Certificate { - cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) - return []tls.Certificate{cert} - }(), - ClientAuth: tls.RequireAndVerifyClientCert, - }, - &TLSCertificateCache{ - serverCert: defaultServerCert, - serverCertHash: defaultServerCerttHash, - serverCertKeyHash: defaultServerCerttKeyHash, - serverCertPath: defaultArgs.cfg.CertPath, - serverCertKeyPath: defaultArgs.cfg.KeyPath, - certRefreshPeriod: 0, + want: &TLSConfigWithTLSCertificateCache{ + &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, + }, + SessionTicketsDisabled: true, + Certificates: func() []tls.Certificate { + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + return []tls.Certificate{cert} + }(), + ClientAuth: tls.RequireAndVerifyClientCert, }, + nil, }, checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { if len(got.TLSConfig.CurvePreferences) != len(want.TLSConfig.CurvePreferences) { @@ -430,14 +462,20 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { return fmt.Errorf("CurvePreferences not Find :\twant %d", want.TLSConfig.MinVersion) } } - gotCert, _ := x509.ParseCertificate(got.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + // config.TLS.certRefreshPeriod is not set, TLSCertificateCache is nil + if got.TLSCertificateCache != want.TLSCertificateCache { + return fmt.Errorf("TLSCertificateCache is not nil\tgot: %v", got.TLSCertificateCache) + } + // config.TLS.certRefreshPeriod is not set, TLSConfig.GetCertificate is nil + if got.TLSConfig.GetCertificate != nil { + return fmt.Errorf("TLSConfig.GetCertificate is not nil") + } + // config.TLS.certRefreshPeriod is not set, TLSConfig.Certificates is set + gotCert, _ := x509.ParseCertificate(got.TLSConfig.Certificates[0].Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSConfig.Certificates[0].Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } - if got.TLSCertificateCache.certRefreshPeriod != want.TLSCertificateCache.certRefreshPeriod { - return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertificateCache.certRefreshPeriod, want.TLSCertificateCache.certRefreshPeriod) - } return nil }, }, @@ -460,27 +498,26 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { }(), ClientAuth: tls.RequireAndVerifyClientCert, }, - &TLSCertificateCache{ - serverCert: defaultServerCert, - serverCertHash: defaultServerCerttHash, - serverCertKeyHash: defaultServerCerttKeyHash, - serverCertPath: defaultArgs.cfg.CertPath, - serverCertKeyPath: defaultArgs.cfg.KeyPath, - certRefreshPeriod: 0, - }, + nil, }, checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { if got.TLSConfig.SessionTicketsDisabled != want.TLSConfig.SessionTicketsDisabled { return fmt.Errorf("SessionTicketsDisabled not matched :\tgot %v\twant %v", got.TLSConfig.SessionTicketsDisabled, want.TLSConfig.SessionTicketsDisabled) } - gotCert, _ := x509.ParseCertificate(got.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + // config.TLS.certRefreshPeriod is not set, TLSCertificateCache is nil + if got.TLSCertificateCache != want.TLSCertificateCache { + return fmt.Errorf("TLSCertificateCache is not nil\tgot: %v", got.TLSCertificateCache) + } + // config.TLS.certRefreshPeriod is not set, TLSConfig.GetCertificate is nil + if got.TLSConfig.GetCertificate != nil { + return fmt.Errorf("TLSConfig.GetCertificate is not nil") + } + // config.TLS.certRefreshPeriod is not set, TLSConfig.Certificates is set + gotCert, _ := x509.ParseCertificate(got.TLSConfig.Certificates[0].Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSConfig.Certificates[0].Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } - if got.TLSCertificateCache.certRefreshPeriod != want.TLSCertificateCache.certRefreshPeriod { - return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertificateCache.certRefreshPeriod, want.TLSCertificateCache.certRefreshPeriod) - } return nil }, }, @@ -503,14 +540,7 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { }(), ClientAuth: tls.RequireAndVerifyClientCert, }, - &TLSCertificateCache{ - serverCert: defaultServerCert, - serverCertHash: defaultServerCerttHash, - serverCertKeyHash: defaultServerCerttKeyHash, - serverCertPath: defaultArgs.cfg.CertPath, - serverCertKeyPath: defaultArgs.cfg.KeyPath, - certRefreshPeriod: 0, - }, + nil, }, checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { for _, wantVal := range want.TLSConfig.Certificates { @@ -525,14 +555,20 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { return fmt.Errorf("Certificates PrivateKey not Matched :\twant %s", wantVal.PrivateKey) } } - gotCert, _ := x509.ParseCertificate(got.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + // config.TLS.certRefreshPeriod is not set, TLSCertificateCache is nil + if got.TLSCertificateCache != want.TLSCertificateCache { + return fmt.Errorf("TLSCertificateCache is not nil\tgot: %v", got.TLSCertificateCache) + } + // config.TLS.certRefreshPeriod is not set, TLSConfig.GetCertificate is nil + if got.TLSConfig.GetCertificate != nil { + return fmt.Errorf("TLSConfig.GetCertificate is not nil") + } + // config.TLS.certRefreshPeriod is not set, TLSConfig.Certificates is set + gotCert, _ := x509.ParseCertificate(got.TLSConfig.Certificates[0].Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSConfig.Certificates[0].Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } - if got.TLSCertificateCache.certRefreshPeriod != want.TLSCertificateCache.certRefreshPeriod { - return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertificateCache.certRefreshPeriod, want.TLSCertificateCache.certRefreshPeriod) - } return nil }, }, @@ -555,27 +591,26 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { }(), ClientAuth: tls.RequireAndVerifyClientCert, }, - &TLSCertificateCache{ - serverCert: defaultServerCert, - serverCertHash: defaultServerCerttHash, - serverCertKeyHash: defaultServerCerttKeyHash, - serverCertPath: defaultArgs.cfg.CertPath, - serverCertKeyPath: defaultArgs.cfg.KeyPath, - certRefreshPeriod: 0, - }, + nil, }, checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { if got.TLSConfig.ClientAuth != want.TLSConfig.ClientAuth { return fmt.Errorf("ClientAuth not Matched :\tgot %d \twant %d", got.TLSConfig.ClientAuth, want.TLSConfig.ClientAuth) } - gotCert, _ := x509.ParseCertificate(got.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + // config.TLS.certRefreshPeriod is not set, TLSCertificateCache is nil + if got.TLSCertificateCache != want.TLSCertificateCache { + return fmt.Errorf("TLSCertificateCache is not nil\tgot: %v", got.TLSCertificateCache) + } + // config.TLS.certRefreshPeriod is not set, TLSConfig.GetCertificate is nil + if got.TLSConfig.GetCertificate != nil { + return fmt.Errorf("TLSConfig.GetCertificate is not nil") + } + // config.TLS.certRefreshPeriod is not set, TLSConfig.Certificates is set + gotCert, _ := x509.ParseCertificate(got.TLSConfig.Certificates[0].Certificate[0]) + wantCert, _ := x509.ParseCertificate(want.TLSConfig.Certificates[0].Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } - if got.TLSCertificateCache.certRefreshPeriod != want.TLSCertificateCache.certRefreshPeriod { - return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertificateCache.certRefreshPeriod, want.TLSCertificateCache.certRefreshPeriod) - } return nil }, }, @@ -618,11 +653,20 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { if got.TLSConfig.ClientAuth != want.TLSConfig.ClientAuth { return fmt.Errorf("ClientAuth not Matched :\tgot %d \twant %d", got.TLSConfig.ClientAuth, want.TLSConfig.ClientAuth) } + // config.TLS.certRefreshPeriod is set, TLSCertificateCache is set gotCert, _ := x509.ParseCertificate(got.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) wantCert, _ := x509.ParseCertificate(want.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } + // config.TLS.certRefreshPeriod is set, TLSConfig.GetCertificate is set + if got.TLSConfig.GetCertificate == nil { + return fmt.Errorf("GetCertificate nil") + } + // config.TLS.certRefreshPeriod is set, TLSConfig.Certificates is nil + if got.TLSConfig.Certificates != nil { + return fmt.Errorf("Certificates not nil\tgot: %v", got.TLSConfig.Certificates) + } if got.TLSCertificateCache.certRefreshPeriod != want.TLSCertificateCache.certRefreshPeriod { return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertificateCache.certRefreshPeriod, want.TLSCertificateCache.certRefreshPeriod) } @@ -649,7 +693,7 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { Certificates: nil, ClientAuth: tls.RequireAndVerifyClientCert, }, - &TLSCertificateCache{}, + nil, }, checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { if got.TLSConfig.Certificates != nil { @@ -678,7 +722,7 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { Certificates: nil, ClientAuth: tls.RequireAndVerifyClientCert, }, - &TLSCertificateCache{}, + nil, }, checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { if got.TLSConfig.Certificates != nil { @@ -712,7 +756,7 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { }(), ClientAuth: tls.RequireAndVerifyClientCert, }, - &TLSCertificateCache{}, + nil, }, checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { if got.TLSConfig.ClientAuth != 0 { @@ -721,6 +765,24 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { return nil }, }, + { + name: "cert file invalid return error test.", + args: args{ + cfg: config.TLS{ + CertPath: "../test/data/invalid_dummyServer.crt", + KeyPath: "../test/data/invalid_dummyServer.key", + }, + }, + + want: nil, + wantErr: errors.New("tls.LoadX509KeyPair(cert, key): tls: failed to find any PEM data in certificate input"), + checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { + if got != nil { + return fmt.Errorf("got not nil :\tgot %d \twant %d", &got, &want) + } + return nil + }, + }, { name: "CertRefreshPeriod invalid return error test.", args: args{ @@ -732,7 +794,7 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { }, }, want: nil, - wantErr: errors.New("ParseDuration(cfg.CertRefreshPeriod): time: invalid duration \"invalid duration\""), + wantErr: errors.New("cannot isValidDuration(cfg.CertRefreshPeriod): time: invalid duration \"invalid duration\""), checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { if got != nil { return fmt.Errorf("got not nil :\tgot %d \twant %d", &got, &want) From 736f2665f7f56ca547b0a7fc50159d03da8b7b9e Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Thu, 9 Feb 2023 21:04:57 +0900 Subject: [PATCH 53/58] Fix condition for running cert refresh daemon Signed-off-by: Kyo Fujisaki --- usecase/authz_proxyd.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/usecase/authz_proxyd.go b/usecase/authz_proxyd.go index a158217..7004a77 100644 --- a/usecase/authz_proxyd.go +++ b/usecase/authz_proxyd.go @@ -154,7 +154,7 @@ func (g *authzProxyDaemon) Start(ctx context.Context) <-chan []error { // handle cert refresh goroutine error // prevent run RefreshCertificate if Enable is false and CertRefreshPeriod is set - if g.cfg.Server.TLS.Enable && g.tlsCertificateCache != nil { + if g.tlsCertificateCache != nil { eg.Go(func() error { return g.tlsCertificateCache.RefreshCertificate(ctx) }) From 45c5d58fb5f8a550ba691b6b75be528af2a0a8bd Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Thu, 9 Feb 2023 21:05:35 +0900 Subject: [PATCH 54/58] Fix tests for authz_proxyd_test.go Signed-off-by: Kyo Fujisaki --- usecase/authz_proxyd_test.go | 84 +++--------------------------------- 1 file changed, 5 insertions(+), 79 deletions(-) diff --git a/usecase/authz_proxyd_test.go b/usecase/authz_proxyd_test.go index b5b638c..85559b7 100644 --- a/usecase/authz_proxyd_test.go +++ b/usecase/authz_proxyd_test.go @@ -182,7 +182,7 @@ func TestNew(t *testing.T) { }, } return test{ - name: "new CertRefreshPeriod not set, tlsCertificateCache is nil.", + name: "CertRefreshPeriod not set, tlsCertificateCache is nil.", args: args{ cfg: cfg, }, @@ -238,7 +238,7 @@ func TestNew(t *testing.T) { }, }, wantErr: true, - wantErrStr: "cannot NewTLSConfig(cfg.Server.TLS): tls.LoadX509KeyPair(cert, key): tls: failed to find any PEM data in certificate input", + wantErrStr: "cannot NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS): tls.LoadX509KeyPair(cert, key): tls: failed to find any PEM data in certificate input", }, { name: "return error when CertRefreshPeriod invalid (failed to parse)", args: args{ @@ -257,7 +257,7 @@ func TestNew(t *testing.T) { }, }, wantErr: true, - wantErrStr: "cannot isValidDuration(cfg.Server.TLS.CertRefreshPeriod): time: invalid duration \"abcdefg\"", + wantErrStr: "cannot NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS): cannot isValidDuration(cfg.CertRefreshPeriod): time: invalid duration \"abcdefg\"", }, } for _, tt := range tests { @@ -740,7 +740,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) dummyErr := errors.New("dummy") return test{ - name: "Cert refrsh daemon stops when TLS.Enable = false and TLS.CertRefreshPeriod is set", + name: "Cert refrsh daemon stops when tlsCertificateCache is nil", fields: fields{ cfg: config.Config{ Server: config.Server{ @@ -773,81 +773,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { return ech }, }, - }, - args: args{ - ctx: ctx, - }, - wantErrs: []error{ - errors.WithMessage(dummyErr, "server fails"), - }, - checkFunc: func(got <-chan []error, wantErrs []error) error { - mux := &sync.Mutex{} - - gotErrs := make([][]error, 0) - mux.Lock() - go func() { - defer mux.Unlock() - // this can be execute == through eg.Wait() == refresh daemon is not running - err, ok := <-got - if !ok { - return - } - gotErrs = append(gotErrs, err) - }() - time.Sleep(time.Second) - - mux.Lock() - defer mux.Unlock() - - // check only send errors once and the errors are expected ignoring order - sort.Slice(gotErrs[0], getLessErrorFunc(gotErrs[0])) - sort.Slice(wantErrs, getLessErrorFunc(wantErrs)) - gotErrsStr := fmt.Sprintf("%v", gotErrs[0]) - wantErrsStr := fmt.Sprintf("%v", wantErrs) - if len(gotErrs) != 1 || !reflect.DeepEqual(gotErrsStr, wantErrsStr) { - return errors.Errorf("Invalid err, got: %v, want: %v", gotErrsStr, wantErrsStr) - } - - cancel() - return nil - }, - } - }(), - func() test { - ctx, cancel := context.WithCancel(context.Background()) - dummyErr := errors.New("dummy") - return test{ - name: "Cert refrsh daemon stops when TLS.Enable = true and TLS.CertRefreshPeriod is 0", - fields: fields{ - cfg: config.Config{ - Server: config.Server{ - TLS: config.TLS{ - Enable: true, - CertRefreshPeriod: "0s", - }, - }, - }, - athenz: &service.AuthorizerdMock{ - StartFunc: func(ctx context.Context) <-chan error { - ech := make(chan error) - go func() { - defer close(ech) - <-ctx.Done() - ech <- ctx.Err() - }() - return ech - }, - }, - server: &service.ServerMock{ - ListenAndServeFunc: func(ctx context.Context) <-chan []error { - ech := make(chan []error) - go func() { - defer close(ech) - ech <- []error{errors.WithMessage(dummyErr, "server fails")} - }() - return ech - }, - }, + tlsCertificateCache: nil, }, args: args{ ctx: ctx, From 29dec34f063ceecbf890d609ce0b0c8a40d7725a Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Thu, 9 Feb 2023 21:20:59 +0900 Subject: [PATCH 55/58] Fix initialization Signed-off-by: Kyo Fujisaki --- service/tls.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/service/tls.go b/service/tls.go index 5d382f7..a7d3982 100644 --- a/service/tls.go +++ b/service/tls.go @@ -107,8 +107,9 @@ func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCerti } if isEnableCertRefresh { // GetCertificate can only be used with TLSCertificateCache. - t.GetCertificate = tcc.getCertificate tcc = &TLSCertificateCache{} + t.GetCertificate = tcc.getCertificate + tcc.certRefreshPeriod, err = time.ParseDuration(cfg.CertRefreshPeriod) if err != nil { return nil, errors.Wrap(err, "ParseDuration(cfg.CertRefreshPeriod)") From 240c2edc3c1b0b93e86dbe5c190a2cc5d51d9027 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Mon, 13 Feb 2023 14:02:09 +0900 Subject: [PATCH 56/58] Fix use modified config Signed-off-by: Kyo Fujisaki --- service/tls.go | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/service/tls.go b/service/tls.go index a7d3982..597a679 100644 --- a/service/tls.go +++ b/service/tls.go @@ -55,22 +55,13 @@ type TLSConfigWithTLSCertificateCache struct { // It initializes TLS configuration, for example the CA certificate and key to start TLS server. // Server and CA Certificate, and private key will read from files from file paths defined in environment variables. func NewTLSConfig(cfg config.TLS) (*tls.Config, error) { - t, err := NewTLSConfigWithTLSCertificateCache(cfg) + // This is config for not using TLSCertificateCache. + modifiedCfg := cfg + modifiedCfg.CertRefreshPeriod = "" + t, err := NewTLSConfigWithTLSCertificateCache(modifiedCfg) if err != nil { return nil, err } - // GetCertificate can only be used with TLSCertificateCache. - t.TLSConfig.GetCertificate = nil - cert := config.GetActualValue(cfg.CertPath) - key := config.GetActualValue(cfg.KeyPath) - if cert != "" && key != "" { - crt, err := tls.LoadX509KeyPair(cert, key) - if err != nil { - return nil, errors.Wrap(err, "tls.LoadX509KeyPair(cert, key)") - } - t.TLSConfig.Certificates = make([]tls.Certificate, 1) - t.TLSConfig.Certificates[0] = crt - } return t.TLSConfig, nil } From f8c6e4a4f88b2e5f2ec0018e4f32353e3dae60f4 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Mon, 13 Feb 2023 15:28:14 +0900 Subject: [PATCH 57/58] Remove TLSConfigWithTLSCertificateCache Signed-off-by: Kyo Fujisaki --- service/tls.go | 32 +-- service/tls_test.go | 453 +++++++++++++++++------------------ usecase/authz_proxyd.go | 4 +- usecase/authz_proxyd_test.go | 16 +- 4 files changed, 243 insertions(+), 262 deletions(-) diff --git a/service/tls.go b/service/tls.go index 597a679..f6db879 100644 --- a/service/tls.go +++ b/service/tls.go @@ -45,11 +45,6 @@ type TLSCertificateCache struct { certRefreshPeriod time.Duration } -type TLSConfigWithTLSCertificateCache struct { - TLSConfig *tls.Config - TLSCertificateCache *TLSCertificateCache -} - // NewTLSConfig returns a *tls.Config struct or error. // It reads TLS configuration and initializes *tls.Config struct. // It initializes TLS configuration, for example the CA certificate and key to start TLS server. @@ -58,21 +53,21 @@ func NewTLSConfig(cfg config.TLS) (*tls.Config, error) { // This is config for not using TLSCertificateCache. modifiedCfg := cfg modifiedCfg.CertRefreshPeriod = "" - t, err := NewTLSConfigWithTLSCertificateCache(modifiedCfg) + t, _, err := NewTLSConfigWithTLSCertificateCache(modifiedCfg) if err != nil { return nil, err } - return t.TLSConfig, nil + return t, nil } -// NewTLSConfigWithTLSCertificateCache returns a *TLSConfigWithTLSCertificateCache struct or error. +// NewTLSConfigWithTLSCertificateCache returns a *tls.Config/*TLSCertificateCache struct or error. // cfg.CertRefreshPeriod is set(cert refresh enable), returns TLSCertificateCache: not nil / TLSConfig.GetCertificate: not nil / TLSConfig.Certificates: nil // cfg.CertRefreshPeriod is not set(cert refresh disable), returns TLSCertificateCache: nil / TLSConfig.GetCertificate: nil / TLSConfig.Certificates: not nil // It uses to enable the certificate auto-reload feature. // It reads TLS configuration and initializes *tls.Config / *TLSCertificateCache struct. // It initializes TLS configuration, for example the CA certificate and key to start TLS server. // Server and CA Certificate, and private key will read from files from file paths defined in environment variables. -func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCertificateCache, error) { +func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*tls.Config, *TLSCertificateCache, error) { var tcc *TLSCertificateCache t := &tls.Config{ MinVersion: tls.VersionTLS12, @@ -94,7 +89,7 @@ func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCerti isEnableCertRefresh, err := isValidDuration(cfg.CertRefreshPeriod) if err != nil { - return nil, errors.Wrap(err, "cannot isValidDuration(cfg.CertRefreshPeriod)") + return nil, nil, errors.Wrap(err, "cannot isValidDuration(cfg.CertRefreshPeriod)") } if isEnableCertRefresh { // GetCertificate can only be used with TLSCertificateCache. @@ -103,22 +98,22 @@ func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCerti tcc.certRefreshPeriod, err = time.ParseDuration(cfg.CertRefreshPeriod) if err != nil { - return nil, errors.Wrap(err, "ParseDuration(cfg.CertRefreshPeriod)") + return nil, nil, errors.Wrap(err, "ParseDuration(cfg.CertRefreshPeriod)") } if cert != "" && key != "" { crt, err := tls.LoadX509KeyPair(cert, key) if err != nil { - return nil, errors.Wrap(err, "tls.LoadX509KeyPair(cert, key)") + return nil, nil, errors.Wrap(err, "tls.LoadX509KeyPair(cert, key)") } crtHash, err := hash(cert) if err != nil { - return nil, errors.Wrap(err, "hash(cert)") + return nil, nil, errors.Wrap(err, "hash(cert)") } crtKeyHash, err := hash(key) if err != nil { - return nil, errors.Wrap(err, "hash(key)") + return nil, nil, errors.Wrap(err, "hash(key)") } tcc.serverCert.Store(&crt) tcc.serverCertHash = crtHash @@ -130,7 +125,7 @@ func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCerti if cert != "" && key != "" { crt, err := tls.LoadX509KeyPair(cert, key) if err != nil { - return nil, errors.Wrap(err, "tls.LoadX509KeyPair(cert, key)") + return nil, nil, errors.Wrap(err, "tls.LoadX509KeyPair(cert, key)") } t.Certificates = make([]tls.Certificate, 1) t.Certificates[0] = crt @@ -140,16 +135,13 @@ func NewTLSConfigWithTLSCertificateCache(cfg config.TLS) (*TLSConfigWithTLSCerti if ca != "" { pool, err := NewX509CertPool(ca) if err != nil { - return nil, errors.Wrap(err, "NewX509CertPool(ca)") + return nil, nil, errors.Wrap(err, "NewX509CertPool(ca)") } t.ClientCAs = pool t.ClientAuth = tls.RequireAndVerifyClientCert } - return &TLSConfigWithTLSCertificateCache{ - TLSConfig: t, - TLSCertificateCache: tcc, - }, nil + return t, tcc, nil } // NewX509CertPool returns *x509.CertPool struct or error. diff --git a/service/tls_test.go b/service/tls_test.go index 15b3271..a9b7b0c 100644 --- a/service/tls_test.go +++ b/service/tls_test.go @@ -375,49 +375,48 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { tests := []struct { name string args args - want *TLSConfigWithTLSCertificateCache + wantConfig *tls.Config + wantCache *TLSCertificateCache beforeFunc func(args args) - checkFunc func(*TLSConfigWithTLSCertificateCache, *TLSConfigWithTLSCertificateCache) error + checkFunc func(*tls.Config, *TLSCertificateCache, *tls.Config, *TLSCertificateCache) error afterFunc func(args args) wantErr error }{ { name: "return value MinVersion test.", args: defaultArgs, - want: &TLSConfigWithTLSCertificateCache{ - TLSConfig: &tls.Config{ - MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{ - tls.CurveP521, - tls.CurveP384, - tls.CurveP256, - tls.X25519, - }, - SessionTicketsDisabled: true, - Certificates: func() []tls.Certificate { - cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) - return []tls.Certificate{cert} - }(), - ClientAuth: tls.RequireAndVerifyClientCert, - GetCertificate: nil, + wantConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, }, - TLSCertificateCache: nil, + SessionTicketsDisabled: true, + Certificates: func() []tls.Certificate { + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + return []tls.Certificate{cert} + }(), + ClientAuth: tls.RequireAndVerifyClientCert, + GetCertificate: nil, }, - checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { - if got.TLSConfig.MinVersion != want.TLSConfig.MinVersion { - return fmt.Errorf("MinVersion not Matched :\tgot %d\twant %d", got.TLSConfig.MinVersion, want.TLSConfig.MinVersion) + wantCache: nil, + checkFunc: func(gotConfig *tls.Config, gotCache *TLSCertificateCache, wantConfig *tls.Config, wantCache *TLSCertificateCache) error { + if gotConfig.MinVersion != wantConfig.MinVersion { + return fmt.Errorf("MinVersion not Matched :\tgot %d\twant %d", gotConfig.MinVersion, wantConfig.MinVersion) } // config.TLS.certRefreshPeriod is not set, TLSCertificateCache is nil - if got.TLSCertificateCache != want.TLSCertificateCache { - return fmt.Errorf("TLSCertificateCache is not nil\tgot: %v", got.TLSCertificateCache) + if gotCache != wantCache { + return fmt.Errorf("TLSCertificateCache is not nil\tgot: %v", gotCache) } - // config.TLS.certRefreshPeriod is not set, TLSConfig.GetCertificate is nil - if got.TLSConfig.GetCertificate != nil { - return fmt.Errorf("TLSConfig.GetCertificate is not nil") + // config.TLS.certRefreshPeriod is not set, GetCertificate is nil + if gotConfig.GetCertificate != nil { + return fmt.Errorf("GetCertificate is not nil") } // config.TLS.certRefreshPeriod is not set, TLSConfig.Certificates is set - gotCert, _ := x509.ParseCertificate(got.TLSConfig.Certificates[0].Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSConfig.Certificates[0].Certificate[0]) + gotCert, _ := x509.ParseCertificate(gotConfig.Certificates[0].Certificate[0]) + wantCert, _ := x509.ParseCertificate(wantConfig.Certificates[0].Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } @@ -427,31 +426,29 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { { name: "return value CurvePreferences test.", args: defaultArgs, - want: &TLSConfigWithTLSCertificateCache{ - &tls.Config{ - MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{ - tls.CurveP521, - tls.CurveP384, - tls.CurveP256, - tls.X25519, - }, - SessionTicketsDisabled: true, - Certificates: func() []tls.Certificate { - cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) - return []tls.Certificate{cert} - }(), - ClientAuth: tls.RequireAndVerifyClientCert, + wantConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, }, - nil, + SessionTicketsDisabled: true, + Certificates: func() []tls.Certificate { + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + return []tls.Certificate{cert} + }(), + ClientAuth: tls.RequireAndVerifyClientCert, }, - checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { - if len(got.TLSConfig.CurvePreferences) != len(want.TLSConfig.CurvePreferences) { - return fmt.Errorf("CurvePreferences not Matched length:\tgot %d\twant %d", len(got.TLSConfig.CurvePreferences), len(want.TLSConfig.CurvePreferences)) + wantCache: nil, + checkFunc: func(gotConfig *tls.Config, gotCache *TLSCertificateCache, wantConfig *tls.Config, wantCache *TLSCertificateCache) error { + if len(gotConfig.CurvePreferences) != len(wantConfig.CurvePreferences) { + return fmt.Errorf("CurvePreferences not Matched length:\tgot %d\twant %d", len(gotConfig.CurvePreferences), len(wantConfig.CurvePreferences)) } - for _, actualValue := range got.TLSConfig.CurvePreferences { + for _, actualValue := range gotConfig.CurvePreferences { var match bool - for _, expectedValue := range want.TLSConfig.CurvePreferences { + for _, expectedValue := range wantConfig.CurvePreferences { if actualValue == expectedValue { match = true break @@ -459,20 +456,20 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { } if !match { - return fmt.Errorf("CurvePreferences not Find :\twant %d", want.TLSConfig.MinVersion) + return fmt.Errorf("CurvePreferences not Find :\twant %d", wantConfig.MinVersion) } } // config.TLS.certRefreshPeriod is not set, TLSCertificateCache is nil - if got.TLSCertificateCache != want.TLSCertificateCache { - return fmt.Errorf("TLSCertificateCache is not nil\tgot: %v", got.TLSCertificateCache) + if gotCache != wantCache { + return fmt.Errorf("TLSCertificateCache is not nil\tgot: %v", gotCache) } - // config.TLS.certRefreshPeriod is not set, TLSConfig.GetCertificate is nil - if got.TLSConfig.GetCertificate != nil { - return fmt.Errorf("TLSConfig.GetCertificate is not nil") + // config.TLS.certRefreshPeriod is not set, GetCertificate is nil + if gotConfig.GetCertificate != nil { + return fmt.Errorf("GetCertificate is not nil") } // config.TLS.certRefreshPeriod is not set, TLSConfig.Certificates is set - gotCert, _ := x509.ParseCertificate(got.TLSConfig.Certificates[0].Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSConfig.Certificates[0].Certificate[0]) + gotCert, _ := x509.ParseCertificate(gotConfig.Certificates[0].Certificate[0]) + wantCert, _ := x509.ParseCertificate(wantConfig.Certificates[0].Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } @@ -482,39 +479,37 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { { name: "return value SessionTicketsDisabled test.", args: defaultArgs, - want: &TLSConfigWithTLSCertificateCache{ - &tls.Config{ - MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{ - tls.CurveP521, - tls.CurveP384, - tls.CurveP256, - tls.X25519, - }, - SessionTicketsDisabled: true, - Certificates: func() []tls.Certificate { - cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) - return []tls.Certificate{cert} - }(), - ClientAuth: tls.RequireAndVerifyClientCert, + wantConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, }, - nil, + SessionTicketsDisabled: true, + Certificates: func() []tls.Certificate { + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + return []tls.Certificate{cert} + }(), + ClientAuth: tls.RequireAndVerifyClientCert, }, - checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { - if got.TLSConfig.SessionTicketsDisabled != want.TLSConfig.SessionTicketsDisabled { - return fmt.Errorf("SessionTicketsDisabled not matched :\tgot %v\twant %v", got.TLSConfig.SessionTicketsDisabled, want.TLSConfig.SessionTicketsDisabled) + wantCache: nil, + checkFunc: func(gotConfig *tls.Config, gotCache *TLSCertificateCache, wantConfig *tls.Config, wantCache *TLSCertificateCache) error { + if gotConfig.SessionTicketsDisabled != wantConfig.SessionTicketsDisabled { + return fmt.Errorf("SessionTicketsDisabled not matched :\tgot %v\twant %v", gotConfig.SessionTicketsDisabled, wantConfig.SessionTicketsDisabled) } // config.TLS.certRefreshPeriod is not set, TLSCertificateCache is nil - if got.TLSCertificateCache != want.TLSCertificateCache { - return fmt.Errorf("TLSCertificateCache is not nil\tgot: %v", got.TLSCertificateCache) + if gotCache != wantCache { + return fmt.Errorf("TLSCertificateCache is not nil\tgot: %v", gotCache) } - // config.TLS.certRefreshPeriod is not set, TLSConfig.GetCertificate is nil - if got.TLSConfig.GetCertificate != nil { - return fmt.Errorf("TLSConfig.GetCertificate is not nil") + // config.TLS.certRefreshPeriod is not set, GetCertificate is nil + if gotConfig.GetCertificate != nil { + return fmt.Errorf("GetCertificate is not nil") } // config.TLS.certRefreshPeriod is not set, TLSConfig.Certificates is set - gotCert, _ := x509.ParseCertificate(got.TLSConfig.Certificates[0].Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSConfig.Certificates[0].Certificate[0]) + gotCert, _ := x509.ParseCertificate(gotConfig.Certificates[0].Certificate[0]) + wantCert, _ := x509.ParseCertificate(wantConfig.Certificates[0].Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } @@ -524,28 +519,26 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { { name: "return value Certificates test.", args: defaultArgs, - want: &TLSConfigWithTLSCertificateCache{ - &tls.Config{ - MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{ - tls.CurveP521, - tls.CurveP384, - tls.CurveP256, - tls.X25519, - }, - SessionTicketsDisabled: true, - Certificates: func() []tls.Certificate { - cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) - return []tls.Certificate{cert} - }(), - ClientAuth: tls.RequireAndVerifyClientCert, + wantConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, }, - nil, + SessionTicketsDisabled: true, + Certificates: func() []tls.Certificate { + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + return []tls.Certificate{cert} + }(), + ClientAuth: tls.RequireAndVerifyClientCert, }, - checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { - for _, wantVal := range want.TLSConfig.Certificates { + wantCache: nil, + checkFunc: func(gotConfig *tls.Config, gotCache *TLSCertificateCache, wantConfig *tls.Config, wantCache *TLSCertificateCache) error { + for _, wantVal := range wantConfig.Certificates { notExist := false - for _, gotVal := range got.TLSConfig.Certificates { + for _, gotVal := range gotConfig.Certificates { if gotVal.PrivateKey == wantVal.PrivateKey { notExist = true break @@ -556,16 +549,16 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { } } // config.TLS.certRefreshPeriod is not set, TLSCertificateCache is nil - if got.TLSCertificateCache != want.TLSCertificateCache { - return fmt.Errorf("TLSCertificateCache is not nil\tgot: %v", got.TLSCertificateCache) + if gotCache != wantCache { + return fmt.Errorf("TLSCertificateCache is not nil\tgot: %v", gotCache) } - // config.TLS.certRefreshPeriod is not set, TLSConfig.GetCertificate is nil - if got.TLSConfig.GetCertificate != nil { - return fmt.Errorf("TLSConfig.GetCertificate is not nil") + // config.TLS.certRefreshPeriod is not set, GetCertificate is nil + if gotConfig.GetCertificate != nil { + return fmt.Errorf("GetCertificate is not nil") } // config.TLS.certRefreshPeriod is not set, TLSConfig.Certificates is set - gotCert, _ := x509.ParseCertificate(got.TLSConfig.Certificates[0].Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSConfig.Certificates[0].Certificate[0]) + gotCert, _ := x509.ParseCertificate(gotConfig.Certificates[0].Certificate[0]) + wantCert, _ := x509.ParseCertificate(wantConfig.Certificates[0].Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } @@ -575,39 +568,37 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { { name: "return value ClientAuth test.", args: defaultArgs, - want: &TLSConfigWithTLSCertificateCache{ - &tls.Config{ - MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{ - tls.CurveP521, - tls.CurveP384, - tls.CurveP256, - tls.X25519, - }, - SessionTicketsDisabled: true, - Certificates: func() []tls.Certificate { - cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) - return []tls.Certificate{cert} - }(), - ClientAuth: tls.RequireAndVerifyClientCert, + wantConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, }, - nil, + SessionTicketsDisabled: true, + Certificates: func() []tls.Certificate { + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + return []tls.Certificate{cert} + }(), + ClientAuth: tls.RequireAndVerifyClientCert, }, - checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { - if got.TLSConfig.ClientAuth != want.TLSConfig.ClientAuth { - return fmt.Errorf("ClientAuth not Matched :\tgot %d \twant %d", got.TLSConfig.ClientAuth, want.TLSConfig.ClientAuth) + wantCache: nil, + checkFunc: func(gotConfig *tls.Config, gotCache *TLSCertificateCache, wantConfig *tls.Config, wantCache *TLSCertificateCache) error { + if gotConfig.ClientAuth != wantConfig.ClientAuth { + return fmt.Errorf("ClientAuth not Matched :\tgot %d \twant %d", gotConfig.ClientAuth, wantConfig.ClientAuth) } // config.TLS.certRefreshPeriod is not set, TLSCertificateCache is nil - if got.TLSCertificateCache != want.TLSCertificateCache { - return fmt.Errorf("TLSCertificateCache is not nil\tgot: %v", got.TLSCertificateCache) + if gotCache != wantCache { + return fmt.Errorf("TLSCertificateCache is not nil\tgot: %v", gotCache) } - // config.TLS.certRefreshPeriod is not set, TLSConfig.GetCertificate is nil - if got.TLSConfig.GetCertificate != nil { - return fmt.Errorf("TLSConfig.GetCertificate is not nil") + // config.TLS.certRefreshPeriod is not set, GetCertificate is nil + if gotConfig.GetCertificate != nil { + return fmt.Errorf("GetCertificate is not nil") } // config.TLS.certRefreshPeriod is not set, TLSConfig.Certificates is set - gotCert, _ := x509.ParseCertificate(got.TLSConfig.Certificates[0].Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSConfig.Certificates[0].Certificate[0]) + gotCert, _ := x509.ParseCertificate(gotConfig.Certificates[0].Certificate[0]) + wantCert, _ := x509.ParseCertificate(wantConfig.Certificates[0].Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } @@ -624,51 +615,49 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { CertRefreshPeriod: "12345s", }, }, - want: &TLSConfigWithTLSCertificateCache{ - &tls.Config{ - MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{ - tls.CurveP521, - tls.CurveP384, - tls.CurveP256, - tls.X25519, - }, - SessionTicketsDisabled: true, - Certificates: func() []tls.Certificate { - cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) - return []tls.Certificate{cert} - }(), - ClientAuth: tls.RequireAndVerifyClientCert, - }, - &TLSCertificateCache{ - serverCert: defaultServerCert, - serverCertHash: defaultServerCerttHash, - serverCertKeyHash: defaultServerCerttKeyHash, - serverCertPath: defaultArgs.cfg.CertPath, - serverCertKeyPath: defaultArgs.cfg.KeyPath, - certRefreshPeriod: 12345 * time.Second, + wantConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, }, + SessionTicketsDisabled: true, + Certificates: func() []tls.Certificate { + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + return []tls.Certificate{cert} + }(), + ClientAuth: tls.RequireAndVerifyClientCert, }, - checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { - if got.TLSConfig.ClientAuth != want.TLSConfig.ClientAuth { - return fmt.Errorf("ClientAuth not Matched :\tgot %d \twant %d", got.TLSConfig.ClientAuth, want.TLSConfig.ClientAuth) + wantCache: &TLSCertificateCache{ + serverCert: defaultServerCert, + serverCertHash: defaultServerCerttHash, + serverCertKeyHash: defaultServerCerttKeyHash, + serverCertPath: defaultArgs.cfg.CertPath, + serverCertKeyPath: defaultArgs.cfg.KeyPath, + certRefreshPeriod: 12345 * time.Second, + }, + checkFunc: func(gotConfig *tls.Config, gotCache *TLSCertificateCache, wantConfig *tls.Config, wantCache *TLSCertificateCache) error { + if gotConfig.ClientAuth != wantConfig.ClientAuth { + return fmt.Errorf("ClientAuth not Matched :\tgot %d \twant %d", gotConfig.ClientAuth, wantConfig.ClientAuth) } // config.TLS.certRefreshPeriod is set, TLSCertificateCache is set - gotCert, _ := x509.ParseCertificate(got.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) - wantCert, _ := x509.ParseCertificate(want.TLSCertificateCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + gotCert, _ := x509.ParseCertificate(gotCache.serverCert.Load().(*tls.Certificate).Certificate[0]) + wantCert, _ := x509.ParseCertificate(wantCache.serverCert.Load().(*tls.Certificate).Certificate[0]) if gotCert.SerialNumber.String() != wantCert.SerialNumber.String() { return fmt.Errorf("Certificate SerialNumber not Matched\tgot: %s\twant: %s", gotCert.SerialNumber, wantCert.SerialNumber) } - // config.TLS.certRefreshPeriod is set, TLSConfig.GetCertificate is set - if got.TLSConfig.GetCertificate == nil { + // config.TLS.certRefreshPeriod is set, GetCertificate is set + if gotConfig.GetCertificate == nil { return fmt.Errorf("GetCertificate nil") } // config.TLS.certRefreshPeriod is set, TLSConfig.Certificates is nil - if got.TLSConfig.Certificates != nil { - return fmt.Errorf("Certificates not nil\tgot: %v", got.TLSConfig.Certificates) + if gotConfig.Certificates != nil { + return fmt.Errorf("Certificates not nil\tgot: %v", gotConfig.Certificates) } - if got.TLSCertificateCache.certRefreshPeriod != want.TLSCertificateCache.certRefreshPeriod { - return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", got.TLSCertificateCache.certRefreshPeriod, want.TLSCertificateCache.certRefreshPeriod) + if gotCache.certRefreshPeriod != wantCache.certRefreshPeriod { + return fmt.Errorf("certRefreshPeriod not Matched\tgot: %s\twant: %s", gotCache.certRefreshPeriod, wantCache.certRefreshPeriod) } return nil }, @@ -680,23 +669,21 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { CertPath: "", }, }, - want: &TLSConfigWithTLSCertificateCache{ - &tls.Config{ - MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{ - tls.CurveP521, - tls.CurveP384, - tls.CurveP256, - tls.X25519, - }, - SessionTicketsDisabled: true, - Certificates: nil, - ClientAuth: tls.RequireAndVerifyClientCert, + wantConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, }, - nil, + SessionTicketsDisabled: true, + Certificates: nil, + ClientAuth: tls.RequireAndVerifyClientCert, }, - checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { - if got.TLSConfig.Certificates != nil { + wantCache: nil, + checkFunc: func(gotConfig *tls.Config, gotCache *TLSCertificateCache, wantConfig *tls.Config, wantCache *TLSCertificateCache) error { + if gotConfig.Certificates != nil { return fmt.Errorf("Certificates not nil") } return nil @@ -709,23 +696,21 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { CertPath: "", }, }, - want: &TLSConfigWithTLSCertificateCache{ - &tls.Config{ - MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{ - tls.CurveP521, - tls.CurveP384, - tls.CurveP256, - tls.X25519, - }, - SessionTicketsDisabled: true, - Certificates: nil, - ClientAuth: tls.RequireAndVerifyClientCert, + wantConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, }, - nil, + SessionTicketsDisabled: true, + Certificates: nil, + ClientAuth: tls.RequireAndVerifyClientCert, }, - checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { - if got.TLSConfig.Certificates != nil { + wantCache: nil, + checkFunc: func(gotConfig *tls.Config, gotCache *TLSCertificateCache, wantConfig *tls.Config, wantCache *TLSCertificateCache) error { + if gotConfig.Certificates != nil { return fmt.Errorf("Certificates not nil") } return nil @@ -740,27 +725,25 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { }, }, - want: &TLSConfigWithTLSCertificateCache{ - &tls.Config{ - MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{ - tls.CurveP521, - tls.CurveP384, - tls.CurveP256, - tls.X25519, - }, - SessionTicketsDisabled: true, - Certificates: func() []tls.Certificate { - cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) - return []tls.Certificate{cert} - }(), - ClientAuth: tls.RequireAndVerifyClientCert, + wantConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + tls.X25519, }, - nil, + SessionTicketsDisabled: true, + Certificates: func() []tls.Certificate { + cert, _ := tls.LoadX509KeyPair(defaultArgs.cfg.CertPath, defaultArgs.cfg.KeyPath) + return []tls.Certificate{cert} + }(), + ClientAuth: tls.RequireAndVerifyClientCert, }, - checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { - if got.TLSConfig.ClientAuth != 0 { - return fmt.Errorf("ClientAuth is :\t%d", got.TLSConfig.ClientAuth) + wantCache: nil, + checkFunc: func(gotConfig *tls.Config, gotCache *TLSCertificateCache, wantConfig *tls.Config, wantCache *TLSCertificateCache) error { + if gotConfig.ClientAuth != 0 { + return fmt.Errorf("ClientAuth is :\t%d", gotConfig.ClientAuth) } return nil }, @@ -774,11 +757,15 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { }, }, - want: nil, - wantErr: errors.New("tls.LoadX509KeyPair(cert, key): tls: failed to find any PEM data in certificate input"), - checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { - if got != nil { - return fmt.Errorf("got not nil :\tgot %d \twant %d", &got, &want) + wantConfig: nil, + wantCache: nil, + wantErr: errors.New("tls.LoadX509KeyPair(cert, key): tls: failed to find any PEM data in certificate input"), + checkFunc: func(gotConfig *tls.Config, gotCache *TLSCertificateCache, wantConfig *tls.Config, wantCache *TLSCertificateCache) error { + if gotConfig != nil { + return fmt.Errorf("gotConfig not nil :\tgot %d \twant %d", &gotConfig, &wantConfig) + } + if gotCache != nil { + return fmt.Errorf("gotConfig not nil :\tgot %d \twant %d", &gotCache, &wantCache) } return nil }, @@ -793,11 +780,15 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { CertRefreshPeriod: "invalid duration", }, }, - want: nil, - wantErr: errors.New("cannot isValidDuration(cfg.CertRefreshPeriod): time: invalid duration \"invalid duration\""), - checkFunc: func(got, want *TLSConfigWithTLSCertificateCache) error { - if got != nil { - return fmt.Errorf("got not nil :\tgot %d \twant %d", &got, &want) + wantConfig: nil, + wantCache: nil, + wantErr: errors.New("cannot isValidDuration(cfg.CertRefreshPeriod): time: invalid duration \"invalid duration\""), + checkFunc: func(gotConfig *tls.Config, gotCache *TLSCertificateCache, wantConfig *tls.Config, wantCache *TLSCertificateCache) error { + if gotConfig != nil { + return fmt.Errorf("gotConfig not nil :\tgot %d \twant %d", &gotConfig, &wantConfig) + } + if gotCache != nil { + return fmt.Errorf("gotConfig not nil :\tgot %d \twant %d", &gotCache, &wantCache) } return nil }, @@ -809,7 +800,7 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { tt.beforeFunc(tt.args) } - got, err := NewTLSConfigWithTLSCertificateCache(tt.args.cfg) + gotConfig, gotCache, err := NewTLSConfigWithTLSCertificateCache(tt.args.cfg) if tt.wantErr == nil && err != nil { t.Errorf("NewTLSConfigWithTLSCertificateCache() error = %v, wantErr %v", err, tt.wantErr) return @@ -823,7 +814,7 @@ func TestNewTLSConfigWithTLSCertificateCache(t *testing.T) { } if tt.checkFunc != nil { - err = tt.checkFunc(got, tt.want) + err = tt.checkFunc(gotConfig, gotCache, tt.wantConfig, tt.wantCache) if err != nil { t.Errorf("NewTLSConfigWithTLSCertificateCache() error = %v", err) return diff --git a/usecase/authz_proxyd.go b/usecase/authz_proxyd.go index 7004a77..040f32a 100644 --- a/usecase/authz_proxyd.go +++ b/usecase/authz_proxyd.go @@ -76,12 +76,10 @@ func New(cfg config.Config) (AuthzProxyDaemon, error) { var tlsConfig *tls.Config var tlsCertificateCache *service.TLSCertificateCache if cfg.Server.TLS.Enable { - configWithCache, err := service.NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS) + tlsConfig, tlsCertificateCache, err = service.NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS) if err != nil { return nil, errors.Wrap(err, "cannot NewTLSConfigWithTLSCertificateCache(cfg.Server.TLS)") } - tlsConfig = configWithCache.TLSConfig - tlsCertificateCache = configWithCache.TLSCertificateCache serverOption = append(serverOption, service.WithTLSConfig(tlsConfig)) } diff --git a/usecase/authz_proxyd_test.go b/usecase/authz_proxyd_test.go index 85559b7..456f6ad 100644 --- a/usecase/authz_proxyd_test.go +++ b/usecase/authz_proxyd_test.go @@ -361,13 +361,13 @@ func Test_authzProxyDaemon_Start(t *testing.T) { return errs[i].Error() < errs[j].Error() } } - defaultTLSConfig := config.TLS{ + defaultConfig := config.TLS{ Enable: true, CertPath: "../test/data/dummyServer.crt", KeyPath: "../test/data/dummyServer.key", CertRefreshPeriod: "5s", } - defaultTLSConfigWithTLSCertificateCache, _ := service.NewTLSConfigWithTLSCertificateCache(defaultTLSConfig) + _, defaultTLSCache, _ := service.NewTLSConfigWithTLSCertificateCache(defaultConfig) tests := []test{ func() test { ctx, cancel := context.WithCancel(context.Background()) @@ -376,7 +376,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { fields: fields{ cfg: config.Config{ Server: config.Server{ - TLS: defaultTLSConfig, + TLS: defaultConfig, }, }, athenz: &service.AuthorizerdMock{ @@ -404,7 +404,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { return ech }, }, - tlsCertificateCache: defaultTLSConfigWithTLSCertificateCache.TLSCertificateCache, + tlsCertificateCache: defaultTLSCache, }, args: args{ ctx: ctx, @@ -592,7 +592,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { fields: fields{ cfg: config.Config{ Server: config.Server{ - TLS: defaultTLSConfig, + TLS: defaultConfig, }, }, athenz: &service.AuthorizerdMock{ @@ -621,7 +621,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { return ech }, }, - tlsCertificateCache: defaultTLSConfigWithTLSCertificateCache.TLSCertificateCache, + tlsCertificateCache: defaultTLSCache, }, args: args{ ctx: ctx, @@ -667,7 +667,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { fields: fields{ cfg: config.Config{ Server: config.Server{ - TLS: defaultTLSConfig, + TLS: defaultConfig, }, }, athenz: &service.AuthorizerdMock{ @@ -696,7 +696,7 @@ func Test_authzProxyDaemon_Start(t *testing.T) { return ech }, }, - tlsCertificateCache: defaultTLSConfigWithTLSCertificateCache.TLSCertificateCache, + tlsCertificateCache: defaultTLSCache, }, args: args{ ctx: ctx, From 66aab41a90a776917d48fb592feef6ea0045b811 Mon Sep 17 00:00:00 2001 From: Kyo Fujisaki Date: Mon, 13 Feb 2023 16:22:16 +0900 Subject: [PATCH 58/58] Add not refreshed log Signed-off-by: Kyo Fujisaki --- service/tls.go | 2 ++ service/tls_test.go | 73 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/service/tls.go b/service/tls.go index f6db879..c6ef4ce 100644 --- a/service/tls.go +++ b/service/tls.go @@ -204,6 +204,8 @@ func (tcc *TLSCertificateCache) RefreshCertificate(ctx context.Context) error { tcc.serverCertHash = serverCertHash tcc.serverCertKeyHash = serverCertKeyHash glg.Info("Refreshed server certificate") + } else { + glg.Info("Server certificate is same as the file") } tcc.serverCertMutex.Unlock() } diff --git a/service/tls_test.go b/service/tls_test.go index a9b7b0c..4113474 100644 --- a/service/tls_test.go +++ b/service/tls_test.go @@ -1119,6 +1119,79 @@ func TestTLSCertificateCache_RefreshCertificate(t *testing.T) { func() test { ctx, cancelFunc := context.WithCancel(context.Background()) + return test{ + name: "Test not refresh and stop", + fields: fields{ + serverCert: oldCert, + serverCertHash: oldCertHash, + serverCertKeyHash: oldCertKeyHash, + serverCertPath: testCertPath, + serverCertKeyPath: testCertKeyPath, + certRefreshPeriod: 500 * time.Millisecond, + serverCertMutex: sync.Mutex{}, + }, + args: args{ + ctx: ctx, + }, + beforeFunc: func() error { + err := copyCert(oldCertPath, testCertPath) + if err != nil { + return err + } + err = copyCert(oldCertKeyPath, testCertKeyPath) + if err != nil { + return err + } + return nil + }, + checkFunc: func(tcc *TLSCertificateCache, want error) error { + cachedCert := tcc.serverCert.Load() + cc, _ := x509.ParseCertificate(cachedCert.(*tls.Certificate).Certificate[0]) + oc, _ := x509.ParseCertificate(oldCertData.Certificate[0]) + if cc.SerialNumber.String() != oc.SerialNumber.String() { + return errors.New("cached cert / old cert Serial Number not Matched") + } + + // wait refresh period + time.Sleep(1 * time.Second) + cachedCert = tcc.serverCert.Load() + cc, _ = x509.ParseCertificate(cachedCert.(*tls.Certificate).Certificate[0]) + // check cert not refreshed + if cc.SerialNumber.String() != oc.SerialNumber.String() { + return errors.New("cached cert / old cert Serial Number not Matched") + } + // refresh stop + cancelFunc() + err = copyCert(newCertPath, testCertPath) + if err != nil { + return err + } + time.Sleep(1 * time.Second) + nc, _ := x509.ParseCertificate(newCert.Certificate[0]) + if cc.SerialNumber.String() == nc.SerialNumber.String() { + return errors.New("refresh not stopped") + } + return nil + }, + afterFunc: func() error { + cancelFunc() + err := os.Remove(testCertPath) + if err != nil { + t.Errorf("test cert remove failed: %s", err) + return err + } + err = os.Remove(testCertKeyPath) + if err != nil { + t.Errorf("test cert remove failed: %s", err) + return err + } + return nil + }, + } + }(), + func() test { + ctx, cancelFunc := context.WithCancel(context.Background()) + return test{ name: "Test invalid cert not refresh, next period refresh success", fields: fields{