Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/20250909150027.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:sparkles: `safeio` Add support for cancelling readers that make blocking kernel reads during copying
27 changes: 27 additions & 0 deletions utils/safeio/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ func Cat(ctx context.Context, dst io.Writer, src ...io.Reader) (copied int64, er
return CopyDataWithContext(ctx, NewContextualMultipleReader(ctx, src...), dst)
}

// SafeCopyDataWithContext copies from src to dst similarly to io.Copy but with context control to stop when asked.
// Unlike CopyWithContext it requires a ReadCloser, this allows it to stop even if the system is doing a kernel read.
func SafeCopyDataWithContext(ctx context.Context, src io.ReadCloser, dst io.Writer) (copied int64, err error) {
return safeCopyDataWithContext(ctx, src, dst, func(dst io.Writer, src io.ReadCloser) (int64, error) { return io.Copy(dst, src) })
}

// SafeCopyNWithContext copies n bytes from src to dst similarly to io.CopyN but with context control to stop when asked.
// Unlike CopyNWithContext it requires a ReadCloser, this allows it to stop even if the system is doing a kernel read.
func SafeCopyNWithContext(ctx context.Context, src io.ReadCloser, dst io.Writer, n int64) (copied int64, err error) {
return safeCopyDataWithContext(ctx, src, dst, func(dst io.Writer, src io.ReadCloser) (int64, error) { return io.CopyN(dst, src, n) })
}

func copyDataWithContext(ctx context.Context, src io.Reader, dst io.Writer, copyFunc func(io.Writer, io.Reader) (int64, error)) (copied int64, err error) {
err = parallelisation.DetermineContextError(ctx)
if err != nil {
Expand All @@ -37,8 +49,23 @@ func copyDataWithContext(ctx context.Context, src io.Reader, dst io.Writer, copy
return
}

func safeCopyDataWithContext(ctx context.Context, src io.ReadCloser, dst io.Writer, copyFunc func(io.Writer, io.ReadCloser) (int64, error)) (copied int64, err error) {
err = parallelisation.DetermineContextError(ctx)
if err != nil {
return
}
copied, err = reallySafeCopy(ContextualWriter(ctx, dst), NewContextualReadCloser(ctx, src), copyFunc)
return
}

func safeCopy(w io.Writer, r io.Reader, iocopyFunc func(io.Writer, io.Reader) (int64, error)) (int64, error) {
copied, err := iocopyFunc(w, r)
err = ConvertIOError(err)
return copied, err
}

func reallySafeCopy(w io.Writer, r io.ReadCloser, iocopyFunc func(io.Writer, io.ReadCloser) (int64, error)) (int64, error) {
copied, err := iocopyFunc(w, r)
err = ConvertIOError(err)
return copied, err
}
133 changes: 133 additions & 0 deletions utils/safeio/copy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ package safeio
import (
"bytes"
"context"
"io"
"os"
"testing"
"time"

"github.com/go-faker/faker/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"

"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/commonerrors/errortest"
Expand Down Expand Up @@ -81,6 +85,135 @@ func TestCopyNWithContext(t *testing.T) {
assert.Equal(t, safecast.ToInt64(len(text)-1), n2)
}

func TestSafeCopyDataWithContext(t *testing.T) {
defer goleak.VerifyNone(t)
var buf1, buf2 bytes.Buffer
text := faker.Sentence()
n, err := WriteString(context.Background(), &buf1, text)
require.NoError(t, err)
require.NotZero(t, n)
assert.Equal(t, len(text), n)
rc := io.NopCloser(bytes.NewReader(buf1.Bytes())) // make it an io.ReadCloser
n2, err := SafeCopyDataWithContext(context.Background(), rc, &buf2)
require.NoError(t, err)
require.NotZero(t, n2)
assert.Equal(t, safecast.ToInt64(len(text)), n2)
assert.Equal(t, text, buf2.String())

ctx, cancel := context.WithCancel(context.Background())
buf1.Reset()
buf2.Reset()
n, err = WriteString(context.Background(), &buf1, text)
require.NoError(t, err)
require.NotZero(t, n)
assert.Equal(t, len(text), n)

cancel()
rc = io.NopCloser(bytes.NewReader(buf1.Bytes()))
n2, err = SafeCopyDataWithContext(ctx, rc, &buf2)
require.Error(t, err)
errortest.AssertError(t, err, commonerrors.ErrCancelled)
assert.Zero(t, n2)
assert.Empty(t, buf2.String())

r, w, err := os.Pipe()
require.NoError(t, err)
defer func() { _ = w.Close() }()
ctx2, unblock := context.WithCancel(context.Background())
done := make(chan struct{})

go func() {
_, errCopy := SafeCopyDataWithContext(ctx2, r, io.Discard)
_ = r.Close()
_ = errCopy
close(done)
}()

time.Sleep(50 * time.Millisecond) // let it enter read(2) https://man7.org/linux/man-pages/man2/read.2.html
unblock()

select {
case <-done:
// Expected case: unblocked
case <-time.After(2 * time.Second):
assert.FailNow(t, "context cancel should have unblocked copy")
}
}

func TestSafeCopyNWithContext(t *testing.T) {
defer goleak.VerifyNone(t)
var buf1, buf2 bytes.Buffer
text := faker.Sentence()
n, err := WriteString(context.Background(), &buf1, text)
require.NoError(t, err)
require.NotZero(t, n)
assert.Equal(t, len(text), n)
rc := io.NopCloser(bytes.NewReader(buf1.Bytes()))
n2, err := SafeCopyNWithContext(context.Background(), rc, &buf2, safecast.ToInt64(len(text)))
require.NoError(t, err)
require.NotZero(t, n2)
assert.Equal(t, safecast.ToInt64(len(text)), n2)
assert.Equal(t, text, buf2.String())

ctx, cancel := context.WithCancel(context.Background())

buf1.Reset()
buf2.Reset()
n, err = WriteString(context.Background(), &buf1, text)
require.NoError(t, err)
require.NotZero(t, n)
assert.Equal(t, len(text), n)

cancel()
rc = io.NopCloser(bytes.NewReader(buf1.Bytes()))
n2, err = SafeCopyNWithContext(ctx, rc, &buf2, safecast.ToInt64(len(text)))
require.Error(t, err)
errortest.AssertError(t, err, commonerrors.ErrCancelled)
assert.Zero(t, n2)
assert.Empty(t, buf2.String())

buf1.Reset()
buf2.Reset()
n, err = WriteString(context.Background(), &buf1, text)
require.NoError(t, err)
require.NotZero(t, n)
rc = io.NopCloser(bytes.NewReader(buf1.Bytes()))

wantN := safecast.ToInt64(len(text) - 1)
n2, err = SafeCopyNWithContext(context.Background(), rc, &buf2, wantN)
require.NoError(t, err)
require.NotZero(t, n2)
assert.Equal(t, wantN, n2)
assert.Equal(t, text[:len(text)-1], buf2.String())

r, w, err := os.Pipe()
require.NoError(t, err)
defer func() { _ = w.Close() }()
ctx2, unblock := context.WithCancel(context.Background())
done := make(chan struct{})
var (
copied int64
copyErr error
)

go func() {
copied, copyErr = SafeCopyNWithContext(ctx2, r, io.Discard, 1024) // nothing to read means it blocks
_ = r.Close()
close(done)
}()

time.Sleep(50 * time.Millisecond) // let it enter read(2) https://man7.org/linux/man-pages/man2/read.2.html
unblock()

select {
case <-done:
errortest.AssertError(t, copyErr, commonerrors.ErrCancelled)
assert.Zero(t, copied)
case <-time.After(2 * time.Second):
assert.FailNow(t, "context cancel should have unblocked copy")
}
}

func TestCat(t *testing.T) {
var buf1, buf2, buf3 bytes.Buffer
text1 := faker.Sentence()
Expand Down
4 changes: 4 additions & 0 deletions utils/safeio/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package safeio

import (
"io"
"os"

"github.com/ARM-software/golang-utils/utils/commonerrors"
)
Expand All @@ -16,6 +17,9 @@ func ConvertIOError(err error) (newErr error) {
case commonerrors.Any(newErr, commonerrors.ErrEOF):
case commonerrors.Any(newErr, io.EOF, io.ErrUnexpectedEOF):
newErr = commonerrors.WrapError(commonerrors.ErrEOF, newErr, "")
case commonerrors.Any(newErr, os.ErrClosed):
// cancelling a reader on a copy will cause it to close the file and return os.ErrClosed so map it to cancelled for this package
newErr = commonerrors.WrapError(commonerrors.ErrCancelled, newErr, "")
}
return
}
37 changes: 37 additions & 0 deletions utils/safeio/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"

"github.com/dolmen-go/contextio"
"go.uber.org/atomic"

"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/parallelisation"
Expand Down Expand Up @@ -76,6 +77,42 @@ func NewContextualReader(ctx context.Context, reader io.Reader) io.Reader {
return contextio.NewReader(ctx, reader)
}

type safeReadCloser struct {
reader io.Reader // use reader to ensure idempotency since you can't call close on the reader itself, only via the wrapper
close parallelisation.CloseFunc
closed *atomic.Bool
}

func (r safeReadCloser) Read(p []byte) (int, error) {
return r.reader.Read(p)
}

func (r safeReadCloser) Close() error {
if r.closed.Swap(true) {
return nil
}

return r.close()
}

// NewContextualReadCloser returns a readcloser which is context aware.
// Context state is checked during the read and close is called if the context is cancelled
// This allows for readers that block on syscalls to be stopped via a context
func NewContextualReadCloser(ctx context.Context, reader io.ReadCloser) io.ReadCloser {
stop := context.AfterFunc(ctx, func() { _ = reader.Close() })

r := safeReadCloser{
reader: contextio.NewReader(ctx, reader),
close: func() error {
_ = stop()
return nil
},
closed: atomic.NewBool(false),
}

return r
}

func NewContextualMultipleReader(ctx context.Context, reader ...io.Reader) io.Reader {
readers := make([]io.Reader, len(reader))
for i := range reader {
Expand Down
70 changes: 70 additions & 0 deletions utils/safeio/read_closer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package safeio

import (
"context"
"io"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)

func TestNewContextualReadCloser(t *testing.T) {
t.Run("Normal contextual reader blocks even after cancel", func(t *testing.T) {
defer goleak.VerifyNone(t)

r, w, err := os.Pipe()
require.NoError(t, err)
defer func() { _ = r.Close(); _ = w.Close() }()

ctx, cancel := context.WithCancel(context.Background())
reader := NewContextualReader(ctx, r)

done := make(chan struct{})
go func() {
_, _ = io.Copy(io.Discard, reader) // will block in read(2) https://man7.org/linux/man-pages/man2/read.2.html
close(done)
}()

// Allow io.Copy to enter kernel read then try to cancel
time.Sleep(50 * time.Millisecond)
cancel()

select {
case <-done:
assert.FailNow(t, "cancelling context shouldn't unblock a blocking Read in io.Copy")
case <-time.After(200 * time.Millisecond):
// Expected case: still blocked
}
})

t.Run("Contextual read closer does not block even on long running copies", func(t *testing.T) {
defer goleak.VerifyNone(t)

r, w, err := os.Pipe()
require.NoError(t, err)
defer func() { _ = w.Close() }()

ctx, cancel := context.WithCancel(context.Background())
rc := NewContextualReadCloser(ctx, r)

done := make(chan struct{})
go func() {
_, _ = io.Copy(io.Discard, rc) // will block in read(2) https://man7.org/linux/man-pages/man2/read.2.html
close(done)
}()

time.Sleep(50 * time.Millisecond)
cancel()

select {
case <-done:
// Expected case: successfully unblocked
case <-time.After(2 * time.Second):
assert.FailNow(t, "copy should have been unblocked by context cancel")
}
})
}
Loading