diff --git a/core/src/main/java/com/google/cloud/sql/core/CloudSqlInstance.java b/core/src/main/java/com/google/cloud/sql/core/CloudSqlInstance.java index 97805f7ef..8de70e4d4 100644 --- a/core/src/main/java/com/google/cloud/sql/core/CloudSqlInstance.java +++ b/core/src/main/java/com/google/cloud/sql/core/CloudSqlInstance.java @@ -34,6 +34,7 @@ import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.logging.Level; import java.util.logging.Logger; import javax.net.ssl.SSLSocket; @@ -108,13 +109,16 @@ class CloudSqlInstance { * Returns the current data related to the instance from {@link #performRefresh()}. May block if * no valid data is currently available. */ - private InstanceData getInstanceData() { + private InstanceData getInstanceData(long timeoutMs) { ListenableFuture instanceDataFuture; synchronized (instanceDataGuard) { instanceDataFuture = currentInstanceData; } try { - return Uninterruptibles.getUninterruptibly(instanceDataFuture); + return Uninterruptibles.getUninterruptibly( + instanceDataFuture, timeoutMs, TimeUnit.MILLISECONDS); + } catch (TimeoutException ex) { + throw new RuntimeException(ex); } catch (ExecutionException ex) { Throwable cause = ex.getCause(); Throwables.throwIfUnchecked(cause); @@ -126,8 +130,8 @@ private InstanceData getInstanceData() { * Returns an unconnected {@link SSLSocket} using the SSLContext associated with the instance. May * block until required instance data is available. */ - SSLSocket createSslSocket() throws IOException { - return (SSLSocket) getInstanceData().getSslContext().getSocketFactory().createSocket(); + SSLSocket createSslSocket(long timeoutMs) throws IOException { + return (SSLSocket) getInstanceData(timeoutMs).getSslContext().getSocketFactory().createSocket(); } /** @@ -140,8 +144,8 @@ SSLSocket createSslSocket() throws IOException { * @throws IllegalArgumentException If the instance has no IP addresses matching the provided * preferences. */ - String getPreferredIp(List preferredTypes) { - Map ipAddrs = getInstanceData().getIpAddrs(); + String getPreferredIp(List preferredTypes, long timeoutMs) { + Map ipAddrs = getInstanceData(timeoutMs).getIpAddrs(); for (String ipType : preferredTypes) { String preferredIp = ipAddrs.get(ipType); if (preferredIp != null) { @@ -237,8 +241,8 @@ private InstanceData performRefresh() throws InterruptedException, ExecutionExce } } - SslData getSslData() { - return getInstanceData().getSslData(); + SslData getSslData(long timeoutMs) { + return getInstanceData(timeoutMs).getSslData(); } ListenableFuture getNext() { diff --git a/core/src/main/java/com/google/cloud/sql/core/CoreSocketFactory.java b/core/src/main/java/com/google/cloud/sql/core/CoreSocketFactory.java index 3895e03f0..19d73d039 100644 --- a/core/src/main/java/com/google/cloud/sql/core/CoreSocketFactory.java +++ b/core/src/main/java/com/google/cloud/sql/core/CoreSocketFactory.java @@ -70,6 +70,7 @@ public final class CoreSocketFactory { */ @Deprecated public static final String USER_TOKEN_PROPERTY_NAME = "_CLOUD_SQL_USER_TOKEN"; + static final long DEFAULT_MAX_REFRESH_MS = 30000; public static final String DEFAULT_IP_TYPES = "PUBLIC,PRIVATE"; private static final String UNIX_SOCKET_PROPERTY = "unixSocketPath"; private static final Logger logger = Logger.getLogger(CoreSocketFactory.class.getName()); @@ -84,6 +85,7 @@ public final class CoreSocketFactory { private final ListeningScheduledExecutorService executor; private final CredentialFactory credentialFactory; private final int serverProxyPort; + private final long refreshTimeoutMs; private final ApiFetcherFactory apiFetcherFactory; @VisibleForTesting @@ -92,12 +94,14 @@ public final class CoreSocketFactory { ApiFetcherFactory apiFetcherFactory, CredentialFactory credentialFactory, int serverProxyPort, + long refreshTimeoutMs, ListeningScheduledExecutorService executor) { this.apiFetcherFactory = apiFetcherFactory; this.credentialFactory = credentialFactory; this.serverProxyPort = serverProxyPort; this.executor = executor; this.localKeyPair = localKeyPair; + this.refreshTimeoutMs = refreshTimeoutMs; } /** Returns the {@link CoreSocketFactory} singleton. */ @@ -115,6 +119,7 @@ public static synchronized CoreSocketFactory getInstance() { new SqlAdminApiFetcherFactory(getUserAgents()), credentialFactory, DEFAULT_SERVER_PROXY_PORT, + CoreSocketFactory.DEFAULT_MAX_REFRESH_MS, executor); } return coreSocketFactory; @@ -225,13 +230,15 @@ public static SslData getSslData( List delegates) throws IOException { if (enableIamAuth) { - return getInstance() + CoreSocketFactory factory = getInstance(); + return factory .getCloudSqlInstance(csqlInstanceName, AuthType.IAM, targetPrincipal, delegates) - .getSslData(); + .getSslData(factory.refreshTimeoutMs); } - return getInstance() + CoreSocketFactory factory = getInstance(); + return factory .getCloudSqlInstance(csqlInstanceName, AuthType.PASSWORD, targetPrincipal, delegates) - .getSslData(); + .getSslData(factory.refreshTimeoutMs); } /** Returns preferred ip address that can be used to establish Cloud SQL connection. */ @@ -246,7 +253,7 @@ private String getHostIp( String instanceName, List ipTypes, String targetPrincipal, List delegates) { CloudSqlInstance instance = getCloudSqlInstance(instanceName, AuthType.PASSWORD, targetPrincipal, delegates); - return instance.getPreferredIp(ipTypes); + return instance.getPreferredIp(ipTypes, refreshTimeoutMs); } /** @@ -363,14 +370,14 @@ Socket createSslSocket( getCloudSqlInstance(instanceName, authType, targetPrincipal, delegates); try { - SSLSocket socket = instance.createSslSocket(); + SSLSocket socket = instance.createSslSocket(this.refreshTimeoutMs); // TODO(kvg): Support all socket related options listed here: // https://dev.mysql.com/doc/connector-j/en/connector-j-reference-configuration-properties.html socket.setKeepAlive(true); socket.setTcpNoDelay(true); - String instanceIp = instance.getPreferredIp(ipTypes); + String instanceIp = instance.getPreferredIp(ipTypes, refreshTimeoutMs); socket.connect(new InetSocketAddress(instanceIp, serverProxyPort)); socket.startHandshake(); diff --git a/core/src/test/java/com/google/cloud/sql/core/CloudSqlInstanceConcurrencyTest.java b/core/src/test/java/com/google/cloud/sql/core/CloudSqlInstanceConcurrencyTest.java index 8d12033c8..e8e68b9f1 100644 --- a/core/src/test/java/com/google/cloud/sql/core/CloudSqlInstanceConcurrencyTest.java +++ b/core/src/test/java/com/google/cloud/sql/core/CloudSqlInstanceConcurrencyTest.java @@ -78,7 +78,7 @@ public void testForceRefreshDoesNotCauseADeadlockOrBrokenRefreshLoop() throws Ex } // Get SSL Data for each instance, forcing the first refresh to complete. - instances.forEach(CloudSqlInstance::getSslData); + instances.forEach((inst) -> inst.getSslData(2000L)); assertThat(supplier.counter.get()).isEqualTo(instanceCount); @@ -117,7 +117,7 @@ private Thread startForceRefreshThread(CloudSqlInstance inst) { inst.forceRefresh(); inst.forceRefresh(); Thread.sleep(0); - inst.getSslData(); + inst.getSslData(2000L); } catch (Exception e) { logger.info("Exception in force refresh loop."); } diff --git a/core/src/test/java/com/google/cloud/sql/core/CloudSqlInstanceTest.java b/core/src/test/java/com/google/cloud/sql/core/CloudSqlInstanceTest.java index ec178e371..41b48a058 100644 --- a/core/src/test/java/com/google/cloud/sql/core/CloudSqlInstanceTest.java +++ b/core/src/test/java/com/google/cloud/sql/core/CloudSqlInstanceTest.java @@ -39,6 +39,8 @@ public class CloudSqlInstanceTest { + public static final long TEST_TIMEOUT_MS = 1000; + @SuppressWarnings("UnstableApiUsage") public static final RateLimiter TEST_RATE_LIMITER = RateLimiter.create(1000 /* permits per second */); @@ -74,7 +76,7 @@ public void testCloudSqlInstanceDataRetrievedSuccessfully() throws Exception { keyPairFuture, TEST_RATE_LIMITER); - SslData gotSslData = instance.getSslData(); + SslData gotSslData = instance.getSslData(TEST_TIMEOUT_MS); assertThat(gotSslData).isSameInstanceAs(instanceDataSupplier.response.getSslData()); assertThat(instanceDataSupplier.counter.get()).isEqualTo(1); } @@ -102,10 +104,41 @@ public void testInstanceFailsOnConnectionError() throws Exception { keyPairFuture, TEST_RATE_LIMITER); - RuntimeException ex = Assert.assertThrows(RuntimeException.class, instance::getSslData); + RuntimeException ex = + Assert.assertThrows(RuntimeException.class, () -> instance.getSslData(TEST_TIMEOUT_MS)); assertThat(ex).hasMessageThat().contains("always fails"); } + @Test + public void testInstanceFailsOnTooLongToRetrieve() { + PauseCondition cond = new PauseCondition(); + InstanceDataSupplier instanceDataSupplier = + (CloudSqlInstanceName instanceName, + AccessTokenSupplier accessTokenSupplier, + AuthType authType, + ListeningScheduledExecutorService exec, + ListenableFuture keyPair) -> { + // This is never allowed to proceed + cond.pause(); + throw new RuntimeException("fake read timeout"); + }; + + // initialize instance after mocks are set up + CloudSqlInstance instance = + new CloudSqlInstance( + "project:region:instance", + instanceDataSupplier, + AuthType.PASSWORD, + stubCredentialFactory, + executorService, + keyPairFuture, + RateLimiter.create(10)); + + RuntimeException ex = + Assert.assertThrows(RuntimeException.class, () -> instance.getSslData(2000)); + assertThat(ex).hasMessageThat().contains("java.util.concurrent.TimeoutException"); + } + @Test public void testCloudSqlInstanceForcesRefresh() throws Exception { SslData sslData = new SslData(null, null, null); @@ -132,7 +165,7 @@ public void testCloudSqlInstanceForcesRefresh() throws Exception { keyPairFuture, TEST_RATE_LIMITER); - instance.getSslData(); + instance.getSslData(TEST_TIMEOUT_MS); assertThat(refreshCount.get()).isEqualTo(1); // Force refresh, which will start, but not finish the refresh process. @@ -140,7 +173,7 @@ public void testCloudSqlInstanceForcesRefresh() throws Exception { // Then immediately getSslData() and assert that the refresh count has not changed. // Refresh count hasn't changed because we re-use the existing connection info. - instance.getSslData(); + instance.getSslData(TEST_TIMEOUT_MS); assertThat(refreshCount.get()).isEqualTo(1); // Allow the second refresh operation to complete @@ -149,7 +182,7 @@ public void testCloudSqlInstanceForcesRefresh() throws Exception { cond.waitForCondition(() -> refreshCount.get() >= 2, 1000L); // getSslData again, and assert the refresh operation completed. - instance.getSslData(); + instance.getSslData(TEST_TIMEOUT_MS); assertThat(refreshCount.get()).isEqualTo(2); } @@ -193,7 +226,7 @@ public void testCloudSqlRefreshesExpiredData() throws Exception { TEST_RATE_LIMITER); // Get the first data that is about to expire - SslData d = instance.getSslData(); + SslData d = instance.getSslData(TEST_TIMEOUT_MS); assertThat(refreshCount.get()).isEqualTo(1); assertThat(d).isSameInstanceAs(initialData.getSslData()); @@ -205,14 +238,15 @@ public void testCloudSqlRefreshesExpiredData() throws Exception { // Now that the InstanceData has expired, getInstanceData will return the same, expired // token until a new one is retrieved. assertThat(refreshCount.get()).isEqualTo(1); - assertThat(instance.getSslData()).isSameInstanceAs(initialData.getSslData()); + assertThat(instance.getSslData(TEST_TIMEOUT_MS)).isSameInstanceAs(initialData.getSslData()); // Allow the second refresh operation to complete refresh1.proceed(); refresh1.waitForPauseToEnd(1000L); // getSslData again, and assert the refresh operation completed. - refresh1.waitForCondition(() -> instance.getSslData() == data.getSslData(), 1000L); + refresh1.waitForCondition( + () -> instance.getSslData(TEST_TIMEOUT_MS) == data.getSslData(), 1000L); } @Test @@ -252,7 +286,7 @@ public void testThatForceRefreshBalksWhenAForceRefreshIsInProgress() throws Exce TEST_RATE_LIMITER); // Get the first data that is about to expire - SslData d = instance.getSslData(); + SslData d = instance.getSslData(TEST_TIMEOUT_MS); assertThat(refreshCount.get()).isEqualTo(1); assertThat(d).isSameInstanceAs(initialData.getSslData()); @@ -268,17 +302,10 @@ public void testThatForceRefreshBalksWhenAForceRefreshIsInProgress() throws Exce refresh1.waitForPauseToEnd(1000); refresh1.waitForCondition(() -> refreshCount.get() >= 2, 1000); - // getSslData until the refresh operation returns the newer - // SslData instance - SslData d2 = instance.getSslData(); - for (int i = 0; i < 10 && d2 != data.getSslData(); i++) { - Thread.sleep(10); - d2 = instance.getSslData(); - } - // assert the refresh operation completed exactly once after // forceRefresh was called multiple times. - assertThat(d2).isSameInstanceAs(data.getSslData()); + refresh1.waitForCondition( + () -> instance.getSslData(TEST_TIMEOUT_MS) == data.getSslData(), 1000L); assertThat(refreshCount.get()).isEqualTo(2); } @@ -314,13 +341,15 @@ public void testGetPreferredIpTypes() throws Exception { keyPairFuture, TEST_RATE_LIMITER); - assertThat(instance.getPreferredIp(Arrays.asList("PUBLIC", "PRIVATE"))).isEqualTo("10.1.2.3"); - assertThat(instance.getPreferredIp(Collections.singletonList("PUBLIC"))).isEqualTo("10.1.2.3"); - assertThat(instance.getPreferredIp(Arrays.asList("PRIVATE", "PUBLIC"))) + assertThat(instance.getPreferredIp(Arrays.asList("PUBLIC", "PRIVATE"), TEST_TIMEOUT_MS)) + .isEqualTo("10.1.2.3"); + assertThat(instance.getPreferredIp(Collections.singletonList("PUBLIC"), TEST_TIMEOUT_MS)) + .isEqualTo("10.1.2.3"); + assertThat(instance.getPreferredIp(Arrays.asList("PRIVATE", "PUBLIC"), TEST_TIMEOUT_MS)) .isEqualTo("10.10.10.10"); - assertThat(instance.getPreferredIp(Collections.singletonList("PRIVATE"))) + assertThat(instance.getPreferredIp(Collections.singletonList("PRIVATE"), TEST_TIMEOUT_MS)) .isEqualTo("10.10.10.10"); - assertThat(instance.getPreferredIp(Collections.singletonList("PSC"))) + assertThat(instance.getPreferredIp(Collections.singletonList("PSC"), TEST_TIMEOUT_MS)) .isEqualTo("abcde.12345.us-central1.sql.goog"); } @@ -352,7 +381,7 @@ public void testGetPreferredIpTypesThrowsException() throws Exception { TEST_RATE_LIMITER); Assert.assertThrows( IllegalArgumentException.class, - () -> instance.getPreferredIp(Collections.singletonList("PRIVATE"))); + () -> instance.getPreferredIp(Collections.singletonList("PRIVATE"), TEST_TIMEOUT_MS)); } private ListeningScheduledExecutorService newTestExecutor() { diff --git a/core/src/test/java/com/google/cloud/sql/core/CoreSocketFactoryTest.java b/core/src/test/java/com/google/cloud/sql/core/CoreSocketFactoryTest.java index 094485ccb..8154e0893 100644 --- a/core/src/test/java/com/google/cloud/sql/core/CoreSocketFactoryTest.java +++ b/core/src/test/java/com/google/cloud/sql/core/CoreSocketFactoryTest.java @@ -43,6 +43,7 @@ // TODO(berezv): add multithreaded test @RunWith(JUnit4.class) public class CoreSocketFactoryTest extends CloudSqlCoreTestingBase { + private final long TEST_MAX_REFRESH_MS = 5000L; ListeningScheduledExecutorService defaultExecutor; @@ -62,7 +63,8 @@ public void create_throwsErrorForInvalidInstanceName() throws IOException { ApiFetcherFactory factory = new StubApiFetcherFactory(fakeSuccessHttpTransport(Duration.ofSeconds(0))); CoreSocketFactory coreSocketFactory = - new CoreSocketFactory(clientKeyPair, factory, credentialFactory, 3307, defaultExecutor); + new CoreSocketFactory( + clientKeyPair, factory, credentialFactory, 3307, TEST_MAX_REFRESH_MS, defaultExecutor); try { coreSocketFactory.createSslSocket( "myProject", @@ -95,7 +97,8 @@ public void create_throwsErrorForInvalidInstanceRegion() throws IOException { ApiFetcherFactory factory = new StubApiFetcherFactory(fakeSuccessHttpTransport(Duration.ofSeconds(0))); CoreSocketFactory coreSocketFactory = - new CoreSocketFactory(clientKeyPair, factory, credentialFactory, 3307, defaultExecutor); + new CoreSocketFactory( + clientKeyPair, factory, credentialFactory, 3307, TEST_MAX_REFRESH_MS, defaultExecutor); try { coreSocketFactory.createSslSocket( "myProject:notMyRegion:myInstance", @@ -125,7 +128,8 @@ public void create_successfulPrivateConnection() throws IOException, Interrupted ApiFetcherFactory factory = new StubApiFetcherFactory(fakeSuccessHttpTransport(Duration.ofSeconds(0))); CoreSocketFactory coreSocketFactory = - new CoreSocketFactory(clientKeyPair, factory, credentialFactory, port, defaultExecutor); + new CoreSocketFactory( + clientKeyPair, factory, credentialFactory, port, TEST_MAX_REFRESH_MS, defaultExecutor); Socket socket = coreSocketFactory.createSslSocket( "myProject:myRegion:myInstance", @@ -145,7 +149,8 @@ public void create_failOnEmptyTargetPrincipal() throws IOException, InterruptedE ApiFetcherFactory factory = new StubApiFetcherFactory(fakeSuccessHttpTransport(Duration.ofSeconds(0))); CoreSocketFactory coreSocketFactory = - new CoreSocketFactory(clientKeyPair, factory, credentialFactory, port, defaultExecutor); + new CoreSocketFactory( + clientKeyPair, factory, credentialFactory, port, TEST_MAX_REFRESH_MS, defaultExecutor); try { coreSocketFactory.createSslSocket( @@ -168,7 +173,8 @@ public void create_successfulConnection() throws IOException, InterruptedExcepti ApiFetcherFactory factory = new StubApiFetcherFactory(fakeSuccessHttpTransport(Duration.ofSeconds(0))); CoreSocketFactory coreSocketFactory = - new CoreSocketFactory(clientKeyPair, factory, credentialFactory, port, defaultExecutor); + new CoreSocketFactory( + clientKeyPair, factory, credentialFactory, port, TEST_MAX_REFRESH_MS, defaultExecutor); Socket socket = coreSocketFactory.createSslSocket( @@ -189,7 +195,8 @@ public void create_successfulDomainScopedConnection() throws IOException, Interr ApiFetcherFactory factory = new StubApiFetcherFactory(fakeSuccessHttpTransport(Duration.ofSeconds(0))); CoreSocketFactory coreSocketFactory = - new CoreSocketFactory(clientKeyPair, factory, credentialFactory, port, defaultExecutor); + new CoreSocketFactory( + clientKeyPair, factory, credentialFactory, port, TEST_MAX_REFRESH_MS, defaultExecutor); Socket socket = coreSocketFactory.createSslSocket( "example.com:myProject:myRegion:myInstance", @@ -204,7 +211,8 @@ public void create_successfulDomainScopedConnection() throws IOException, Interr public void create_adminApiNotEnabled() throws IOException { ApiFetcherFactory factory = new StubApiFetcherFactory(fakeNotConfiguredException()); CoreSocketFactory coreSocketFactory = - new CoreSocketFactory(clientKeyPair, factory, credentialFactory, 3307, defaultExecutor); + new CoreSocketFactory( + clientKeyPair, factory, credentialFactory, 3307, TEST_MAX_REFRESH_MS, defaultExecutor); try { // Use a different project to get Api Not Enabled Error. coreSocketFactory.createSslSocket( @@ -228,7 +236,8 @@ public void create_adminApiNotEnabled() throws IOException { public void create_notAuthorized() throws IOException { ApiFetcherFactory factory = new StubApiFetcherFactory(fakeNotAuthorizedException()); CoreSocketFactory coreSocketFactory = - new CoreSocketFactory(clientKeyPair, factory, credentialFactory, 3307, defaultExecutor); + new CoreSocketFactory( + clientKeyPair, factory, credentialFactory, 3307, TEST_MAX_REFRESH_MS, defaultExecutor); try { // Use a different instance to simulate incorrect permissions. coreSocketFactory.createSslSocket( @@ -262,7 +271,13 @@ public void supportsCustomCredentialFactoryWithIAM() throws InterruptedException new StubApiFetcherFactory(fakeSuccessHttpTransport(Duration.ofSeconds(0))); CoreSocketFactory coreSocketFactory = - new CoreSocketFactory(clientKeyPair, factory, stubCredentialFactory, port, defaultExecutor); + new CoreSocketFactory( + clientKeyPair, + factory, + stubCredentialFactory, + port, + TEST_MAX_REFRESH_MS, + defaultExecutor); Socket socket = coreSocketFactory.createSslSocket( "myProject:myRegion:myInstance", @@ -285,7 +300,13 @@ public void supportsCustomCredentialFactoryWithNoExpirationTime() ApiFetcherFactory factory = new StubApiFetcherFactory(fakeSuccessHttpTransport(Duration.ofSeconds(0))); CoreSocketFactory coreSocketFactory = - new CoreSocketFactory(clientKeyPair, factory, stubCredentialFactory, port, defaultExecutor); + new CoreSocketFactory( + clientKeyPair, + factory, + stubCredentialFactory, + port, + TEST_MAX_REFRESH_MS, + defaultExecutor); Socket socket = coreSocketFactory.createSslSocket( "myProject:myRegion:myInstance", @@ -315,7 +336,13 @@ public HttpRequestInitializer create() { ApiFetcherFactory factory = new StubApiFetcherFactory(fakeSuccessHttpTransport(Duration.ofSeconds(0))); CoreSocketFactory coreSocketFactory = - new CoreSocketFactory(clientKeyPair, factory, stubCredentialFactory, port, defaultExecutor); + new CoreSocketFactory( + clientKeyPair, + factory, + stubCredentialFactory, + port, + TEST_MAX_REFRESH_MS, + defaultExecutor); assertThrows( RuntimeException.class, () -> diff --git a/r2dbc/core/src/test/java/com/google/cloud/sql/core/GcpConnectionFactoryProviderTest.java b/r2dbc/core/src/test/java/com/google/cloud/sql/core/GcpConnectionFactoryProviderTest.java index 6f9582178..6d409cffc 100644 --- a/r2dbc/core/src/test/java/com/google/cloud/sql/core/GcpConnectionFactoryProviderTest.java +++ b/r2dbc/core/src/test/java/com/google/cloud/sql/core/GcpConnectionFactoryProviderTest.java @@ -178,6 +178,12 @@ public void setup() throws GeneralSecurityException { new StubApiFetcherFactory(fakeSuccessHttpTransport(Duration.ofSeconds(0))); coreSocketFactoryStub = - new CoreSocketFactory(clientKeyPair, fetcher, credentialFactory, 3307, defaultExecutor); + new CoreSocketFactory( + clientKeyPair, + fetcher, + credentialFactory, + 3307, + CoreSocketFactory.DEFAULT_MAX_REFRESH_MS, + defaultExecutor); } }