Skip to content

Commit

Permalink
Merge pull request #4 from burtcorp/query-timeout
Browse files Browse the repository at this point in the history
Implement support for query timeouts
  • Loading branch information
iconara committed Sep 30, 2019
2 parents 0d5f3af + 3a11e14 commit 82026f0
Show file tree
Hide file tree
Showing 20 changed files with 329 additions and 96 deletions.
7 changes: 4 additions & 3 deletions src/main/java/io/burt/athena/AthenaConnection.java
Expand Up @@ -19,6 +19,7 @@
import java.sql.Savepoint;
import java.sql.Statement;
import java.sql.Struct;
import java.time.Clock;
import java.time.Duration;
import java.util.Collections;
import java.util.Map;
Expand All @@ -45,7 +46,7 @@ private void checkClosed() throws SQLException {
@Override
public Statement createStatement() throws SQLException {
checkClosed();
return new AthenaStatement(configuration);
return new AthenaStatement(configuration, Clock.systemDefaultZone());
}

@Override
Expand Down Expand Up @@ -305,13 +306,13 @@ public Properties getClientInfo() throws SQLException {
@Override
public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException {
checkClosed();
configuration = configuration.withTimeout(Duration.ofMillis(milliseconds));
configuration = configuration.withNetworkTimeout(Duration.ofMillis(milliseconds));
}

@Override
public int getNetworkTimeout() throws SQLException {
checkClosed();
return (int) configuration.apiCallTimeout().toMillis();
return (int) configuration.networkTimeout().toMillis();
}

@Override
Expand Down
1 change: 1 addition & 0 deletions src/main/java/io/burt/athena/AthenaDriver.java
Expand Up @@ -133,6 +133,7 @@ public Connection connect(String url, Properties connectionProperties) {
workGroup,
outputLocation,
Duration.ofMinutes(1),
Duration.ofMinutes(30),
ResultLoadingStrategy.S3
);
return new AthenaConnection(configuration);
Expand Down
45 changes: 34 additions & 11 deletions src/main/java/io/burt/athena/AthenaStatement.java
Expand Up @@ -11,7 +11,9 @@
import java.sql.SQLTimeoutException;
import java.sql.SQLWarning;
import java.sql.Statement;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
Expand All @@ -20,16 +22,18 @@

public class AthenaStatement implements Statement {
private final AthenaAsyncClient athenaClient;
private Clock clock;

private ConnectionConfiguration configuration;
private String queryExecutionId;
private ResultSet currentResultSet;
private Function<String, Optional<String>> clientRequestTokenProvider;
private boolean open;

AthenaStatement(ConnectionConfiguration configuration) {
AthenaStatement(ConnectionConfiguration configuration, Clock clock) {
this.configuration = configuration;
this.athenaClient = configuration.athenaClient();
this.clock = clock;
this.queryExecutionId = null;
this.currentResultSet = null;
this.clientRequestTokenProvider = sql -> Optional.empty();
Expand Down Expand Up @@ -74,22 +78,33 @@ public boolean execute(String sql) throws SQLException {
currentResultSet = null;
}
try {
queryExecutionId = startQueryExecution(sql);
currentResultSet = configuration.pollingStrategy().pollUntilCompleted(this::poll);
Instant deadline = clock.instant().plus(configuration.queryTimeout());
queryExecutionId = startQueryExecution(sql, deadline);
currentResultSet = configuration.pollingStrategy().pollUntilCompleted(this::poll, deadline);
return currentResultSet != null;
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
return false;
} catch (TimeoutException ie) {
throw new SQLTimeoutException(ie);
} catch (TimeoutException te) {
SQLTimeoutException ste = new SQLTimeoutException(te);
if (queryExecutionId != null) {
try {
athenaClient.stopQueryExecution(b -> {
b.queryExecutionId(queryExecutionId);
});
} catch (Exception e) {
ste.addSuppressed(e);
}
}
throw ste;
} catch (ExecutionException ee) {
SQLException eee = new SQLException(ee.getCause());
eee.addSuppressed(ee);
throw eee;
}
}

private String startQueryExecution(String sql) throws InterruptedException, ExecutionException, TimeoutException {
private String startQueryExecution(String sql, Instant deadline) throws InterruptedException, ExecutionException, TimeoutException {
return athenaClient
.startQueryExecution(b -> {
b.queryString(sql);
Expand All @@ -98,14 +113,14 @@ private String startQueryExecution(String sql) throws InterruptedException, Exec
b.resultConfiguration(bb -> bb.outputLocation(configuration.outputLocation()));
clientRequestTokenProvider.apply(sql).ifPresent(b::clientRequestToken);
})
.get(configuration.apiCallTimeout().toMillis(), TimeUnit.MILLISECONDS)
.get(networkTimeoutMillis(deadline), TimeUnit.MILLISECONDS)
.queryExecutionId();
}

private Optional<ResultSet> poll() throws SQLException, InterruptedException, ExecutionException, TimeoutException {
private Optional<ResultSet> poll(Instant deadline) throws SQLException, InterruptedException, ExecutionException, TimeoutException {
QueryExecution queryExecution = athenaClient
.getQueryExecution(b -> b.queryExecutionId(queryExecutionId))
.get(configuration.apiCallTimeout().toMillis(), TimeUnit.MILLISECONDS)
.get(networkTimeoutMillis(deadline), TimeUnit.MILLISECONDS)
.queryExecution();
switch (queryExecution.status().state()) {
case SUCCEEDED:
Expand All @@ -118,6 +133,10 @@ private Optional<ResultSet> poll() throws SQLException, InterruptedException, Ex
}
}

private long networkTimeoutMillis(Instant deadline) {
return Math.max(0, Math.min(configuration.networkTimeout().toMillis(), Duration.between(clock.instant(), deadline).toMillis()));
}

private ResultSet createResultSet(QueryExecution queryExecution) {
return new AthenaResultSet(
configuration.createResult(queryExecution),
Expand Down Expand Up @@ -269,14 +288,18 @@ public void setEscapeProcessing(boolean enable) {
throw new UnsupportedOperationException("Not implemented");
}

public void setQueryTimeout(Duration timeout) {
configuration = configuration.withQueryTimeout(timeout);
}

@Override
public int getQueryTimeout() {
return (int) configuration.apiCallTimeout().toMillis() / 1000;
return (int) configuration.queryTimeout().toMillis() / 1000;
}

@Override
public void setQueryTimeout(int seconds) {
configuration = configuration.withTimeout(Duration.ofSeconds(seconds));
setQueryTimeout(Duration.ofSeconds(seconds));
}

@Override
Expand Down
Expand Up @@ -18,24 +18,26 @@ class ConcreteConnectionConfiguration implements ConnectionConfiguration {
private final String databaseName;
private final String workGroupName;
private final String outputLocation;
private final Duration timeout;
private final Duration networkTimeout;
private final Duration queryTimeout;
private final ResultLoadingStrategy resultLoadingStrategy;

private AthenaAsyncClient athenaClient;
private S3AsyncClient s3Client;
private PollingStrategy pollingStrategy;

ConcreteConnectionConfiguration(Region awsRegion, String databaseName, String workGroupName, String outputLocation, Duration timeout, ResultLoadingStrategy resultLoadingStrategy) {
ConcreteConnectionConfiguration(Region awsRegion, String databaseName, String workGroupName, String outputLocation, Duration networkTimeout, Duration queryTimeout, ResultLoadingStrategy resultLoadingStrategy) {
this.awsRegion = awsRegion;
this.databaseName = databaseName;
this.workGroupName = workGroupName;
this.outputLocation = outputLocation;
this.timeout = timeout;
this.networkTimeout = networkTimeout;
this.queryTimeout = queryTimeout;
this.resultLoadingStrategy = resultLoadingStrategy;
}

private ConcreteConnectionConfiguration(Region awsRegion, String databaseName, String workGroupName, String outputLocation, Duration timeout, ResultLoadingStrategy resultLoadingStrategy, AthenaAsyncClient athenaClient, S3AsyncClient s3Client, PollingStrategy pollingStrategy) {
this(awsRegion, databaseName, workGroupName, outputLocation, timeout, resultLoadingStrategy);
private ConcreteConnectionConfiguration(Region awsRegion, String databaseName, String workGroupName, String outputLocation, Duration networkTimeout, Duration queryTimeout, ResultLoadingStrategy resultLoadingStrategy, AthenaAsyncClient athenaClient, S3AsyncClient s3Client, PollingStrategy pollingStrategy) {
this(awsRegion, databaseName, workGroupName, outputLocation, networkTimeout, queryTimeout, resultLoadingStrategy);
this.athenaClient = athenaClient;
this.s3Client = s3Client;
this.pollingStrategy = pollingStrategy;
Expand All @@ -57,10 +59,13 @@ public String outputLocation() {
}

@Override
public Duration apiCallTimeout() {
return timeout;
public Duration networkTimeout() {
return networkTimeout;
}

@Override
public Duration queryTimeout() { return queryTimeout; }

@Override
public AthenaAsyncClient athenaClient() {
if (athenaClient == null) {
Expand All @@ -87,12 +92,17 @@ public PollingStrategy pollingStrategy() {

@Override
public ConnectionConfiguration withDatabaseName(String databaseName) {
return new ConcreteConnectionConfiguration(awsRegion, databaseName, workGroupName, outputLocation, timeout, resultLoadingStrategy, athenaClient, s3Client, pollingStrategy);
return new ConcreteConnectionConfiguration(awsRegion, databaseName, workGroupName, outputLocation, networkTimeout, queryTimeout, resultLoadingStrategy, athenaClient, s3Client, pollingStrategy);
}

@Override
public ConnectionConfiguration withNetworkTimeout(Duration networkTimeout) {
return new ConcreteConnectionConfiguration(awsRegion, databaseName, workGroupName, outputLocation, networkTimeout, queryTimeout, resultLoadingStrategy, athenaClient, s3Client, pollingStrategy);
}

@Override
public ConnectionConfiguration withTimeout(Duration timeout) {
return new ConcreteConnectionConfiguration(awsRegion, databaseName, workGroupName, outputLocation, timeout, resultLoadingStrategy, athenaClient, s3Client, pollingStrategy);
public ConnectionConfiguration withQueryTimeout(Duration queryTimeout) {
return new ConcreteConnectionConfiguration(awsRegion, databaseName, workGroupName, outputLocation, networkTimeout, queryTimeout, resultLoadingStrategy, athenaClient, s3Client, pollingStrategy);
}

@Override
Expand Down
Expand Up @@ -15,7 +15,9 @@ public interface ConnectionConfiguration extends AutoCloseable {

String outputLocation();

Duration apiCallTimeout();
Duration networkTimeout();

Duration queryTimeout();

AthenaAsyncClient athenaClient();

Expand All @@ -25,7 +27,9 @@ public interface ConnectionConfiguration extends AutoCloseable {

ConnectionConfiguration withDatabaseName(String databaseName);

ConnectionConfiguration withTimeout(Duration timeout);
ConnectionConfiguration withNetworkTimeout(Duration timeout);

ConnectionConfiguration withQueryTimeout(Duration timeout);

Result createResult(QueryExecution queryExecution);
}
Expand Up @@ -5,8 +5,8 @@
import java.time.Duration;

public class ConnectionConfigurationFactory {
public ConnectionConfiguration createConnectionConfiguration(Region awsRegion, String databaseName, String workGroupName, String outputLocation, Duration timeout, ResultLoadingStrategy resultLoadingStrategy) {
return new ConcreteConnectionConfiguration(awsRegion, databaseName, workGroupName, outputLocation, timeout, resultLoadingStrategy);
public ConnectionConfiguration createConnectionConfiguration(Region awsRegion, String databaseName, String workGroupName, String outputLocation, Duration networkTimeout, Duration queryTimeout, ResultLoadingStrategy resultLoadingStrategy) {
return new ConcreteConnectionConfiguration(awsRegion, databaseName, workGroupName, outputLocation, networkTimeout, queryTimeout, resultLoadingStrategy);
}
}

22 changes: 15 additions & 7 deletions src/main/java/io/burt/athena/polling/BackoffPollingStrategy.java
Expand Up @@ -2,7 +2,9 @@

import java.sql.ResultSet;
import java.sql.SQLException;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
Expand All @@ -13,35 +15,41 @@ class BackoffPollingStrategy implements PollingStrategy {
private final Duration maxDelay;
private final long factor;
private final Sleeper sleeper;
private final Clock clock;

BackoffPollingStrategy(Duration firstDelay, Duration maxDelay) {
this(firstDelay, maxDelay, 2L, duration -> TimeUnit.MILLISECONDS.sleep(duration.toMillis()));
this(firstDelay, maxDelay, 2L, duration -> TimeUnit.MILLISECONDS.sleep(duration.toMillis()), Clock.systemDefaultZone());
}

BackoffPollingStrategy(Duration firstDelay, Duration maxDelay, long factor) {
this(firstDelay, maxDelay, factor, duration -> TimeUnit.MILLISECONDS.sleep(duration.toMillis()));
this(firstDelay, maxDelay, factor, duration -> TimeUnit.MILLISECONDS.sleep(duration.toMillis()), Clock.systemDefaultZone());
}

BackoffPollingStrategy(Duration firstDelay, Duration maxDelay, Sleeper sleeper) {
this(firstDelay, maxDelay, 2L, sleeper);
this(firstDelay, maxDelay, 2L, sleeper, Clock.systemDefaultZone());
}

BackoffPollingStrategy(Duration firstDelay, Duration maxDelay, long factor, Sleeper sleeper) {
BackoffPollingStrategy(Duration firstDelay, Duration maxDelay, Sleeper sleeper, Clock clock) {
this(firstDelay, maxDelay, 2L, sleeper, clock);
}

BackoffPollingStrategy(Duration firstDelay, Duration maxDelay, long factor, Sleeper sleeper, Clock clock) {
this.firstDelay = firstDelay;
this.maxDelay = maxDelay;
this.factor = factor;
this.sleeper = sleeper;
this.clock = clock;
}

@Override
public ResultSet pollUntilCompleted(PollingCallback callback) throws SQLException, TimeoutException, ExecutionException, InterruptedException {
public ResultSet pollUntilCompleted(PollingCallback callback, Instant deadline) throws SQLException, TimeoutException, ExecutionException, InterruptedException {
Duration nextDelay = firstDelay;
while (true) {
Optional<ResultSet> resultSet = callback.poll();
Optional<ResultSet> resultSet = callback.poll(deadline);
if (resultSet.isPresent()) {
return resultSet.get();
} else {
sleeper.sleep(nextDelay);
sleeper.sleep(sleepDuration(nextDelay, clock.instant(), deadline));
nextDelay = nextDelay.multipliedBy(factor);
if (nextDelay.compareTo(maxDelay) > 0) {
nextDelay = maxDelay;
Expand Down
Expand Up @@ -2,7 +2,9 @@

import java.sql.ResultSet;
import java.sql.SQLException;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
Expand All @@ -11,24 +13,26 @@
public class FixedDelayPollingStrategy implements PollingStrategy {
private final Duration delay;
private final Sleeper sleeper;
private Clock clock;

FixedDelayPollingStrategy(Duration delay) {
this(delay, duration -> TimeUnit.MILLISECONDS.sleep(duration.toMillis()));
this(delay, duration -> TimeUnit.MILLISECONDS.sleep(duration.toMillis()), Clock.systemDefaultZone());
}

FixedDelayPollingStrategy(Duration delay, Sleeper sleeper) {
FixedDelayPollingStrategy(Duration delay, Sleeper sleeper, Clock clock) {
this.delay = delay;
this.sleeper = sleeper;
this.clock = clock;
}

@Override
public ResultSet pollUntilCompleted(PollingCallback callback) throws SQLException, TimeoutException, ExecutionException, InterruptedException {
public ResultSet pollUntilCompleted(PollingCallback callback, Instant deadline) throws SQLException, TimeoutException, ExecutionException, InterruptedException {
while (true) {
Optional<ResultSet> resultSet = callback.poll();
Optional<ResultSet> resultSet = callback.poll(deadline);
if (resultSet.isPresent()) {
return resultSet.get();
} else {
sleeper.sleep(delay);
sleeper.sleep(sleepDuration(delay, clock.instant(), deadline));
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/io/burt/athena/polling/PollingCallback.java
Expand Up @@ -2,11 +2,12 @@

import java.sql.ResultSet;
import java.sql.SQLException;
import java.time.Instant;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;

@FunctionalInterface
public interface PollingCallback {
Optional<ResultSet> poll() throws SQLException, TimeoutException, ExecutionException, InterruptedException;
Optional<ResultSet> poll(Instant deadline) throws SQLException, TimeoutException, ExecutionException, InterruptedException;
}

0 comments on commit 82026f0

Please sign in to comment.