Skip to content

Commit

Permalink
Add batch.BySize
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie committed Jun 17, 2024
1 parent 9f1c5e2 commit cbd1eb0
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ static:

.PHONY: test
test:
go test ./...
go test -cover ./...

.PHONY: race
race:
Expand Down
48 changes: 48 additions & 0 deletions lib/batch/batch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package batch

import "fmt"

// BatchBySize takes a slice of elements, encodes them, groups batches of bytes that sum to at most [maxSize], and then p
// passes each of those batches to the [yield] function.
func BySize[T any](in []T, maxSize int, encode func(T) ([]byte, error), yield func([][]byte) error) error {
var buffer [][]byte
var currentSize int

for i, item := range in {
bytes, err := encode(item)
if err != nil {
return fmt.Errorf("failed to encode item %d: %w", i, err)
}

if len(bytes) > maxSize {
return fmt.Errorf("item %d is larger (%d bytes) than max size (%d bytes)", i, len(bytes), maxSize)
}

currentSize += len(bytes)

if currentSize < maxSize {
buffer = append(buffer, bytes)
} else if currentSize == maxSize {
buffer = append(buffer, bytes)
if err := yield(buffer); err != nil {
return err
}
buffer = [][]byte{}
currentSize = 0
} else {
if err := yield(buffer); err != nil {
return err
}
buffer = [][]byte{bytes}
currentSize = len(bytes)
}
}

if len(buffer) > 0 {
if err := yield(buffer); err != nil {
return err
}
}

return nil
}
90 changes: 90 additions & 0 deletions lib/batch/batch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package batch

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestBySize(t *testing.T) {
goodEncoder := func(value string) ([]byte, error) {
return []byte(value), nil
}

panicEncoder := func(value string) ([]byte, error) {
panic("should not be called")
}

badEncoder := func(value string) ([]byte, error) {
return nil, fmt.Errorf("failed to encode %q", value)
}

testBySize := func(in []string, maxSize int, encoder func(value string) ([]byte, error)) ([][][]byte, error) {
out := [][][]byte{}
err := BySize(in, maxSize, encoder, func(batch [][]byte) error { out = append(out, batch); return nil })
return out, err
}

{
// Empty slice:
batches, err := testBySize([]string{}, 0, panicEncoder)
assert.NoError(t, err)
assert.Empty(t, batches)
}
{
// Non-empty slice + bad encoder:
_, err := testBySize([]string{"foo", "bar"}, 10, badEncoder)
assert.ErrorContains(t, err, `failed to encode item 0: failed to encode "foo"`)
}
{
// Non-empty slice + two items that are < maxSize + yield returns error.
err := BySize([]string{"foo", "bar"}, 10, goodEncoder, func(batch [][]byte) error { return fmt.Errorf("yield failed") })
assert.ErrorContains(t, err, "yield failed")
}
{
// Non-empty slice + two items that are = maxSize + yield returns error.
err := BySize([]string{"foo", "bar"}, 6, goodEncoder, func(batch [][]byte) error { return fmt.Errorf("yield failed") })
assert.ErrorContains(t, err, "yield failed")
}
{
// Non-empty slice + two items that are > maxSize + yield returns error.
err := BySize([]string{"foo", "bar-baz"}, 8, goodEncoder, func(batch [][]byte) error { return fmt.Errorf("yield failed") })
assert.ErrorContains(t, err, "yield failed")
}
{
// Non-empty slice + item is larger than max size:
_, err := testBySize([]string{"foo", "i-am-23-characters-long", "bar"}, 20, goodEncoder)
assert.ErrorContains(t, err, "item 1 is larger (23 bytes) than max size (20 bytes)")
}
{
// Non-empty slice + item equal to max size:
batches, err := testBySize([]string{"foo", "i-am-23-characters-long", "bar"}, 23, goodEncoder)
assert.NoError(t, err)
assert.Len(t, batches, 3)
assert.Equal(t, [][]byte{[]byte("foo")}, batches[0])
assert.Equal(t, [][]byte{[]byte("i-am-23-characters-long")}, batches[1])
assert.Equal(t, [][]byte{[]byte("bar")}, batches[2])
}
{
// Non-empty slice + one item:
batches, err := testBySize([]string{"foo"}, 100, goodEncoder)
assert.NoError(t, err)
assert.Len(t, batches, 1)
assert.Equal(t, [][]byte{[]byte("foo")}, batches[0])
}
{
// Non-empty slice + all items exactly fit into one batch:
batches, err := testBySize([]string{"foo", "bar", "baz", "qux"}, 12, goodEncoder)
assert.NoError(t, err)
assert.Len(t, batches, 1)
assert.Equal(t, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz"), []byte("qux")}, batches[0])
}
{
// Non-empty slice + all items exactly fit into just under one batch:
batches, err := testBySize([]string{"foo", "bar", "baz", "qux"}, 13, goodEncoder)
assert.NoError(t, err)
assert.Len(t, batches, 1)
assert.Equal(t, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz"), []byte("qux")}, batches[0])
}
}

0 comments on commit cbd1eb0

Please sign in to comment.