From 999a6ab83f6829d4e2f43e4f126c342de7175d49 Mon Sep 17 00:00:00 2001 From: Adrien CABARBAYE Date: Wed, 20 Aug 2025 00:29:20 +0100 Subject: [PATCH 1/4] :sparkles: `[parallelisation]` New groups and Store options --- changes/20250820002654.feature | 1 + utils/parallelisation/cancel_functions.go | 244 +------------- utils/parallelisation/contextual.go | 39 +++ utils/parallelisation/contextual_test.go | 48 +++ utils/parallelisation/group.go | 368 ++++++++++++++++++++++ utils/parallelisation/group_test.go | 148 +++++++++ utils/parallelisation/onclose.go | 55 +++- utils/parallelisation/onclose_test.go | 2 +- utils/parallelisation/parallelisation.go | 7 +- 9 files changed, 655 insertions(+), 257 deletions(-) create mode 100644 changes/20250820002654.feature create mode 100644 utils/parallelisation/contextual.go create mode 100644 utils/parallelisation/contextual_test.go create mode 100644 utils/parallelisation/group.go create mode 100644 utils/parallelisation/group_test.go diff --git a/changes/20250820002654.feature b/changes/20250820002654.feature new file mode 100644 index 0000000000..cdae0a6d02 --- /dev/null +++ b/changes/20250820002654.feature @@ -0,0 +1 @@ +:sparkles: `[parallelisation]` Added new groups (ContextualFunctionGroup) and new Store options to configure the execution (number of workers, single execution, etc.) diff --git a/utils/parallelisation/cancel_functions.go b/utils/parallelisation/cancel_functions.go index daa6ab8885..fa2c526560 100644 --- a/utils/parallelisation/cancel_functions.go +++ b/utils/parallelisation/cancel_functions.go @@ -5,245 +5,14 @@ package parallelisation -import ( - "context" - - "github.com/sasha-s/go-deadlock" - "golang.org/x/sync/errgroup" - - "github.com/ARM-software/golang-utils/utils/commonerrors" - "github.com/ARM-software/golang-utils/utils/reflection" -) - -type StoreOptions struct { - clearOnExecution bool - stopOnFirstError bool - sequential bool - reverse bool - joinErrors bool -} -type StoreOption func(*StoreOptions) *StoreOptions - -// StopOnFirstError stops store execution on first error. -var StopOnFirstError StoreOption = func(o *StoreOptions) *StoreOptions { - if o == nil { - return o - } - o.stopOnFirstError = true - o.joinErrors = false - return o -} - -// JoinErrors will collate any errors which happened when executing functions in store. -// This option should not be used in combination to StopOnFirstError. -var JoinErrors StoreOption = func(o *StoreOptions) *StoreOptions { - if o == nil { - return o - } - o.stopOnFirstError = false - o.joinErrors = true - return o -} - -// ExecuteAll executes all functions in the store even if an error is raised. the first error raised is then returned. -var ExecuteAll StoreOption = func(o *StoreOptions) *StoreOptions { - if o == nil { - return o - } - o.stopOnFirstError = false - return o -} - -// ClearAfterExecution clears the store after execution. -var ClearAfterExecution StoreOption = func(o *StoreOptions) *StoreOptions { - if o == nil { - return o - } - o.clearOnExecution = true - return o -} - -// RetainAfterExecution keep the store intact after execution (no reset). -var RetainAfterExecution StoreOption = func(o *StoreOptions) *StoreOptions { - if o == nil { - return o - } - o.clearOnExecution = false - return o -} - -// Parallel ensures every function registered in the store is executed concurrently in the order they were registered. -var Parallel StoreOption = func(o *StoreOptions) *StoreOptions { - if o == nil { - return o - } - o.sequential = false - return o -} - -// Sequential ensures every function registered in the store is executed sequentially in the order they were registered. -var Sequential StoreOption = func(o *StoreOptions) *StoreOptions { - if o == nil { - return o - } - o.sequential = true - return o -} - -// SequentialInReverse ensures every function registered in the store is executed sequentially but in the reverse order they were registered. -var SequentialInReverse StoreOption = func(o *StoreOptions) *StoreOptions { - if o == nil { - return o - } - o.sequential = true - o.reverse = true - return o -} - -func newFunctionStore[T any](executeFunc func(context.Context, T) error, options ...StoreOption) *store[T] { - - opts := &StoreOptions{} - - for i := range options { - opts = options[i](opts) - } - return &store[T]{ - mu: deadlock.RWMutex{}, - functions: make([]T, 0), - executeFunc: executeFunc, - options: *opts, - } -} - -type store[T any] struct { - mu deadlock.RWMutex - functions []T - executeFunc func(ctx context.Context, element T) error - options StoreOptions -} - -func (s *store[T]) RegisterFunction(function ...T) { - defer s.mu.Unlock() - s.mu.Lock() - s.functions = append(s.functions, function...) -} - -func (s *store[T]) Len() int { - defer s.mu.RUnlock() - s.mu.RLock() - return len(s.functions) -} - -func (s *store[T]) Execute(ctx context.Context) (err error) { - defer s.mu.Unlock() - s.mu.Lock() - if reflection.IsEmpty(s.executeFunc) { - return commonerrors.New(commonerrors.ErrUndefined, "the store was not initialised correctly") - } - - if s.options.sequential { - err = s.executeSequentially(ctx, s.options.stopOnFirstError, s.options.reverse, s.options.joinErrors) - } else { - err = s.executeConcurrently(ctx, s.options.stopOnFirstError, s.options.joinErrors) - } - - if err == nil && s.options.clearOnExecution { - s.functions = make([]T, 0, len(s.functions)) - } - return -} - -func (s *store[T]) executeConcurrently(ctx context.Context, stopOnFirstError bool, collateErrors bool) error { - g, gCtx := errgroup.WithContext(ctx) - if !stopOnFirstError { - gCtx = ctx - } - funcNum := len(s.functions) - errCh := make(chan error, funcNum) - g.SetLimit(funcNum) - for i := range s.functions { - g.Go(func() error { - _, subErr := s.executeFunction(gCtx, s.functions[i]) - errCh <- subErr - return subErr - }) - } - err := g.Wait() - close(errCh) - if collateErrors { - collateErr := make([]error, funcNum) - i := 0 - for subErr := range errCh { - collateErr[i] = subErr - i++ - } - err = commonerrors.Join(collateErr...) - } - - return err -} - -func (s *store[T]) executeSequentially(ctx context.Context, stopOnFirstError, reverse, collateErrors bool) (err error) { - err = DetermineContextError(ctx) - if err != nil { - return - } - funcNum := len(s.functions) - collateErr := make([]error, funcNum) - if reverse { - for i := funcNum - 1; i >= 0; i-- { - shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) - collateErr[funcNum-i-1] = subErr - if shouldBreak { - err = subErr - return - } - if subErr != nil && err == nil { - err = subErr - if stopOnFirstError { - return - } - } - } - } else { - for i := range s.functions { - shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) - collateErr[i] = subErr - if shouldBreak { - err = subErr - return - } - if subErr != nil && err == nil { - err = subErr - if stopOnFirstError { - return - } - } - } - } - - if collateErrors { - err = commonerrors.Join(collateErr...) - } - return -} - -func (s *store[T]) executeFunction(ctx context.Context, element T) (mustBreak bool, err error) { - err = DetermineContextError(ctx) - if err != nil { - mustBreak = true - return - } - err = s.executeFunc(ctx, element) - return -} +import "context" type CancelFunctionStore struct { - store[context.CancelFunc] + ExecutionGroup[context.CancelFunc] } func (s *CancelFunctionStore) RegisterCancelFunction(cancel ...context.CancelFunc) { - s.store.RegisterFunction(cancel...) + s.ExecutionGroup.RegisterFunction(cancel...) } // Cancel will execute the cancel functions in the store. Any errors will be ignored and Execute() is recommended if you need to know if a cancellation failed @@ -252,15 +21,14 @@ func (s *CancelFunctionStore) Cancel() { } func (s *CancelFunctionStore) Len() int { - return s.store.Len() + return s.ExecutionGroup.Len() } // NewCancelFunctionsStore creates a store for cancel functions. Whatever the options passed, all cancel functions will be executed and cleared. In other words, options `RetainAfterExecution` and `StopOnFirstError` would be discarded if selected to create the Cancel store func NewCancelFunctionsStore(options ...StoreOption) *CancelFunctionStore { return &CancelFunctionStore{ - store: *newFunctionStore[context.CancelFunc](func(_ context.Context, cancelFunc context.CancelFunc) error { - cancelFunc() - return nil + ExecutionGroup: *NewExecutionGroup[context.CancelFunc](func(ctx context.Context, cancelFunc context.CancelFunc) error { + return WrapCancelToContextualFunc(cancelFunc)(ctx) }, append(options, ClearAfterExecution, ExecuteAll)...), } } diff --git a/utils/parallelisation/contextual.go b/utils/parallelisation/contextual.go new file mode 100644 index 0000000000..68267490a0 --- /dev/null +++ b/utils/parallelisation/contextual.go @@ -0,0 +1,39 @@ +package parallelisation + +import ( + "context" + + "github.com/ARM-software/golang-utils/utils/commonerrors" +) + +// DetermineContextError determines what the context error is if any. +func DetermineContextError(ctx context.Context) error { + return commonerrors.ConvertContextError(ctx.Err()) +} + +type ContextualFunctionGroup struct { + ExecutionGroup[ContextualFunc] +} + +// NewContextualGroup returns a group executing contextual functions. +func NewContextualGroup(options ...StoreOption) *ContextualFunctionGroup { + return &ContextualFunctionGroup{ + ExecutionGroup: *NewExecutionGroup[ContextualFunc](func(ctx context.Context, contextualF ContextualFunc) error { + return contextualF(ctx) + }, options...), + } +} + +// ForEach executes all the contextual functions according to the store options and returns an error if one occurred. +func ForEach(ctx context.Context, executionOptions *StoreOptions, contextualFunc ...ContextualFunc) error { + group := NewContextualGroup(ExecuteAll(executionOptions).Options()...) + group.RegisterFunction(contextualFunc...) + return group.Execute(ctx) +} + +// BreakOnError executes each functions in the group until an error is found or the context gets cancelled. +func BreakOnError(ctx context.Context, executionOptions *StoreOptions, contextualFunc ...ContextualFunc) error { + group := NewContextualGroup(StopOnFirstError(executionOptions).Options()...) + group.RegisterFunction(contextualFunc...) + return group.Execute(ctx) +} diff --git a/utils/parallelisation/contextual_test.go b/utils/parallelisation/contextual_test.go new file mode 100644 index 0000000000..5391831d9c --- /dev/null +++ b/utils/parallelisation/contextual_test.go @@ -0,0 +1,48 @@ +package parallelisation + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" +) + +func TestForEach(t *testing.T) { + cancelFunc := func() {} + t.Run("close with 1 error", func(t *testing.T) { + closeError := commonerrors.ErrUnexpected + + errortest.AssertError(t, ForEach(context.Background(), WithOptions(Parallel), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError) + }) + + t.Run("close with 1 error but error collection", func(t *testing.T) { + closeError := commonerrors.ErrUnexpected + errortest.AssertError(t, ForEach(context.Background(), WithOptions(Parallel, JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError) + }) + + t.Run("close with 1 error but error collection", func(t *testing.T) { + closeError := commonerrors.ErrUnexpected + errortest.AssertError(t, ForEach(context.Background(), WithOptions(Workers(5), JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError) + }) + + t.Run("close with 1 error but sequential", func(t *testing.T) { + closeError := commonerrors.ErrUnexpected + errortest.AssertError(t, ForEach(context.Background(), WithOptions(SequentialInReverse, JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError) + errortest.AssertError(t, BreakOnError(context.Background(), WithOptions(SequentialInReverse, JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError) + }) + + t.Run("close with cancellation", func(t *testing.T) { + closeError := commonerrors.ErrUnexpected + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + errortest.AssertError(t, ForEach(cancelCtx, WithOptions(SequentialInReverse, JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), commonerrors.ErrCancelled) + errortest.AssertError(t, BreakOnError(cancelCtx, WithOptions(SequentialInReverse, JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc)), commonerrors.ErrCancelled) + }) + + t.Run("break on error with no error", func(t *testing.T) { + require.NoError(t, BreakOnError(context.Background(), WithOptions(Workers(5), JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc))) + }) +} diff --git a/utils/parallelisation/group.go b/utils/parallelisation/group.go new file mode 100644 index 0000000000..7e50391852 --- /dev/null +++ b/utils/parallelisation/group.go @@ -0,0 +1,368 @@ +package parallelisation + +import ( + "context" + "math" + + "github.com/sasha-s/go-deadlock" + "go.uber.org/atomic" + "golang.org/x/sync/errgroup" + + "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/reflection" + "github.com/ARM-software/golang-utils/utils/safecast" +) + +type StoreOptions struct { + clearOnExecution bool + stopOnFirstError bool + sequential bool + reverse bool + joinErrors bool + onlyOnce bool + workers int +} + +func (o *StoreOptions) Merge(opts *StoreOptions) *StoreOptions { + if opts == nil { + return o + } + return &StoreOptions{ + clearOnExecution: opts.clearOnExecution || o.clearOnExecution, + stopOnFirstError: opts.stopOnFirstError || o.stopOnFirstError, + sequential: opts.sequential || o.sequential, + reverse: opts.reverse || o.reverse, + joinErrors: opts.joinErrors || o.joinErrors, + onlyOnce: opts.onlyOnce || o.onlyOnce, + workers: safecast.ToInt(math.Max(float64(opts.workers), float64(o.workers))), + } +} + +func (o *StoreOptions) Options() []StoreOption { + return []StoreOption{ + func(opts *StoreOptions) *StoreOptions { + op := o + if op == nil { + op = &StoreOptions{} + } + return op.Merge(opts) + }, + } +} + +type StoreOption func(*StoreOptions) *StoreOptions + +// StopOnFirstError stops ExecutionGroup execution on first error. +var StopOnFirstError StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = &StoreOptions{} + } + o.stopOnFirstError = true + o.joinErrors = false + return o +} + +// JoinErrors will collate any errors which happened when executing functions in ExecutionGroup. +// This option should not be used in combination to StopOnFirstError. +var JoinErrors StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = &StoreOptions{} + } + o.stopOnFirstError = false + o.joinErrors = true + return o +} + +// OnlyOnce will ensure the function are executed only once if they do. +var OnlyOnce StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = &StoreOptions{} + } + o.onlyOnce = true + return o +} + +// AnyTimes will allow the functions to be executed as often that they might be. +var AnyTimes StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = &StoreOptions{} + } + o.onlyOnce = false + return o +} + +// ExecuteAll executes all functions in the ExecutionGroup even if an error is raised. the first error raised is then returned. +var ExecuteAll StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = &StoreOptions{} + } + o.stopOnFirstError = false + return o +} + +// ClearAfterExecution clears the ExecutionGroup after execution. +var ClearAfterExecution StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = &StoreOptions{} + } + o.clearOnExecution = true + return o +} + +// RetainAfterExecution keep the ExecutionGroup intact after execution (no reset). +var RetainAfterExecution StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = &StoreOptions{} + } + o.clearOnExecution = false + return o +} + +// Parallel ensures every function registered in the ExecutionGroup is executed concurrently in the order they were registered. +var Parallel StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = &StoreOptions{} + } + o.sequential = false + return o +} + +// Workers defines a limit number of workers for executing the function registered in the ExecutionGroup. +func Workers(workers int) StoreOption { + return func(o *StoreOptions) *StoreOptions { + if o == nil { + o = &StoreOptions{} + } + o.workers = workers + o.sequential = false + return o + } +} + +// Sequential ensures every function registered in the ExecutionGroup is executed sequentially in the order they were registered. +var Sequential StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = &StoreOptions{} + } + o.sequential = true + return o +} + +// SequentialInReverse ensures every function registered in the ExecutionGroup is executed sequentially but in the reverse order they were registered. +var SequentialInReverse StoreOption = func(o *StoreOptions) *StoreOptions { + if o == nil { + o = &StoreOptions{} + } + o.sequential = true + o.reverse = true + return o +} + +// WithOptions defines a store configuration. +func WithOptions(option ...StoreOption) (opts *StoreOptions) { + for i := range option { + opts = option[i](opts) + } + if opts == nil { + opts = &StoreOptions{} + } + return +} + +// NewExecutionGroup returns an execution group which executes functions according to store options. +func NewExecutionGroup[T any](executeFunc ExecuteFunc[T], options ...StoreOption) *ExecutionGroup[T] { + + opts := WithOptions(options...) + return &ExecutionGroup[T]{ + mu: deadlock.RWMutex{}, + functions: make([]wrappedElement[T], 0), + executeFunc: executeFunc, + options: *opts, + } +} + +type ExecuteFunc[T any] func(ctx context.Context, element T) error + +type ExecutionGroup[T any] struct { + mu deadlock.RWMutex + functions []wrappedElement[T] + executeFunc ExecuteFunc[T] + options StoreOptions +} + +// RegisterFunction registers functions to the group. +func (s *ExecutionGroup[T]) RegisterFunction(function ...T) { + defer s.mu.Unlock() + s.mu.Lock() + wrapped := make([]wrappedElement[T], len(function)) + for i := range function { + wrapped[i] = newWrapped(function[i], s.options.onlyOnce) + } + s.functions = append(s.functions, wrapped...) +} + +func (s *ExecutionGroup[T]) Len() int { + defer s.mu.RUnlock() + s.mu.RLock() + return len(s.functions) +} + +// Execute executes all the function in the group according to store options. +func (s *ExecutionGroup[T]) Execute(ctx context.Context) (err error) { + defer s.mu.Unlock() + s.mu.Lock() + if reflection.IsEmpty(s.executeFunc) { + return commonerrors.New(commonerrors.ErrUndefined, "the group was not initialised correctly") + } + + if s.options.sequential { + err = s.executeSequentially(ctx, s.options.stopOnFirstError, s.options.reverse, s.options.joinErrors) + } else { + err = s.executeConcurrently(ctx, s.options.stopOnFirstError, s.options.joinErrors) + } + + if err == nil && s.options.clearOnExecution { + s.functions = make([]wrappedElement[T], 0, len(s.functions)) + } + return +} + +func (s *ExecutionGroup[T]) executeConcurrently(ctx context.Context, stopOnFirstError bool, collateErrors bool) error { + g, gCtx := errgroup.WithContext(ctx) + if !stopOnFirstError { + gCtx = ctx + } + funcNum := len(s.functions) + workers := s.options.workers + if workers <= 0 { + workers = funcNum + } + errCh := make(chan error, funcNum) + + g.SetLimit(workers) + for i := range s.functions { + g.Go(func() error { + _, subErr := s.executeFunction(gCtx, s.functions[i]) + errCh <- subErr + return subErr + }) + } + err := g.Wait() + close(errCh) + if collateErrors { + collateErr := make([]error, funcNum) + i := 0 + for subErr := range errCh { + collateErr[i] = subErr + i++ + } + err = commonerrors.Join(collateErr...) + } + + return err +} + +func (s *ExecutionGroup[T]) executeSequentially(ctx context.Context, stopOnFirstError, reverse, collateErrors bool) (err error) { + err = DetermineContextError(ctx) + if err != nil { + return + } + funcNum := len(s.functions) + collateErr := make([]error, funcNum) + if reverse { + for i := funcNum - 1; i >= 0; i-- { + shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) + collateErr[funcNum-i-1] = subErr + if shouldBreak { + err = subErr + return + } + if subErr != nil && err == nil { + err = subErr + if stopOnFirstError { + return + } + } + } + } else { + for i := range s.functions { + shouldBreak, subErr := s.executeFunction(ctx, s.functions[i]) + collateErr[i] = subErr + if shouldBreak { + err = subErr + return + } + if subErr != nil && err == nil { + err = subErr + if stopOnFirstError { + return + } + } + } + } + + if collateErrors { + err = commonerrors.Join(collateErr...) + } + return +} + +func (s *ExecutionGroup[T]) executeFunction(ctx context.Context, w wrappedElement[T]) (mustBreak bool, err error) { + err = DetermineContextError(ctx) + if err != nil { + mustBreak = true + return + } + if w == nil { + err = commonerrors.UndefinedVariable("function element") + mustBreak = true + return + } + err = w.Execute(ctx, s.executeFunc) + + return +} + +type wrappedElement[T any] interface { + Execute(ctx context.Context, f ExecuteFunc[T]) error +} +type basicWrap[T any] struct { + value T +} + +func (w *basicWrap[T]) Execute(ctx context.Context, f ExecuteFunc[T]) error { + return f(ctx, w.value) +} + +func newBasicWrap[T any](e T) wrappedElement[T] { + return &basicWrap[T]{ + value: e, + } +} + +func newOnce[T any](e T) wrappedElement[T] { + return &once[T]{ + wrappedElement: newBasicWrap[T](e), + once: atomic.NewBool(false), + } +} + +type once[T any] struct { + wrappedElement[T] + once *atomic.Bool +} + +func (w *once[T]) Execute(ctx context.Context, f ExecuteFunc[T]) error { + if !w.once.Swap(true) { + return w.wrappedElement.Execute(ctx, f) + } + return nil +} + +func newWrapped[T any](e T, once bool) wrappedElement[T] { + if once { + return newOnce[T](e) + } else { + return newBasicWrap[T](e) + } +} diff --git a/utils/parallelisation/group_test.go b/utils/parallelisation/group_test.go new file mode 100644 index 0000000000..b46563e13a --- /dev/null +++ b/utils/parallelisation/group_test.go @@ -0,0 +1,148 @@ +package parallelisation + +import ( + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/ARM-software/golang-utils/utils/parallelisation/mocks" +) + +func TestExecutionTimes(t *testing.T) { + + t.Run("close only Once Parallel with retention", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(3) + + group := NewCloserStoreWithOptions(ExecuteAll, Parallel, OnlyOnce, RetainAfterExecution) + group.RegisterFunction(closerMock, closerMock, closerMock) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + }) + t.Run("close only Once Sequential with retention", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(3) + + group := NewCloserStoreWithOptions(ExecuteAll, OnlyOnce, Sequential, RetainAfterExecution) + group.RegisterFunction(closerMock, closerMock, closerMock) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + }) + t.Run("close only Once Parallel without retention", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(3) + + group := NewCloserStoreWithOptions(ExecuteAll, Parallel, OnlyOnce, ClearAfterExecution) + group.RegisterFunction(closerMock, closerMock, closerMock) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + }) + t.Run("close only Once Sequential without retention", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(3) + + group := NewCloserStoreWithOptions(ExecuteAll, OnlyOnce, Sequential, ClearAfterExecution) + group.RegisterFunction(closerMock, closerMock, closerMock) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + }) + + t.Run("close Multiple times Parallel", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(21) + group := NewCloserStoreWithOptions(ExecuteAll, AnyTimes, Parallel, RetainAfterExecution, Workers(3)) + group.RegisterFunction(closerMock, closerMock, closerMock) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + }) + t.Run("close Multiple times Sequential", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(21) + group := NewCloserStoreWithOptions(ExecuteAll, AnyTimes, Sequential, RetainAfterExecution) + group.RegisterFunction(closerMock, closerMock, closerMock) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + }) + + t.Run("close Multiple times Parallel without retention", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(3) + group := NewCloserStoreWithOptions(ExecuteAll, AnyTimes, Parallel, ClearAfterExecution, Workers(3)) + group.RegisterFunction(closerMock, closerMock, closerMock) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + }) + t.Run("close Multiple times Sequential without retention", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(3) + group := NewCloserStoreWithOptions(ExecuteAll, AnyTimes, Sequential, ClearAfterExecution) + group.RegisterFunction(closerMock, closerMock, closerMock) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + require.NoError(t, group.Close()) + }) +} diff --git a/utils/parallelisation/onclose.go b/utils/parallelisation/onclose.go index 78142d8302..c03a29b7b2 100644 --- a/utils/parallelisation/onclose.go +++ b/utils/parallelisation/onclose.go @@ -8,11 +8,11 @@ import ( ) type CloserStore struct { - store[io.Closer] + ExecutionGroup[io.Closer] } func (s *CloserStore) RegisterCloser(closerObj ...io.Closer) { - s.store.RegisterFunction(closerObj...) + s.ExecutionGroup.RegisterFunction(closerObj...) } func (s *CloserStore) Close() error { @@ -20,7 +20,7 @@ func (s *CloserStore) Close() error { } func (s *CloserStore) Len() int { - return s.store.Len() + return s.ExecutionGroup.Len() } // NewCloserStore returns a store of io.Closer object which will all be closed concurrently on Close(). The first error received will be returned @@ -35,7 +35,7 @@ func NewCloserStore(stopOnFirstError bool) *CloserStore { // NewCloserStoreWithOptions returns a store of io.Closer object which will all be closed on Close(). The first error received if any will be returned func NewCloserStoreWithOptions(opts ...StoreOption) *CloserStore { return &CloserStore{ - store: *newFunctionStore[io.Closer](func(_ context.Context, closerObj io.Closer) error { + ExecutionGroup: *NewExecutionGroup[io.Closer](func(_ context.Context, closerObj io.Closer) error { if closerObj == nil { return commonerrors.UndefinedVariable("closer object") } @@ -86,24 +86,55 @@ func CloseAllFuncAndCollateErrors(cs ...CloseFunc) error { return group.Close() } +type ContextualFunc func(ctx context.Context) error type CloseFunc func() error +func WrapCancelToCloseFunc(f context.CancelFunc) CloseFunc { + return func() error { + f() + return nil + } +} + +func WrapCancelToContextualFunc(f context.CancelFunc) ContextualFunc { + return WrapCloseToContextualFunc(WrapCancelToCloseFunc(f)) +} + +func WrapCloseToContextualFunc(f CloseFunc) ContextualFunc { + return func(_ context.Context) error { + return f() + } +} + +func WrapCloseToCancelFunc(f CloseFunc) context.CancelFunc { + return func() { + _ = f() + } +} + +func WrapContextualToCloseFunc(f ContextualFunc) CloseFunc { + return func() error { + return f(context.Background()) + } +} + +func WrapContextualToCancelFunc(f ContextualFunc) context.CancelFunc { + return WrapCloseToCancelFunc(WrapContextualToCloseFunc(f)) +} + type CloseFunctionStore struct { - store[CloseFunc] + ExecutionGroup[CloseFunc] } func (s *CloseFunctionStore) RegisterCloseFunction(closerObj ...CloseFunc) { - s.store.RegisterFunction(closerObj...) + s.ExecutionGroup.RegisterFunction(closerObj...) } func (s *CloseFunctionStore) RegisterCancelStore(cancelStore *CancelFunctionStore) { if cancelStore == nil { return } - s.store.RegisterFunction(func() error { - cancelStore.Cancel() - return nil - }) + s.ExecutionGroup.RegisterFunction(WrapCancelToCloseFunc(cancelStore.Cancel)) } func (s *CloseFunctionStore) RegisterCancelFunction(cancelFunc ...context.CancelFunc) { @@ -117,13 +148,13 @@ func (s *CloseFunctionStore) Close() error { } func (s *CloseFunctionStore) Len() int { - return s.store.Len() + return s.ExecutionGroup.Len() } // NewCloseFunctionStore returns a store closing functions which will all be called on Close(). The first error received if any will be returned. func NewCloseFunctionStore(options ...StoreOption) *CloseFunctionStore { return &CloseFunctionStore{ - store: *newFunctionStore[CloseFunc](func(_ context.Context, closerObj CloseFunc) error { + ExecutionGroup: *NewExecutionGroup[CloseFunc](func(_ context.Context, closerObj CloseFunc) error { return closerObj() }, options...), } diff --git a/utils/parallelisation/onclose_test.go b/utils/parallelisation/onclose_test.go index 2287796c11..e92574e9b9 100644 --- a/utils/parallelisation/onclose_test.go +++ b/utils/parallelisation/onclose_test.go @@ -14,7 +14,7 @@ import ( "github.com/ARM-software/golang-utils/utils/parallelisation/mocks" ) -//go:generate go tool mockgen -destination=./mocks/mock_$GOPACKAGE.go -package=mocks io Closer +//go:generate go tool mockgen -destination=./mocks/mock_$GOPACKAGE.go -package=mocks io.Closer func TestCloseAll(t *testing.T) { t.Run("close", func(t *testing.T) { ctlr := gomock.NewController(t) diff --git a/utils/parallelisation/parallelisation.go b/utils/parallelisation/parallelisation.go index 93fc596619..30e909afdd 100644 --- a/utils/parallelisation/parallelisation.go +++ b/utils/parallelisation/parallelisation.go @@ -17,11 +17,6 @@ import ( "github.com/ARM-software/golang-utils/utils/commonerrors" ) -// DetermineContextError determines what the context error is if any. -func DetermineContextError(ctx context.Context) error { - return commonerrors.ConvertContextError(ctx.Err()) -} - type result struct { Item any err error @@ -178,7 +173,7 @@ func RunActionWithTimeoutAndContext(ctx context.Context, timeout time.Duration, } // RunActionWithTimeoutAndCancelStore runs an action with timeout -// The cancel store is used just to register the cancel function so that it can be called on Cancel. +// The cancel ExecutionGroup is used just to register the cancel function so that it can be called on Cancel. func RunActionWithTimeoutAndCancelStore(ctx context.Context, timeout time.Duration, store *CancelFunctionStore, blockingAction func(context.Context) error) error { err := DetermineContextError(ctx) if err != nil { From 81abfd7acf686195df5450090ffc5b1d7c18ac2e Mon Sep 17 00:00:00 2001 From: Adrien CABARBAYE Date: Wed, 20 Aug 2025 14:20:03 +0100 Subject: [PATCH 2/4] :sparkles: Address review comments --- changes/20250820140853.feature | 1 + utils/parallelisation/contextual.go | 2 + utils/parallelisation/contextual_test.go | 7 +- utils/parallelisation/group.go | 111 ++++++++++++++++++----- utils/parallelisation/group_test.go | 44 +++++++++ utils/parallelisation/onclose.go | 30 ++++-- utils/parallelisation/onclose_test.go | 17 +++- 7 files changed, 179 insertions(+), 33 deletions(-) create mode 100644 changes/20250820140853.feature diff --git a/changes/20250820140853.feature b/changes/20250820140853.feature new file mode 100644 index 0000000000..d79a2a989e --- /dev/null +++ b/changes/20250820140853.feature @@ -0,0 +1 @@ +:sparkles: `[parallelisation]` Added new compound execution group to support nested execution groups diff --git a/utils/parallelisation/contextual.go b/utils/parallelisation/contextual.go index 68267490a0..dbd3d99c0a 100644 --- a/utils/parallelisation/contextual.go +++ b/utils/parallelisation/contextual.go @@ -11,6 +11,8 @@ func DetermineContextError(ctx context.Context) error { return commonerrors.ConvertContextError(ctx.Err()) } +type ContextualFunc func(ctx context.Context) error + type ContextualFunctionGroup struct { ExecutionGroup[ContextualFunc] } diff --git a/utils/parallelisation/contextual_test.go b/utils/parallelisation/contextual_test.go index 5391831d9c..dcbe4b8e00 100644 --- a/utils/parallelisation/contextual_test.go +++ b/utils/parallelisation/contextual_test.go @@ -20,10 +20,10 @@ func TestForEach(t *testing.T) { t.Run("close with 1 error but error collection", func(t *testing.T) { closeError := commonerrors.ErrUnexpected - errortest.AssertError(t, ForEach(context.Background(), WithOptions(Parallel, JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError) + errortest.AssertError(t, ForEach(context.Background(), WithOptions(Parallel), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError) }) - t.Run("close with 1 error but error collection", func(t *testing.T) { + t.Run("close with 1 error and limited number of parallel workers", func(t *testing.T) { closeError := commonerrors.ErrUnexpected errortest.AssertError(t, ForEach(context.Background(), WithOptions(Workers(5), JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCloseToContextualFunc(func() error { return closeError }), WrapCancelToContextualFunc(cancelFunc)), closeError) }) @@ -45,4 +45,7 @@ func TestForEach(t *testing.T) { t.Run("break on error with no error", func(t *testing.T) { require.NoError(t, BreakOnError(context.Background(), WithOptions(Workers(5), JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc))) }) + t.Run("for each with no error", func(t *testing.T) { + require.NoError(t, ForEach(context.Background(), WithOptions(Workers(5), JoinErrors), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc), WrapCancelToContextualFunc(cancelFunc))) + }) } diff --git a/utils/parallelisation/group.go b/utils/parallelisation/group.go index 7e50391852..a7310641d6 100644 --- a/utils/parallelisation/group.go +++ b/utils/parallelisation/group.go @@ -23,19 +23,41 @@ type StoreOptions struct { workers int } +func (o *StoreOptions) Default() *StoreOptions { + o.clearOnExecution = false + o.stopOnFirstError = false + o.sequential = false + o.reverse = false + o.joinErrors = false + o.onlyOnce = false + o.workers = 0 + return o +} + func (o *StoreOptions) Merge(opts *StoreOptions) *StoreOptions { if opts == nil { return o } - return &StoreOptions{ - clearOnExecution: opts.clearOnExecution || o.clearOnExecution, - stopOnFirstError: opts.stopOnFirstError || o.stopOnFirstError, - sequential: opts.sequential || o.sequential, - reverse: opts.reverse || o.reverse, - joinErrors: opts.joinErrors || o.joinErrors, - onlyOnce: opts.onlyOnce || o.onlyOnce, - workers: safecast.ToInt(math.Max(float64(opts.workers), float64(o.workers))), - } + o.clearOnExecution = opts.clearOnExecution || o.clearOnExecution + o.stopOnFirstError = opts.stopOnFirstError || o.stopOnFirstError + o.sequential = opts.sequential || o.sequential + o.reverse = opts.reverse || o.reverse + o.joinErrors = opts.joinErrors || o.joinErrors + o.onlyOnce = opts.onlyOnce || o.onlyOnce + o.workers = safecast.ToInt(math.Max(float64(opts.workers), float64(o.workers))) + return o +} + +func (o *StoreOptions) MergeWithOptions(opt ...StoreOption) *StoreOptions { + return o.Merge(WithOptions(opt...)) +} + +func (o *StoreOptions) Overwrite(opts *StoreOptions) *StoreOptions { + return o.Default().Merge(opts) +} + +func (o *StoreOptions) WithOptions(opts ...StoreOption) *StoreOptions { + return o.Overwrite(WithOptions(opts...)) } func (o *StoreOptions) Options() []StoreOption { @@ -43,7 +65,7 @@ func (o *StoreOptions) Options() []StoreOption { func(opts *StoreOptions) *StoreOptions { op := o if op == nil { - op = &StoreOptions{} + op = DefaultOptions() } return op.Merge(opts) }, @@ -55,7 +77,7 @@ type StoreOption func(*StoreOptions) *StoreOptions // StopOnFirstError stops ExecutionGroup execution on first error. var StopOnFirstError StoreOption = func(o *StoreOptions) *StoreOptions { if o == nil { - o = &StoreOptions{} + o = DefaultOptions() } o.stopOnFirstError = true o.joinErrors = false @@ -66,7 +88,7 @@ var StopOnFirstError StoreOption = func(o *StoreOptions) *StoreOptions { // This option should not be used in combination to StopOnFirstError. var JoinErrors StoreOption = func(o *StoreOptions) *StoreOptions { if o == nil { - o = &StoreOptions{} + o = DefaultOptions() } o.stopOnFirstError = false o.joinErrors = true @@ -76,7 +98,7 @@ var JoinErrors StoreOption = func(o *StoreOptions) *StoreOptions { // OnlyOnce will ensure the function are executed only once if they do. var OnlyOnce StoreOption = func(o *StoreOptions) *StoreOptions { if o == nil { - o = &StoreOptions{} + o = DefaultOptions() } o.onlyOnce = true return o @@ -85,7 +107,7 @@ var OnlyOnce StoreOption = func(o *StoreOptions) *StoreOptions { // AnyTimes will allow the functions to be executed as often that they might be. var AnyTimes StoreOption = func(o *StoreOptions) *StoreOptions { if o == nil { - o = &StoreOptions{} + o = DefaultOptions() } o.onlyOnce = false return o @@ -94,7 +116,7 @@ var AnyTimes StoreOption = func(o *StoreOptions) *StoreOptions { // ExecuteAll executes all functions in the ExecutionGroup even if an error is raised. the first error raised is then returned. var ExecuteAll StoreOption = func(o *StoreOptions) *StoreOptions { if o == nil { - o = &StoreOptions{} + o = DefaultOptions() } o.stopOnFirstError = false return o @@ -103,7 +125,7 @@ var ExecuteAll StoreOption = func(o *StoreOptions) *StoreOptions { // ClearAfterExecution clears the ExecutionGroup after execution. var ClearAfterExecution StoreOption = func(o *StoreOptions) *StoreOptions { if o == nil { - o = &StoreOptions{} + o = DefaultOptions() } o.clearOnExecution = true return o @@ -112,7 +134,7 @@ var ClearAfterExecution StoreOption = func(o *StoreOptions) *StoreOptions { // RetainAfterExecution keep the ExecutionGroup intact after execution (no reset). var RetainAfterExecution StoreOption = func(o *StoreOptions) *StoreOptions { if o == nil { - o = &StoreOptions{} + o = DefaultOptions() } o.clearOnExecution = false return o @@ -121,7 +143,7 @@ var RetainAfterExecution StoreOption = func(o *StoreOptions) *StoreOptions { // Parallel ensures every function registered in the ExecutionGroup is executed concurrently in the order they were registered. var Parallel StoreOption = func(o *StoreOptions) *StoreOptions { if o == nil { - o = &StoreOptions{} + o = DefaultOptions() } o.sequential = false return o @@ -131,7 +153,7 @@ var Parallel StoreOption = func(o *StoreOptions) *StoreOptions { func Workers(workers int) StoreOption { return func(o *StoreOptions) *StoreOptions { if o == nil { - o = &StoreOptions{} + o = DefaultOptions() } o.workers = workers o.sequential = false @@ -142,7 +164,7 @@ func Workers(workers int) StoreOption { // Sequential ensures every function registered in the ExecutionGroup is executed sequentially in the order they were registered. var Sequential StoreOption = func(o *StoreOptions) *StoreOptions { if o == nil { - o = &StoreOptions{} + o = DefaultOptions() } o.sequential = true return o @@ -151,7 +173,7 @@ var Sequential StoreOption = func(o *StoreOptions) *StoreOptions { // SequentialInReverse ensures every function registered in the ExecutionGroup is executed sequentially but in the reverse order they were registered. var SequentialInReverse StoreOption = func(o *StoreOptions) *StoreOptions { if o == nil { - o = &StoreOptions{} + o = DefaultOptions() } o.sequential = true o.reverse = true @@ -164,11 +186,34 @@ func WithOptions(option ...StoreOption) (opts *StoreOptions) { opts = option[i](opts) } if opts == nil { - opts = &StoreOptions{} + opts = DefaultOptions() } return } +// DefaultOptions returns the default store configuration +func DefaultOptions() *StoreOptions { + opts := &StoreOptions{} + return opts.Default() +} + +type IExecutor interface { + // Execute executes all the functions in the group. + Execute(ctx context.Context) error +} + +type IExecutionGroup[T any] interface { + IExecutor + RegisterFunction(function ...T) + Len() int +} + +type ICompoundExecutionGroup[T any] interface { + IExecutionGroup[T] + // RegisterExecutor registers executors of any kind to the group: they could be functions or sub-groups. + RegisterExecutor(executor ...IExecutor) +} + // NewExecutionGroup returns an execution group which executes functions according to store options. func NewExecutionGroup[T any](executeFunc ExecuteFunc[T], options ...StoreOption) *ExecutionGroup[T] { @@ -366,3 +411,25 @@ func newWrapped[T any](e T, once bool) wrappedElement[T] { return newBasicWrap[T](e) } } + +var _ ICompoundExecutionGroup[ContextualFunc] = &CompoundExecutionGroup{} + +// NewCompoundExecutionGroup returns an execution group made of executors +func NewCompoundExecutionGroup(options ...StoreOption) *CompoundExecutionGroup { + return &CompoundExecutionGroup{ + ContextualFunctionGroup: *NewContextualGroup(options...), + } +} + +type CompoundExecutionGroup struct { + ContextualFunctionGroup +} + +// RegisterExecutor registers executors +func (g *CompoundExecutionGroup) RegisterExecutor(group ...IExecutor) { + for i := range group { + g.RegisterFunction(func(ctx context.Context) error { + return group[i].Execute(ctx) + }) + } +} diff --git a/utils/parallelisation/group_test.go b/utils/parallelisation/group_test.go index b46563e13a..240fffaa5e 100644 --- a/utils/parallelisation/group_test.go +++ b/utils/parallelisation/group_test.go @@ -1,11 +1,15 @@ package parallelisation import ( + "context" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" "github.com/ARM-software/golang-utils/utils/parallelisation/mocks" ) @@ -146,3 +150,43 @@ func TestExecutionTimes(t *testing.T) { require.NoError(t, group.Close()) }) } + +func TestCompoundGroup(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(nil).Times(17) + group := NewCloserStoreWithOptions(ExecuteAll, OnlyOnce, Sequential) + group.RegisterFunction(closerMock, closerMock, closerMock) + + compoundGroup := NewCompoundExecutionGroup(Parallel, RetainAfterExecution) + compoundGroup.RegisterFunction(WrapCloseToContextualFunc(WrapCloserIntoCloseFunc(closerMock))) + compoundGroup.RegisterExecutor(group) + compoundGroup.RegisterFunction(WrapCancelToContextualFunc(WrapContextualToCancelFunc(WrapCloseToContextualFunc(WrapCloserIntoCloseFunc(closerMock))))) + + assert.Equal(t, 3, compoundGroup.Len()) + + require.NoError(t, compoundGroup.Execute(context.Background())) + require.NoError(t, compoundGroup.Execute(context.Background())) + require.NoError(t, compoundGroup.Execute(context.Background())) + require.NoError(t, compoundGroup.Execute(context.Background())) + require.NoError(t, compoundGroup.Execute(context.Background())) + require.NoError(t, compoundGroup.Execute(context.Background())) + require.NoError(t, compoundGroup.Execute(context.Background())) + + t.Run("With cancelled context", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + errortest.AssertError(t, compoundGroup.Execute(ctx), commonerrors.ErrCancelled) + }) + +} + +func TestStoreOptions_MergeWithOptions(t *testing.T) { + opts := WithOptions(Parallel).MergeWithOptions(OnlyOnce, ExecuteAll, Workers(5), Sequential) + assert.True(t, opts.onlyOnce) + assert.False(t, opts.stopOnFirstError) + assert.True(t, opts.sequential) + assert.Equal(t, 5, opts.workers) +} diff --git a/utils/parallelisation/onclose.go b/utils/parallelisation/onclose.go index c03a29b7b2..39079730c2 100644 --- a/utils/parallelisation/onclose.go +++ b/utils/parallelisation/onclose.go @@ -29,17 +29,14 @@ func NewCloserStore(stopOnFirstError bool) *CloserStore { if stopOnFirstError { option = StopOnFirstError } - return NewCloserStoreWithOptions(option, Parallel, RetainAfterExecution) + return NewCloserStoreWithOptions(option, Parallel, OnlyOnce, RetainAfterExecution) } // NewCloserStoreWithOptions returns a store of io.Closer object which will all be closed on Close(). The first error received if any will be returned func NewCloserStoreWithOptions(opts ...StoreOption) *CloserStore { return &CloserStore{ - ExecutionGroup: *NewExecutionGroup[io.Closer](func(_ context.Context, closerObj io.Closer) error { - if closerObj == nil { - return commonerrors.UndefinedVariable("closer object") - } - return closerObj.Close() + ExecutionGroup: *NewExecutionGroup[io.Closer](func(ctx context.Context, closerObj io.Closer) error { + return WrapCloseToContextualFunc(WrapCloserIntoCloseFunc(closerObj))(ctx) }, opts...), } } @@ -86,9 +83,21 @@ func CloseAllFuncAndCollateErrors(cs ...CloseFunc) error { return group.Close() } -type ContextualFunc func(ctx context.Context) error type CloseFunc func() error +func (c CloseFunc) Close() error { + return c() +} + +func WrapCloserIntoCloseFunc(closer io.Closer) CloseFunc { + return func() error { + if closer == nil { + return commonerrors.UndefinedVariable("closer object") + } + return closer.Close() + } +} + func WrapCancelToCloseFunc(f context.CancelFunc) CloseFunc { return func() error { f() @@ -172,5 +181,10 @@ func NewConcurrentCloseFunctionStore(stopOnFirstError bool) *CloseFunctionStore if stopOnFirstError { option = StopOnFirstError } - return NewCloseFunctionStore(option, Parallel, RetainAfterExecution) + return NewCloseFunctionStore(option, Parallel, RetainAfterExecution, OnlyOnce) +} + +// NewCloseOnceGroup is the same as NewCloseFunctionStore but ensures any closing functions are only executed once. +func NewCloseOnceGroup(options ...StoreOption) *CloseFunctionStore { + return NewCloseFunctionStore(OnlyOnce(WithOptions(options...)).Options()...) } diff --git a/utils/parallelisation/onclose_test.go b/utils/parallelisation/onclose_test.go index e92574e9b9..9a81d61150 100644 --- a/utils/parallelisation/onclose_test.go +++ b/utils/parallelisation/onclose_test.go @@ -72,6 +72,21 @@ func TestCloseAll(t *testing.T) { } +func TestCloseOnce(t *testing.T) { + t.Run("close every function once", func(t *testing.T) { + ctlr := gomock.NewController(t) + defer ctlr.Finish() + closeError := commonerrors.ErrUnexpected + + closerMock := mocks.NewMockCloser(ctlr) + closerMock.EXPECT().Close().Return(closeError).Times(3) + + group := NewCloseOnceGroup(Parallel, RetainAfterExecution) + group.RegisterCloseFunction(WrapCloserIntoCloseFunc(closerMock), WrapCloserIntoCloseFunc(closerMock), WrapCloserIntoCloseFunc(closerMock)) + errortest.AssertError(t, group.Close(), closeError) + }) +} + func TestCancelOnClose(t *testing.T) { t.Run("parallel", func(t *testing.T) { closeStore := NewCloseFunctionStoreStore(true) @@ -136,7 +151,7 @@ func TestSequentialExecution(t *testing.T) { for i := range tests { test := tests[i] t.Run(fmt.Sprintf("%v-%#v", i, test.option), func(t *testing.T) { - opt := test.option(&StoreOptions{}) + opt := test.option(DefaultOptions()) t.Run("sequentially", func(t *testing.T) { closeStore := NewCloseFunctionStore(test.option, Sequential) ctx1, cancel1 := context.WithCancel(context.Background()) From d5d1642e6ef902699396411e1e1bbf6d9e600025 Mon Sep 17 00:00:00 2001 From: Adrien CABARBAYE Date: Wed, 20 Aug 2025 14:30:58 +0100 Subject: [PATCH 3/4] :sparkles: further features --- utils/parallelisation/cancel_functions.go | 9 ++++++ .../parallelisation/cancel_functions_test.go | 30 +++++++++++++------ 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/utils/parallelisation/cancel_functions.go b/utils/parallelisation/cancel_functions.go index fa2c526560..3464084e04 100644 --- a/utils/parallelisation/cancel_functions.go +++ b/utils/parallelisation/cancel_functions.go @@ -15,6 +15,15 @@ func (s *CancelFunctionStore) RegisterCancelFunction(cancel ...context.CancelFun s.ExecutionGroup.RegisterFunction(cancel...) } +func (s *CancelFunctionStore) RegisterCancelStore(store *CancelFunctionStore) { + if store == nil { + return + } + s.RegisterCancelFunction(func() { + store.Cancel() + }) +} + // Cancel will execute the cancel functions in the store. Any errors will be ignored and Execute() is recommended if you need to know if a cancellation failed func (s *CancelFunctionStore) Cancel() { _ = s.Execute(context.Background()) diff --git a/utils/parallelisation/cancel_functions_test.go b/utils/parallelisation/cancel_functions_test.go index 75163f5081..73130d2b67 100644 --- a/utils/parallelisation/cancel_functions_test.go +++ b/utils/parallelisation/cancel_functions_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/atomic" "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" @@ -19,24 +20,35 @@ func testCancelStore(t *testing.T, store *CancelFunctionStore) { t.Helper() require.NotNil(t, store) // Set up some fake CancelFuncs to make sure they are called - called1 := false - called2 := false + called1 := atomic.NewBool(false) + called2 := atomic.NewBool(false) + called3 := atomic.NewBool(false) + cancelFunc1 := func() { - called1 = true + called1.Store(true) } cancelFunc2 := func() { - called2 = true + called2.Store(true) + } + cancelFunc3 := func() { + called3.Store(true) } + subStore := NewCancelFunctionsStore() + subStore.RegisterCancelFunction(cancelFunc3) store.RegisterCancelFunction(cancelFunc1, cancelFunc2) + store.RegisterCancelStore(subStore) + store.RegisterCancelStore(nil) - assert.Equal(t, 2, store.Len()) - assert.False(t, called1) - assert.False(t, called2) + assert.Equal(t, 3, store.Len()) + assert.False(t, called1.Load()) + assert.False(t, called2.Load()) + assert.False(t, called3.Load()) store.Cancel() - assert.True(t, called1) - assert.True(t, called2) + assert.True(t, called1.Load()) + assert.True(t, called2.Load()) + assert.True(t, called3.Load()) } // Given a CancelFunctionsStore From 36b9c088a238c1a934c3cf00d4298b2ebb2562c6 Mon Sep 17 00:00:00 2001 From: Adrien CABARBAYE Date: Wed, 20 Aug 2025 16:05:44 +0100 Subject: [PATCH 4/4] complete test --- utils/parallelisation/onclose_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/parallelisation/onclose_test.go b/utils/parallelisation/onclose_test.go index 9a81d61150..4fa2ddfc39 100644 --- a/utils/parallelisation/onclose_test.go +++ b/utils/parallelisation/onclose_test.go @@ -84,6 +84,7 @@ func TestCloseOnce(t *testing.T) { group := NewCloseOnceGroup(Parallel, RetainAfterExecution) group.RegisterCloseFunction(WrapCloserIntoCloseFunc(closerMock), WrapCloserIntoCloseFunc(closerMock), WrapCloserIntoCloseFunc(closerMock)) errortest.AssertError(t, group.Close(), closeError) + require.NoError(t, group.Close()) }) }