Skip to content

Commit

Permalink
fix: remove race condition bug in refresh logic and increase thread pool size
Browse files Browse the repository at this point in the history
  • Loading branch information
ttosta-google committed Sep 18, 2023
1 parent 8889f5f commit 7d44e0d
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 14 deletions.
17 changes: 17 additions & 0 deletions alloydb-jdbc-connector/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@
<!-- The postgres driver must be provided by client application and is otherwise unused here. -->
<ignoredDependency>org.postgresql:postgresql:*</ignoredDependency>
</ignoredDependencies>
<usedDependencies>
<!-- This dependency is not used at compile-time. -->
<dependency>org.slf4j:slf4j-jdk14</dependency>
</usedDependencies>
</configuration>
</plugin>
</plugins>
Expand Down Expand Up @@ -113,6 +117,19 @@
<scope>provided</scope>
</dependency>

<!-- Logging -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>2.0.9</version>
</dependency>

<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-jdk14</artifactId>
<version>2.0.9</version>
</dependency>

<!-- Test dependencies -->
<dependency>
<groupId>junit</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ public enum ConnectorRegistry implements Closeable {
private final Connector connector;

ConnectorRegistry() {
this.executor = Executors.newScheduledThreadPool(2);
// During refresh, each instance consumes 2 threads from the thread pool. By using 8 threads,
// there should be enough free threads so that there will not be a deadlock. Most users
// configure 3 or fewer instances, requiring 6 threads during refresh. By setting
// this to 8, it's enough threads for most users, plus a safety factor of 2.
this.executor = Executors.newScheduledThreadPool(8);
try {
alloyDBAdminClient = AlloyDBAdminClient.create();
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,22 @@
import java.security.KeyPair;
import java.security.cert.CertificateException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* DefaultConnectionInfoCache is the cache used by default to hold connection info. In testing, this
* class may be replaced with alternative implementations of ConnectionInfoCache.
*/
class DefaultConnectionInfoCache implements ConnectionInfoCache {

private static final Logger logger = LoggerFactory.getLogger(DefaultConnectionInfoCache.class);

private final ScheduledExecutorService executor;
private final ConnectionInfoRepository connectionInfoRepo;
private final InstanceName instanceName;
Expand All @@ -49,6 +54,9 @@ class DefaultConnectionInfoCache implements ConnectionInfoCache {
@GuardedBy("connectionInfoLock")
private Future<ConnectionInfo> next;

@GuardedBy("connectionInfoLock")
private boolean forceRefreshRunning;

DefaultConnectionInfoCache(
ScheduledExecutorService executor,
ConnectionInfoRepository connectionInfoRepo,
Expand Down Expand Up @@ -91,35 +99,61 @@ public ConnectionInfo getConnectionInfo() {
*/
private ConnectionInfo performRefresh()
throws CertificateException, ExecutionException, InterruptedException {
logger.info(
String.format("[%s] Refresh Operation: Acquiring rate limiter permit.", instanceName));
// Rate limit the speed of refresh operations.
this.rateLimiter.acquire();
logger.info(
String.format(
"[%s] Refresh Operation: Acquired rate limiter permit. Starting refresh...",
instanceName));

try {
ConnectionInfo connectionInfo =
this.connectionInfoRepo.getConnectionInfo(this.instanceName, this.clientConnectorKeyPair);
logger.info(
String.format(
"[%s] Refresh Operation: Completed refresh with new certificate expiration at %s.",
instanceName, connectionInfo.getClientCertificateExpiration().toString()));

long secondsToRefresh =
refreshCalculator.calculateSecondsUntilNextRefresh(
Instant.now(), connectionInfo.getClientCertificateExpiration());
logger.info(
String.format(
"[%s] Refresh Operation: Next operation scheduled at %s.",
instanceName,
Instant.now()
.plus(secondsToRefresh, ChronoUnit.SECONDS)
.truncatedTo(ChronoUnit.SECONDS)
.toString()));

synchronized (connectionInfoLock) {
current = Futures.immediateFuture(connectionInfo);
next =
executor.schedule(
this::performRefresh,
refreshCalculator.calculateSecondsUntilNextRefresh(
Instant.now(), connectionInfo.getClientCertificateExpiration()),
TimeUnit.SECONDS);
next = executor.schedule(this::performRefresh, secondsToRefresh, TimeUnit.SECONDS);
forceRefreshRunning = false;
}

return connectionInfo;
} catch (CertificateException | ExecutionException | InterruptedException e) {
logger.info(
String.format(
"[%s] Refresh Operation: Failed! Starting next refresh operation immediately.",
instanceName),
e);
// For known exceptions, schedule a refresh immediately.
synchronized (connectionInfoLock) {
next = executor.submit(this::performRefresh);
}
throw e;
} catch (RuntimeException e) {
logger.info(String.format("[%s] Refresh Operation: Failed!", instanceName), e);
// If the exception is an ApiException, schedule a refresh immediately
// before re-throwing the exception.
Throwable cause = e.getCause();
if (cause instanceof ApiException) {
logger.info(
String.format("[%s] Starting next refresh operation immediately.", instanceName), e);
synchronized (connectionInfoLock) {
next = executor.submit(this::performRefresh);
}
Expand All @@ -135,15 +169,26 @@ private ConnectionInfo performRefresh()
/**
 * Requests an immediate refresh of the cached connection info.
 *
 * <p>If an earlier forced refresh has not yet produced a successful result (tracked by {@code
 * forceRefreshRunning}, which {@code performRefresh} clears after a successful refresh), the
 * call is logged and ignored so that overlapping callers do not pile up refresh operations.
 * Otherwise the pending scheduled refresh is cancelled and exactly one new refresh is submitted
 * to the executor; both {@code current} and {@code next} then point at that single operation.
 */
@Override
public void forceRefresh() {
  synchronized (connectionInfoLock) {
    // Don't force a refresh until the current forceRefresh operation
    // has produced a successful refresh.
    if (forceRefreshRunning) {
      logger.info(
          String.format(
              "[%s] Force Refresh: ignore this call as a refresh operation is currently in progress.",
              instanceName));
      return;
    }

    forceRefreshRunning = true;

    // Stop the pending scheduled refresh from starting. cancel(false) does not interrupt a
    // refresh that is already executing.
    next.cancel(false);
    logger.info(
        String.format(
            "[%s] Force Refresh: the next refresh operation was cancelled."
                + " Scheduling new refresh operation immediately.",
            instanceName));

    // Submit a single refresh. Submitting conditionally on next.isCancelled() and then again
    // unconditionally would start two refresh operations and orphan the first future.
    current = executor.submit(this::performRefresh);
    next = current;
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public class ConnectionInfoCacheTest {
private static final String TEST_INSTANCE_ID = "some-instance-id";
private static final Instant ONE_HOUR_FROM_NOW = Instant.now().plus(1, ChronoUnit.HOURS);
private static final Instant TWO_HOURS_FROM_NOW = ONE_HOUR_FROM_NOW.plus(1, ChronoUnit.HOURS);
private static final Instant THREE_HOURS_FROM_NOW = TWO_HOURS_FROM_NOW.plus(1, ChronoUnit.HOURS);
private InstanceName instanceName;
private KeyPair keyPair;
private SpyRateLimiter spyRateLimiter;
Expand Down Expand Up @@ -339,6 +340,67 @@ public void testForceRefresh_schedulesNextRefreshImmediately() {
.isEqualTo(TWO_HOURS_FROM_NOW.truncatedTo(ChronoUnit.SECONDS));
}

/**
 * Calls {@code forceRefresh()} twice in a row and checks that the connection info served
 * afterwards carries the certificate expiring two hours from now, i.e. not the third response
 * registered with the repo.
 */
@Test
public void testForceRefresh_refreshCalledOnlyOnceDuringMultipleCalls() {
  ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(2);

  InMemoryConnectionInfoRepo repo = new InMemoryConnectionInfoRepo();
  List<X509Certificate> chain =
      Arrays.asList(
          testCertificates.getIntermediateCertificate(), testCertificates.getRootCertificate());

  // Three canned responses whose client certificates expire in one, two and three hours.
  // NOTE(review): presumably the repo hands these out in order, one per refresh - confirm
  // against InMemoryConnectionInfoRepo.
  repo.addResponses(
      () ->
          new ConnectionInfo(
              TEST_INSTANCE_IP,
              TEST_INSTANCE_ID,
              testCertificates.getEphemeralCertificate(keyPair.getPublic(), ONE_HOUR_FROM_NOW),
              chain),
      () ->
          new ConnectionInfo(
              TEST_INSTANCE_IP,
              TEST_INSTANCE_ID,
              testCertificates.getEphemeralCertificate(keyPair.getPublic(), TWO_HOURS_FROM_NOW),
              chain),
      () ->
          new ConnectionInfo(
              TEST_INSTANCE_IP,
              TEST_INSTANCE_ID,
              testCertificates.getEphemeralCertificate(keyPair.getPublic(), THREE_HOURS_FROM_NOW),
              chain));

  DefaultConnectionInfoCache cache =
      new DefaultConnectionInfoCache(
          scheduler, repo, instanceName, keyPair, new RefreshCalculator(), spyRateLimiter);

  // The initial refresh result is what the cache serves before any forced refresh.
  ConnectionInfo initialInfo = cache.getConnectionInfo();
  Instant initialExpiry =
      initialInfo.getClientCertificate().getNotAfter().toInstant().truncatedTo(ChronoUnit.SECONDS);
  assertThat(initialExpiry).isEqualTo(ONE_HOUR_FROM_NOW.truncatedTo(ChronoUnit.SECONDS));

  cache.forceRefresh();
  // The second call is expected to be ignored because a refresh operation is in progress.
  cache.forceRefresh();

  // Once the forced refresh has completed, the cache serves the second response.
  ConnectionInfo refreshedInfo = cache.getConnectionInfo();
  Instant refreshedExpiry =
      refreshedInfo
          .getClientCertificate()
          .getNotAfter()
          .toInstant()
          .truncatedTo(ChronoUnit.SECONDS);
  assertThat(refreshedExpiry).isEqualTo(TWO_HOURS_FROM_NOW.truncatedTo(ChronoUnit.SECONDS));
}

private static class SpyRateLimiter implements RateLimiter {
public final AtomicBoolean wasRateLimited = new AtomicBoolean(false);

Expand Down

0 comments on commit 7d44e0d

Please sign in to comment.