Skip to content

Commit

Permalink
[otelarrowreceiver] Ensure consume operations are not canceled at str…
Browse files Browse the repository at this point in the history
…eam EOF (open-telemetry#33570)

**Description:** Fixes a bug in the OTel Arrow receiver. When a stream
reaches its end-of-life, the exporter closes the send channel and the
receiver's `Recv()` loop receives an EOF error. This was inadvertently
canceling a context too soon, such that requests in-flight during EOF
would be canceled before finishing.

**Link to tracking Issue:** open-telemetry#26491 

**Testing:** Several tests cover this scenario, and they had to change
for this fix. There is now an extra assertion in the healthy test
channel to ensure that consumers never receive data on a canceled
context (in testing).
  • Loading branch information
jmacd authored and cparkins committed Jul 11, 2024
1 parent 74c5410 commit 1fc43b5
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 59 deletions.
27 changes: 27 additions & 0 deletions .chloggen/otelarrow-eof-cancel.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Use this changelog template to create an entry for release notes.

# One of 'breaking', 'deprecation', 'new_component', 'enhancement', 'bug_fix'
change_type: bug_fix

# The name of the component, or a single word describing the area of concern, (e.g. filelogreceiver)
component: otelarrowreceiver

# A brief description of the change. Surround your text with quotes ("") if it needs to start with a backtick (`).
note: Ensure consume operations are not canceled at stream EOF.

# Mandatory: One or more tracking issues related to the change. You can use the PR number here if no issue exists.
issues: [33570]

# (Optional) One or more lines of additional information to render under the primary note.
# These lines will be padded with 2 spaces and then inserted directly into the document.
# Use pipe (|) for multiline entries.
subtext:

# If your change doesn't affect end users or the exported elements of any package,
# you should instead start your pull request title with [chore] or use the "Skip Changelog" label.
# Optional: The change log or logs in which this entry should be included.
# e.g. '[user]' or '[user, api]'
# Include 'user' if the change is relevant to end users.
# Include 'api' if there is a change to a library API.
# Default: '[user]'
change_logs: [user]
44 changes: 31 additions & 13 deletions receiver/otelarrowreceiver/internal/arrow/arrow.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,8 @@ func (r *Receiver) anyStream(serverStream anyStreamServer, method string) (retEr
doneCtx, doneCancel := context.WithCancel(streamCtx)
defer doneCancel()

// streamErrCh returns up to two errors from the sender and
// receiver threads started below.
streamErrCh := make(chan error, 2)
recvErrCh := make(chan error, 1)
sendErrCh := make(chan error, 1)
pendingCh := make(chan batchResp, runtime.NumCPU())

// wg is used to ensure this thread returns after both
Expand All @@ -390,6 +389,11 @@ func (r *Receiver) anyStream(serverStream anyStreamServer, method string) (retEr
sendWG.Add(1)
recvWG.Add(1)

// flushCtx controls the start of flushing. when this is canceled
// after the receiver finishes, the flush operation begins.
flushCtx, flushCancel := context.WithCancel(doneCtx)
defer flushCancel()

rstream := &receiverStream{
Receiver: r,
}
Expand All @@ -399,27 +403,41 @@ func (r *Receiver) anyStream(serverStream anyStreamServer, method string) (retEr
defer recvWG.Done()
defer r.recoverErr(&err)
err = rstream.srvReceiveLoop(doneCtx, serverStream, pendingCh, method, ac)
streamErrCh <- err
recvErrCh <- err
}()

go func() {
var err error
defer sendWG.Done()
defer r.recoverErr(&err)
err = rstream.srvSendLoop(doneCtx, serverStream, &recvWG, pendingCh)
streamErrCh <- err
// the sender receives flushCtx, which is canceled after the
// receiver returns (success or no).
err = rstream.srvSendLoop(flushCtx, serverStream, &recvWG, pendingCh)
sendErrCh <- err
}()

// Wait for sender/receiver threads to return before returning.
defer recvWG.Wait()
defer sendWG.Wait()

select {
case <-doneCtx.Done():
return status.Error(codes.Canceled, "server stream shutdown")
case retErr = <-streamErrCh:
doneCancel()
return
for {
select {
case <-doneCtx.Done():
return status.Error(codes.Canceled, "server stream shutdown")
case err := <-recvErrCh:
flushCancel()
if errors.Is(err, io.EOF) {
// the receiver returned EOF, next we
// expect the sender to finish.
continue
}
return err
case err := <-sendErrCh:
// explicit cancel here, in case the sender fails before
// the receiver does. break the receiver loop here:
doneCancel()
return err
}
}
}

Expand Down Expand Up @@ -555,7 +573,7 @@ func (r *receiverStream) recvOne(streamCtx context.Context, serverStream anyStre

if err != nil {
if errors.Is(err, io.EOF) {
return status.Error(codes.Canceled, "client stream shutdown")
return err
} else if errors.Is(err, context.Canceled) {
return status.Error(codes.Canceled, "server stream shutdown")
}
Expand Down
102 changes: 56 additions & 46 deletions receiver/otelarrowreceiver/internal/arrow/arrow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,42 @@ type commonTestCase struct {
}

type testChannel interface {
onConsume() error
onConsume(ctx context.Context) error
}

type healthyTestChannel struct{}
type healthyTestChannel struct {
t *testing.T
}

func newHealthyTestChannel(t *testing.T) *healthyTestChannel {
return &healthyTestChannel{t: t}
}

func (healthyTestChannel) onConsume() error {
return nil
func (h healthyTestChannel) onConsume(ctx context.Context) error {
select {
case <-ctx.Done():
h.t.Error("unexpected consume with canceled request")
return ctx.Err()
default:
return nil
}
}

type unhealthyTestChannel struct{}
type unhealthyTestChannel struct {
t *testing.T
}

func (unhealthyTestChannel) onConsume() error {
return status.Errorf(codes.Unavailable, "consumer unhealthy")
func newUnhealthyTestChannel(t *testing.T) *unhealthyTestChannel {
return &unhealthyTestChannel{t: t}
}

func (u unhealthyTestChannel) onConsume(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
return status.Errorf(codes.Unavailable, "consumer unhealthy")
}
}

type recvResult struct {
Expand Down Expand Up @@ -160,7 +183,7 @@ func (ctc *commonTestCase) doAndReturnConsumeTraces(tc testChannel) func(ctx con
Ctx: ctx,
Data: traces,
}
return tc.onConsume()
return tc.onConsume(ctx)
}
}

Expand All @@ -170,7 +193,7 @@ func (ctc *commonTestCase) doAndReturnConsumeMetrics(tc testChannel) func(ctx co
Ctx: ctx,
Data: metrics,
}
return tc.onConsume()
return tc.onConsume(ctx)
}
}

Expand All @@ -180,7 +203,7 @@ func (ctc *commonTestCase) doAndReturnConsumeLogs(tc testChannel) func(ctx conte
Ctx: ctx,
Data: logs,
}
return tc.onConsume()
return tc.onConsume(ctx)
}
}

Expand Down Expand Up @@ -420,7 +443,7 @@ func TestBoundedQueueWithPdataHeaders(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tc := healthyTestChannel{}
tc := newHealthyTestChannel(t)
ctc := newCommonTestCase(t, tc)

td := testdata.GenerateTraces(tt.numTraces)
Expand Down Expand Up @@ -468,7 +491,7 @@ func TestBoundedQueueWithPdataHeaders(t *testing.T) {

func TestReceiverTraces(t *testing.T) {
stdTesting := otelAssert.NewStdUnitTest(t)
tc := healthyTestChannel{}
tc := newHealthyTestChannel(t)
ctc := newCommonTestCase(t, tc)

td := testdata.GenerateTraces(2)
Expand All @@ -491,7 +514,7 @@ func TestReceiverTraces(t *testing.T) {
}

func TestReceiverLogs(t *testing.T) {
tc := healthyTestChannel{}
tc := newHealthyTestChannel(t)
ctc := newCommonTestCase(t, tc)

ld := testdata.GenerateLogs(2)
Expand All @@ -510,7 +533,7 @@ func TestReceiverLogs(t *testing.T) {
}

func TestReceiverMetrics(t *testing.T) {
tc := healthyTestChannel{}
tc := newHealthyTestChannel(t)
ctc := newCommonTestCase(t, tc)
stdTesting := otelAssert.NewStdUnitTest(t)

Expand All @@ -534,7 +557,7 @@ func TestReceiverMetrics(t *testing.T) {
}

func TestReceiverRecvError(t *testing.T) {
tc := healthyTestChannel{}
tc := newHealthyTestChannel(t)
ctc := newCommonTestCase(t, tc)

ctc.start(ctc.newRealConsumer, defaultBQ())
Expand All @@ -547,7 +570,7 @@ func TestReceiverRecvError(t *testing.T) {
}

func TestReceiverSendError(t *testing.T) {
tc := healthyTestChannel{}
tc := newHealthyTestChannel(t)
ctc := newCommonTestCase(t, tc)

ld := testdata.GenerateLogs(2)
Expand Down Expand Up @@ -587,7 +610,7 @@ func TestReceiverConsumeError(t *testing.T) {
}

for _, item := range data {
tc := unhealthyTestChannel{}
tc := newUnhealthyTestChannel(t)
ctc := newCommonTestCase(t, tc)

var batch *arrowpb.BatchArrowRecords
Expand Down Expand Up @@ -646,7 +669,7 @@ func TestReceiverInvalidData(t *testing.T) {
}

for _, item := range data {
tc := healthyTestChannel{}
tc := newHealthyTestChannel(t)
ctc := newCommonTestCase(t, tc)

var batch *arrowpb.BatchArrowRecords
Expand Down Expand Up @@ -682,7 +705,7 @@ func TestReceiverMemoryLimit(t *testing.T) {
}

for _, item := range data {
tc := healthyTestChannel{}
tc := newHealthyTestChannel(t)
ctc := newCommonTestCase(t, tc)

var batch *arrowpb.BatchArrowRecords
Expand Down Expand Up @@ -738,7 +761,7 @@ func copyBatch(in *arrowpb.BatchArrowRecords) *arrowpb.BatchArrowRecords {
}

func TestReceiverEOF(t *testing.T) {
tc := healthyTestChannel{}
tc := newHealthyTestChannel(t)
ctc := newCommonTestCase(t, tc)
stdTesting := otelAssert.NewStdUnitTest(t)

Expand Down Expand Up @@ -771,9 +794,7 @@ func TestReceiverEOF(t *testing.T) {
wg.Add(1)

go func() {
err := ctc.wait()
// EOF is treated the same as Canceled.
requireCanceledStatus(t, err)
require.NoError(t, ctc.wait())
wg.Done()
}()

Expand All @@ -800,7 +821,7 @@ func TestReceiverHeadersNoAuth(t *testing.T) {
}

func testReceiverHeaders(t *testing.T, includeMeta bool) {
tc := healthyTestChannel{}
tc := newHealthyTestChannel(t)
ctc := newCommonTestCase(t, tc)

expectData := []map[string][]string{
Expand Down Expand Up @@ -855,9 +876,7 @@ func testReceiverHeaders(t *testing.T, includeMeta bool) {
wg.Add(1)

go func() {
err := ctc.wait()
// EOF is treated the same as Canceled.
requireCanceledStatus(t, err)
require.NoError(t, ctc.wait())
wg.Done()
}()

Expand All @@ -883,7 +902,7 @@ func testReceiverHeaders(t *testing.T, includeMeta bool) {
}

func TestReceiverCancel(t *testing.T) {
tc := healthyTestChannel{}
tc := newHealthyTestChannel(t)
ctc := newCommonTestCase(t, tc)

ctc.cancel()
Expand Down Expand Up @@ -1159,7 +1178,7 @@ func TestReceiverAuthHeadersStream(t *testing.T) {
}

func testReceiverAuthHeaders(t *testing.T, includeMeta bool, dataAuth bool) {
tc := healthyTestChannel{}
tc := newHealthyTestChannel(t)
ctc := newCommonTestCase(t, tc)

expectData := []map[string][]string{
Expand Down Expand Up @@ -1245,7 +1264,7 @@ func testReceiverAuthHeaders(t *testing.T, includeMeta bool, dataAuth bool) {
close(ctc.receive)
}()

var expectErrs []bool
var expectCodes []arrowpb.StatusCode

for _, testInput := range expectData {
// The static stream context contains one extra variable.
Expand All @@ -1256,7 +1275,7 @@ func testReceiverAuthHeaders(t *testing.T, includeMeta bool, dataAuth bool) {
cpy[k] = v
}

expectErr := false
expectCode := arrowpb.StatusCode_OK
if dataAuth {
hasAuth := false
for _, val := range cpy["auth"] {
Expand All @@ -1265,13 +1284,13 @@ func testReceiverAuthHeaders(t *testing.T, includeMeta bool, dataAuth bool) {
if hasAuth {
cpy["has_auth"] = []string{":+1:", ":100:"}
} else {
expectErr = true
expectCode = arrowpb.StatusCode_UNAUTHENTICATED
}
}

expectErrs = append(expectErrs, expectErr)
expectCodes = append(expectCodes, expectCode)

if expectErr {
if expectCode != arrowpb.StatusCode_OK {
continue
}

Expand All @@ -1286,23 +1305,14 @@ func testReceiverAuthHeaders(t *testing.T, includeMeta bool, dataAuth bool) {
}
}

err := ctc.wait()
// EOF is treated the same as Canceled
requireCanceledStatus(t, err)

// Add in expectErrs for when receiver sees EOF,
// the status code will not be arrowpb.StatusCode_OK.
expectErrs = append(expectErrs, true)
require.NoError(t, ctc.wait())

require.Equal(t, len(expectCodes), dataCount)
require.Equal(t, len(expectData), dataCount)
require.Equal(t, len(recvBatches), dataCount)

for idx, batch := range recvBatches {
if expectErrs[idx] {
require.NotEqual(t, arrowpb.StatusCode_OK, batch.StatusCode)
} else {
require.Equal(t, arrowpb.StatusCode_OK, batch.StatusCode)
}
require.Equal(t, expectCodes[idx], batch.StatusCode)
}
}

Expand Down

0 comments on commit 1fc43b5

Please sign in to comment.