Skip to content

Commit

Permalink
GH-35240: [Go][FlightRPC] Fix crash in client middleware (#35241)
Browse files Browse the repository at this point in the history
### Rationale for this change

The Go interceptor API includes a provision for errors from trying to start an RPC. The handler for this in the Flight code was trying to use a nil pointer as a result.

### What changes are included in this PR?

Fix a crash when those errors are encountered.

### Are these changes tested?

New tests were added.

### Are there any user-facing changes?

There are no user-facing changes.
* Closes: #35240

Authored-by: David Li <li.davidm96@gmail.com>
Signed-off-by: David Li <li.davidm96@gmail.com>
  • Loading branch information
lidavidm committed Apr 20, 2023
1 parent 7376b33 commit 27066c1
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 5 deletions.
4 changes: 0 additions & 4 deletions go/arrow/flight/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,6 @@ func CreateClientMiddleware(middleware CustomClientMiddleware) ClientMiddleware
}

if err != nil {
if isHdrs {
md, _ := cs.Header()
hdrs.HeadersReceived(ctx, metadata.Join(md, cs.Trailer()))
}
if isPostcall {
post.CallCompleted(ctx, err)
}
Expand Down
64 changes: 64 additions & 0 deletions go/arrow/flight/flight_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ func (s *ServerMiddlewareAddHeader) CallCompleted(ctx context.Context, err error
}
}

type ServerMiddlewareAddHeaderError struct{}

func (s *ServerMiddlewareAddHeaderError) StartCall(ctx context.Context) context.Context {
grpc.SetHeader(ctx, metadata.Pairs("foo", "bar"))
return nil
}

func (s *ServerMiddlewareAddHeaderError) CallCompleted(ctx context.Context, err error) {
grpc.SetTrailer(ctx, metadata.Pairs("super", "duper"))
}

type ServerTraceMiddleware struct{}

type tracetestKey struct{}
Expand Down Expand Up @@ -252,6 +263,33 @@ func TestClientStreamMiddleware(t *testing.T) {
assert.Equal(t, []string{"duper"}, middleware.md.Get("super"))
}

func TestClientStreamMiddlewareWithError(t *testing.T) {
s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{
flight.CreateServerMiddleware(&ServerMiddlewareAddHeaderError{}),
})
s.Init("localhost:0")
f := &flightServer{}
s.RegisterFlightService(f)

go s.Serve()
defer s.Shutdown()

middle := &ClientTestSendHeaderMiddleware{}
client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, []flight.ClientMiddleware{
flight.CreateClientMiddleware(middle),
}, grpc.WithTransportCredentials(insecure.NewCredentials()))

require.NoError(t, err)
defer client.Close()

// UseCompressor triggers a particular rare failure path.
_, err = client.DoGet(context.Background(), &flight.Ticket{Ticket: []byte("this flight does not exist")}, grpc.UseCompressor("foo"))
if err == nil {
t.Fatal("Expected error but got nothing")
}
assert.Contains(t, err.Error(), "Compressor is not installed")
}

func TestClientUnaryMiddleware(t *testing.T) {
s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{
flight.CreateServerMiddleware(&ServerMiddlewareAddHeader{}),
Expand Down Expand Up @@ -295,3 +333,29 @@ func TestClientUnaryMiddleware(t *testing.T) {
})
}
}

func TestClientUnaryMiddlewareWithError(t *testing.T) {
s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{
flight.CreateServerMiddleware(&ServerMiddlewareAddHeaderError{}),
})
s.Init("localhost:0")
f := &flightServer{}
s.RegisterFlightService(f)

go s.Serve()
defer s.Shutdown()

middle := &ClientTestSendHeaderMiddleware{}
client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, []flight.ClientMiddleware{
flight.CreateClientMiddleware(middle),
}, grpc.WithTransportCredentials(insecure.NewCredentials()))

require.NoError(t, err)
defer client.Close()

_, err = client.GetSchema(context.Background(), &flight.FlightDescriptor{Path: []string{"this flight does not exist"}}, grpc.UseCompressor("foo"))
if err == nil {
t.Fatal("Expected error but got nothing")
}
assert.Contains(t, err.Error(), "Compressor is not installed")
}
5 changes: 4 additions & 1 deletion go/arrow/flight/flight_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ func (f *flightServer) GetSchema(_ context.Context, in *flight.FlightDescriptor)
}

func (f *flightServer) DoGet(tkt *flight.Ticket, fs flight.FlightService_DoGetServer) error {
recs := arrdata.Records[string(tkt.GetTicket())]
recs, ok := arrdata.Records[string(tkt.GetTicket())]
if !ok {
return status.Error(codes.NotFound, "flight not found")
}

w := flight.NewRecordWriter(fs, ipc.WithSchema(recs[0].Schema()))
for _, r := range recs {
Expand Down

0 comments on commit 27066c1

Please sign in to comment.