Skip to content

Commit

Permalink
test: set timeout for refresh operation in CoreSocketFactory.
Browse files Browse the repository at this point in the history
  • Loading branch information
hessjcg committed Oct 4, 2023
1 parent 2006a87 commit 554ee9f
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 47 deletions.
20 changes: 11 additions & 9 deletions core/src/main/java/com/google/cloud/sql/core/CloudSqlInstance.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ssl.SSLSocket;
Expand Down Expand Up @@ -108,14 +109,15 @@ class CloudSqlInstance {
* Returns the current data related to the instance from {@link #performRefresh()}. May block if
* no valid data is currently available.
*/
private InstanceData getInstanceData() {
private InstanceData getInstanceData(long timeoutMs) {
ListenableFuture<InstanceData> instanceDataFuture;
synchronized (instanceDataGuard) {
instanceDataFuture = currentInstanceData;
}
try {
return Uninterruptibles.getUninterruptibly(instanceDataFuture);
} catch (ExecutionException ex) {
return Uninterruptibles.getUninterruptibly(
instanceDataFuture, timeoutMs, TimeUnit.MILLISECONDS);
} catch (ExecutionException | TimeoutException ex) {
Throwable cause = ex.getCause();
Throwables.throwIfUnchecked(cause);
throw new RuntimeException(cause);
Expand All @@ -126,8 +128,8 @@ private InstanceData getInstanceData() {
* Returns an unconnected {@link SSLSocket} using the SSLContext associated with the instance. May
* block until required instance data is available.
*/
SSLSocket createSslSocket() throws IOException {
return (SSLSocket) getInstanceData().getSslContext().getSocketFactory().createSocket();
SSLSocket createSslSocket(long timeoutMs) throws IOException {
return (SSLSocket) getInstanceData(timeoutMs).getSslContext().getSocketFactory().createSocket();
}

/**
Expand All @@ -140,8 +142,8 @@ SSLSocket createSslSocket() throws IOException {
* @throws IllegalArgumentException If the instance has no IP addresses matching the provided
* preferences.
*/
String getPreferredIp(List<String> preferredTypes) {
Map<String, String> ipAddrs = getInstanceData().getIpAddrs();
String getPreferredIp(List<String> preferredTypes, long timeoutMs) {
Map<String, String> ipAddrs = getInstanceData(timeoutMs).getIpAddrs();
for (String ipType : preferredTypes) {
String preferredIp = ipAddrs.get(ipType);
if (preferredIp != null) {
Expand Down Expand Up @@ -237,8 +239,8 @@ private InstanceData performRefresh() throws InterruptedException, ExecutionExce
}
}

SslData getSslData() {
return getInstanceData().getSslData();
SslData getSslData(long timeoutMs) {
return getInstanceData(timeoutMs).getSslData();
}

ListenableFuture<InstanceData> getNext() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public final class CoreSocketFactory {
*/
@Deprecated public static final String USER_TOKEN_PROPERTY_NAME = "_CLOUD_SQL_USER_TOKEN";

static final long DEFAULT_MAX_REFRESH_MS = 30000;
public static final String DEFAULT_IP_TYPES = "PUBLIC,PRIVATE";
private static final String UNIX_SOCKET_PROPERTY = "unixSocketPath";
private static final Logger logger = Logger.getLogger(CoreSocketFactory.class.getName());
Expand All @@ -84,6 +85,7 @@ public final class CoreSocketFactory {
private final ListeningScheduledExecutorService executor;
private final CredentialFactory credentialFactory;
private final int serverProxyPort;
private final long refreshTimeoutMs;
private final ApiFetcherFactory apiFetcherFactory;

@VisibleForTesting
Expand All @@ -92,12 +94,14 @@ public final class CoreSocketFactory {
ApiFetcherFactory apiFetcherFactory,
CredentialFactory credentialFactory,
int serverProxyPort,
long refreshTimeoutMs,
ListeningScheduledExecutorService executor) {
this.apiFetcherFactory = apiFetcherFactory;
this.credentialFactory = credentialFactory;
this.serverProxyPort = serverProxyPort;
this.executor = executor;
this.localKeyPair = localKeyPair;
this.refreshTimeoutMs = refreshTimeoutMs;
}

/** Returns the {@link CoreSocketFactory} singleton. */
Expand All @@ -115,6 +119,7 @@ public static synchronized CoreSocketFactory getInstance() {
new SqlAdminApiFetcherFactory(getUserAgents()),
credentialFactory,
DEFAULT_SERVER_PROXY_PORT,
CoreSocketFactory.DEFAULT_MAX_REFRESH_MS,
executor);
}
return coreSocketFactory;
Expand Down Expand Up @@ -225,13 +230,15 @@ public static SslData getSslData(
List<String> delegates)
throws IOException {
if (enableIamAuth) {
return getInstance()
CoreSocketFactory factory = getInstance();
return factory
.getCloudSqlInstance(csqlInstanceName, AuthType.IAM, targetPrincipal, delegates)
.getSslData();
.getSslData(factory.refreshTimeoutMs);
}
return getInstance()
CoreSocketFactory factory = getInstance();
return factory
.getCloudSqlInstance(csqlInstanceName, AuthType.PASSWORD, targetPrincipal, delegates)
.getSslData();
.getSslData(factory.refreshTimeoutMs);
}

/** Returns preferred ip address that can be used to establish Cloud SQL connection. */
Expand All @@ -246,7 +253,7 @@ private String getHostIp(
String instanceName, List<String> ipTypes, String targetPrincipal, List<String> delegates) {
CloudSqlInstance instance =
getCloudSqlInstance(instanceName, AuthType.PASSWORD, targetPrincipal, delegates);
return instance.getPreferredIp(ipTypes);
return instance.getPreferredIp(ipTypes, refreshTimeoutMs);
}

/**
Expand Down Expand Up @@ -363,14 +370,14 @@ Socket createSslSocket(
getCloudSqlInstance(instanceName, authType, targetPrincipal, delegates);

try {
SSLSocket socket = instance.createSslSocket();
SSLSocket socket = instance.createSslSocket(this.refreshTimeoutMs);

// TODO(kvg): Support all socket related options listed here:
// https://dev.mysql.com/doc/connector-j/en/connector-j-reference-configuration-properties.html
socket.setKeepAlive(true);
socket.setTcpNoDelay(true);

String instanceIp = instance.getPreferredIp(ipTypes);
String instanceIp = instance.getPreferredIp(ipTypes, refreshTimeoutMs);

socket.connect(new InetSocketAddress(instanceIp, serverProxyPort));
socket.startHandshake();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public void testForceRefreshDoesNotCauseADeadlockOrBrokenRefreshLoop() throws Ex
}

// Get SSL Data for each instance, forcing the first refresh to complete.
instances.forEach(CloudSqlInstance::getSslData);
instances.forEach((inst) -> inst.getSslData(2000L));

assertThat(supplier.counter.get()).isEqualTo(instanceCount);

Expand Down Expand Up @@ -117,7 +117,7 @@ private Thread startForceRefreshThread(CloudSqlInstance inst) {
inst.forceRefresh();
inst.forceRefresh();
Thread.sleep(0);
inst.getSslData();
inst.getSslData(2000L);
} catch (Exception e) {
logger.info("Exception in force refresh loop.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.junit.Test;

public class CloudSqlInstanceTest {
public static final long TEST_TIMEOUT_MS = 1000;

@SuppressWarnings("UnstableApiUsage")
public static final RateLimiter TEST_RATE_LIMITER =
Expand Down Expand Up @@ -74,7 +75,7 @@ public void testCloudSqlInstanceDataRetrievedSuccessfully() throws Exception {
keyPairFuture,
TEST_RATE_LIMITER);

SslData gotSslData = instance.getSslData();
SslData gotSslData = instance.getSslData(TEST_TIMEOUT_MS);
assertThat(gotSslData).isSameInstanceAs(instanceDataSupplier.response.getSslData());
assertThat(instanceDataSupplier.counter.get()).isEqualTo(1);
}
Expand Down Expand Up @@ -102,7 +103,8 @@ public void testInstanceFailsOnConnectionError() throws Exception {
keyPairFuture,
TEST_RATE_LIMITER);

RuntimeException ex = Assert.assertThrows(RuntimeException.class, instance::getSslData);
RuntimeException ex =
Assert.assertThrows(RuntimeException.class, () -> instance.getSslData(TEST_TIMEOUT_MS));
assertThat(ex).hasMessageThat().contains("always fails");
}

Expand Down Expand Up @@ -132,15 +134,15 @@ public void testCloudSqlInstanceForcesRefresh() throws Exception {
keyPairFuture,
TEST_RATE_LIMITER);

instance.getSslData();
instance.getSslData(TEST_TIMEOUT_MS);
assertThat(refreshCount.get()).isEqualTo(1);

// Force refresh, which will start, but not finish the refresh process.
instance.forceRefresh();

// Then immediately getSslData() and assert that the refresh count has not changed.
// Refresh count hasn't changed because we re-use the existing connection info.
instance.getSslData();
instance.getSslData(TEST_TIMEOUT_MS);
assertThat(refreshCount.get()).isEqualTo(1);

// Allow the second refresh operation to complete
Expand All @@ -149,7 +151,7 @@ public void testCloudSqlInstanceForcesRefresh() throws Exception {
cond.waitForCondition(() -> refreshCount.get() >= 2, 1000L);

// getSslData again, and assert the refresh operation completed.
instance.getSslData();
instance.getSslData(TEST_TIMEOUT_MS);
assertThat(refreshCount.get()).isEqualTo(2);
}

Expand Down Expand Up @@ -193,7 +195,7 @@ public void testCloudSqlRefreshesExpiredData() throws Exception {
TEST_RATE_LIMITER);

// Get the first data that is about to expire
SslData d = instance.getSslData();
SslData d = instance.getSslData(TEST_TIMEOUT_MS);
assertThat(refreshCount.get()).isEqualTo(1);
assertThat(d).isSameInstanceAs(initialData.getSslData());

Expand All @@ -205,15 +207,16 @@ public void testCloudSqlRefreshesExpiredData() throws Exception {
// Now that the InstanceData has expired, getInstanceData will return the same, expired
// token until a new one is retrieved.
assertThat(refreshCount.get()).isEqualTo(1);
assertThat(instance.getSslData()).isSameInstanceAs(initialData.getSslData());
assertThat(instance.getSslData(TEST_TIMEOUT_MS))
.isSameInstanceAs(initialData.getSslData());

// Allow the second refresh operation to complete
refresh1.proceed();
refresh1.waitForPauseToEnd(1000L);
refresh1.waitForCondition(() -> refreshCount.get() > 1, 1000L);

// getSslData again, and assert the refresh operation completed.
SslData d2 = instance.getSslData();
SslData d2 = instance.getSslData(TEST_TIMEOUT_MS);
assertThat(d2).isSameInstanceAs(data.getSslData());
}

Expand Down Expand Up @@ -254,7 +257,7 @@ public void testThatForceRefreshBalksWhenAForceRefreshIsInProgress() throws Exce
TEST_RATE_LIMITER);

// Get the first data that is about to expire
SslData d = instance.getSslData();
SslData d = instance.getSslData(TEST_TIMEOUT_MS);
assertThat(refreshCount.get()).isEqualTo(1);
assertThat(d).isSameInstanceAs(initialData.getSslData());

Expand All @@ -272,10 +275,10 @@ public void testThatForceRefreshBalksWhenAForceRefreshIsInProgress() throws Exce

// getSslData until the refresh operation returns the newer
// SslData instance
SslData d2 = instance.getSslData();
SslData d2 = instance.getSslData(TEST_TIMEOUT_MS);
for (int i = 0; i < 10 && d2 != data.getSslData(); i++) {
Thread.sleep(10);
d2 = instance.getSslData();
d2 = instance.getSslData(TEST_TIMEOUT_MS);
}

// assert the refresh operation completed exactly once after
Expand Down Expand Up @@ -316,13 +319,15 @@ public void testGetPreferredIpTypes() throws Exception {
keyPairFuture,
TEST_RATE_LIMITER);

assertThat(instance.getPreferredIp(Arrays.asList("PUBLIC", "PRIVATE"))).isEqualTo("10.1.2.3");
assertThat(instance.getPreferredIp(Collections.singletonList("PUBLIC"))).isEqualTo("10.1.2.3");
assertThat(instance.getPreferredIp(Arrays.asList("PRIVATE", "PUBLIC")))
assertThat(instance.getPreferredIp(Arrays.asList("PUBLIC", "PRIVATE"), TEST_TIMEOUT_MS))
.isEqualTo("10.1.2.3");
assertThat(instance.getPreferredIp(Collections.singletonList("PUBLIC"), TEST_TIMEOUT_MS))
.isEqualTo("10.1.2.3");
assertThat(instance.getPreferredIp(Arrays.asList("PRIVATE", "PUBLIC"), TEST_TIMEOUT_MS))
.isEqualTo("10.10.10.10");
assertThat(instance.getPreferredIp(Collections.singletonList("PRIVATE")))
assertThat(instance.getPreferredIp(Collections.singletonList("PRIVATE"), TEST_TIMEOUT_MS))
.isEqualTo("10.10.10.10");
assertThat(instance.getPreferredIp(Collections.singletonList("PSC")))
assertThat(instance.getPreferredIp(Collections.singletonList("PSC"), TEST_TIMEOUT_MS))
.isEqualTo("abcde.12345.us-central1.sql.goog");
}

Expand Down Expand Up @@ -354,7 +359,7 @@ public void testGetPreferredIpTypesThrowsException() throws Exception {
TEST_RATE_LIMITER);
Assert.assertThrows(
IllegalArgumentException.class,
() -> instance.getPreferredIp(Collections.singletonList("PRIVATE")));
() -> instance.getPreferredIp(Collections.singletonList("PRIVATE"), TEST_TIMEOUT_MS));
}

private ListeningScheduledExecutorService newTestExecutor() {
Expand Down
Loading

0 comments on commit 554ee9f

Please sign in to comment.