diff --git a/health/check.go b/health/check.go index 12795bfd..e14f1255 100644 --- a/health/check.go +++ b/health/check.go @@ -33,7 +33,7 @@ type Check struct { state *State } -// NewCheck creates a new Check +// NewCheck creates a new Check. Every check should have a unique name. // // You are able to return custom statuses by returning a StatusError from the check function. This way you can perform // checks that return a status other than up or down. For example, you can return a status of "degraded" if the check diff --git a/health/checker.go b/health/checker.go index 348fdec4..03476ab8 100644 --- a/health/checker.go +++ b/health/checker.go @@ -4,53 +4,38 @@ import ( "context" "encoding/json" "errors" + "fmt" "net/http" "sync" ) -// Checker is an interface that defines a health checker. -type Checker interface { - // Handler returns the handler for the check. - Handler() http.HandlerFunc - - // Check returns the result of the check. - Check(ctx context.Context) *Result -} - -// checker is a struct that implements the Checker interface. +// Checker is a struct that handles the checking of multiple health checks. // // This is a group of checks that can be run in parallel. -type checker struct { - baseCtx context.Context - cancel context.CancelFunc - mtx *sync.Mutex - checks []*Check +type Checker struct { + checks sync.Map httpStatusCodeUp int httpStatusCodeDown int } // NewChecker creates a new Checker. -func NewChecker(opts ...CheckerOption) Checker { - c := &checker{ - mtx: new(sync.Mutex), - checks: make([]*Check, 0), +func NewChecker(opts ...CheckerOption) (*Checker, error) { + c := &Checker{ httpStatusCodeUp: http.StatusOK, httpStatusCodeDown: http.StatusServiceUnavailable, } for _, opt := range opts { - opt(c) - } - - if c.baseCtx == nil { - c.baseCtx, c.cancel = context.WithCancel(context.Background()) + if err := opt(c); err != nil { + return nil, fmt.Errorf("failed to apply checker option: %w", err) + } } - return c + return c, nil } -func (c *checker) httpCodeFromStatus(status Status) int { +func (c *Checker) httpCodeFromStatus(status Status) int { switch status { case StatusUp: return c.httpStatusCodeUp @@ -62,7 +47,7 @@ func (c *checker) httpCodeFromStatus(status Status) int { } // Handler returns the handler for the check. -func (c *checker) Handler() http.HandlerFunc { +func (c *Checker) Handler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { result := c.Check(r.Context()) httpStatus := c.httpCodeFromStatus(result.Status) @@ -73,10 +58,7 @@ func (c *checker) Handler() http.HandlerFunc { } // Check returns the result of the check. -func (c *checker) Check(ctx context.Context) *Result { - c.mtx.Lock() - defer c.mtx.Unlock() - +func (c *Checker) Check(ctx context.Context) *Result { if ctx == nil { ctx = context.Background() } @@ -84,10 +66,17 @@ func (c *checker) Check(ctx context.Context) *Result { result := NewResult() wg := new(sync.WaitGroup) - for _, check := range c.checks { + c.checks.Range(func(key, value any) bool { + check, ok := value.(*Check) + if !ok { + // This "should" never happen, but just in case + return true + } + wg.Add(1) - go func(check *Check, result *Result) { + go func(check *Check) { defer wg.Done() + checkResult := NewResult() checkStatus := StatusUp @@ -104,9 +93,30 @@ func (c *checker) Check(ctx context.Context) *Result { result.SetStatus(checkResult.Status) result.addDetail(check.String(), checkResult) - }(check, result) - } + }(check) + return true + }) wg.Wait() + return result } + +// AddCheck adds a check to the checker. +func (c *Checker) AddCheck(check *Check) error { + if check == nil { + return errors.New("check is nil") + } + + if check.name == "" { + return errors.New("check name is empty") + } + + if _, ok := c.checks.Load(check.String()); ok { + return fmt.Errorf("check already exists with the same key: %s", check.String()) + } + + c.checks.Store(check.String(), check) + + return nil +} diff --git a/health/checker_options.go b/health/checker_options.go index ca078bd4..49ff5669 100644 --- a/health/checker_options.go +++ b/health/checker_options.go @@ -1,52 +1,45 @@ package health -import ( - "context" -) - -type CheckerOption func(*checker) - -// WithCheckerBaseContext sets the base context for the checker. -func WithCheckerBaseContext(baseCtx context.Context) CheckerOption { - return func(c *checker) { - ctx, cancel := context.WithCancel(baseCtx) - c.baseCtx = ctx - c.cancel = cancel - } -} +import "fmt" + +type CheckerOption func(*Checker) error // WithCheckerCheck adds a single check to the checker. func WithCheckerCheck(check *Check) CheckerOption { - return func(c *checker) { - if check == nil { - return + return func(c *Checker) error { + if err := c.AddCheck(check); err != nil { + return fmt.Errorf("failed to add check: %w", err) } - if c.checks == nil { - c.checks = make([]*Check, 0) - } - c.checks = append(c.checks, check) + + return nil } } // WithCheckerChecks adds multiple checks to the checker. func WithCheckerChecks(checks ...*Check) CheckerOption { - return func(c *checker) { + return func(c *Checker) error { for _, check := range checks { - WithCheckerCheck(check)(c) + if err := WithCheckerCheck(check)(c); err != nil { + return fmt.Errorf("failed to add check %s: %w", check.String(), err) + } } + + return nil } } // WithCheckerHTTPCodeUp sets the HTTP status code when the system is up. func WithCheckerHTTPCodeUp(code int) CheckerOption { - return func(c *checker) { + return func(c *Checker) error { c.httpStatusCodeUp = code + return nil } } // WithCheckerHTTPCodeDown sets the HTTP status code when the system is down. func WithCheckerHTTPCodeDown(code int) CheckerOption { - return func(c *checker) { + return func(c *Checker) error { c.httpStatusCodeDown = code + return nil } } diff --git a/health/checker_test.go b/health/checker_test.go index 5f45f355..b7e3cfc2 100644 --- a/health/checker_test.go +++ b/health/checker_test.go @@ -18,14 +18,12 @@ func TestNewChecker(t *testing.T) { return nil }) - got := NewChecker(WithCheckerCheck(gotCheck)) + c, err := NewChecker(WithCheckerCheck(gotCheck)) + require.NoError(t, err) - c, ok := got.(*checker) - require.True(t, ok, "NewChecker() should return a *checker") require.NotNil(t, c) require.Equal(t, http.StatusOK, c.httpStatusCodeUp) require.Equal(t, http.StatusServiceUnavailable, c.httpStatusCodeDown) - require.NotNil(t, c.baseCtx, "NewChecker() should set the base context") } func TestNewCheckerHandler_Single(t *testing.T) { @@ -36,7 +34,8 @@ func TestNewCheckerHandler_Single(t *testing.T) { return nil }) - got := NewChecker(WithCheckerCheck(gotCheck)) + got, err := NewChecker(WithCheckerCheck(gotCheck)) + require.NoError(t, err) handler := got.Handler() require.NotNil(t, handler) @@ -62,7 +61,8 @@ func TestNewCheckerHandler_Single_StatusError(t *testing.T) { return NewStatusError(errors.New("test error"), StatusDegraded) }) - got := NewChecker(WithCheckerCheck(gotCheck)) + got, err := NewChecker(WithCheckerCheck(gotCheck)) + require.NoError(t, err) handler := got.Handler() require.NotNil(t, handler) @@ -88,7 +88,8 @@ func TestNewCheckerHandler_Single_StatusError_InvalidStatus(t *testing.T) { return NewStatusError(errors.New("test error"), 123) }) - got := NewChecker(WithCheckerCheck(gotCheck)) + got, err := NewChecker(WithCheckerCheck(gotCheck)) + require.NoError(t, err) handler := got.Handler() require.NotNil(t, handler) @@ -118,7 +119,8 @@ func TestNewCheckerHandler_Multiple(t *testing.T) { return nil }) - got := NewChecker(WithCheckerChecks([]*Check{gotCheck, secondCheck}...)) + got, err := NewChecker(WithCheckerChecks([]*Check{gotCheck, secondCheck}...)) + require.NoError(t, err) handler := got.Handler() require.NotNil(t, handler) @@ -144,7 +146,8 @@ func TestNewCheckerHandler_Single_Error(t *testing.T) { return errors.New("test error") }) - got := NewChecker(WithCheckerCheck(gotCheck)) + got, err := NewChecker(WithCheckerCheck(gotCheck)) + require.NoError(t, err) handler := got.Handler() require.NotNil(t, handler) @@ -170,7 +173,8 @@ func TestNewCheckerHandler_NoParentContext(t *testing.T) { return nil }) - got := NewChecker(WithCheckerCheck(gotCheck)) + got, err := NewChecker(WithCheckerCheck(gotCheck)) + require.NoError(t, err) handler := got.Handler() require.NotNil(t, handler) @@ -224,12 +228,78 @@ func TestChecker_HttpCodeFromStatus(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got := NewChecker() - c, ok := got.(*checker) - require.True(t, ok, "NewChecker() should return a *checker") + c, err := NewChecker() + require.NoError(t, err) require.NotNil(t, c) - require.Equal(t, tt.expectedStatus, c.httpCodeFromStatus(tt.status)) }) } } + +func TestChecker_AddCheck(t *testing.T) { + t.Parallel() + + c, err := NewChecker() + require.NoError(t, err) + require.NotNil(t, c) + + check := NewCheck("test_check", func(_ context.Context) error { + return nil + }) + + err = c.AddCheck(check) + require.NoError(t, err) + + // Check if the check was added + c.checks.Range(func(key, value any) bool { + require.Equal(t, "test_check", key) + return false + }) +} + +func TestChecker_AddTest_Invalid_Nil(t *testing.T) { + t.Parallel() + + c, err := NewChecker() + require.NoError(t, err) + require.NotNil(t, c) + + err = c.AddCheck(nil) + require.Error(t, err) + require.Equal(t, "check is nil", err.Error()) +} + +func TestChecker_AddTest_Invalid_NoName(t *testing.T) { + t.Parallel() + + c, err := NewChecker() + require.NoError(t, err) + require.NotNil(t, c) + + check := NewCheck("", func(_ context.Context) error { + return nil + }) + + err = c.AddCheck(check) + require.Error(t, err) + require.Equal(t, "check name is empty", err.Error()) +} + +func TestChecker_AddTest_Invalid_AlreadyExists(t *testing.T) { + t.Parallel() + + c, err := NewChecker() + require.NoError(t, err) + require.NotNil(t, c) + + check := NewCheck("test_check", func(_ context.Context) error { + return nil + }) + + err = c.AddCheck(check) + require.NoError(t, err) + + err = c.AddCheck(check) + require.Error(t, err) + require.Equal(t, "check already exists with the same key: test_check", err.Error()) +} diff --git a/health/mock_Checker.go b/health/mock_Checker.go deleted file mode 100644 index c9becff6..00000000 --- a/health/mock_Checker.go +++ /dev/null @@ -1,69 +0,0 @@ -// Code generated by mockery. DO NOT EDIT. - -package health - -import ( - context "context" - http "net/http" - - mock "github.com/stretchr/testify/mock" -) - -// MockChecker is an autogenerated mock type for the Checker type -type MockChecker struct { - mock.Mock -} - -// Check provides a mock function with given fields: ctx -func (_m *MockChecker) Check(ctx context.Context) *Result { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Check") - } - - var r0 *Result - if rf, ok := ret.Get(0).(func(context.Context) *Result); ok { - r0 = rf(ctx) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*Result) - } - } - - return r0 -} - -// Handler provides a mock function with no fields -func (_m *MockChecker) Handler() http.HandlerFunc { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for Handler") - } - - var r0 http.HandlerFunc - if rf, ok := ret.Get(0).(func() http.HandlerFunc); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(http.HandlerFunc) - } - } - - return r0 -} - -// NewMockChecker creates a new instance of MockChecker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockChecker(t interface { - mock.TestingT - Cleanup(func()) -}) *MockChecker { - mock := &MockChecker{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/health/mock_CheckerOption.go b/health/mock_CheckerOption.go index b745cfe1..69da0e0c 100644 --- a/health/mock_CheckerOption.go +++ b/health/mock_CheckerOption.go @@ -10,7 +10,7 @@ type MockCheckerOption struct { } // Execute provides a mock function with given fields: _a0 -func (_m *MockCheckerOption) Execute(_a0 *checker) { +func (_m *MockCheckerOption) Execute(_a0 *Checker) { _m.Called(_a0) } diff --git a/options.go b/options.go index 216c8a89..ff20a0e4 100644 --- a/options.go +++ b/options.go @@ -178,12 +178,16 @@ func WithLeaderElection(lockName string) StartOption { // WithHealthCheck is a StartOption that sets up the health2 check. func WithHealthCheck(checks ...*health.Check) StartOption { return func(a *App) error { - checkerOpts := make([]health.CheckerOption, 0) - for _, check := range checks { - checkerOpts = append(checkerOpts, health.WithCheckerCheck(check)) + checker, err := health.NewChecker() + if err != nil { + return fmt.Errorf("error creating health checker: %w", err) } - checker := health.NewChecker(checkerOpts...) + for _, check := range checks { + if err := checker.AddCheck(check); err != nil { + return fmt.Errorf("error adding health check %s: %w", check.String(), err) + } + } a.servers.Store("health", &http.Server{ Addr: fmt.Sprintf(":%d", HealthPort),