Navigation Menu

Skip to content

Commit

Permalink
GH-33717: [Go] Flight SQL Server handle StreamChunk errors (#33718)
Browse files Browse the repository at this point in the history
### Rationale for this change

Errors returned in the `StreamChunk` channel of the Flight SQL server's `DoGet` method are currently being dropped rather than returned to gRPC client. This is a fix for that problem.

### What changes are included in this PR?

Fix the bug detailed above and add tesing for the fixed code path.

### Are these changes tested?

Yes, the majority of the PR add appropriate test cases.

### Are there any user-facing changes?

Only the fixed bug.

* Closes: #33717

Authored-by: Martin Hilton <mhilton@influxdata.com>
Signed-off-by: Matt Topol <zotthewizard@gmail.com>
  • Loading branch information
mhilton committed Jan 17, 2023
1 parent 98da819 commit 85a111f
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 1 deletion.
2 changes: 1 addition & 1 deletion go/arrow/flight/flightsql/server.go
Expand Up @@ -632,7 +632,7 @@ func (f *flightSqlServer) DoGet(request *flight.Ticket, stream flight.FlightServ

for chunk := range cc {
if chunk.Err != nil {
return err
return chunk.Err
}

wr.SetFlightDescriptor(chunk.Desc)
Expand Down
154 changes: 154 additions & 0 deletions go/arrow/flight/flightsql/server_test.go
Expand Up @@ -18,9 +18,12 @@ package flightsql_test

import (
"context"
"fmt"
"strings"
"testing"

"github.com/apache/arrow/go/v11/arrow"
"github.com/apache/arrow/go/v11/arrow/array"
"github.com/apache/arrow/go/v11/arrow/flight"
"github.com/apache/arrow/go/v11/arrow/flight/flightsql"
pb "github.com/apache/arrow/go/v11/arrow/flight/internal/flight"
Expand All @@ -36,6 +39,156 @@ import (

var dialOpts = []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}

type testServer struct {
flightsql.BaseServer
}

func (*testServer) GetFlightInfoStatement(ctx context.Context, q flightsql.StatementQuery, fd *flight.FlightDescriptor) (*flight.FlightInfo, error) {
ticket, err := flightsql.CreateStatementQueryTicket([]byte(q.GetQuery()))
if err != nil {
return nil, err
}
return &flight.FlightInfo{
FlightDescriptor: fd,
Endpoint: []*flight.FlightEndpoint{{
Ticket: &flight.Ticket{Ticket: ticket},
}},
}, nil
}

func (*testServer) DoGetStatement(ctx context.Context, ticket flightsql.StatementQueryTicket) (sc *arrow.Schema, cc <-chan flight.StreamChunk, err error) {
handle := string(ticket.GetStatementHandle())
switch handle {
case "1":
b := array.NewInt16Builder(memory.DefaultAllocator)
sc = arrow.NewSchema([]arrow.Field{{
Name: "t1",
Type: b.Type(),
Nullable: true,
}}, nil)
b.AppendNull()
c := make(chan flight.StreamChunk, 2)
c <- flight.StreamChunk{
Data: array.NewRecord(sc, []arrow.Array{b.NewArray()}, 1),
}
b.Append(1)
c <- flight.StreamChunk{
Data: array.NewRecord(sc, []arrow.Array{b.NewArray()}, 1),
}
close(c)
cc = c
case "2":
b := array.NewInt16Builder(memory.DefaultAllocator)
sc = arrow.NewSchema([]arrow.Field{{
Name: "t1",
Type: b.Type(),
Nullable: true,
}}, nil)
b.Append(2)
c := make(chan flight.StreamChunk, 2)
c <- flight.StreamChunk{
Data: array.NewRecord(sc, []arrow.Array{b.NewArray()}, 1),
}
c <- flight.StreamChunk{
Err: status.Error(codes.Internal, "test error"),
}
close(c)
cc = c
default:
err = fmt.Errorf("unknown statement handle: %s", handle)
}
return
}

type FlightSqlServerSuite struct {
suite.Suite

s flight.Server
cl *flightsql.Client
}

func (s *FlightSqlServerSuite) SetupSuite() {
s.s = flight.NewServerWithMiddleware(nil)
srv := flightsql.NewFlightServer(&testServer{})
s.s.RegisterFlightService(srv)
s.s.Init("localhost:0")

go s.s.Serve()
}

func (s *FlightSqlServerSuite) TearDownSuite() {
s.s.Shutdown()
}

func (s *FlightSqlServerSuite) SetupTest() {
cl, err := flightsql.NewClient(s.s.Addr().String(), nil, nil, dialOpts...)
s.Require().NoError(err)
s.cl = cl
}

func (s *FlightSqlServerSuite) TearDownTest() {
s.Require().NoError(s.cl.Close())
s.cl = nil
}

func (s *FlightSqlServerSuite) TestExecute() {
fi, err := s.cl.Execute(context.TODO(), "1")
s.Require().NoError(err)
ep := fi.GetEndpoint()
s.Require().Len(ep, 1)
fr, err := s.cl.DoGet(context.TODO(), ep[0].GetTicket())
s.Require().NoError(err)
var recs []arrow.Record
for fr.Next() {
rec := fr.Record()
rec.Retain()
defer rec.Release()
recs = append(recs, rec)
}
s.Require().NoError(fr.Err())
tbl := array.NewTableFromRecords(fr.Schema(), recs)
defer tbl.Release()
s.Assert().Equal(int64(2), tbl.NumRows())
s.Assert().Equal(int64(1), tbl.NumCols())
col := tbl.Column(0)
s.Assert().Equal("t1", col.Name())
s.Assert().Equal(2, col.Len())
s.Assert().Equal(1, col.NullN())
s.Assert().Equal(arrow.INT16, col.DataType().ID())
var n int
for _, arr := range col.Data().Chunks() {
data := array.NewInt16Data(arr.Data())
defer data.Release()
for i := 0; i < data.Len(); i++ {
switch n {
case 0:
s.Assert().Equal(true, data.IsNull(i))
case 1:
s.Assert().Equal(false, data.IsNull(i))
s.Assert().Equal(int16(1), data.Value(i))
}
n++
}
}
}

func (s *FlightSqlServerSuite) TestExecuteChunkError() {
fi, err := s.cl.Execute(context.TODO(), "2")
s.Require().NoError(err)
ep := fi.GetEndpoint()
s.Require().Len(ep, 1)
fr, err := s.cl.DoGet(context.TODO(), ep[0].GetTicket())
s.Require().NoError(err)
for fr.Next() {
}
err = fr.Err()
if s.Assert().Error(err) {
st := status.Convert(err)
s.Assert().Equal(codes.Internal, st.Code())
s.Assert().Equal("test error", st.Message())
}
}

type UnimplementedFlightSqlServerSuite struct {
suite.Suite

Expand Down Expand Up @@ -209,4 +362,5 @@ func (s *UnimplementedFlightSqlServerSuite) TestDoAction() {

func TestBaseServer(t *testing.T) {
suite.Run(t, new(UnimplementedFlightSqlServerSuite))
suite.Run(t, new(FlightSqlServerSuite))
}

0 comments on commit 85a111f

Please sign in to comment.