From 44a6e1930c5db68619faa1609c1692b73b4f709d Mon Sep 17 00:00:00 2001 From: Gustav Munkby Date: Tue, 17 Sep 2019 10:53:54 +0200 Subject: [PATCH 1/6] Split handling of network and query timeout The configuration used to only have an API timeout, and if my reading of the JDBC documentation is correct, this would correspond the specified network timeout for JDBC, although for proper handling it should probably also go into the HTTP client inside the AWS SDK in order to ensure we close sockets promptly and not leave operations running. As far as I understand, the query timeout should instead apply to the full query operation, and not the individual network requests. So far, we retain the old functionality for query timeouts in addition to adding support for network timeouts. --- .../java/io/burt/athena/AthenaConnection.java | 4 +-- .../java/io/burt/athena/AthenaDriver.java | 1 + .../java/io/burt/athena/AthenaStatement.java | 12 +++++--- .../ConcreteConnectionConfiguration.java | 30 ++++++++++++------- .../ConnectionConfiguration.java | 8 +++-- .../ConnectionConfigurationFactory.java | 4 +-- .../io/burt/athena/AthenaConnectionTest.java | 1 + .../io/burt/athena/AthenaDataSourceTest.java | 6 ++-- .../java/io/burt/athena/AthenaDriverTest.java | 4 +-- .../io/burt/athena/AthenaResultSetTest.java | 2 +- .../io/burt/athena/AthenaStatementTest.java | 25 ++++++++++++++++ .../ConfigurableConnectionConfiguration.java | 26 +++++++++++----- 12 files changed, 89 insertions(+), 34 deletions(-) diff --git a/src/main/java/io/burt/athena/AthenaConnection.java b/src/main/java/io/burt/athena/AthenaConnection.java index 01d03e6..45661d5 100644 --- a/src/main/java/io/burt/athena/AthenaConnection.java +++ b/src/main/java/io/burt/athena/AthenaConnection.java @@ -305,13 +305,13 @@ public Properties getClientInfo() throws SQLException { @Override public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException { checkClosed(); - configuration = configuration.withTimeout(Duration.ofMillis(milliseconds)); + configuration = configuration.withNetworkTimeout(Duration.ofMillis(milliseconds)); } @Override public int getNetworkTimeout() throws SQLException { checkClosed(); - return (int) configuration.apiCallTimeout().toMillis(); + return (int) configuration.networkTimeout().toMillis(); } @Override diff --git a/src/main/java/io/burt/athena/AthenaDriver.java b/src/main/java/io/burt/athena/AthenaDriver.java index 44b989a..d90d17f 100644 --- a/src/main/java/io/burt/athena/AthenaDriver.java +++ b/src/main/java/io/burt/athena/AthenaDriver.java @@ -133,6 +133,7 @@ public Connection connect(String url, Properties connectionProperties) { workGroup, outputLocation, Duration.ofMinutes(1), + Duration.ofMinutes(30), ResultLoadingStrategy.S3 ); return new AthenaConnection(configuration); diff --git a/src/main/java/io/burt/athena/AthenaStatement.java b/src/main/java/io/burt/athena/AthenaStatement.java index 280328c..9b340d3 100644 --- a/src/main/java/io/burt/athena/AthenaStatement.java +++ b/src/main/java/io/burt/athena/AthenaStatement.java @@ -98,14 +98,14 @@ private String startQueryExecution(String sql) throws InterruptedException, Exec b.resultConfiguration(bb -> bb.outputLocation(configuration.outputLocation())); clientRequestTokenProvider.apply(sql).ifPresent(b::clientRequestToken); }) - .get(configuration.apiCallTimeout().toMillis(), TimeUnit.MILLISECONDS) + .get(networkTimeoutMillis(), TimeUnit.MILLISECONDS) .queryExecutionId(); } private Optional poll() throws SQLException, InterruptedException, ExecutionException, TimeoutException { QueryExecution queryExecution = athenaClient .getQueryExecution(b -> b.queryExecutionId(queryExecutionId)) - .get(configuration.apiCallTimeout().toMillis(), TimeUnit.MILLISECONDS) + .get(networkTimeoutMillis(), TimeUnit.MILLISECONDS) .queryExecution(); switch (queryExecution.status().state()) { case SUCCEEDED: @@ -118,6 +118,10 @@ private Optional poll() throws SQLException, InterruptedException, Ex } } + private long networkTimeoutMillis() { + return Math.min(configuration.networkTimeout().toMillis(), configuration.queryTimeout().toMillis()); + } + private ResultSet createResultSet(QueryExecution queryExecution) { return new AthenaResultSet( configuration.createResult(queryExecution), @@ -271,12 +275,12 @@ public void setEscapeProcessing(boolean enable) { @Override public int getQueryTimeout() { - return (int) configuration.apiCallTimeout().toMillis() / 1000; + return (int) configuration.queryTimeout().toMillis() / 1000; } @Override public void setQueryTimeout(int seconds) { - configuration = configuration.withTimeout(Duration.ofSeconds(seconds)); + configuration = configuration.withQueryTimeout(Duration.ofSeconds(seconds)); } @Override diff --git a/src/main/java/io/burt/athena/configuration/ConcreteConnectionConfiguration.java b/src/main/java/io/burt/athena/configuration/ConcreteConnectionConfiguration.java index 91bc10d..987d1a0 100644 --- a/src/main/java/io/burt/athena/configuration/ConcreteConnectionConfiguration.java +++ b/src/main/java/io/burt/athena/configuration/ConcreteConnectionConfiguration.java @@ -18,24 +18,26 @@ class ConcreteConnectionConfiguration implements ConnectionConfiguration { private final String databaseName; private final String workGroupName; private final String outputLocation; - private final Duration timeout; + private final Duration networkTimeout; + private final Duration queryTimeout; private final ResultLoadingStrategy resultLoadingStrategy; private AthenaAsyncClient athenaClient; private S3AsyncClient s3Client; private PollingStrategy pollingStrategy; - ConcreteConnectionConfiguration(Region awsRegion, String databaseName, String workGroupName, String outputLocation, Duration timeout, ResultLoadingStrategy resultLoadingStrategy) { + ConcreteConnectionConfiguration(Region awsRegion, String databaseName, String workGroupName, String outputLocation, Duration networkTimeout, Duration queryTimeout, ResultLoadingStrategy resultLoadingStrategy) { this.awsRegion = awsRegion; this.databaseName = databaseName; this.workGroupName = workGroupName; this.outputLocation = outputLocation; - this.timeout = timeout; + this.networkTimeout = networkTimeout; + this.queryTimeout = queryTimeout; this.resultLoadingStrategy = resultLoadingStrategy; } - private ConcreteConnectionConfiguration(Region awsRegion, String databaseName, String workGroupName, String outputLocation, Duration timeout, ResultLoadingStrategy resultLoadingStrategy, AthenaAsyncClient athenaClient, S3AsyncClient s3Client, PollingStrategy pollingStrategy) { - this(awsRegion, databaseName, workGroupName, outputLocation, timeout, resultLoadingStrategy); + private ConcreteConnectionConfiguration(Region awsRegion, String databaseName, String workGroupName, String outputLocation, Duration networkTimeout, Duration queryTimeout, ResultLoadingStrategy resultLoadingStrategy, AthenaAsyncClient athenaClient, S3AsyncClient s3Client, PollingStrategy pollingStrategy) { + this(awsRegion, databaseName, workGroupName, outputLocation, networkTimeout, queryTimeout, resultLoadingStrategy); this.athenaClient = athenaClient; this.s3Client = s3Client; this.pollingStrategy = pollingStrategy; @@ -57,10 +59,13 @@ public String outputLocation() { } @Override - public Duration apiCallTimeout() { - return timeout; + public Duration networkTimeout() { + return networkTimeout; } + @Override + public Duration queryTimeout() { return queryTimeout; } + @Override public AthenaAsyncClient athenaClient() { if (athenaClient == null) { @@ -87,12 +92,17 @@ public PollingStrategy pollingStrategy() { @Override public ConnectionConfiguration withDatabaseName(String databaseName) { - return new ConcreteConnectionConfiguration(awsRegion, databaseName, workGroupName, outputLocation, timeout, resultLoadingStrategy, athenaClient, s3Client, pollingStrategy); + return new ConcreteConnectionConfiguration(awsRegion, databaseName, workGroupName, outputLocation, networkTimeout, queryTimeout, resultLoadingStrategy, athenaClient, s3Client, pollingStrategy); + } + + @Override + public ConnectionConfiguration withNetworkTimeout(Duration networkTimeout) { + return new ConcreteConnectionConfiguration(awsRegion, databaseName, workGroupName, outputLocation, networkTimeout, queryTimeout, resultLoadingStrategy, athenaClient, s3Client, pollingStrategy); } @Override - public ConnectionConfiguration withTimeout(Duration timeout) { - return new ConcreteConnectionConfiguration(awsRegion, databaseName, workGroupName, outputLocation, timeout, resultLoadingStrategy, athenaClient, s3Client, pollingStrategy); + public ConnectionConfiguration withQueryTimeout(Duration queryTimeout) { + return new ConcreteConnectionConfiguration(awsRegion, databaseName, workGroupName, outputLocation, networkTimeout, queryTimeout, resultLoadingStrategy, athenaClient, s3Client, pollingStrategy); } @Override diff --git a/src/main/java/io/burt/athena/configuration/ConnectionConfiguration.java b/src/main/java/io/burt/athena/configuration/ConnectionConfiguration.java index bf4714f..ba60406 100644 --- a/src/main/java/io/burt/athena/configuration/ConnectionConfiguration.java +++ b/src/main/java/io/burt/athena/configuration/ConnectionConfiguration.java @@ -15,7 +15,9 @@ public interface ConnectionConfiguration extends AutoCloseable { String outputLocation(); - Duration apiCallTimeout(); + Duration networkTimeout(); + + Duration queryTimeout(); AthenaAsyncClient athenaClient(); @@ -25,7 +27,9 @@ public interface ConnectionConfiguration extends AutoCloseable { ConnectionConfiguration withDatabaseName(String databaseName); - ConnectionConfiguration withTimeout(Duration timeout); + ConnectionConfiguration withNetworkTimeout(Duration timeout); + + ConnectionConfiguration withQueryTimeout(Duration timeout); Result createResult(QueryExecution queryExecution); } diff --git a/src/main/java/io/burt/athena/configuration/ConnectionConfigurationFactory.java b/src/main/java/io/burt/athena/configuration/ConnectionConfigurationFactory.java index 9e5a98f..550a989 100644 --- a/src/main/java/io/burt/athena/configuration/ConnectionConfigurationFactory.java +++ b/src/main/java/io/burt/athena/configuration/ConnectionConfigurationFactory.java @@ -5,8 +5,8 @@ import java.time.Duration; public class ConnectionConfigurationFactory { - public ConnectionConfiguration createConnectionConfiguration(Region awsRegion, String databaseName, String workGroupName, String outputLocation, Duration timeout, ResultLoadingStrategy resultLoadingStrategy) { - return new ConcreteConnectionConfiguration(awsRegion, databaseName, workGroupName, outputLocation, timeout, resultLoadingStrategy); + public ConnectionConfiguration createConnectionConfiguration(Region awsRegion, String databaseName, String workGroupName, String outputLocation, Duration networkTimeout, Duration queryTimeout, ResultLoadingStrategy resultLoadingStrategy) { + return new ConcreteConnectionConfiguration(awsRegion, databaseName, workGroupName, outputLocation, networkTimeout, queryTimeout, resultLoadingStrategy); } } diff --git a/src/test/java/io/burt/athena/AthenaConnectionTest.java b/src/test/java/io/burt/athena/AthenaConnectionTest.java index 93360e2..f202a52 100644 --- a/src/test/java/io/burt/athena/AthenaConnectionTest.java +++ b/src/test/java/io/burt/athena/AthenaConnectionTest.java @@ -74,6 +74,7 @@ private ConnectionConfiguration createConfiguration() { "test_wg", "s3://test/location", Duration.ofSeconds(1), + Duration.ofSeconds(1), () -> queryExecutionHelper, () -> null, () -> pollingStrategy, diff --git a/src/test/java/io/burt/athena/AthenaDataSourceTest.java b/src/test/java/io/burt/athena/AthenaDataSourceTest.java index 73b3de9..fbe2d6c 100644 --- a/src/test/java/io/burt/athena/AthenaDataSourceTest.java +++ b/src/test/java/io/burt/athena/AthenaDataSourceTest.java @@ -43,7 +43,7 @@ class AthenaDataSourceTest { @BeforeEach void setUp() { connectionConfigurationFactory = spy(new ConnectionConfigurationFactory()); - lenient().when(connectionConfigurationFactory.createConnectionConfiguration(any(), any(), any(), any(), any(), any())).then(invocation -> { + lenient().when(connectionConfigurationFactory.createConnectionConfiguration(any(), any(), any(), any(), any(), any(), any())).then(invocation -> { ConnectionConfiguration cc = (ConnectionConfiguration) invocation.callRealMethod(); cc = spy(cc); lenient().when(cc.athenaClient()).thenReturn(queryExecutionHelper); @@ -72,7 +72,7 @@ void returnsAnAthenaConnection() throws Exception { void createsAnAthenaClientForTheConfiguredRegion() throws Exception { dataSource.setRegion("sa-east-1"); dataSource.getConnection(); - verify(connectionConfigurationFactory).createConnectionConfiguration(eq(Region.SA_EAST_1), any(), any(), any(), any(), any()); + verify(connectionConfigurationFactory).createConnectionConfiguration(eq(Region.SA_EAST_1), any(), any(), any(), any(), any(), any()); } @Test @@ -105,7 +105,7 @@ class WhenGivenAString { void setsTheRegionOfTheAthenaClient() throws Exception { dataSource.setRegion("ca-central-1"); dataSource.getConnection(); - verify(connectionConfigurationFactory).createConnectionConfiguration(eq(Region.CA_CENTRAL_1), any(), any(), any(), any(), any()); + verify(connectionConfigurationFactory).createConnectionConfiguration(eq(Region.CA_CENTRAL_1), any(), any(), any(), any(), any(), any()); } } } diff --git a/src/test/java/io/burt/athena/AthenaDriverTest.java b/src/test/java/io/burt/athena/AthenaDriverTest.java index 4184755..1bdb46d 100644 --- a/src/test/java/io/burt/athena/AthenaDriverTest.java +++ b/src/test/java/io/burt/athena/AthenaDriverTest.java @@ -47,7 +47,7 @@ class AthenaDriverTest implements PomVersionLoader { @BeforeEach void setUpDriver() { connectionConfigurationFactory = spy(new ConnectionConfigurationFactory()); - lenient().when(connectionConfigurationFactory.createConnectionConfiguration(any(), any(), any(), any(), any(), any())).then(invocation -> { + lenient().when(connectionConfigurationFactory.createConnectionConfiguration(any(), any(), any(), any(), any(), any(), any())).then(invocation -> { ConnectionConfiguration cc = (ConnectionConfiguration) invocation.callRealMethod(); cc = spy(cc); lenient().when(cc.athenaClient()).thenReturn(queryExecutionHelper); @@ -95,7 +95,7 @@ void usesTheDefaultDatabaseWhenThereIsNoDatabaseNameInTheUrl() throws Exception @Test void usesTheAwsRegionFromTheProperties() { driver.connect("jdbc:athena", defaultProperties); - verify(connectionConfigurationFactory).createConnectionConfiguration(eq(Region.AP_SOUTHEAST_1), any(), any(), any(), any(), any()); + verify(connectionConfigurationFactory).createConnectionConfiguration(eq(Region.AP_SOUTHEAST_1), any(), any(), any(), any(), any(), any()); } @Test diff --git a/src/test/java/io/burt/athena/AthenaResultSetTest.java b/src/test/java/io/burt/athena/AthenaResultSetTest.java index 2846d81..7156591 100644 --- a/src/test/java/io/burt/athena/AthenaResultSetTest.java +++ b/src/test/java/io/burt/athena/AthenaResultSetTest.java @@ -74,7 +74,7 @@ class AthenaResultSetTest { @BeforeEach void setUpResultSet() { parentStatement = mock(AthenaStatement.class); - connectionConfiguration = new ConfigurableConnectionConfiguration("test_db", "test_wg", "s3://test/location", Duration.ofMillis(10), () -> null, () -> null, () -> null, (q) -> null); + connectionConfiguration = new ConfigurableConnectionConfiguration("test_db", "test_wg", "s3://test/location", Duration.ofMillis(10), Duration.ofMillis(10), () -> null, () -> null, () -> null, (q) -> null); QueryExecution queryExecution = QueryExecution.builder().queryExecutionId("Q1234").build(); queryResultsHelper = new GetQueryResultsHelper(); Result result = new PreloadingStandardResult(queryResultsHelper, queryExecution, StandardResult.MAX_FETCH_SIZE, Duration.ofSeconds(1)); diff --git a/src/test/java/io/burt/athena/AthenaStatementTest.java b/src/test/java/io/burt/athena/AthenaStatementTest.java index 338497e..2374b89 100644 --- a/src/test/java/io/burt/athena/AthenaStatementTest.java +++ b/src/test/java/io/burt/athena/AthenaStatementTest.java @@ -74,6 +74,7 @@ ConnectionConfiguration createConfiguration() { "test_wg", "s3://test/location", Duration.ofSeconds(60), + Duration.ofSeconds(60), () -> queryExecutionHelper, () -> null, () -> pollingStrategy, @@ -188,6 +189,30 @@ void executeAgainClosesPreviousResultSet() throws Exception { assertTrue(rs1.isClosed()); assertFalse(rs2.isClosed()); } + + @Nested + class WhenInterruptedWhileSleeping { + @BeforeEach + void setUp() { + statement = new AthenaStatement(createConfiguration().withNetworkTimeout(Duration.ofMillis(10))); + } + + @Test + void throwsWhenStartQueryExecutionDurationExceedsNetworkTimeout() { + queryExecutionHelper.queueStartQueryResponse("Q1234"); + queryExecutionHelper.queueGetQueryExecutionResponse(QueryExecutionState.SUCCEEDED); + queryExecutionHelper.delayStartQueryExecutionResponses(Duration.ofMillis(100)); + assertThrows(SQLTimeoutException.class, () -> statement.executeQuery("SELECT 1")); + } + + @Test + void throwsWhenGetQueryExecutionDurationExceedsNetworkTimeout() { + queryExecutionHelper.queueStartQueryResponse("Q1234"); + queryExecutionHelper.queueGetQueryExecutionResponse(QueryExecutionState.SUCCEEDED); + queryExecutionHelper.delayGetQueryExecutionResponses(Duration.ofMillis(100)); + assertThrows(SQLTimeoutException.class, () -> statement.executeQuery("SELECT 1")); + } + } } @Nested diff --git a/src/test/java/io/burt/athena/support/ConfigurableConnectionConfiguration.java b/src/test/java/io/burt/athena/support/ConfigurableConnectionConfiguration.java index bedeaeb..79625c8 100644 --- a/src/test/java/io/burt/athena/support/ConfigurableConnectionConfiguration.java +++ b/src/test/java/io/burt/athena/support/ConfigurableConnectionConfiguration.java @@ -15,17 +15,19 @@ public class ConfigurableConnectionConfiguration implements ConnectionConfigurat private final String databaseName; private final String workGroupName; private final String outputLocation; - private final Duration timeout; + private final Duration networkTimeout; + private final Duration queryTimeout; private final Supplier athenaClientFactory; private final Supplier s3ClientFactory; private final Supplier pollingStrategyFactory; private final Function resultFactory; - public ConfigurableConnectionConfiguration(String databaseName, String workGroupName, String outputLocation, Duration timeout, Supplier athenaClientFactory, Supplier s3ClientFactory, Supplier pollingStrategyFactory, Function resultFactory) { + public ConfigurableConnectionConfiguration(String databaseName, String workGroupName, String outputLocation, Duration networkTimeout, Duration queryTimeout, Supplier athenaClientFactory, Supplier s3ClientFactory, Supplier pollingStrategyFactory, Function resultFactory) { this.databaseName = databaseName; this.workGroupName = workGroupName; this.outputLocation = outputLocation; - this.timeout = timeout; + this.networkTimeout = networkTimeout; + this.queryTimeout = queryTimeout; this.athenaClientFactory = athenaClientFactory; this.s3ClientFactory = s3ClientFactory; this.pollingStrategyFactory = pollingStrategyFactory; @@ -48,10 +50,13 @@ public String outputLocation() { } @Override - public Duration apiCallTimeout() { - return timeout; + public Duration networkTimeout() { + return networkTimeout; } + @Override + public Duration queryTimeout() { return queryTimeout; } + @Override public AthenaAsyncClient athenaClient() { return athenaClientFactory.get(); @@ -69,12 +74,17 @@ public PollingStrategy pollingStrategy() { @Override public ConnectionConfiguration withDatabaseName(String newDatabaseName) { - return new ConfigurableConnectionConfiguration(newDatabaseName, workGroupName, outputLocation, timeout, athenaClientFactory, s3ClientFactory, pollingStrategyFactory, resultFactory); + return new ConfigurableConnectionConfiguration(newDatabaseName, workGroupName, outputLocation, networkTimeout, queryTimeout, athenaClientFactory, s3ClientFactory, pollingStrategyFactory, resultFactory); + } + + @Override + public ConnectionConfiguration withNetworkTimeout(Duration newNetworkTimeout) { + return new ConfigurableConnectionConfiguration(databaseName, workGroupName, outputLocation, newNetworkTimeout, queryTimeout, athenaClientFactory, s3ClientFactory, pollingStrategyFactory, resultFactory); } @Override - public ConnectionConfiguration withTimeout(Duration newTimeout) { - return new ConfigurableConnectionConfiguration(databaseName, workGroupName, outputLocation, newTimeout, athenaClientFactory, s3ClientFactory, pollingStrategyFactory, resultFactory); + public ConnectionConfiguration withQueryTimeout(Duration newQueryTimeout) { + return new ConfigurableConnectionConfiguration(databaseName, workGroupName, outputLocation, networkTimeout, newQueryTimeout, athenaClientFactory, s3ClientFactory, pollingStrategyFactory, resultFactory); } @Override From 5cab2b311f55266789a915fab7bb456ef0efd424 Mon Sep 17 00:00:00 2001 From: Gustav Munkby Date: Fri, 27 Sep 2019 21:17:27 +0200 Subject: [PATCH 2/6] Add a query-timeout based deadline to AthenStatement This replaces the CompletableFuture.get operations with timeouts corresponding to the time remaining, and propagates the deadline across the polling. One of the tests in AthenaStatement is now very slow, because you cannot set a non-zero query timeout lower than one second, and it wants to detect the distinction between one operation being delayed and multiple different operations being delayed. --- .../java/io/burt/athena/AthenaConnection.java | 3 +- .../java/io/burt/athena/AthenaStatement.java | 23 ++++++++----- .../io/burt/athena/AthenaStatementTest.java | 22 +++++++++--- .../athena/support/QueryExecutionHelper.java | 8 +++++ .../io/burt/athena/support/TestClock.java | 34 +++++++++++++++++++ 5 files changed, 75 insertions(+), 15 deletions(-) create mode 100644 src/test/java/io/burt/athena/support/TestClock.java diff --git a/src/main/java/io/burt/athena/AthenaConnection.java b/src/main/java/io/burt/athena/AthenaConnection.java index 45661d5..0233a82 100644 --- a/src/main/java/io/burt/athena/AthenaConnection.java +++ b/src/main/java/io/burt/athena/AthenaConnection.java @@ -19,6 +19,7 @@ import java.sql.Savepoint; import java.sql.Statement; import java.sql.Struct; +import java.time.Clock; import java.time.Duration; import java.util.Collections; import java.util.Map; @@ -45,7 +46,7 @@ private void checkClosed() throws SQLException { @Override public Statement createStatement() throws SQLException { checkClosed(); - return new AthenaStatement(configuration); + return new AthenaStatement(configuration, Clock.systemDefaultZone()); } @Override diff --git a/src/main/java/io/burt/athena/AthenaStatement.java b/src/main/java/io/burt/athena/AthenaStatement.java index 9b340d3..666263f 100644 --- a/src/main/java/io/burt/athena/AthenaStatement.java +++ b/src/main/java/io/burt/athena/AthenaStatement.java @@ -11,7 +11,9 @@ import java.sql.SQLTimeoutException; import java.sql.SQLWarning; import java.sql.Statement; +import java.time.Clock; import java.time.Duration; +import java.time.Instant; import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -20,6 +22,7 @@ public class AthenaStatement implements Statement { private final AthenaAsyncClient athenaClient; + private Clock clock; private ConnectionConfiguration configuration; private String queryExecutionId; @@ -27,9 +30,10 @@ public class AthenaStatement implements Statement { private Function> clientRequestTokenProvider; private boolean open; - AthenaStatement(ConnectionConfiguration configuration) { + AthenaStatement(ConnectionConfiguration configuration, Clock clock) { this.configuration = configuration; this.athenaClient = configuration.athenaClient(); + this.clock = clock; this.queryExecutionId = null; this.currentResultSet = null; this.clientRequestTokenProvider = sql -> Optional.empty(); @@ -74,8 +78,9 @@ public boolean execute(String sql) throws SQLException { currentResultSet = null; } try { - queryExecutionId = startQueryExecution(sql); - currentResultSet = configuration.pollingStrategy().pollUntilCompleted(this::poll); + Instant deadline = clock.instant().plus(configuration.queryTimeout()); + queryExecutionId = startQueryExecution(sql, deadline); + currentResultSet = configuration.pollingStrategy().pollUntilCompleted(() -> poll(deadline)); return currentResultSet != null; } catch (InterruptedException ie) { Thread.currentThread().interrupt(); @@ -89,7 +94,7 @@ public boolean execute(String sql) throws SQLException { } } - private String startQueryExecution(String sql) throws InterruptedException, ExecutionException, TimeoutException { + private String startQueryExecution(String sql, Instant deadline) throws InterruptedException, ExecutionException, TimeoutException { return athenaClient .startQueryExecution(b -> { b.queryString(sql); @@ -98,14 +103,14 @@ private String startQueryExecution(String sql) throws InterruptedException, Exec b.resultConfiguration(bb -> bb.outputLocation(configuration.outputLocation())); clientRequestTokenProvider.apply(sql).ifPresent(b::clientRequestToken); }) - .get(networkTimeoutMillis(), TimeUnit.MILLISECONDS) + .get(networkTimeoutMillis(deadline), TimeUnit.MILLISECONDS) .queryExecutionId(); } - private Optional poll() throws SQLException, InterruptedException, ExecutionException, TimeoutException { + private Optional poll(Instant deadline) throws SQLException, InterruptedException, ExecutionException, TimeoutException { QueryExecution queryExecution = athenaClient .getQueryExecution(b -> b.queryExecutionId(queryExecutionId)) - .get(networkTimeoutMillis(), TimeUnit.MILLISECONDS) + .get(networkTimeoutMillis(deadline), TimeUnit.MILLISECONDS) .queryExecution(); switch (queryExecution.status().state()) { case SUCCEEDED: @@ -118,8 +123,8 @@ private Optional poll() throws SQLException, InterruptedException, Ex } } - private long networkTimeoutMillis() { - return Math.min(configuration.networkTimeout().toMillis(), configuration.queryTimeout().toMillis()); + private long networkTimeoutMillis(Instant deadline) { + return Math.max(0, Math.min(configuration.networkTimeout().toMillis(), Duration.between(clock.instant(), deadline).toMillis())); } private ResultSet createResultSet(QueryExecution queryExecution) { diff --git a/src/test/java/io/burt/athena/AthenaStatementTest.java b/src/test/java/io/burt/athena/AthenaStatementTest.java index 2374b89..8b02504 100644 --- a/src/test/java/io/burt/athena/AthenaStatementTest.java +++ b/src/test/java/io/burt/athena/AthenaStatementTest.java @@ -5,6 +5,7 @@ import io.burt.athena.result.Result; import io.burt.athena.support.ConfigurableConnectionConfiguration; import io.burt.athena.support.QueryExecutionHelper; +import io.burt.athena.support.TestClock; import io.burt.athena.support.TestNameGenerator; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayNameGeneration; @@ -48,13 +49,15 @@ class AthenaStatementTest { private AthenaStatement statement; private PollingStrategy pollingStrategy; private QueryExecution resultFactoryQueryExecution; + private TestClock clock; @BeforeEach void setUpStatement() { result = mock(Result.class); pollingStrategy = createPollingStrategy(); - queryExecutionHelper = new QueryExecutionHelper(); - statement = new AthenaStatement(createConfiguration()); + clock = new TestClock(); + queryExecutionHelper = new QueryExecutionHelper(clock); + statement = new AthenaStatement(createConfiguration(), clock); } PollingStrategy createPollingStrategy() { @@ -194,7 +197,7 @@ void executeAgainClosesPreviousResultSet() throws Exception { class WhenInterruptedWhileSleeping { @BeforeEach void setUp() { - statement = new AthenaStatement(createConfiguration().withNetworkTimeout(Duration.ofMillis(10))); + statement = new AthenaStatement(createConfiguration().withNetworkTimeout(Duration.ofMillis(10)), clock); } @Test @@ -494,22 +497,31 @@ class SetQueryTimeout { @BeforeEach void setUp() { queryExecutionHelper.queueStartQueryResponse("Q1234"); + queryExecutionHelper.queueGetQueryExecutionResponse(QueryExecutionState.QUEUED); queryExecutionHelper.queueGetQueryExecutionResponse(QueryExecutionState.SUCCEEDED); } @Test - void setsTheTimeoutUsedForApiCalls1() { + void setsTheTimeoutUsedForStartQueryExecution() throws SQLException { queryExecutionHelper.delayStartQueryExecutionResponses(Duration.ofMillis(10)); statement.setQueryTimeout(0); assertThrows(SQLTimeoutException.class, () -> statement.executeQuery("SELECT 1")); } @Test - void setsTheTimeoutUsedForApiCalls2() { + void setsTheTimeoutUsedForGetQueryExecution() throws SQLException { queryExecutionHelper.delayGetQueryExecutionResponses(Duration.ofMillis(10)); statement.setQueryTimeout(0); assertThrows(SQLTimeoutException.class, () -> statement.executeQuery("SELECT 1")); } + + @Test + void setsTheTimeoutUsedForQuerySpanningMultipleOperations() { + queryExecutionHelper.delayStartQueryExecutionResponses(Duration.ofMillis(400)); + queryExecutionHelper.delayGetQueryExecutionResponses(Duration.ofMillis(400)); + statement.setQueryTimeout(1); + assertThrows(SQLTimeoutException.class, () -> statement.executeQuery("SELECT 1")); + } } @Nested diff --git a/src/test/java/io/burt/athena/support/QueryExecutionHelper.java b/src/test/java/io/burt/athena/support/QueryExecutionHelper.java index a882d80..1563018 100644 --- a/src/test/java/io/burt/athena/support/QueryExecutionHelper.java +++ b/src/test/java/io/burt/athena/support/QueryExecutionHelper.java @@ -40,8 +40,14 @@ public class QueryExecutionHelper implements AthenaAsyncClient { private Duration getQueryResultsDelay; private Lock getQueryExecutionBlocker; private boolean open; + private TestClock clock; public QueryExecutionHelper() { + this(new TestClock()); + } + + public QueryExecutionHelper(TestClock clock) { + this.clock = clock; this.startQueryRequests = new LinkedList<>(); this.getQueryExecutionRequests = new LinkedList<>(); this.getQueryResultsRequests = new LinkedList<>(); @@ -147,6 +153,8 @@ private CompletableFuture maybeDelayResponse(CompletableFuture future, newFuture.completeExceptionally(e.getCause()); } catch (Exception e) { newFuture.completeExceptionally(e); + } finally { + clock.tick(delay); } }, delay.toMillis(), diff --git a/src/test/java/io/burt/athena/support/TestClock.java b/src/test/java/io/burt/athena/support/TestClock.java new file mode 100644 index 0000000..82154cd --- /dev/null +++ b/src/test/java/io/burt/athena/support/TestClock.java @@ -0,0 +1,34 @@ +package io.burt.athena.support; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneId; + +public class TestClock extends Clock { + private long millis; + + @Override + public ZoneId getZone() { + return ZoneId.of("UTC"); + } + + @Override + public Clock withZone(ZoneId zone) { + return null; + } + + @Override + public Instant instant() { + return Instant.ofEpochMilli(millis()); + } + + @Override + public long millis() { + return this.millis; + } + + public void tick(Duration duration) { + this.millis += duration.toMillis(); + } +} From cde11d6daa0fcbbba951abfa084ac44910b6d8eb Mon Sep 17 00:00:00 2001 From: Gustav Munkby Date: Thu, 19 Sep 2019 11:19:23 +0200 Subject: [PATCH 3/6] Speed up AthenaStatementTest By exposing a non-standard operation for setting a subsecond query timeout. --- src/main/java/io/burt/athena/AthenaStatement.java | 6 +++++- src/test/java/io/burt/athena/AthenaStatementTest.java | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/main/java/io/burt/athena/AthenaStatement.java b/src/main/java/io/burt/athena/AthenaStatement.java index 666263f..c3b0606 100644 --- a/src/main/java/io/burt/athena/AthenaStatement.java +++ b/src/main/java/io/burt/athena/AthenaStatement.java @@ -278,6 +278,10 @@ public void setEscapeProcessing(boolean enable) { throw new UnsupportedOperationException("Not implemented"); } + public void setQueryTimeout(Duration timeout) { + configuration = configuration.withQueryTimeout(timeout); + } + @Override public int getQueryTimeout() { return (int) configuration.queryTimeout().toMillis() / 1000; @@ -285,7 +289,7 @@ public int getQueryTimeout() { @Override public void setQueryTimeout(int seconds) { - configuration = configuration.withQueryTimeout(Duration.ofSeconds(seconds)); + setQueryTimeout(Duration.ofSeconds(seconds)); } @Override diff --git a/src/test/java/io/burt/athena/AthenaStatementTest.java b/src/test/java/io/burt/athena/AthenaStatementTest.java index 8b02504..c55314c 100644 --- a/src/test/java/io/burt/athena/AthenaStatementTest.java +++ b/src/test/java/io/burt/athena/AthenaStatementTest.java @@ -517,9 +517,9 @@ void setsTheTimeoutUsedForGetQueryExecution() throws SQLException { @Test void setsTheTimeoutUsedForQuerySpanningMultipleOperations() { - queryExecutionHelper.delayStartQueryExecutionResponses(Duration.ofMillis(400)); - queryExecutionHelper.delayGetQueryExecutionResponses(Duration.ofMillis(400)); - statement.setQueryTimeout(1); + queryExecutionHelper.delayStartQueryExecutionResponses(Duration.ofMillis(40)); + queryExecutionHelper.delayGetQueryExecutionResponses(Duration.ofMillis(40)); + statement.setQueryTimeout(Duration.ofMillis(100)); assertThrows(SQLTimeoutException.class, () -> statement.executeQuery("SELECT 1")); } } From 07d207e9d134173de9fc1df8943c479d3ac52b66 Mon Sep 17 00:00:00 2001 From: Gustav Munkby Date: Fri, 27 Sep 2019 21:25:04 +0200 Subject: [PATCH 4/6] Add a query-timeout based deadline to polling The tests in AthenaStatement are not yet addressed, as they these tests use a custom polling strategy and from their perspective, an encountered deadline in the polling is the same as a TimeoutException from the get query results operation. --- .../java/io/burt/athena/AthenaStatement.java | 2 +- .../polling/BackoffPollingStrategy.java | 30 +++++++-- .../polling/FixedDelayPollingStrategy.java | 22 ++++-- .../burt/athena/polling/PollingCallback.java | 3 +- .../burt/athena/polling/PollingStrategy.java | 3 +- .../io/burt/athena/AthenaConnectionTest.java | 4 +- .../io/burt/athena/AthenaStatementTest.java | 6 +- .../polling/BackoffPollingStrategyTest.java | 67 ++++++++++++++----- .../FixedDelayPollingStrategyTest.java | 55 +++++++++++---- 9 files changed, 141 insertions(+), 51 deletions(-) diff --git a/src/main/java/io/burt/athena/AthenaStatement.java b/src/main/java/io/burt/athena/AthenaStatement.java index c3b0606..7c90008 100644 --- a/src/main/java/io/burt/athena/AthenaStatement.java +++ b/src/main/java/io/burt/athena/AthenaStatement.java @@ -80,7 +80,7 @@ public boolean execute(String sql) throws SQLException { try { Instant deadline = clock.instant().plus(configuration.queryTimeout()); queryExecutionId = startQueryExecution(sql, deadline); - currentResultSet = configuration.pollingStrategy().pollUntilCompleted(() -> poll(deadline)); + currentResultSet = configuration.pollingStrategy().pollUntilCompleted(this::poll, deadline); return currentResultSet != null; } catch (InterruptedException ie) { Thread.currentThread().interrupt(); diff --git a/src/main/java/io/burt/athena/polling/BackoffPollingStrategy.java b/src/main/java/io/burt/athena/polling/BackoffPollingStrategy.java index 1f5303b..41e80f9 100644 --- a/src/main/java/io/burt/athena/polling/BackoffPollingStrategy.java +++ b/src/main/java/io/burt/athena/polling/BackoffPollingStrategy.java @@ -2,7 +2,9 @@ import java.sql.ResultSet; import java.sql.SQLException; +import java.time.Clock; import java.time.Duration; +import java.time.Instant; import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -13,35 +15,49 @@ class BackoffPollingStrategy implements PollingStrategy { private final Duration maxDelay; private final long factor; private final Sleeper sleeper; + private final Clock clock; BackoffPollingStrategy(Duration firstDelay, Duration maxDelay) { - this(firstDelay, maxDelay, 2L, duration -> TimeUnit.MILLISECONDS.sleep(duration.toMillis())); + this(firstDelay, maxDelay, 2L, duration -> TimeUnit.MILLISECONDS.sleep(duration.toMillis()), Clock.systemDefaultZone()); } BackoffPollingStrategy(Duration firstDelay, Duration maxDelay, long factor) { - this(firstDelay, maxDelay, factor, duration -> TimeUnit.MILLISECONDS.sleep(duration.toMillis())); + this(firstDelay, maxDelay, factor, duration -> TimeUnit.MILLISECONDS.sleep(duration.toMillis()), Clock.systemDefaultZone()); } BackoffPollingStrategy(Duration firstDelay, Duration maxDelay, Sleeper sleeper) { - this(firstDelay, maxDelay, 2L, sleeper); + this(firstDelay, maxDelay, 2L, sleeper, Clock.systemDefaultZone()); } - BackoffPollingStrategy(Duration firstDelay, Duration maxDelay, long factor, Sleeper sleeper) { + BackoffPollingStrategy(Duration firstDelay, Duration maxDelay, Sleeper sleeper, Clock clock) { + this(firstDelay, maxDelay, 2L, sleeper, clock); + } + + BackoffPollingStrategy(Duration firstDelay, Duration maxDelay, long factor, Sleeper sleeper, Clock clock) { this.firstDelay = firstDelay; this.maxDelay = maxDelay; this.factor = factor; this.sleeper = sleeper; + this.clock = clock; } @Override - public ResultSet pollUntilCompleted(PollingCallback callback) throws SQLException, TimeoutException, ExecutionException, InterruptedException { + public ResultSet pollUntilCompleted(PollingCallback callback, Instant deadline) throws SQLException, TimeoutException, ExecutionException, InterruptedException { Duration nextDelay = firstDelay; while (true) { - Optional resultSet = callback.poll(); + Optional resultSet = callback.poll(deadline); if (resultSet.isPresent()) { return resultSet.get(); } else { - sleeper.sleep(nextDelay); + Duration beforeDeadline = Duration.between(clock.instant(), deadline); + if (beforeDeadline.compareTo(nextDelay) < 0) { + if (beforeDeadline.isNegative()) { + throw new TimeoutException(); + } + sleeper.sleep(beforeDeadline); + } else { + sleeper.sleep(nextDelay); + } nextDelay = nextDelay.multipliedBy(factor); if (nextDelay.compareTo(maxDelay) > 0) { nextDelay = maxDelay; diff --git a/src/main/java/io/burt/athena/polling/FixedDelayPollingStrategy.java b/src/main/java/io/burt/athena/polling/FixedDelayPollingStrategy.java index f989f89..85ab627 100644 --- a/src/main/java/io/burt/athena/polling/FixedDelayPollingStrategy.java +++ b/src/main/java/io/burt/athena/polling/FixedDelayPollingStrategy.java @@ -2,7 +2,9 @@ import java.sql.ResultSet; import java.sql.SQLException; +import java.time.Clock; import java.time.Duration; +import java.time.Instant; import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -11,24 +13,34 @@ public class FixedDelayPollingStrategy implements PollingStrategy { private final Duration delay; private final Sleeper sleeper; + private Clock clock; FixedDelayPollingStrategy(Duration delay) { - this(delay, duration -> TimeUnit.MILLISECONDS.sleep(duration.toMillis())); + this(delay, duration -> TimeUnit.MILLISECONDS.sleep(duration.toMillis()), Clock.systemDefaultZone()); } - FixedDelayPollingStrategy(Duration delay, Sleeper sleeper) { + FixedDelayPollingStrategy(Duration delay, Sleeper sleeper, Clock clock) { this.delay = delay; this.sleeper = sleeper; + this.clock = clock; } @Override - public ResultSet pollUntilCompleted(PollingCallback callback) throws SQLException, TimeoutException, ExecutionException, InterruptedException { + public ResultSet pollUntilCompleted(PollingCallback callback, Instant deadline) throws SQLException, TimeoutException, ExecutionException, InterruptedException { while (true) { - Optional resultSet = callback.poll(); + Optional resultSet = callback.poll(deadline); if (resultSet.isPresent()) { return resultSet.get(); } else { - sleeper.sleep(delay); + Duration beforeDeadline = Duration.between(clock.instant(), deadline); + if (beforeDeadline.compareTo(delay) < 0) { + if (beforeDeadline.isNegative()) { + throw new TimeoutException(); + } + sleeper.sleep(beforeDeadline); + } else { + sleeper.sleep(delay); + } } } } diff --git a/src/main/java/io/burt/athena/polling/PollingCallback.java b/src/main/java/io/burt/athena/polling/PollingCallback.java index 3729d62..0e3dc70 100644 --- a/src/main/java/io/burt/athena/polling/PollingCallback.java +++ b/src/main/java/io/burt/athena/polling/PollingCallback.java @@ -2,11 +2,12 @@ import java.sql.ResultSet; import java.sql.SQLException; +import java.time.Instant; import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeoutException; @FunctionalInterface public interface PollingCallback { - Optional poll() throws SQLException, TimeoutException, ExecutionException, InterruptedException; + Optional poll(Instant deadline) throws SQLException, TimeoutException, ExecutionException, InterruptedException; } diff --git a/src/main/java/io/burt/athena/polling/PollingStrategy.java b/src/main/java/io/burt/athena/polling/PollingStrategy.java index baf9727..4d7be45 100644 --- a/src/main/java/io/burt/athena/polling/PollingStrategy.java +++ b/src/main/java/io/burt/athena/polling/PollingStrategy.java @@ -2,9 +2,10 @@ import java.sql.ResultSet; import java.sql.SQLException; +import java.time.Instant; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeoutException; public interface PollingStrategy { - ResultSet pollUntilCompleted(PollingCallback callback) throws SQLException, TimeoutException, ExecutionException, InterruptedException; + ResultSet pollUntilCompleted(PollingCallback callback, Instant deadline) throws SQLException, TimeoutException, ExecutionException, InterruptedException; } diff --git a/src/test/java/io/burt/athena/AthenaConnectionTest.java b/src/test/java/io/burt/athena/AthenaConnectionTest.java index f202a52..1165f9a 100644 --- a/src/test/java/io/burt/athena/AthenaConnectionTest.java +++ b/src/test/java/io/burt/athena/AthenaConnectionTest.java @@ -58,9 +58,9 @@ void setUpConnection() { } PollingStrategy createPollingStrategy() { - return callback -> { + return (callback, deadline) -> { while (true) { - Optional rs = callback.poll(); + Optional rs = callback.poll(deadline); if (rs.isPresent()) { return rs.get(); } diff --git a/src/test/java/io/burt/athena/AthenaStatementTest.java b/src/test/java/io/burt/athena/AthenaStatementTest.java index c55314c..587f5c6 100644 --- a/src/test/java/io/burt/athena/AthenaStatementTest.java +++ b/src/test/java/io/burt/athena/AthenaStatementTest.java @@ -61,9 +61,9 @@ void setUpStatement() { } PollingStrategy createPollingStrategy() { - return callback -> { + return (callback, deadline) -> { while (true) { - Optional rs = callback.poll(); + Optional rs = callback.poll(deadline); if (rs.isPresent()) { return rs.get(); } @@ -238,7 +238,7 @@ class WhenInterruptedWhileSleeping { @BeforeEach void setUp() { - pollingStrategy = callback -> { + pollingStrategy = (callback, deadline) -> { throw new InterruptedException(); }; executeResult = new AtomicReference<>(null); diff --git a/src/test/java/io/burt/athena/polling/BackoffPollingStrategyTest.java b/src/test/java/io/burt/athena/polling/BackoffPollingStrategyTest.java index 16864c5..c4f2e4e 100644 --- a/src/test/java/io/burt/athena/polling/BackoffPollingStrategyTest.java +++ b/src/test/java/io/burt/athena/polling/BackoffPollingStrategyTest.java @@ -1,5 +1,6 @@ package io.burt.athena.polling; +import io.burt.athena.support.TestClock; import io.burt.athena.support.TestNameGenerator; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayNameGeneration; @@ -13,6 +14,7 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.time.Duration; +import java.time.Instant; import java.util.List; import java.util.Optional; import java.util.concurrent.ExecutionException; @@ -30,12 +32,14 @@ @DisplayNameGeneration(TestNameGenerator.class) class BackoffPollingStrategyTest { private Sleeper sleeper; + private TestClock clock; private PollingStrategy pollingStrategy; @BeforeEach void setUp() { sleeper = mock(Sleeper.class); - pollingStrategy = new BackoffPollingStrategy(Duration.ofMillis(3), Duration.ofSeconds(1), sleeper); + clock = new TestClock(); + pollingStrategy = new BackoffPollingStrategy(Duration.ofMillis(3), Duration.ofSeconds(1), sleeper, clock); } @Nested @@ -45,34 +49,34 @@ class PollUntilCompleted { @Test void pollsUntilTheCallbackReturnsAResultSet() throws Exception { AtomicInteger counter = new AtomicInteger(0); - pollingStrategy.pollUntilCompleted(() -> { + pollingStrategy.pollUntilCompleted((Instant deadline) -> { if (counter.get() == 3) { return Optional.of(mock(ResultSet.class)); } else { counter.incrementAndGet(); return Optional.empty(); } - }); + }, Instant.now().plus(Duration.ofSeconds(30))); assertEquals(3, counter.get()); } @Test void returnsTheResultSet() throws Exception { ResultSet rs1 = mock(ResultSet.class); - ResultSet rs2 = pollingStrategy.pollUntilCompleted(() -> Optional.of(rs1)); + ResultSet rs2 = pollingStrategy.pollUntilCompleted((Instant deadline) -> Optional.of(rs1), clock.instant().plus(Duration.ofSeconds(30))); assertSame(rs1, rs2); } @Test void doublesTheDelayAfterEachPollUpToTheConfiguredMax() throws Exception { AtomicInteger counter = new AtomicInteger(0); - pollingStrategy.pollUntilCompleted(() -> { + pollingStrategy.pollUntilCompleted((Instant deadline) -> { if (counter.getAndIncrement() == 15) { return Optional.of(mock(ResultSet.class)); } else { return Optional.empty(); } - }); + }, clock.instant().plus(Duration.ofSeconds(30))); verify(sleeper, atLeastOnce()).sleep(delayCaptor.capture()); List delays = delayCaptor.getAllValues(); assertEquals(Duration.ofMillis(3), delays.get(0)); @@ -90,23 +94,52 @@ void doublesTheDelayAfterEachPollUpToTheConfiguredMax() throws Exception { assertEquals(Duration.ofMillis(1000), delays.get(12)); } + @Test + void reducesFinalDelayToMatchDeadline() throws Exception { + AtomicInteger counter = new AtomicInteger(0); + pollingStrategy.pollUntilCompleted((Instant deadline) -> { + clock.tick(Duration.ofMillis(20)); + if (counter.getAndIncrement() >= 4) { + return Optional.of(mock(ResultSet.class)); + } else { + return Optional.empty(); + } + }, clock.instant().plus(Duration.ofMillis(100))); + verify(sleeper, atLeastOnce()).sleep(delayCaptor.capture()); + List delays = delayCaptor.getAllValues(); + assertEquals(Duration.ofMillis(3), delays.get(0)); + assertEquals(Duration.ofMillis(6), delays.get(1)); + assertEquals(Duration.ofMillis(12), delays.get(2)); + assertEquals(Duration.ofMillis(20), delays.get(3)); + } + + @Test + void throwsTimeoutExceptionIfNotCompletedWithinDeadline() throws Exception { + assertThrows(TimeoutException.class, () -> { + pollingStrategy.pollUntilCompleted((Instant deadline) -> { + clock.tick(Duration.ofSeconds(10)); + return Optional.empty(); + }, clock.instant()); + }); + } + @Nested class WithAFactor { @BeforeEach void setUp() { - pollingStrategy = new BackoffPollingStrategy(Duration.ofMillis(3), Duration.ofSeconds(1), 7, sleeper); + pollingStrategy = new BackoffPollingStrategy(Duration.ofMillis(3), Duration.ofSeconds(1), 7, sleeper, clock); } @Test void usesTheFactorToCalculateTheNextDelay() throws Exception { AtomicInteger counter = new AtomicInteger(0); - pollingStrategy.pollUntilCompleted(() -> { + pollingStrategy.pollUntilCompleted((Instant deadline) -> { if (counter.getAndIncrement() == 15) { return Optional.of(mock(ResultSet.class)); } else { return Optional.empty(); } - }); + }, clock.instant().plus(Duration.ofSeconds(30))); verify(sleeper, atLeastOnce()).sleep(delayCaptor.capture()); List delays = delayCaptor.getAllValues(); assertEquals(Duration.ofMillis(3), delays.get(0)); @@ -123,24 +156,24 @@ class WhenTheCallbackThrowsAnException { @Test void passesTheExceptionAlong() { assertThrows(InterruptedException.class, () -> { - pollingStrategy.pollUntilCompleted(() -> { + pollingStrategy.pollUntilCompleted((Instant deadline) -> { throw new InterruptedException(); - }); + }, clock.instant().plus(Duration.ofSeconds(30))); }); assertThrows(SQLException.class, () -> { - pollingStrategy.pollUntilCompleted(() -> { + pollingStrategy.pollUntilCompleted((Instant deadline) -> { throw new SQLException(); - }); + }, clock.instant().plus(Duration.ofSeconds(30))); }); assertThrows(ExecutionException.class, () -> { - pollingStrategy.pollUntilCompleted(() -> { + pollingStrategy.pollUntilCompleted((Instant deadline) -> { throw new ExecutionException(new ArithmeticException()); - }); + }, clock.instant().plus(Duration.ofSeconds(30))); }); assertThrows(TimeoutException.class, () -> { - pollingStrategy.pollUntilCompleted(() -> { + pollingStrategy.pollUntilCompleted((Instant deadline) -> { throw new TimeoutException(); - }); + }, clock.instant().plus(Duration.ofSeconds(30))); }); } } diff --git a/src/test/java/io/burt/athena/polling/FixedDelayPollingStrategyTest.java b/src/test/java/io/burt/athena/polling/FixedDelayPollingStrategyTest.java index 9123325..4ada2c6 100644 --- a/src/test/java/io/burt/athena/polling/FixedDelayPollingStrategyTest.java +++ b/src/test/java/io/burt/athena/polling/FixedDelayPollingStrategyTest.java @@ -1,5 +1,6 @@ package io.burt.athena.polling; +import io.burt.athena.support.TestClock; import io.burt.athena.support.TestNameGenerator; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayNameGeneration; @@ -11,6 +12,7 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.time.Duration; +import java.time.Instant; import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeoutException; @@ -27,12 +29,14 @@ @DisplayNameGeneration(TestNameGenerator.class) class FixedDelayPollingStrategyTest { private Sleeper sleeper; + private TestClock clock; private PollingStrategy pollingStrategy; @BeforeEach void setUp() { sleeper = mock(Sleeper.class); - pollingStrategy = new FixedDelayPollingStrategy(Duration.ofSeconds(3), sleeper); + clock = new TestClock(); + pollingStrategy = new FixedDelayPollingStrategy(Duration.ofSeconds(3), sleeper, clock); } @Nested @@ -40,60 +44,83 @@ class PollUntilCompleted { @Test void pollsUntilTheCallbackReturnsAResultSet() throws Exception { AtomicInteger counter = new AtomicInteger(0); - pollingStrategy.pollUntilCompleted(() -> { + pollingStrategy.pollUntilCompleted((Instant deadline) -> { if (counter.get() == 3) { return Optional.of(mock(ResultSet.class)); } else { counter.incrementAndGet(); return Optional.empty(); } - }); + }, clock.instant().plus(Duration.ofSeconds(30))); assertEquals(3, counter.get()); } @Test void returnsTheResultSet() throws Exception { ResultSet rs1 = mock(ResultSet.class); - ResultSet rs2 = pollingStrategy.pollUntilCompleted(() -> Optional.of(rs1)); + ResultSet rs2 = pollingStrategy.pollUntilCompleted((Instant deadline) -> Optional.of(rs1), clock.instant().plus(Duration.ofSeconds(30))); assertSame(rs1, rs2); } @Test void delaysTheConfiguredDurationBetweenPolls() throws Exception { AtomicInteger counter = new AtomicInteger(0); - pollingStrategy.pollUntilCompleted(() -> { + pollingStrategy.pollUntilCompleted((Instant deadline) -> { if (counter.getAndIncrement() == 3) { return Optional.of(mock(ResultSet.class)); } else { return Optional.empty(); } - }); + }, clock.instant().plus(Duration.ofSeconds(30))); verify(sleeper, times(3)).sleep(Duration.ofSeconds(3)); } + @Test + void reducesFinalDelayToMatchDeadline() throws Exception { + AtomicInteger counter = new AtomicInteger(0); + pollingStrategy.pollUntilCompleted((Instant deadline) -> { + if (counter.getAndIncrement() >= 1) { + return Optional.of(mock(ResultSet.class)); + } else { + return Optional.empty(); + } + }, clock.instant().plus(Duration.ofMillis(100))); + verify(sleeper, times(1)).sleep(Duration.ofMillis(100)); + } + + @Test + void throwsTimeoutExceptionIfNotCompletedWithinDeadline() throws Exception { + assertThrows(TimeoutException.class, () -> { + pollingStrategy.pollUntilCompleted((Instant deadline) -> { + clock.tick(Duration.ofSeconds(10)); + return Optional.empty(); + }, clock.instant()); + }); + } + @Nested class WhenTheCallbackThrowsAnException { @Test void passesTheExceptionAlong() { assertThrows(InterruptedException.class, () -> { - pollingStrategy.pollUntilCompleted(() -> { + pollingStrategy.pollUntilCompleted((Instant deadline) -> { throw new InterruptedException(); - }); + }, clock.instant().plus(Duration.ofSeconds(30))); }); assertThrows(SQLException.class, () -> { - pollingStrategy.pollUntilCompleted(() -> { + pollingStrategy.pollUntilCompleted((Instant deadline) -> { throw new SQLException(); - }); + }, clock.instant().plus(Duration.ofSeconds(30))); }); assertThrows(ExecutionException.class, () -> { - pollingStrategy.pollUntilCompleted(() -> { + pollingStrategy.pollUntilCompleted((Instant deadline) -> { throw new ExecutionException(new ArithmeticException()); - }); + }, clock.instant().plus(Duration.ofSeconds(30))); }); assertThrows(TimeoutException.class, () -> { - pollingStrategy.pollUntilCompleted(() -> { + pollingStrategy.pollUntilCompleted((Instant deadline) -> { throw new TimeoutException(); - }); + }, clock.instant().plus(Duration.ofSeconds(30))); }); } } From f0dbec79038e1d55abcf06c73287872214df462c Mon Sep 17 00:00:00 2001 From: Gustav Munkby Date: Fri, 27 Sep 2019 22:46:31 +0200 Subject: [PATCH 5/6] Cancel a query after timeout passed The purpose of cancelling the query is to avoid consuming all Athena resources on queries that nobody is waiting for anyhow. If the timeout was passed before StartQueryExceution finished, we can't easily cancel it, as we do not yet have a query execution ID. This doesn't actually wait for the cancellation to go through, and any failure reported in cancellation will just be ignored. Given that we have already used up all the time the user said we had, it seems unreasonable to wait for the cancellation to finish, but there is also no natural way to report an asynchronous error. --- .../java/io/burt/athena/AthenaStatement.java | 14 ++++++++++++-- .../io/burt/athena/AthenaStatementTest.java | 17 +++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/src/main/java/io/burt/athena/AthenaStatement.java b/src/main/java/io/burt/athena/AthenaStatement.java index 7c90008..7f29771 100644 --- a/src/main/java/io/burt/athena/AthenaStatement.java +++ b/src/main/java/io/burt/athena/AthenaStatement.java @@ -85,8 +85,18 @@ public boolean execute(String sql) throws SQLException { } catch (InterruptedException ie) { Thread.currentThread().interrupt(); return false; - } catch (TimeoutException ie) { - throw new SQLTimeoutException(ie); + } catch (TimeoutException te) { + SQLTimeoutException ste = new SQLTimeoutException(te); + if (queryExecutionId != null) { + try { + athenaClient.stopQueryExecution(b -> { + b.queryExecutionId(queryExecutionId); + }); + } catch (Exception e) { + ste.addSuppressed(e); + } + } + throw ste; } catch (ExecutionException ee) { SQLException eee = new SQLException(ee.getCause()); eee.addSuppressed(ee); diff --git a/src/test/java/io/burt/athena/AthenaStatementTest.java b/src/test/java/io/burt/athena/AthenaStatementTest.java index 587f5c6..8273b8b 100644 --- a/src/test/java/io/burt/athena/AthenaStatementTest.java +++ b/src/test/java/io/burt/athena/AthenaStatementTest.java @@ -522,6 +522,23 @@ void setsTheTimeoutUsedForQuerySpanningMultipleOperations() { statement.setQueryTimeout(Duration.ofMillis(100)); assertThrows(SQLTimeoutException.class, () -> statement.executeQuery("SELECT 1")); } + + @Test + void cancelsQueryAfterTimeout() throws Exception { + queryExecutionHelper.delayGetQueryExecutionResponses(Duration.ofMillis(10)); + statement.setQueryTimeout(0); + try { statement.executeQuery("SELECT 1"); } catch (SQLTimeoutException ste) { /* expected */ } + StopQueryExecutionRequest request = queryExecutionHelper.stopQueryExecutionRequests().get(0); + assertEquals("Q1234", request.queryExecutionId()); + } + + @Test + void doesNotCancelQueryThatDidNotStart() throws Exception { + queryExecutionHelper.delayStartQueryExecutionResponses(Duration.ofMillis(10)); + statement.setQueryTimeout(0); + try { statement.executeQuery("SELECT 1"); } catch (SQLTimeoutException ste) { /* expected */ } + assertEquals(0, queryExecutionHelper.stopQueryExecutionRequests().size()); + } } @Nested From 3a11e14b9f57910d95db876b2fef4ee52ee176b5 Mon Sep 17 00:00:00 2001 From: Gustav Munkby Date: Sun, 29 Sep 2019 19:23:10 +0200 Subject: [PATCH 6/6] Refactor polling strategies to reduce code duplication --- .../burt/athena/polling/BackoffPollingStrategy.java | 10 +--------- .../athena/polling/FixedDelayPollingStrategy.java | 10 +--------- .../io/burt/athena/polling/PollingStrategy.java | 13 +++++++++++++ 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/src/main/java/io/burt/athena/polling/BackoffPollingStrategy.java b/src/main/java/io/burt/athena/polling/BackoffPollingStrategy.java index 41e80f9..7f8c77a 100644 --- a/src/main/java/io/burt/athena/polling/BackoffPollingStrategy.java +++ b/src/main/java/io/burt/athena/polling/BackoffPollingStrategy.java @@ -49,15 +49,7 @@ public ResultSet pollUntilCompleted(PollingCallback callback, Instant deadline) if (resultSet.isPresent()) { return resultSet.get(); } else { - Duration beforeDeadline = Duration.between(clock.instant(), deadline); - if (beforeDeadline.compareTo(nextDelay) < 0) { - if (beforeDeadline.isNegative()) { - throw new TimeoutException(); - } - sleeper.sleep(beforeDeadline); - } else { - sleeper.sleep(nextDelay); - } + sleeper.sleep(sleepDuration(nextDelay, clock.instant(), deadline)); nextDelay = nextDelay.multipliedBy(factor); if (nextDelay.compareTo(maxDelay) > 0) { nextDelay = maxDelay; diff --git a/src/main/java/io/burt/athena/polling/FixedDelayPollingStrategy.java b/src/main/java/io/burt/athena/polling/FixedDelayPollingStrategy.java index 85ab627..dcfb7ec 100644 --- a/src/main/java/io/burt/athena/polling/FixedDelayPollingStrategy.java +++ b/src/main/java/io/burt/athena/polling/FixedDelayPollingStrategy.java @@ -32,15 +32,7 @@ public ResultSet pollUntilCompleted(PollingCallback callback, Instant deadline) if (resultSet.isPresent()) { return resultSet.get(); } else { - Duration beforeDeadline = Duration.between(clock.instant(), deadline); - if (beforeDeadline.compareTo(delay) < 0) { - if (beforeDeadline.isNegative()) { - throw new TimeoutException(); - } - sleeper.sleep(beforeDeadline); - } else { - sleeper.sleep(delay); - } + sleeper.sleep(sleepDuration(delay, clock.instant(), deadline)); } } } diff --git a/src/main/java/io/burt/athena/polling/PollingStrategy.java b/src/main/java/io/burt/athena/polling/PollingStrategy.java index 4d7be45..39953c3 100644 --- a/src/main/java/io/burt/athena/polling/PollingStrategy.java +++ b/src/main/java/io/burt/athena/polling/PollingStrategy.java @@ -2,10 +2,23 @@ import java.sql.ResultSet; import java.sql.SQLException; +import java.time.Duration; import java.time.Instant; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeoutException; public interface PollingStrategy { ResultSet pollUntilCompleted(PollingCallback callback, Instant deadline) throws SQLException, TimeoutException, ExecutionException, InterruptedException; + + default Duration sleepDuration(Duration desired, Instant now, Instant deadline) throws TimeoutException { + Duration beforeDeadline = Duration.between(now, deadline); + if (beforeDeadline.compareTo(desired) < 0) { + if (beforeDeadline.isNegative()) { + throw new TimeoutException("polling reached deadline"); + } + return beforeDeadline; + } else { + return desired; + } + } }