Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,7 @@ protected void updateCommandBuilder(CommandStatementIngest.Builder builder) {
public static class PreparedStatement implements AutoCloseable {
private final FlightClient client;
private final ActionCreatePreparedStatementResult preparedStatementResult;
private ByteString handle;
private VectorSchemaRoot parameterBindingRoot;
private boolean isClosed;
private Schema resultSetSchema;
Expand All @@ -1229,6 +1230,7 @@ public static class PreparedStatement implements AutoCloseable {
preparedStatementResult =
FlightSqlUtils.unpackAndParseOrThrow(
preparedStatementResults.next().getBody(), ActionCreatePreparedStatementResult.class);
handle = preparedStatementResult.getPreparedStatementHandle();
isClosed = false;
}

Expand Down Expand Up @@ -1292,8 +1294,7 @@ public SchemaResult fetchSchema(CallOption... options) {
FlightDescriptor.command(
Any.pack(
CommandPreparedStatementQuery.newBuilder()
.setPreparedStatementHandle(
preparedStatementResult.getPreparedStatementHandle())
.setPreparedStatementHandle(handle)
.build())
.toByteArray());
return client.getSchema(descriptor, options);
Expand Down Expand Up @@ -1324,8 +1325,7 @@ public FlightInfo execute(final CallOption... options) {
FlightDescriptor.command(
Any.pack(
CommandPreparedStatementQuery.newBuilder()
.setPreparedStatementHandle(
preparedStatementResult.getPreparedStatementHandle())
.setPreparedStatementHandle(handle)
.build())
.toByteArray());

Expand All @@ -1339,12 +1339,16 @@ public FlightInfo execute(final CallOption... options) {
try (final ArrowBuf metadata = read.getApplicationMetadata()) {
final FlightSql.DoPutPreparedStatementResult doPutPreparedStatementResult =
FlightSql.DoPutPreparedStatementResult.parseFrom(metadata.nioBuffer());
final ByteString updatedHandle =
doPutPreparedStatementResult.getPreparedStatementHandle();
if (!updatedHandle.isEmpty()) {
handle = updatedHandle;
}
descriptor =
FlightDescriptor.command(
Any.pack(
CommandPreparedStatementQuery.newBuilder()
.setPreparedStatementHandle(
doPutPreparedStatementResult.getPreparedStatementHandle())
.setPreparedStatementHandle(handle)
.build())
.toByteArray());
}
Expand Down Expand Up @@ -1396,8 +1400,7 @@ public long executeUpdate(final CallOption... options) {
FlightDescriptor.command(
Any.pack(
CommandPreparedStatementUpdate.newBuilder()
.setPreparedStatementHandle(
preparedStatementResult.getPreparedStatementHandle())
.setPreparedStatementHandle(handle)
.build())
.toByteArray());
setParameters(parameterBindingRoot == null ? VectorSchemaRoot.of() : parameterBindingRoot);
Expand Down Expand Up @@ -1434,8 +1437,7 @@ public void close(final CallOption... options) {
FlightSqlUtils.FLIGHT_SQL_CLOSE_PREPARED_STATEMENT.getType(),
Any.pack(
ActionClosePreparedStatementRequest.newBuilder()
.setPreparedStatementHandle(
preparedStatementResult.getPreparedStatementHandle())
.setPreparedStatementHandle(handle)
.build())
.toByteArray());
final Iterator<Result> closePreparedStatementResults = client.doAction(action, options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@
import static org.junit.jupiter.api.Assertions.assertThrows;

import com.google.common.collect.ImmutableList;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.sql.SQLException;
import java.util.ArrayList;
Expand All @@ -42,31 +46,38 @@
import java.util.stream.IntStream;
import org.apache.arrow.flight.CancelFlightInfoRequest;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.PutResult;
import org.apache.arrow.flight.RenewFlightEndpointRequest;
import org.apache.arrow.flight.Result;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.flight.sql.FlightSqlClient.PreparedStatement;
import org.apache.arrow.flight.sql.FlightSqlColumnMetadata;
import org.apache.arrow.flight.sql.FlightSqlProducer;
import org.apache.arrow.flight.sql.NoOpFlightSqlProducer;
import org.apache.arrow.flight.sql.example.FlightSqlExample;
import org.apache.arrow.flight.sql.impl.FlightSql;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions.TableExistsOption;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions.TableNotExistOption;
import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedCaseSensitivity;
import org.apache.arrow.flight.sql.util.TableRef;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
Expand Down Expand Up @@ -1594,4 +1605,144 @@ public void testRenewEndpoint() {
new RenewFlightEndpointRequest(info.getEndpoints().get(0))));
assertEquals(FlightStatusCode.UNIMPLEMENTED, fre.status().code());
}

@Test
public void testPreparedStatementUsesUpdatedHandleAfterDoPut() throws Exception {
final ByteString originalHandle = ByteString.copyFromUtf8("original-handle");
final ByteString updatedHandle = ByteString.copyFromUtf8("updated-handle");

try (BufferAllocator testAllocator = new RootAllocator(Integer.MAX_VALUE)) {
final Schema paramSchema =
new Schema(singletonList(Field.nullable("id", MinorType.INT.getType())));
final UpdatedHandleFlightSqlProducer mockProducer =
new UpdatedHandleFlightSqlProducer(
testAllocator, originalHandle, updatedHandle, paramSchema);

try (FlightServer testServer =
FlightServer.builder(
testAllocator, Location.forGrpcInsecure(LOCALHOST, 0), mockProducer)
.build()
.start();
FlightSqlClient testClient =
new FlightSqlClient(
FlightClient.builder(
testAllocator, Location.forGrpcInsecure(LOCALHOST, testServer.getPort()))
.build())) {

try (PreparedStatement ps = testClient.prepare("test query with param=?");
VectorSchemaRoot params = VectorSchemaRoot.create(paramSchema, testAllocator)) {
final IntVector v = (IntVector) params.getVector(0);
v.setSafe(0, 42);
params.setRowCount(1);
ps.setParameters(params);
ps.execute(); // DoPut → server returns updatedHandle in DoPutPreparedStatementResult
} // close() called here via try-with-resources

assertAll(
() ->
assertThat(mockProducer.executeHandle)
.as("getFlightInfoPreparedStatement must use the updated handle")
.isEqualTo(updatedHandle),
() ->
assertThat(mockProducer.closeHandle)
.as("ClosePreparedStatement must use the updated handle")
.isEqualTo(updatedHandle));
}
}
}

/**
* Minimal producer that returns an updated prepared-statement handle in the {@code
* CommandPreparedStatementQuery} used with {@code DoPut} and records which handle is used in
* subsequent operations, allowing the test to verify that the client propagates the updated
* handle correctly.
*/
private static final class UpdatedHandleFlightSqlProducer extends NoOpFlightSqlProducer {

private final BufferAllocator allocator;
private final ByteString originalHandle;
private final ByteString updatedHandle;
private final ByteString serializedParamSchema;
ByteString executeHandle;
ByteString closeHandle;

UpdatedHandleFlightSqlProducer(
BufferAllocator allocator,
ByteString originalHandle,
ByteString updatedHandle,
Schema paramSchema) {
this.allocator = allocator;
this.originalHandle = originalHandle;
this.updatedHandle = updatedHandle;
this.serializedParamSchema = serializeSchema(paramSchema);
}

private static ByteString serializeSchema(Schema schema) {
try {
final ByteArrayOutputStream out = new ByteArrayOutputStream();
MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), schema);
return ByteString.copyFrom(out.toByteArray());
} catch (IOException e) {
throw new RuntimeException(e);
}
}

@Override
public void createPreparedStatement(
FlightSql.ActionCreatePreparedStatementRequest request,
CallContext context,
StreamListener<Result> listener) {
listener.onNext(
new Result(
Any.pack(
FlightSql.ActionCreatePreparedStatementResult.newBuilder()
.setPreparedStatementHandle(originalHandle)
.setParameterSchema(serializedParamSchema)
.build())
.toByteArray()));
listener.onCompleted();
}

@Override
public Runnable acceptPutPreparedStatementQuery(
FlightSql.CommandPreparedStatementQuery command,
CallContext context,
FlightStream flightStream,
StreamListener<PutResult> ackStream) {
return () -> {
while (flightStream.next()) {
// consume parameter batches
}
final byte[] responseBytes =
FlightSql.DoPutPreparedStatementResult.newBuilder()
.setPreparedStatementHandle(updatedHandle)
.build()
.toByteArray();
final ArrowBuf buf = allocator.buffer(responseBytes.length);
buf.writeBytes(responseBytes);
try (PutResult putResult = PutResult.metadata(buf)) {
ackStream.onNext(putResult);
ackStream.onCompleted();
}
};
}

@Override
public FlightInfo getFlightInfoPreparedStatement(
FlightSql.CommandPreparedStatementQuery command,
CallContext context,
FlightDescriptor descriptor) {
executeHandle = command.getPreparedStatementHandle();
return new FlightInfo(new Schema(emptyList()), descriptor, emptyList(), -1, -1);
}

@Override
public void closePreparedStatement(
FlightSql.ActionClosePreparedStatementRequest request,
CallContext context,
StreamListener<Result> listener) {
closeHandle = request.getPreparedStatementHandle();
listener.onCompleted();
}
}
}
Loading