From e7703776f70616e68ac523185b6ac2e041250ac3 Mon Sep 17 00:00:00 2001 From: joshjennings98 Date: Tue, 9 Sep 2025 15:01:26 +0100 Subject: [PATCH 1/4] :sparkles: `safeio` Add support for cancelling readers that make blocking kernel reads during copying --- changes/20250909150027.feature | 1 + utils/safeio/copy.go | 27 +++++++ utils/safeio/copy_test.go | 133 +++++++++++++++++++++++++++++++ utils/safeio/error.go | 4 + utils/safeio/read.go | 37 +++++++++ utils/safeio/read_closer_test.go | 65 +++++++++++++++ 6 files changed, 267 insertions(+) create mode 100644 changes/20250909150027.feature create mode 100644 utils/safeio/read_closer_test.go diff --git a/changes/20250909150027.feature b/changes/20250909150027.feature new file mode 100644 index 0000000000..fce78d98c9 --- /dev/null +++ b/changes/20250909150027.feature @@ -0,0 +1 @@ +:sparkles: `safeio` Add support for cancelling readers that make blocking kernel reads during copying diff --git a/utils/safeio/copy.go b/utils/safeio/copy.go index 0f2f4038e7..3f8735f48f 100644 --- a/utils/safeio/copy.go +++ b/utils/safeio/copy.go @@ -28,6 +28,18 @@ func Cat(ctx context.Context, dst io.Writer, src ...io.Reader) (copied int64, er return CopyDataWithContext(ctx, NewContextualMultipleReader(ctx, src...), dst) } +// SafeCopyDataWithContext copies from src to dst similarly to io.Copy but with context control to stop when asked. +// Unlike CopyDataWithContext, it requires a ReadCloser, which allows it to stop even if the system is doing a kernel read. +func SafeCopyDataWithContext(ctx context.Context, src io.ReadCloser, dst io.Writer) (copied int64, err error) { + return safeCopyDataWithContext(ctx, src, dst, func(dst io.Writer, src io.ReadCloser) (int64, error) { return io.Copy(dst, src) }) +} + +// SafeCopyNWithContext copies n bytes from src to dst similarly to io.CopyN but with context control to stop when asked. +// Unlike CopyNWithContext, it requires a ReadCloser, which allows it to stop even if the system is doing a kernel read.
+func SafeCopyNWithContext(ctx context.Context, src io.ReadCloser, dst io.Writer, n int64) (copied int64, err error) { + return safeCopyDataWithContext(ctx, src, dst, func(dst io.Writer, src io.ReadCloser) (int64, error) { return io.CopyN(dst, src, n) }) +} + func copyDataWithContext(ctx context.Context, src io.Reader, dst io.Writer, copyFunc func(io.Writer, io.Reader) (int64, error)) (copied int64, err error) { err = parallelisation.DetermineContextError(ctx) if err != nil { @@ -37,8 +49,23 @@ func copyDataWithContext(ctx context.Context, src io.Reader, dst io.Writer, copy return } +func safeCopyDataWithContext(ctx context.Context, src io.ReadCloser, dst io.Writer, copyFunc func(io.Writer, io.ReadCloser) (int64, error)) (copied int64, err error) { + err = parallelisation.DetermineContextError(ctx) + if err != nil { + return + } + copied, err = reallySafeCopy(ContextualWriter(ctx, dst), NewContextualReadCloser(ctx, src), copyFunc) + return +} + func safeCopy(w io.Writer, r io.Reader, iocopyFunc func(io.Writer, io.Reader) (int64, error)) (int64, error) { copied, err := iocopyFunc(w, r) err = ConvertIOError(err) return copied, err } + +func reallySafeCopy(w io.Writer, r io.ReadCloser, iocopyFunc func(io.Writer, io.ReadCloser) (int64, error)) (int64, error) { + copied, err := iocopyFunc(w, r) + err = ConvertIOError(err) + return copied, err +} diff --git a/utils/safeio/copy_test.go b/utils/safeio/copy_test.go index ccae39860d..9745fe4091 100644 --- a/utils/safeio/copy_test.go +++ b/utils/safeio/copy_test.go @@ -3,11 +3,15 @@ package safeio import ( "bytes" "context" + "io" + "os" "testing" + "time" "github.com/go-faker/faker/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" @@ -81,6 +85,135 @@ func TestCopyNWithContext(t *testing.T) { assert.Equal(t, safecast.ToInt64(len(text)-1), n2) } +func 
TestSafeCopyDataWithContext(t *testing.T) { + defer goleak.VerifyNone(t) + var buf1, buf2 bytes.Buffer + text := faker.Sentence() + n, err := WriteString(context.Background(), &buf1, text) + require.NoError(t, err) + require.NotZero(t, n) + assert.Equal(t, len(text), n) + rc := io.NopCloser(bytes.NewReader(buf1.Bytes())) // make it an io.ReadCloser + n2, err := SafeCopyDataWithContext(context.Background(), rc, &buf2) + require.NoError(t, err) + require.NotZero(t, n2) + assert.Equal(t, safecast.ToInt64(len(text)), n2) + assert.Equal(t, text, buf2.String()) + + ctx, cancel := context.WithCancel(context.Background()) + buf1.Reset() + buf2.Reset() + n, err = WriteString(context.Background(), &buf1, text) + require.NoError(t, err) + require.NotZero(t, n) + assert.Equal(t, len(text), n) + + cancel() + rc = io.NopCloser(bytes.NewReader(buf1.Bytes())) + n2, err = SafeCopyDataWithContext(ctx, rc, &buf2) + require.Error(t, err) + errortest.AssertError(t, err, commonerrors.ErrCancelled) + assert.Zero(t, n2) + assert.Empty(t, buf2.String()) + + r, w, err := os.Pipe() + require.NoError(t, err) + defer func() { _ = w.Close() }() + ctx2, unblock := context.WithCancel(context.Background()) + done := make(chan struct{}) + + go func() { + _, errCopy := SafeCopyDataWithContext(ctx2, r, io.Discard) + _ = r.Close() + _ = errCopy + close(done) + }() + + time.Sleep(50 * time.Millisecond) // let it enter read(2) https://man7.org/linux/man-pages/man2/read.2.html + unblock() + + select { + case <-done: + // Expected case: unblocked + case <-time.After(2 * time.Second): + assert.FailNow(t, "context cancel should have unblocked copy") + } +} + +func TestSafeCopyNWithContext(t *testing.T) { + defer goleak.VerifyNone(t) + var buf1, buf2 bytes.Buffer + text := faker.Sentence() + n, err := WriteString(context.Background(), &buf1, text) + require.NoError(t, err) + require.NotZero(t, n) + assert.Equal(t, len(text), n) + rc := io.NopCloser(bytes.NewReader(buf1.Bytes())) + n2, err := 
SafeCopyNWithContext(context.Background(), rc, &buf2, safecast.ToInt64(len(text))) + require.NoError(t, err) + require.NotZero(t, n2) + assert.Equal(t, safecast.ToInt64(len(text)), n2) + assert.Equal(t, text, buf2.String()) + + ctx, cancel := context.WithCancel(context.Background()) + + buf1.Reset() + buf2.Reset() + n, err = WriteString(context.Background(), &buf1, text) + require.NoError(t, err) + require.NotZero(t, n) + assert.Equal(t, len(text), n) + + cancel() + rc = io.NopCloser(bytes.NewReader(buf1.Bytes())) + n2, err = SafeCopyNWithContext(ctx, rc, &buf2, safecast.ToInt64(len(text))) + require.Error(t, err) + errortest.AssertError(t, err, commonerrors.ErrCancelled) + assert.Zero(t, n2) + assert.Empty(t, buf2.String()) + + buf1.Reset() + buf2.Reset() + n, err = WriteString(context.Background(), &buf1, text) + require.NoError(t, err) + require.NotZero(t, n) + rc = io.NopCloser(bytes.NewReader(buf1.Bytes())) + + wantN := safecast.ToInt64(len(text) - 1) + n2, err = SafeCopyNWithContext(context.Background(), rc, &buf2, wantN) + require.NoError(t, err) + require.NotZero(t, n2) + assert.Equal(t, wantN, n2) + assert.Equal(t, text[:len(text)-1], buf2.String()) + + r, w, err := os.Pipe() + require.NoError(t, err) + defer func() { _ = w.Close() }() + ctx2, unblock := context.WithCancel(context.Background()) + done := make(chan struct{}) + var ( + copied int64 + copyErr error + ) + + go func() { + copied, copyErr = SafeCopyNWithContext(ctx2, r, io.Discard, 1024) // nothing to read means it blocks + _ = r.Close() + close(done) + }() + + time.Sleep(50 * time.Millisecond) // let it enter read(2) https://man7.org/linux/man-pages/man2/read.2.html + unblock() + + select { + case <-done: + errortest.AssertError(t, copyErr, commonerrors.ErrCancelled) + assert.Zero(t, copied) + case <-time.After(2 * time.Second): + assert.FailNow(t, "context cancel should have unblocked copy") + } +} + func TestCat(t *testing.T) { var buf1, buf2, buf3 bytes.Buffer text1 := faker.Sentence() diff 
--git a/utils/safeio/error.go b/utils/safeio/error.go index 368116be49..041ec613e1 100644 --- a/utils/safeio/error.go +++ b/utils/safeio/error.go @@ -2,6 +2,7 @@ package safeio import ( "io" + "os" "github.com/ARM-software/golang-utils/utils/commonerrors" ) @@ -16,6 +17,9 @@ func ConvertIOError(err error) (newErr error) { case commonerrors.Any(newErr, commonerrors.ErrEOF): case commonerrors.Any(newErr, io.EOF, io.ErrUnexpectedEOF): newErr = commonerrors.WrapError(commonerrors.ErrEOF, newErr, "") + case commonerrors.Any(newErr, os.ErrClosed): + // cancelling a reader on a copy will cause it to close the file and return os.ErrClosed so map it to cancelled for this package + newErr = commonerrors.WrapError(commonerrors.ErrCancelled, newErr, "") } return } diff --git a/utils/safeio/read.go b/utils/safeio/read.go index 8044617dfa..f874010a7d 100644 --- a/utils/safeio/read.go +++ b/utils/safeio/read.go @@ -6,6 +6,7 @@ import ( "io" "github.com/dolmen-go/contextio" + "go.uber.org/atomic" "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/parallelisation" @@ -76,6 +77,42 @@ func NewContextualReader(ctx context.Context, reader io.Reader) io.Reader { return contextio.NewReader(ctx, reader) } +type safeReadCloser struct { + reader io.Reader + close parallelisation.CloseFunc + closed *atomic.Bool +} + +func (r safeReadCloser) Read(p []byte) (int, error) { + return r.reader.Read(p) +} + +func (r safeReadCloser) Close() error { + if r.closed.Swap(true) { + return nil + } + + return r.close() +} + +// NewContextualReadCloser returns a readcloser which is context aware. 
+// Context state is checked during the read, and the underlying reader is closed if the context is cancelled. +// This allows readers that block on syscalls to be stopped via a context. +func NewContextualReadCloser(ctx context.Context, reader io.ReadCloser) io.ReadCloser { + stop := context.AfterFunc(ctx, func() { _ = reader.Close() }) + + r := safeReadCloser{ + reader: contextio.NewReader(ctx, reader), + close: func() error { + _ = stop() + return reader.Close() + }, + closed: atomic.NewBool(false), + } + + return r +} + func NewContextualMultipleReader(ctx context.Context, reader ...io.Reader) io.Reader { readers := make([]io.Reader, len(reader)) for i := range reader { diff --git a/utils/safeio/read_closer_test.go b/utils/safeio/read_closer_test.go new file mode 100644 index 0000000000..76f9d367ec --- /dev/null +++ b/utils/safeio/read_closer_test.go @@ -0,0 +1,65 @@ +package safeio + +import ( + "context" + "io" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewContextualReadCloser(t *testing.T) { + t.Run("Normal contextual reader blocks even after cancel", func(t *testing.T) { + r, w, err := os.Pipe() + require.NoError(t, err) + defer func() { _ = r.Close(); _ = w.Close() }() + + ctx, cancel := context.WithCancel(context.Background()) + reader := NewContextualReader(ctx, r) + + done := make(chan struct{}) + go func() { + _, _ = io.Copy(io.Discard, reader) // will block in read(2) https://man7.org/linux/man-pages/man2/read.2.html + close(done) + }() + + // Allow io.Copy to enter kernel read then try to cancel + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case <-done: + assert.FailNow(t, "cancelling context shouldn't unblock a blocking Read in io.Copy") + case <-time.After(200 * time.Millisecond): + // Expected case: still blocked + } + }) + + t.Run("Contextual read closer does not block even on long running copies", func(t *testing.T) { + r, w, err := os.Pipe() + require.NoError(t, 
err) + defer func() { _ = w.Close() }() + + ctx, cancel := context.WithCancel(context.Background()) + rc := NewContextualReadCloser(ctx, r) + + done := make(chan struct{}) + go func() { + _, _ = io.Copy(io.Discard, rc) // will block in read(2) https://man7.org/linux/man-pages/man2/read.2.html + close(done) + }() + + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case <-done: + // Expected case: successfully unblocked + case <-time.After(2 * time.Second): + assert.FailNow(t, "copy should have been unblocked by context cancel") + } + }) +} From eb31c93cbf314b5879bbf194286ec909637c2d43 Mon Sep 17 00:00:00 2001 From: joshjennings98 Date: Tue, 9 Sep 2025 15:15:55 +0100 Subject: [PATCH 2/4] ensure idempotency --- utils/safeio/read.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/safeio/read.go b/utils/safeio/read.go index f874010a7d..ab3c2a1bad 100644 --- a/utils/safeio/read.go +++ b/utils/safeio/read.go @@ -105,7 +105,7 @@ func NewContextualReadCloser(ctx context.Context, reader io.ReadCloser) io.ReadC reader: contextio.NewReader(ctx, reader), close: func() error { _ = stop() - return reader.Close() + return nil }, closed: atomic.NewBool(false), } From 18bd73aea9245dab980ea277f5290a2569f70461 Mon Sep 17 00:00:00 2001 From: joshjennings98 Date: Tue, 9 Sep 2025 15:21:31 +0100 Subject: [PATCH 3/4] goleak.VerifyNone --- utils/safeio/read_closer_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/utils/safeio/read_closer_test.go b/utils/safeio/read_closer_test.go index 76f9d367ec..09d5809d23 100644 --- a/utils/safeio/read_closer_test.go +++ b/utils/safeio/read_closer_test.go @@ -9,10 +9,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/goleak" ) func TestNewContextualReadCloser(t *testing.T) { t.Run("Normal contextual reader blocks even after cancel", func(t *testing.T) { + defer goleak.VerifyNone(t) + r, w, err := os.Pipe() require.NoError(t, err) defer func() { _ = 
r.Close(); _ = w.Close() }() @@ -39,6 +42,8 @@ func TestNewContextualReadCloser(t *testing.T) { }) t.Run("Contextual read closer does not block even on long running copies", func(t *testing.T) { + defer goleak.VerifyNone(t) + r, w, err := os.Pipe() require.NoError(t, err) defer func() { _ = w.Close() }() From 5471997423f33026eec1f37b2511417d357e51b6 Mon Sep 17 00:00:00 2001 From: joshjennings98 Date: Tue, 9 Sep 2025 15:25:10 +0100 Subject: [PATCH 4/4] comment --- utils/safeio/read.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/safeio/read.go b/utils/safeio/read.go index ab3c2a1bad..c831c0fd43 100644 --- a/utils/safeio/read.go +++ b/utils/safeio/read.go @@ -78,7 +78,7 @@ func NewContextualReader(ctx context.Context, reader io.Reader) io.Reader { } type safeReadCloser struct { - reader io.Reader + reader io.Reader // use reader to ensure idempotency since you can't call close on the reader itself, only via the wrapper close parallelisation.CloseFunc closed *atomic.Bool }