diff --git a/item/item.go b/item/item.go index 03fe9ab..b1cb312 100644 --- a/item/item.go +++ b/item/item.go @@ -110,9 +110,10 @@ var ErrExceededRejectRetryLimit = errors.New("item: exceeded lease rejected retr var ErrInvalidLeaseGetStatus = errors.New("item: invalid lease get response status") type multiGetState[T any, K comparable] struct { - keys []K - result map[K]T - err error + completed bool + keys []K + result map[K]T + err error } type multiGetFillerConfig struct { @@ -160,7 +161,8 @@ func NewMultiGetFiller[T any, K comparable]( s.keys = append(s.keys, key) return func() (T, error) { - if state != nil { + if !s.completed { + s.completed = true state = nil values, err := multiGetFunc(ctx, s.keys) diff --git a/item/item_test.go b/item/item_test.go index 6341768..e7e73fd 100644 --- a/item/item_test.go +++ b/item/item_test.go @@ -1237,6 +1237,62 @@ func TestMultiGetFiller(t *testing.T) { {user1.GetKey(), user2.GetKey()}, }, calledKeys) }) + + t.Run("interleaving", func(t *testing.T) { + user1 := userValue{ + Tenant: "TENANT01", + Name: "user01", + Age: 31, + } + user2 := userValue{ + Tenant: "TENANT01", + Name: "user02", + Age: 32, + } + user3 := userValue{ + Tenant: "TENANT02", + Name: "user03", + Age: 33, + } + + var calledKeys [][]userKey + values := [][]userValue{ + {user1, user2}, + {user3}, + } + + filler := NewMultiGetFiller[userValue, userKey]( + func(ctx context.Context, keys []userKey) ([]userValue, error) { + index := len(calledKeys) + calledKeys = append(calledKeys, keys) + return values[index], nil + }, + userValue.GetKey, + ) + + fn1 := filler(newContext(), user1.GetKey()) + fn2 := filler(newContext(), user2.GetKey()) + + resp1, err := fn1() + assert.Equal(t, nil, err) + assert.Equal(t, user1, resp1) + + fn3 := filler(newContext(), user3.GetKey()) + + resp2, err := fn2() + assert.Equal(t, nil, err) + assert.Equal(t, user2, resp2) + + // Get Stage 2 + resp3, err := fn3() + assert.Equal(t, nil, err) + assert.Equal(t, user3, resp3) + + assert.Equal(t, [][]userKey{ + {user1.GetKey(), user2.GetKey()}, + {user3.GetKey()}, + }, calledKeys) + }) } func TestItem_WithFakePipeline(t *testing.T) {