Skip to content

Commit

Permalink
GH-39910: [Go] Add func to load prepared statement from ActionCreateP…
Browse files Browse the repository at this point in the history
…reparedStatementResult (#39913)

Currently, in order to create a PreparedStatement a DoAction call will always be made via the client. I need to be able to make a PreparedStatement from persisted data that will not trigger the DoAction call to the server.
* Closes: #39910

Authored-by: Alva Bandy <abandy@live.com>
Signed-off-by: Matt Topol <zotthewizard@gmail.com>
  • Loading branch information
abandy committed Feb 7, 2024
1 parent e83295b commit f609bb1
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 0 deletions.
65 changes: 65 additions & 0 deletions go/arrow/flight/flightsql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,31 @@ func (c *Client) PrepareSubstrait(ctx context.Context, plan SubstraitPlan, opts
return parsePreparedStatementResponse(c, c.Alloc, stream)
}

func (c *Client) LoadPreparedStatementFromResult(result *CreatePreparedStatementResult) (*PreparedStatement, error) {
var (
err error
dsSchema, paramSchema *arrow.Schema
)
if result.DatasetSchema != nil {
dsSchema, err = flight.DeserializeSchema(result.DatasetSchema, c.Alloc)
if err != nil {
return nil, err
}
}
if result.ParameterSchema != nil {
paramSchema, err = flight.DeserializeSchema(result.ParameterSchema, c.Alloc)
if err != nil {
return nil, err
}
}
return &PreparedStatement{
client: c,
handle: result.PreparedStatementHandle,
datasetSchema: dsSchema,
paramSchema: paramSchema,
}, nil
}

func parsePreparedStatementResponse(c *Client, mem memory.Allocator, results pb.FlightService_DoActionClient) (*PreparedStatement, error) {
if err := results.CloseSend(); err != nil {
return nil, err
Expand Down Expand Up @@ -1027,6 +1052,46 @@ func (p *PreparedStatement) Execute(ctx context.Context, opts ...grpc.CallOption
return p.client.getFlightInfo(ctx, desc, opts...)
}

// ExecutePut calls DoPut for the prepared statement on the server. If SetParameters
// has been called then the parameter bindings will be sent before execution.
//
// Will error if already closed.
func (p *PreparedStatement) ExecutePut(ctx context.Context, opts ...grpc.CallOption) error {
if p.closed {
return errors.New("arrow/flightsql: prepared statement already closed")
}

cmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: p.handle}

desc, err := descForCommand(cmd)
if err != nil {
return err
}

if p.hasBindParameters() {
pstream, err := p.client.Client.DoPut(ctx, opts...)
if err != nil {
return err
}

wr, err := p.writeBindParameters(pstream, desc)
if err != nil {
return err
}
if err = wr.Close(); err != nil {
return err
}
pstream.CloseSend()

// wait for the server to ack the result
if _, err = pstream.Recv(); err != nil && err != io.EOF {
return err
}
}

return nil
}

// ExecutePoll executes the prepared statement on the server and returns a PollInfo
// indicating the progress of execution.
//
Expand Down
30 changes: 30 additions & 0 deletions go/arrow/flight/flightsql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,36 @@ func (s *FlightSqlClientSuite) TestRenewFlightEndpoint() {
s.Equal(&mockedRenewedEndpoint, renewedEndpoint)
}

func (s *FlightSqlClientSuite) TestPreparedStatementLoadFromResult() {
const query = "query"

result := &pb.ActionCreatePreparedStatementResult{
PreparedStatementHandle: []byte(query),
}

parameterSchemaResult := arrow.NewSchema([]arrow.Field{{Name: "p_id", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil)
result.ParameterSchema = flight.SerializeSchema(parameterSchemaResult, memory.DefaultAllocator)
datasetSchemaResult := arrow.NewSchema([]arrow.Field{{Name: "ds_id", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil)
result.DatasetSchema = flight.SerializeSchema(datasetSchemaResult, memory.DefaultAllocator)

prepared, err := s.sqlClient.LoadPreparedStatementFromResult(result)
s.NoError(err)

s.Equal(string(prepared.Handle()), "query")

paramSchema := prepared.ParameterSchema()
paramRec, _, err := array.RecordFromJSON(memory.DefaultAllocator, paramSchema, strings.NewReader(`[{"p_id": 1}]`))
s.NoError(err)
defer paramRec.Release()

datasetSchema := prepared.DatasetSchema()
datasetRec, _, err := array.RecordFromJSON(memory.DefaultAllocator, datasetSchema, strings.NewReader(`[{"ds_id": 1}]`))
s.NoError(err)
defer datasetRec.Release()

s.Equal(string(prepared.Handle()), "query")
}

func TestFlightSqlClient(t *testing.T) {
suite.Run(t, new(FlightSqlClientSuite))
}
2 changes: 2 additions & 0 deletions go/arrow/flight/flightsql/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -852,3 +852,5 @@ const (
// cancellation request.
CancelResultNotCancellable = pb.ActionCancelQueryResult_CANCEL_RESULT_NOT_CANCELLABLE
)

type CreatePreparedStatementResult = pb.ActionCreatePreparedStatementResult

0 comments on commit f609bb1

Please sign in to comment.