From 85a111f8d5cef3a668c9cf8c47ccb943048e50f6 Mon Sep 17 00:00:00 2001 From: Martin Hilton Date: Tue, 17 Jan 2023 16:15:48 +0000 Subject: [PATCH] GH-33717: [Go] Flight SQL Server handle StreamChunk errors (#33718) ### Rationale for this change Errors returned in the `StreamChunk` channel of the Flight SQL server's `DoGet` method are currently being dropped rather than returned to gRPC client. This is a fix for that problem. ### What changes are included in this PR? Fix the bug detailed above and add tesing for the fixed code path. ### Are these changes tested? Yes, the majority of the PR add appropriate test cases. ### Are there any user-facing changes? Only the fixed bug. * Closes: #33717 Authored-by: Martin Hilton Signed-off-by: Matt Topol --- go/arrow/flight/flightsql/server.go | 2 +- go/arrow/flight/flightsql/server_test.go | 154 +++++++++++++++++++++++ 2 files changed, 155 insertions(+), 1 deletion(-) diff --git a/go/arrow/flight/flightsql/server.go b/go/arrow/flight/flightsql/server.go index c6938073d56e9..9061ffd2eaa50 100644 --- a/go/arrow/flight/flightsql/server.go +++ b/go/arrow/flight/flightsql/server.go @@ -632,7 +632,7 @@ func (f *flightSqlServer) DoGet(request *flight.Ticket, stream flight.FlightServ for chunk := range cc { if chunk.Err != nil { - return err + return chunk.Err } wr.SetFlightDescriptor(chunk.Desc) diff --git a/go/arrow/flight/flightsql/server_test.go b/go/arrow/flight/flightsql/server_test.go index 41420c1dcb3fe..db174ce5b1a43 100644 --- a/go/arrow/flight/flightsql/server_test.go +++ b/go/arrow/flight/flightsql/server_test.go @@ -18,9 +18,12 @@ package flightsql_test import ( "context" + "fmt" "strings" "testing" + "github.com/apache/arrow/go/v11/arrow" + "github.com/apache/arrow/go/v11/arrow/array" "github.com/apache/arrow/go/v11/arrow/flight" "github.com/apache/arrow/go/v11/arrow/flight/flightsql" pb "github.com/apache/arrow/go/v11/arrow/flight/internal/flight" @@ -36,6 +39,156 @@ import ( var dialOpts = []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())} +type testServer struct { + flightsql.BaseServer +} + +func (*testServer) GetFlightInfoStatement(ctx context.Context, q flightsql.StatementQuery, fd *flight.FlightDescriptor) (*flight.FlightInfo, error) { + ticket, err := flightsql.CreateStatementQueryTicket([]byte(q.GetQuery())) + if err != nil { + return nil, err + } + return &flight.FlightInfo{ + FlightDescriptor: fd, + Endpoint: []*flight.FlightEndpoint{{ + Ticket: &flight.Ticket{Ticket: ticket}, + }}, + }, nil +} + +func (*testServer) DoGetStatement(ctx context.Context, ticket flightsql.StatementQueryTicket) (sc *arrow.Schema, cc <-chan flight.StreamChunk, err error) { + handle := string(ticket.GetStatementHandle()) + switch handle { + case "1": + b := array.NewInt16Builder(memory.DefaultAllocator) + sc = arrow.NewSchema([]arrow.Field{{ + Name: "t1", + Type: b.Type(), + Nullable: true, + }}, nil) + b.AppendNull() + c := make(chan flight.StreamChunk, 2) + c <- flight.StreamChunk{ + Data: array.NewRecord(sc, []arrow.Array{b.NewArray()}, 1), + } + b.Append(1) + c <- flight.StreamChunk{ + Data: array.NewRecord(sc, []arrow.Array{b.NewArray()}, 1), + } + close(c) + cc = c + case "2": + b := array.NewInt16Builder(memory.DefaultAllocator) + sc = arrow.NewSchema([]arrow.Field{{ + Name: "t1", + Type: b.Type(), + Nullable: true, + }}, nil) + b.Append(2) + c := make(chan flight.StreamChunk, 2) + c <- flight.StreamChunk{ + Data: array.NewRecord(sc, []arrow.Array{b.NewArray()}, 1), + } + c <- flight.StreamChunk{ + Err: status.Error(codes.Internal, "test error"), + } + close(c) + cc = c + default: + err = fmt.Errorf("unknown statement handle: %s", handle) + } + return +} + +type FlightSqlServerSuite struct { + suite.Suite + + s flight.Server + cl *flightsql.Client +} + +func (s *FlightSqlServerSuite) SetupSuite() { + s.s = flight.NewServerWithMiddleware(nil) + srv := flightsql.NewFlightServer(&testServer{}) + s.s.RegisterFlightService(srv) + s.s.Init("localhost:0") + + go s.s.Serve() +} + +func (s *FlightSqlServerSuite) TearDownSuite() { + s.s.Shutdown() +} + +func (s *FlightSqlServerSuite) SetupTest() { + cl, err := flightsql.NewClient(s.s.Addr().String(), nil, nil, dialOpts...) + s.Require().NoError(err) + s.cl = cl +} + +func (s *FlightSqlServerSuite) TearDownTest() { + s.Require().NoError(s.cl.Close()) + s.cl = nil +} + +func (s *FlightSqlServerSuite) TestExecute() { + fi, err := s.cl.Execute(context.TODO(), "1") + s.Require().NoError(err) + ep := fi.GetEndpoint() + s.Require().Len(ep, 1) + fr, err := s.cl.DoGet(context.TODO(), ep[0].GetTicket()) + s.Require().NoError(err) + var recs []arrow.Record + for fr.Next() { + rec := fr.Record() + rec.Retain() + defer rec.Release() + recs = append(recs, rec) + } + s.Require().NoError(fr.Err()) + tbl := array.NewTableFromRecords(fr.Schema(), recs) + defer tbl.Release() + s.Assert().Equal(int64(2), tbl.NumRows()) + s.Assert().Equal(int64(1), tbl.NumCols()) + col := tbl.Column(0) + s.Assert().Equal("t1", col.Name()) + s.Assert().Equal(2, col.Len()) + s.Assert().Equal(1, col.NullN()) + s.Assert().Equal(arrow.INT16, col.DataType().ID()) + var n int + for _, arr := range col.Data().Chunks() { + data := array.NewInt16Data(arr.Data()) + defer data.Release() + for i := 0; i < data.Len(); i++ { + switch n { + case 0: + s.Assert().Equal(true, data.IsNull(i)) + case 1: + s.Assert().Equal(false, data.IsNull(i)) + s.Assert().Equal(int16(1), data.Value(i)) + } + n++ + } + } +} + +func (s *FlightSqlServerSuite) TestExecuteChunkError() { + fi, err := s.cl.Execute(context.TODO(), "2") + s.Require().NoError(err) + ep := fi.GetEndpoint() + s.Require().Len(ep, 1) + fr, err := s.cl.DoGet(context.TODO(), ep[0].GetTicket()) + s.Require().NoError(err) + for fr.Next() { + } + err = fr.Err() + if s.Assert().Error(err) { + st := status.Convert(err) + s.Assert().Equal(codes.Internal, st.Code()) + s.Assert().Equal("test error", st.Message()) + } +} + type UnimplementedFlightSqlServerSuite struct { suite.Suite @@ -209,4 +362,5 @@ func (s *UnimplementedFlightSqlServerSuite) TestDoAction() { func TestBaseServer(t *testing.T) { suite.Run(t, new(UnimplementedFlightSqlServerSuite)) + suite.Run(t, new(FlightSqlServerSuite)) }