Skip to content

Commit

Permalink
ARROW-18294: [Java] Fix Flight SQL JDBC PreparedStatement#executeUpda…
Browse files Browse the repository at this point in the history
…te (#14616)

We need to implement a separate code path for executing a prepared statement that returns an  update count.

Authored-by: David Li <li.davidm96@gmail.com>
Signed-off-by: David Li <li.davidm96@gmail.com>
  • Loading branch information
lidavidm committed Nov 9, 2022
1 parent 2ca0b0d commit a590b00
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 12 deletions.
Expand Up @@ -139,7 +139,7 @@ void reset() throws SQLException {
*
* @return the handler.
*/
ArrowFlightSqlClientHandler getClientHandler() throws SQLException {
ArrowFlightSqlClientHandler getClientHandler() {
return clientHandler;
}

Expand Down
Expand Up @@ -42,7 +42,7 @@
* Metadata handler for Arrow Flight.
*/
public class ArrowFlightMetaImpl extends MetaImpl {
private final Map<StatementHandle, PreparedStatement> statementHandlePreparedStatementMap;
private final Map<StatementHandleKey, PreparedStatement> statementHandlePreparedStatementMap;

/**
* Constructs a {@link MetaImpl} object specific for Arrow Flight.
Expand All @@ -67,7 +67,8 @@ static Signature newSignature(final String sql) {

@Override
public void closeStatement(final StatementHandle statementHandle) {
PreparedStatement preparedStatement = statementHandlePreparedStatementMap.remove(statementHandle);
PreparedStatement preparedStatement =
statementHandlePreparedStatementMap.remove(new StatementHandleKey(statementHandle));
// Testing if the prepared statement was created because the statement can be not created until this moment
if (preparedStatement != null) {
preparedStatement.close();
Expand All @@ -82,12 +83,25 @@ public void commit(final ConnectionHandle connectionHandle) {
@Override
public ExecuteResult execute(final StatementHandle statementHandle,
final List<TypedValue> typedValues, final long maxRowCount) {
// TODO Why is maxRowCount ignored?
Preconditions.checkNotNull(statementHandle.signature, "Signature not found.");
return new ExecuteResult(
Collections.singletonList(MetaResultSet.create(
statementHandle.connectionId, statementHandle.id,
true, statementHandle.signature, null)));
Preconditions.checkArgument(connection.id.equals(statementHandle.connectionId),
"Connection IDs are not consistent");
if (statementHandle.signature == null) {
// Update query
final StatementHandleKey key = new StatementHandleKey(statementHandle);
PreparedStatement preparedStatement = statementHandlePreparedStatementMap.get(key);
if (preparedStatement == null) {
throw new IllegalStateException("Prepared statement not found: " + statementHandle);
}
long updatedCount = preparedStatement.executeUpdate();
return new ExecuteResult(Collections.singletonList(MetaResultSet.count(statementHandle.connectionId,
statementHandle.id, updatedCount)));
} else {
// TODO Why is maxRowCount ignored?
return new ExecuteResult(
Collections.singletonList(MetaResultSet.create(
statementHandle.connectionId, statementHandle.id,
true, statementHandle.signature, null)));
}
}

@Override
Expand Down Expand Up @@ -121,6 +135,9 @@ public StatementHandle prepare(final ConnectionHandle connectionHandle,
final String query, final long maxRowCount) {
final StatementHandle handle = super.createStatement(connectionHandle);
handle.signature = newSignature(query);
final PreparedStatement preparedStatement =
((ArrowFlightConnection) connection).getClientHandler().prepare(query);
statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement);
return handle;
}

Expand All @@ -143,7 +160,7 @@ public ExecuteResult prepareAndExecute(final StatementHandle handle,
final PreparedStatement preparedStatement =
((ArrowFlightConnection) connection).getClientHandler().prepare(query);
final StatementType statementType = preparedStatement.getType();
statementHandlePreparedStatementMap.put(handle, preparedStatement);
statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement);
final Signature signature = newSignature(query);
final long updateCount =
statementType.equals(StatementType.UPDATE) ? preparedStatement.executeUpdate() : -1;
Expand Down Expand Up @@ -195,6 +212,47 @@ void setDefaultConnectionProperties() {
}

PreparedStatement getPreparedStatement(StatementHandle statementHandle) {
return statementHandlePreparedStatementMap.get(statementHandle);
return statementHandlePreparedStatementMap.get(new StatementHandleKey(statementHandle));
}

// Helper used to look up prepared statement instances later. Avatica doesn't give us the signature in
// an UPDATE code path so we can't directly use StatementHandle as a map key.
private static final class StatementHandleKey {
public final String connectionId;
public final int id;

StatementHandleKey(String connectionId, int id) {
this.connectionId = connectionId;
this.id = id;
}

StatementHandleKey(StatementHandle statementHandle) {
this.connectionId = statementHandle.connectionId;
this.id = statementHandle.id;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}

StatementHandleKey that = (StatementHandleKey) o;

if (id != that.id) {
return false;
}
return connectionId.equals(that.connectionId);
}

@Override
public int hashCode() {
int result = connectionId.hashCode();
result = 31 * result + id;
return result;
}
}
}
Expand Up @@ -18,13 +18,15 @@
package org.apache.arrow.driver.jdbc;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.junit.jupiter.api.Assertions.assertEquals;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;

import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers;
import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.ClassRule;
Expand All @@ -34,9 +36,10 @@

public class ArrowFlightPreparedStatementTest {

public static final MockFlightSqlProducer PRODUCER = CoreMockedSqlProducers.getLegacyProducer();
@ClassRule
public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE = FlightServerTestRule
.createStandardTestRule(CoreMockedSqlProducers.getLegacyProducer());
.createStandardTestRule(PRODUCER);

private static Connection connection;

Expand Down Expand Up @@ -75,4 +78,14 @@ public void testReturnColumnCount() throws SQLException {
collector.checkThat(6, equalTo(psmt.getMetaData().getColumnCount()));
}
}

@Test
public void testUpdateQuery() throws SQLException {
String query = "Fake update";
PRODUCER.addUpdateQuery(query, /*updatedRows*/42);
try (final PreparedStatement stmt = connection.prepareStatement(query)) {
int updated = stmt.executeUpdate();
assertEquals(42, updated);
}
}
}

0 comments on commit a590b00

Please sign in to comment.