diff --git a/changes/20250820002654.feature b/changes/20250820002654.feature new file mode 100644 index 0000000000..cdae0a6d02 --- /dev/null +++ b/changes/20250820002654.feature @@ -0,0 +1 @@ +:sparkles: `[parallelisation]` Added new groups (ContextualFunctionGroup) and new Store options to configure the execution (number of workers, single execution, etc.) diff --git a/changes/20250820140853.feature b/changes/20250820140853.feature new file mode 100644 index 0000000000..d79a2a989e --- /dev/null +++ b/changes/20250820140853.feature @@ -0,0 +1 @@ +:sparkles: `[parallelisation]` Added new compound execution group to support nested execution groups diff --git a/utils/parallelisation/cancel_functions.go b/utils/parallelisation/cancel_functions.go index daa6ab8885..3464084e04 100644 --- a/utils/parallelisation/cancel_functions.go +++ b/utils/parallelisation/cancel_functions.go @@ -5,245 +5,23 @@ package parallelisation -import ( - "context" +import "context" - "github.com/sasha-s/go-deadlock" - "golang.org/x/sync/errgroup" - - "github.com/ARM-software/golang-utils/utils/commonerrors" - "github.com/ARM-software/golang-utils/utils/reflection" -) - -type StoreOptions struct { - clearOnExecution bool - stopOnFirstError bool - sequential bool - reverse bool - joinErrors bool -} -type StoreOption func(*StoreOptions) *StoreOptions - -// StopOnFirstError stops store execution on first error. -var StopOnFirstError StoreOption = func(o *StoreOptions) *StoreOptions { - if o == nil { - return o - } - o.stopOnFirstError = true - o.joinErrors = false - return o -} - -// JoinErrors will collate any errors which happened when executing functions in store. -// This option should not be used in combination to StopOnFirstError. -var JoinErrors StoreOption = func(o *StoreOptions) *StoreOptions { - if o == nil { - return o - } - o.stopOnFirstError = false - o.joinErrors = true - return o -} - -// ExecuteAll executes all functions in the store even if an error is raised. the first error raised is then returned. -var ExecuteAll StoreOption = func(o *StoreOptions) *StoreOptions { - if o == nil { - return o - } - o.stopOnFirstError = false - return o -} - -// ClearAfterExecution clears the store after execution. -var ClearAfterExecution StoreOption = func(o *StoreOptions) *StoreOptions { - if o == nil { - return o - } - o.clearOnExecution = true - return o -} - -// RetainAfterExecution keep the store intact after execution (no reset). -var RetainAfterExecution StoreOption = func(o *StoreOptions) *StoreOptions { - if o == nil { - return o - } - o.clearOnExecution = false - return o -} - -// Parallel ensures every function registered in the store is executed concurrently in the order they were registered. -var Parallel StoreOption = func(o *StoreOptions) *StoreOptions { - if o == nil { - return o - } - o.sequential = false - return o -} - -// Sequential ensures every function registered in the store is executed sequentially in the order they were registered. -var Sequential StoreOption = func(o *StoreOptions) *StoreOptions { - if o == nil { - return o - } - o.sequential = true - return o -} - -// SequentialInReverse ensures every function registered in the store is executed sequentially but in the reverse order they were registered. -var SequentialInReverse StoreOption = func(o *StoreOptions) *StoreOptions { - if o == nil { - return o - } - o.sequential = true - o.reverse = true - return o -} - -func newFunctionStore[T any](executeFunc func(context.Context, T) error, options ...StoreOption) *store[T] { - - opts := &StoreOptions{} - - for i := range options { - opts = options[i](opts) - } - return &store[T]{ - mu: deadlock.RWMutex{}, - functions: make([]T, 0), - executeFunc: executeFunc, - options: *opts, - } -} - -type store[T any] struct { - mu deadlock.RWMutex - functions []T - executeFunc func(ctx context.Context, element T) error - options StoreOptions -} - -func (s *store[T]) RegisterFunction(function ...T) { - defer s.mu.Unlock() - s.mu.Lock() - s.functions = append(s.functions, function...) -} - -func (s *store[T]) Len() int { - defer s.mu.RUnlock() - s.mu.RLock() - return len(s.functions) -} - -func (s *store[T]) Execute(ctx context.Context) (err error) { - defer s.mu.Unlock() - s.mu.Lock() - if reflection.IsEmpty(s.executeFunc) { - return commonerrors.New(commonerrors.ErrUndefined, "the store was not initialised correctly") - } - - if s.options.sequential { - err = s.executeSequentially(ctx, s.options.stopOnFirstError, s.options.reverse, s.options.joinErrors) - } else { - err = s.executeConcurrently(ctx, s.options.stopOnFirstError, s.options.joinErrors) - } - - if err == nil && s.options.clearOnExecution { - s.functions = make([]T, 0, len(s.functions)) - } - return -} - -func (s *store[T]) executeConcurrently(ctx context.Context, stopOnFirstError bool, collateErrors bool) error { - g, gCtx := errgroup.WithContext(ctx) - if !stopOnFirstError { - gCtx = ctx - } - funcNum := len(s.functions) - errCh := make(chan error, funcNum) - g.SetLimit(funcNum) - for i := range s.functions { - g.Go(func() error { - _, subErr := s.executeFunction(gCtx, s.functions[i]) - errCh <- subErr - return subErr - }) - } - err := g.Wait() - close(errCh) - if collateErrors { - collateErr := make([]error, funcNum) - i := 0 - for subErr := range errCh { - collateErr[i] = subErr - i++ - } - err = commonerrors.Join(collateErr...) - } - - return err +type CancelFunctionStore struct { + ExecutionGroup[context.CancelFunc] } -func (s *store[T]) executeSequentially(ctx context.Context, stopOnFirstError, reverse, collateErrors bool) (err error) { - err = DetermineContextError(ctx) - if err != nil { - return - } - funcNum := len(s.functions) - collateErr := make([]error, funcNum) - if reverse { - for i := funcNum - 1; i >= 0; i-- { - shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) - collateErr[funcNum-i-1] = subErr - if shouldBreak { - err = subErr - return - } - if subErr != nil && err == nil { - err = subErr - if stopOnFirstError { - return - } - } - } - } else { - for i := range s.functions { - shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) - collateErr[i] = subErr - if shouldBreak { - err = subErr - return - } - if subErr != nil && err == nil { - err = subErr - if stopOnFirstError { - return - } - } - } - } - - if collateErrors { - err = commonerrors.Join(collateErr...) - } - return +func (s *CancelFunctionStore) RegisterCancelFunction(cancel ...context.CancelFunc) { + s.ExecutionGroup.RegisterFunction(cancel...) } -func (s *store[T]) executeFunction(ctx context.Context, element T) (mustBreak bool, err error) { - err = DetermineContextError(ctx) - if err != nil { - mustBreak = true +func (s *CancelFunctionStore) RegisterCancelStore(store *CancelFunctionStore) { + if store == nil { return } - err = s.executeFunc(ctx, element) - return -} - -type CancelFunctionStore struct { - store[context.CancelFunc] -} - -func (s *CancelFunctionStore) RegisterCancelFunction(cancel ...context.CancelFunc) { - s.store.RegisterFunction(cancel...) + s.RegisterCancelFunction(func() { + store.Cancel() + }) } // Cancel will execute the cancel functions in the store. Any errors will be ignored and Execute() is recommended if you need to know if a cancellation failed @@ -252,15 +30,14 @@ func (s *CancelFunctionStore) Cancel() { } func (s *CancelFunctionStore) Len() int { - return s.store.Len() + return s.ExecutionGroup.Len() } // NewCancelFunctionsStore creates a store for cancel functions. Whatever the options passed, all cancel functions will be executed and cleared. In other words, options `RetainAfterExecution` and `StopOnFirstError` would be discarded if selected to create the Cancel store func NewCancelFunctionsStore(options ...StoreOption) *CancelFunctionStore { return &CancelFunctionStore{ - store: *newFunctionStore[context.CancelFunc](func(_ context.Context, cancelFunc context.CancelFunc) error { - cancelFunc() - return nil + ExecutionGroup: *NewExecutionGroup[context.CancelFunc](func(ctx context.Context, cancelFunc context.CancelFunc) error { + return WrapCancelToContextualFunc(cancelFunc)(ctx) }, append(options, ClearAfterExecution, ExecuteAll)...), } } diff --git a/utils/parallelisation/cancel_functions_test.go b/utils/parallelisation/cancel_functions_test.go index 75163f5081..73130d2b67 100644 --- a/utils/parallelisation/cancel_functions_test.go +++ b/utils/parallelisation/cancel_functions_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/atomic" "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" @@ -19,24 +20,35 @@ func testCancelStore(t *testing.T, store *CancelFunctionStore) { t.Helper() require.NotNil(t, store) // Set up some fake CancelFuncs to make sure they are called - called1 := false - called2 := false + called1 := atomic.NewBool(false) + called2 := atomic.NewBool(false) + called3 := atomic.NewBool(false) + cancelFunc1 := func() { - called1 = true + called1.Store(true) } cancelFunc2 := func() { - called2 = true + called2.Store(true) + } + cancelFunc3 := func() { + called3.Store(true) } + subStore := NewCancelFunctionsStore() + subStore.RegisterCancelFunction(cancelFunc3) store.RegisterCancelFunction(cancelFunc1, cancelFunc2) + store.RegisterCancelStore(subStore) + store.RegisterCancelStore(nil) - assert.Equal(t, 2, store.Len()) - assert.False(t, called1) - assert.False(t, called2) + assert.Equal(t, 3, store.Len()) + assert.False(t, called1.Load()) + assert.False(t, called2.Load()) + assert.False(t, called3.Load()) store.Cancel() - assert.True(t, called1) - assert.True(t, called2) + assert.True(t, called1.Load()) + assert.True(t, called2.Load()) + assert.True(t, called3.Load()) } // Given a CancelFunctionsStore diff --git a/utils/parallelisation/contextual.go b/utils/parallelisation/contextual.go new file mode 100644 index 0000000000..dbd3d99c0a --- /dev/null +++ b/utils/parallelisation/contextual.go @@ -0,0 +1,41 @@ +package parallelisation + +import ( + "context" + + "github.com/ARM-software/golang-utils/utils/commonerrors" +) + +// DetermineContextError determines what the context error is if any. +func DetermineContextError(ctx context.Context) error { + return commonerrors.ConvertContextError(ctx.Err()) +} + +type ContextualFunc func(ctx context.Context) error + +type ContextualFunctionGroup struct { + ExecutionGroup[ContextualFunc] +} + +// NewContextualGroup returns a group executing contextual functions. +func NewContextualGroup(options ...StoreOption) *ContextualFunctionGroup { + return &ContextualFunctionGroup{ + ExecutionGroup: *NewExecutionGroup[ContextualFunc](func(ctx context.Context, contextualF ContextualFunc) error { + return contextualF(ctx) + }, options...), + } +} + +// ForEach executes all the contextual functions according to the store options and returns an error if one occurred. +func ForEach(ctx context.Context, executionOptions *StoreOptions, contextualFunc ...ContextualFunc) error { + group := NewContextualGroup(ExecuteAll(executionOptions).Options()...) + group.RegisterFunction(contextualFunc...) + return group.Execute(ctx) +} + +// BreakOnError executes each functions in the group until an error is found or the context gets cancelled. +func BreakOnError(ctx context.Context, executionOptions *StoreOptions, contextualFunc ...ContextualFunc) error { + group := NewContextualGroup(StopOnFirstError(executionOptions).Options()...) + group.RegisterFunction(contextualFunc...) + return group.Execute(ctx) +} diff --git a/utils/parallelisation/contextual_test.go b/utils/parallelisation/contextual_test.go new file mode 100644 index 0000000000..dcbe4b8e00 --- /dev/null +++ b/utils/parallelisation/contextual_test.go @@ -0,0 +1,51 @@ +package parallelisation + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" +) + +func TestForEach(t *testing.T) { + cancelFunc := func() {} + t.Run("close with 1 error", func(t *testing.T) { + closeError := commonerrors.ErrUnexpected + + errortest.AssertError(t, ForEach(context.Background(), WithOptions(Parallel), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError) + }) + + t.Run("close with 1 error but error collection", func(t *testing.T) { + closeError := commonerrors.ErrUnexpected + errortest.AssertError(t, ForEach(context.Background(), WithOptions(Parallel), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError) + }) + + t.Run("close with 1 error and limited number of parallel workers", func(t *testing.T) { + closeError := commonerrors.ErrUnexpected + errortest.AssertError(t, ForEach(context.Background(), WithOptions(Workers(5), JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError) + }) + + t.Run("close with 1 error but sequential", func(t *testing.T) { + closeError := commonerrors.ErrUnexpected + errortest.AssertError(t, ForEach(context.Background(), WithOptions(SequentialInReverse, JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError) + errortest.AssertError(t, BreakOnError(context.Background(), WithOptions(SequentialInReverse, JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError) + }) + + t.Run("close with cancellation", func(t *testing.T) { + closeError := commonerrors.ErrUnexpected + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + errortest.AssertError(t, ForEach(cancelCtx, WithOptions(SequentialInReverse, JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), commonerrors.ErrCancelled) + errortest.AssertError(t, BreakOnError(cancelCtx, WithOptions(SequentialInReverse, JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc)), commonerrors.ErrCancelled) + }) + + t.Run("break on error with no error", func(t *testing.T) { + require.NoError(t, BreakOnError(context.Background(), WithOptions(Workers(5), JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc))) + }) + t.Run("for each with no error", func(t *testing.T) { + require.NoError(t, ForEach(context.Background(), WithOptions(Workers(5), JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc))) + }) +} diff --git a/utils/parallelisation/group.go b/utils/parallelisation/group.go new file mode 100644 index 0000000000..a7310641d6 --- /dev/null +++ b/utils/parallelisation/group.go @@ -0,0 +1,435 @@ +package parallelisation + +import ( + "context" + "math" + + "github.com/sasha-s/go-deadlock" + "go.uber.org/atomic" + "golang.org/x/sync/errgroup" + + "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/reflection" + "github.com/ARM-software/golang-utils/utils/safecast" +) + +type StoreOptions struct { + clearOnExecution bool + stopOnFirstError bool + sequential bool + reverse bool + joinErrors bool + onlyOnce bool + workers int +} + +func (o *StoreOptions) Default() *StoreOptions { + o.clearOnExecution = false + o.stopOnFirstError = false + o.sequential = false + o.reverse = false + o.joinErrors = false + o.onlyOnce = false + o.workers = 0 + return o +} + +func (o *StoreOptions) Merge(opts *StoreOptions) *StoreOptions { + if opts == nil { + return o + } + o.clearOnExecution = opts.clearOnExecution || o.clearOnExecution + o.stopOnFirstError = opts.stopOnFirstError || o.stopOnFirstError + o.sequential = opts.sequential || o.sequential + o.reverse = opts.reverse || o.reverse + o.joinErrors = opts.joinErrors || o.joinErrors + o.onlyOnce = opts.onlyOnce || o.onlyOnce + o.workers = safecast.ToInt(math.Max(float64(opts.workers), float64(o.workers))) + return o +} + +func (o *StoreOptions) MergeWithOptions(opt ...StoreOption) *StoreOptions { + return o.Merge(WithOptions(opt...)) +} + +func (o *StoreOptions) Overwrite(opts *StoreOptions) *StoreOptions { + return o.Default().Merge(opts) +} + +func (o *StoreOptions) WithOptions(opts ...StoreOption) *StoreOptions { + return o.Overwrite(WithOptions(opts...)) +} + +func (o *StoreOptions) Options() []StoreOption { + return []StoreOption{ + func(opts *StoreOptions) *StoreOptions { + op := o + if op == nil { + op = DefaultOptions() + } + return op.Merge(opts) + }, + } +} + +type StoreOption func(*StoreOptions) *StoreOptions + +// StopOnFirstError stops ExecutionGroup execution on first error. +var StopOnFirstError StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = DefaultOptions() + } + o.stopOnFirstError = true + o.joinErrors = false + return o +} + +// JoinErrors will collate any errors which happened when executing functions in ExecutionGroup. +// This option should not be used in combination to StopOnFirstError. +var JoinErrors StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = DefaultOptions() + } + o.stopOnFirstError = false + o.joinErrors = true + return o +} + +// OnlyOnce will ensure the function are executed only once if they do. +var OnlyOnce StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = DefaultOptions() + } + o.onlyOnce = true + return o +} + +// AnyTimes will allow the functions to be executed as often that they might be. +var AnyTimes StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = DefaultOptions() + } + o.onlyOnce = false + return o +} + +// ExecuteAll executes all functions in the ExecutionGroup even if an error is raised. the first error raised is then returned. +var ExecuteAll StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = DefaultOptions() + } + o.stopOnFirstError = false + return o +} + +// ClearAfterExecution clears the ExecutionGroup after execution. +var ClearAfterExecution StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = DefaultOptions() + } + o.clearOnExecution = true + return o +} + +// RetainAfterExecution keep the ExecutionGroup intact after execution (no reset). +var RetainAfterExecution StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = DefaultOptions() + } + o.clearOnExecution = false + return o +} + +// Parallel ensures every function registered in the ExecutionGroup is executed concurrently in the order they were registered. +var Parallel StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = DefaultOptions() + } + o.sequential = false + return o +} + +// Workers defines a limit number of workers for executing the function registered in the ExecutionGroup. +func Workers(workers int) StoreOption { + return func(o *StoreOptions) *StoreOptions { + if o == nil { + o = DefaultOptions() + } + o.workers = workers + o.sequential = false + return o + } +} + +// Sequential ensures every function registered in the ExecutionGroup is executed sequentially in the order they were registered. +var Sequential StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = DefaultOptions() + } + o.sequential = true + return o +} + +// SequentialInReverse ensures every function registered in the ExecutionGroup is executed sequentially but in the reverse order they were registered. +var SequentialInReverse StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = DefaultOptions() + } + o.sequential = true + o.reverse = true + return o +} + +// WithOptions defines a store configuration. +func WithOptions(option ...StoreOption) (opts *StoreOptions) { + for i := range option { + opts = option[i](opts) + } + if opts == nil { + opts = DefaultOptions() + } + return +} + +// DefaultOptions returns the default store configuration +func DefaultOptions() *StoreOptions { + opts := &StoreOptions{} + return opts.Default() +} + +type IExecutor interface { + // Execute executes all the functions in the group. + Execute(ctx context.Context) error +} + +type IExecutionGroup[T any] interface { + IExecutor + RegisterFunction(function ...T) + Len() int +} + +type ICompoundExecutionGroup[T any] interface { + IExecutionGroup[T] + // RegisterExecutor registers executors of any kind to the group: they could be functions or sub-groups. + RegisterExecutor(executor ...IExecutor) +} + +// NewExecutionGroup returns an execution group which executes functions according to store options. +func NewExecutionGroup[T any](executeFunc ExecuteFunc[T], options ...StoreOption) *ExecutionGroup[T] { + + opts := WithOptions(options...) + return &ExecutionGroup[T]{ + mu: deadlock.RWMutex{}, + functions: make([]wrappedElement[T], 0), + executeFunc: executeFunc, + options: *opts, + } +} + +type ExecuteFunc[T any] func(ctx context.Context, element T) error + +type ExecutionGroup[T any] struct { + mu deadlock.RWMutex + functions []wrappedElement[T] + executeFunc ExecuteFunc[T] + options StoreOptions +} + +// RegisterFunction registers functions to the group. +func (s *ExecutionGroup[T]) RegisterFunction(function ...T) { + defer s.mu.Unlock() + s.mu.Lock() + wrapped := make([]wrappedElement[T], len(function)) + for i := range function { + wrapped[i] = newWrapped(function[i], s.options.onlyOnce) + } + s.functions = append(s.functions, wrapped...) +} + +func (s *ExecutionGroup[T]) Len() int { + defer s.mu.RUnlock() + s.mu.RLock() + return len(s.functions) +} + +// Execute executes all the function in the group according to store options. +func (s *ExecutionGroup[T]) Execute(ctx context.Context) (err error) { + defer s.mu.Unlock() + s.mu.Lock() + if reflection.IsEmpty(s.executeFunc) { + return commonerrors.New(commonerrors.ErrUndefined, "the group was not initialised correctly") + } + + if s.options.sequential { + err = s.executeSequentially(ctx, s.options.stopOnFirstError, s.options.reverse, s.options.joinErrors) + } else { + err = s.executeConcurrently(ctx, s.options.stopOnFirstError, s.options.joinErrors) + } + + if err == nil && s.options.clearOnExecution { + s.functions = make([]wrappedElement[T], 0, len(s.functions)) + } + return +} + +func (s *ExecutionGroup[T]) executeConcurrently(ctx context.Context, stopOnFirstError bool, collateErrors bool) error { + g, gCtx := errgroup.WithContext(ctx) + if !stopOnFirstError { + gCtx = ctx + } + funcNum := len(s.functions) + workers := s.options.workers + if workers <= 0 { + workers = funcNum + } + errCh := make(chan error, funcNum) + + g.SetLimit(workers) + for i := range s.functions { + g.Go(func() error { + _, subErr := s.executeFunction(gCtx, s.functions[i]) + errCh <- subErr + return subErr + }) + } + err := g.Wait() + close(errCh) + if collateErrors { + collateErr := make([]error, funcNum) + i := 0 + for subErr := range errCh { + collateErr[i] = subErr + i++ + } + err = commonerrors.Join(collateErr...) + } + + return err +} + +func (s *ExecutionGroup[T]) executeSequentially(ctx context.Context, stopOnFirstError, reverse, collateErrors bool) (err error) { + err = DetermineContextError(ctx) + if err != nil { + return + } + funcNum := len(s.functions) + collateErr := make([]error, funcNum) + if reverse { + for i := funcNum - 1; i >= 0; i-- { + shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) + collateErr[funcNum-i-1] = subErr + if shouldBreak { + err = subErr + return + } + if subErr != nil && err == nil { + err = subErr + if stopOnFirstError { + return + } + } + } + } else { + for i := range s.functions { + shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) + collateErr[i] = subErr + if shouldBreak { + err = subErr + return + } + if subErr != nil && err == nil { + err = subErr + if stopOnFirstError { + return + } + } + } + } + + if collateErrors { + err = commonerrors.Join(collateErr...) + } + return +} + +func (s *ExecutionGroup[T]) executeFunction(ctx context.Context, w wrappedElement[T]) (mustBreak bool, err error) { + err = DetermineContextError(ctx) + if err != nil { + mustBreak = true + return + } + if w == nil { + err = commonerrors.UndefinedVariable("function element") + mustBreak = true + return + } + err = w.Execute(ctx, s.executeFunc) + + return +} + +type wrappedElement[T any] interface { + Execute(ctx context.Context, f ExecuteFunc[T]) error +} +type basicWrap[T any] struct { + value T +} + +func (w *basicWrap[T]) Execute(ctx context.Context, f ExecuteFunc[T]) error { + return f(ctx, w.value) +} + +func newBasicWrap[T any](e T) wrappedElement[T] { + return &basicWrap[T]{ + value: e, + } +} + +func newOnce[T any](e T) wrappedElement[T] { + return &once[T]{ + wrappedElement: newBasicWrap[T](e), + once: atomic.NewBool(false), + } +} + +type once[T any] struct { + wrappedElement[T] + once *atomic.Bool +} + +func (w *once[T]) Execute(ctx context.Context, f ExecuteFunc[T]) error { + if !w.once.Swap(true) { + return w.wrappedElement.Execute(ctx, f) + } + return nil +} + +func newWrapped[T any](e T, once bool) wrappedElement[T] { + if once { + return newOnce[T](e) + } else { + return newBasicWrap[T](e) + } +} + +var _ ICompoundExecutionGroup[ContextualFunc] = &CompoundExecutionGroup{} + +// NewCompoundExecutionGroup returns an execution group made of executors +func NewCompoundExecutionGroup(options ...StoreOption) *CompoundExecutionGroup { + return &CompoundExecutionGroup{ + ContextualFunctionGroup: *NewContextualGroup(options...), + } +} + +type CompoundExecutionGroup struct { + ContextualFunctionGroup +} + +// RegisterExecutor registers executors +func (g *CompoundExecutionGroup) RegisterExecutor(group ...IExecutor) { + for i := range group { + g.RegisterFunction(func(ctx context.Context) error { + return group[i].Execute(ctx) + }) + } +} diff --git a/utils/parallelisation/group_test.go b/utils/parallelisation/group_test.go new file mode 100644 index 0000000000..240fffaa5e --- /dev/null +++ b/utils/parallelisation/group_test.go @@ -0,0 +1,192 @@ +package parallelisation + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" + "github.com/ARM-software/golang-utils/utils/parallelisation/mocks" +) + +func TestExecutionTimes(t *testing.T) { + + t.Run("close only Once Parallel with retention", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(3) + + group := NewCloserStoreWithOptions(ExecuteAll, Parallel, OnlyOnce, RetainAfterExecution) + group.RegisterFunction(closerMock, closerMock, closerMock) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + }) + t.Run("close only Once Sequential with retention", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(3) + + group := NewCloserStoreWithOptions(ExecuteAll, OnlyOnce, Sequential, RetainAfterExecution) + group.RegisterFunction(closerMock, closerMock, closerMock) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + }) + t.Run("close only Once Parallel without retention", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(3) + + group := NewCloserStoreWithOptions(ExecuteAll, Parallel, OnlyOnce, ClearAfterExecution) + group.RegisterFunction(closerMock, closerMock, closerMock) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + }) + t.Run("close only Once Sequential without retention", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(3) + + group := NewCloserStoreWithOptions(ExecuteAll, OnlyOnce, Sequential, ClearAfterExecution) + group.RegisterFunction(closerMock, closerMock, closerMock) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + }) + + t.Run("close Multiple times Parallel", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(21) + group := NewCloserStoreWithOptions(ExecuteAll, AnyTimes, Parallel, RetainAfterExecution, Workers(3)) + group.RegisterFunction(closerMock, closerMock, closerMock) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + }) + t.Run("close Multiple times Sequential", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(21) + group := NewCloserStoreWithOptions(ExecuteAll, AnyTimes, Sequential, RetainAfterExecution) + group.RegisterFunction(closerMock, closerMock, closerMock) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + }) + + t.Run("close Multiple times Parallel without retention", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(3) + group := NewCloserStoreWithOptions(ExecuteAll, AnyTimes, Parallel, ClearAfterExecution, Workers(3)) + group.RegisterFunction(closerMock, closerMock, closerMock) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + }) + t.Run("close Multiple times Sequential without retention", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(3) + group := NewCloserStoreWithOptions(ExecuteAll, AnyTimes, Sequential, ClearAfterExecution) + group.RegisterFunction(closerMock, closerMock, closerMock) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + }) +} + +func TestCompoundGroup(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(17) + group := NewCloserStoreWithOptions(ExecuteAll, OnlyOnce, Sequential) + group.RegisterFunction(closerMock, closerMock, closerMock) + + compoundGroup := NewCompoundExecutionGroup(Parallel, RetainAfterExecution) + compoundGroup.RegisterFunction(WrapCloseToContextualFunc(WrapCloserIntoCloseFunc(closerMock))) + compoundGroup.RegisterExecutor(group) + compoundGroup.RegisterFunction(WrapCancelToContextualFunc(WrapContextualToCancelFunc(WrapCloseToContextualFunc(WrapCloserIntoCloseFunc(closerMock))))) + + assert.Equal(t, 3, compoundGroup.Len()) + + require.NoError(t, compoundGroup.Execute(context.Background())) + require.NoError(t, compoundGroup.Execute(context.Background())) + require.NoError(t, compoundGroup.Execute(context.Background())) + require.NoError(t, compoundGroup.Execute(context.Background())) + require.NoError(t, compoundGroup.Execute(context.Background())) + require.NoError(t, compoundGroup.Execute(context.Background())) + require.NoError(t, compoundGroup.Execute(context.Background())) + + t.Run("With cancelled context", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + errortest.AssertError(t, compoundGroup.Execute(ctx), commonerrors.ErrCancelled) + }) + +} + +func TestStoreOptions_MergeWithOptions(t *testing.T) { + opts := WithOptions(Parallel).MergeWithOptions(OnlyOnce, ExecuteAll, Workers(5), Sequential) + assert.True(t, opts.onlyOnce) + assert.False(t, opts.stopOnFirstError) + assert.True(t, opts.sequential) + assert.Equal(t, 5, opts.workers) +} diff --git a/utils/parallelisation/onclose.go b/utils/parallelisation/onclose.go index 78142d8302..39079730c2 100644 --- a/utils/parallelisation/onclose.go +++ b/utils/parallelisation/onclose.go @@ -8,11 +8,11 @@ import ( ) type CloserStore struct { - store[io.Closer] + ExecutionGroup[io.Closer] } func (s *CloserStore) RegisterCloser(closerObj ...io.Closer) { - s.store.RegisterFunction(closerObj...) + s.ExecutionGroup.RegisterFunction(closerObj...) } func (s *CloserStore) Close() error { @@ -20,7 +20,7 @@ func (s *CloserStore) Close() error { } func (s *CloserStore) Len() int { - return s.store.Len() + return s.ExecutionGroup.Len() } // NewCloserStore returns a store of io.Closer object which will all be closed concurrently on Close(). The first error received will be returned @@ -29,17 +29,14 @@ func NewCloserStore(stopOnFirstError bool) *CloserStore { if stopOnFirstError { option = StopOnFirstError } - return NewCloserStoreWithOptions(option, Parallel, RetainAfterExecution) + return NewCloserStoreWithOptions(option, Parallel, OnlyOnce, RetainAfterExecution) } // NewCloserStoreWithOptions returns a store of io.Closer object which will all be closed on Close(). The first error received if any will be returned func NewCloserStoreWithOptions(opts ...StoreOption) *CloserStore { return &CloserStore{ - store: *newFunctionStore[io.Closer](func(_ context.Context, closerObj io.Closer) error { - if closerObj == nil { - return commonerrors.UndefinedVariable("closer object") - } - return closerObj.Close() + ExecutionGroup: *NewExecutionGroup[io.Closer](func(ctx context.Context, closerObj io.Closer) error { + return WrapCloseToContextualFunc(WrapCloserIntoCloseFunc(closerObj))(ctx) }, opts...), } } @@ -88,22 +85,65 @@ func CloseAllFuncAndCollateErrors(cs ...CloseFunc) error { type CloseFunc func() error +func (c CloseFunc) Close() error { + return c() +} + +func WrapCloserIntoCloseFunc(closer io.Closer) CloseFunc { + return func() error { + if closer == nil { + return commonerrors.UndefinedVariable("closer object") + } + return closer.Close() + } +} + +func WrapCancelToCloseFunc(f context.CancelFunc) CloseFunc { + return func() error { + f() + return nil + } +} + +func WrapCancelToContextualFunc(f context.CancelFunc) ContextualFunc { + return WrapCloseToContextualFunc(WrapCancelToCloseFunc(f)) +} + +func WrapCloseToContextualFunc(f CloseFunc) ContextualFunc { + return func(_ context.Context) error { + return f() + } +} + +func WrapCloseToCancelFunc(f CloseFunc) context.CancelFunc { + return func() { + _ = f() + } +} + +func WrapContextualToCloseFunc(f ContextualFunc) CloseFunc { + return func() error { + return f(context.Background()) + } +} + +func WrapContextualToCancelFunc(f ContextualFunc) context.CancelFunc { + return WrapCloseToCancelFunc(WrapContextualToCloseFunc(f)) +} + type CloseFunctionStore struct { - store[CloseFunc] + ExecutionGroup[CloseFunc] } func (s *CloseFunctionStore) RegisterCloseFunction(closerObj ...CloseFunc) { - s.store.RegisterFunction(closerObj...) + s.ExecutionGroup.RegisterFunction(closerObj...) } func (s *CloseFunctionStore) RegisterCancelStore(cancelStore *CancelFunctionStore) { if cancelStore == nil { return } - s.store.RegisterFunction(func() error { - cancelStore.Cancel() - return nil - }) + s.ExecutionGroup.RegisterFunction(WrapCancelToCloseFunc(cancelStore.Cancel)) } func (s *CloseFunctionStore) RegisterCancelFunction(cancelFunc ...context.CancelFunc) { @@ -117,13 +157,13 @@ func (s *CloseFunctionStore) Close() error { } func (s *CloseFunctionStore) Len() int { - return s.store.Len() + return s.ExecutionGroup.Len() } // NewCloseFunctionStore returns a store closing functions which will all be called on Close(). The first error received if any will be returned. func NewCloseFunctionStore(options ...StoreOption) *CloseFunctionStore { return &CloseFunctionStore{ - store: *newFunctionStore[CloseFunc](func(_ context.Context, closerObj CloseFunc) error { + ExecutionGroup: *NewExecutionGroup[CloseFunc](func(_ context.Context, closerObj CloseFunc) error { return closerObj() }, options...), } @@ -141,5 +181,10 @@ func NewConcurrentCloseFunctionStore(stopOnFirstError bool) *CloseFunctionStore if stopOnFirstError { option = StopOnFirstError } - return NewCloseFunctionStore(option, Parallel, RetainAfterExecution) + return NewCloseFunctionStore(option, Parallel, RetainAfterExecution, OnlyOnce) +} + +// NewCloseOnceGroup is the same as NewCloseFunctionStore but ensures any closing functions are only executed once. +func NewCloseOnceGroup(options ...StoreOption) *CloseFunctionStore { + return NewCloseFunctionStore(OnlyOnce(WithOptions(options...)).Options()...) } diff --git a/utils/parallelisation/onclose_test.go b/utils/parallelisation/onclose_test.go index 2287796c11..4fa2ddfc39 100644 --- a/utils/parallelisation/onclose_test.go +++ b/utils/parallelisation/onclose_test.go @@ -14,7 +14,7 @@ import ( "github.com/ARM-software/golang-utils/utils/parallelisation/mocks" ) -//go:generate go tool mockgen -destination=./mocks/mock_$GOPACKAGE.go -package=mocks io Closer +//go:generate go tool mockgen -destination=./mocks/mock_$GOPACKAGE.go -package=mocks io.Closer func TestCloseAll(t *testing.T) { t.Run("close", func(t *testing.T) { ctlr := gomock.NewController(t) @@ -72,6 +72,22 @@ func TestCloseAll(t *testing.T) { } +func TestCloseOnce(t *testing.T) { + t.Run("close every function once", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + closeError := commonerrors.ErrUnexpected + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(closeError).Times(3) + + group := NewCloseOnceGroup(Parallel, RetainAfterExecution) + group.RegisterCloseFunction(WrapCloserIntoCloseFunc(closerMock), WrapCloserIntoCloseFunc(closerMock), WrapCloserIntoCloseFunc(closerMock)) + errortest.AssertError(t, group.Close(), closeError) + require.NoError(t, group.Close()) + }) +} + func TestCancelOnClose(t *testing.T) { t.Run("parallel", func(t *testing.T) { closeStore := NewCloseFunctionStoreStore(true) @@ -136,7 +152,7 @@ func TestSequentialExecution(t *testing.T) { for i := range tests { test := tests[i] t.Run(fmt.Sprintf("%v-%#v", i, test.option), func(t *testing.T) { - opt := test.option(&StoreOptions{}) + opt := test.option(DefaultOptions()) t.Run("sequentially", func(t *testing.T) { closeStore := NewCloseFunctionStore(test.option, Sequential) ctx1, cancel1 := context.WithCancel(context.Background()) diff --git a/utils/parallelisation/parallelisation.go b/utils/parallelisation/parallelisation.go index 93fc596619..30e909afdd 100644 --- a/utils/parallelisation/parallelisation.go +++ b/utils/parallelisation/parallelisation.go @@ -17,11 +17,6 @@ import ( "github.com/ARM-software/golang-utils/utils/commonerrors" ) -// DetermineContextError determines what the context error is if any. -func DetermineContextError(ctx context.Context) error { - return commonerrors.ConvertContextError(ctx.Err()) -} - type result struct { Item any err error @@ -178,7 +173,7 @@ func RunActionWithTimeoutAndContext(ctx context.Context, timeout time.Duration, } // RunActionWithTimeoutAndCancelStore runs an action with timeout -// The cancel store is used just to register the cancel function so that it can be called on Cancel. +// The cancel ExecutionGroup is used just to register the cancel function so that it can be called on Cancel. func RunActionWithTimeoutAndCancelStore(ctx context.Context, timeout time.Duration, store *CancelFunctionStore, blockingAction func(context.Context) error) error { err := DetermineContextError(ctx) if err != nil {