Skip to content

Commit

Permalink
chore: Refactor RefreshAheadConnectionInfoCache. Part of #992.
Browse files Browse the repository at this point in the history
The lazy refresh strategy only refreshes credentials and certificate information when
the application attempts to establish a new database connection. On Cloud Run
and other serverless runtimes, this is more reliable than the default background
refresh strategy.   

Fixes #992.

WIP Refactor BaseConnectionInfoCache

chore: Refactor RefreshAheadConnectionInfoCache. Part of #992.

The lazy refresh strategy only refreshes credentials and certificate information when
the application attempts to establish a new database connection. On Cloud Run
and other serverless runtimes, this is more reliable than the default background
refresh strategy.   

Fixes #992.

WIP Refactor BaseConnectionInfoCache
  • Loading branch information
hessjcg committed May 24, 2024
1 parent b7493bc commit 3716afd
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,62 +19,39 @@
import com.google.cloud.sql.AuthType;
import com.google.cloud.sql.CredentialFactory;
import com.google.cloud.sql.IpType;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningScheduledExecutorService;
import java.io.IOException;
import java.security.KeyPair;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.net.ssl.SSLSocket;

/**
* This class manages information on and creates connections to a Cloud SQL instance using the Cloud
* SQL Admin API. The operations to retrieve information with the API are largely done
* asynchronously, and this class should be considered threadsafe.
*/
class BaseConnectionInfoCache implements ConnectionInfoCache {
private final AccessTokenSupplier accessTokenSupplier;
private final CloudSqlInstanceName instanceName;
private final RefreshStrategy refreshStrategy;
private final ConnectionConfig config;

/**
* Initializes a new Cloud SQL instance based on the given connection name.
*
* @param config instance connection name in the format "PROJECT_ID:REGION_ID:INSTANCE_ID"
* @param connectionInfoRepository Service class for interacting with the Cloud SQL Admin API
* @param executor executor used to schedule asynchronous tasks
* @param keyPair public/private key pair used to authenticate connections
*/
BaseConnectionInfoCache(
ConnectionConfig config,
ConnectionInfoRepository connectionInfoRepository,
CredentialFactory tokenSourceFactory,
ListeningScheduledExecutorService executor,
ListenableFuture<KeyPair> keyPair,
long minRefreshDelayMs) {
this.instanceName = new CloudSqlInstanceName(config.getCloudSqlInstance());
this.config = config;

abstract class BaseConnectionInfoCache implements ConnectionInfoCache {
protected final AccessTokenSupplier accessTokenSupplier;
protected final CloudSqlInstanceName instanceName;
protected final RefreshStrategy refreshStrategy;
protected final ConnectionConfig config;

protected static AccessTokenSupplier newAccessTokenSupplier(
ConnectionConfig config, CredentialFactory tokenSourceFactory) {
if (config.getAuthType() == AuthType.IAM) {
this.accessTokenSupplier = new DefaultAccessTokenSupplier(tokenSourceFactory);
return new DefaultAccessTokenSupplier(tokenSourceFactory);
} else {
this.accessTokenSupplier = Optional::empty;
return Optional::empty;
}
}

// Initialize the data refresher to retrieve instance data.
refreshStrategy =
new RefreshAheadStrategy(
config.getCloudSqlInstance(),
executor,
() ->
connectionInfoRepository.getConnectionInfo(
this.instanceName,
this.accessTokenSupplier,
config.getAuthType(),
executor,
keyPair),
new AsyncRateLimiter(minRefreshDelayMs));
/** Initializes a new Cloud SQL instance based on the given connection name. */
BaseConnectionInfoCache(
ConnectionConfig config,
AccessTokenSupplier accessTokenSupplier,
RefreshStrategy refreshStrategy) {

this.config = config;
this.instanceName = new CloudSqlInstanceName(config.getCloudSqlInstance());
this.accessTokenSupplier = accessTokenSupplier;
this.refreshStrategy = refreshStrategy;
}

/**
Expand All @@ -92,16 +69,6 @@ private ConnectionInfo getConnectionInfo(long timeoutMs) {
return refreshStrategy.getConnectionInfo(timeoutMs);
}

/**
* Returns an unconnected {@link SSLSocket} using the SSLContext associated with the instance. May
* block until required instance data is available.
*/
@Override
public SSLSocket createSslSocket(long timeoutMs) throws IOException {
return (SSLSocket)
getConnectionInfo(timeoutMs).getSslContext().getSocketFactory().createSocket();
}

/**
* Returns metadata needed to create a connection to the instance.
*
Expand Down Expand Up @@ -131,7 +98,8 @@ public ConnectionMetadata getConnectionMetadata(long timeoutMs) {
return new ConnectionMetadata(
preferredIp,
info.getSslData().getKeyManagerFactory(),
info.getSslData().getTrustManagerFactory());
info.getSslData().getTrustManagerFactory(),
info.getSslData().getSslContext());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,9 @@

package com.google.cloud.sql.core;

import java.io.IOException;
import javax.net.ssl.SSLSocket;

/** ConnectionInfoCache is the contract for a caching strategy for ConnectionInfo. */
public interface ConnectionInfoCache {

/**
* Returns an unconnected {@link SSLSocket} using the SSLContext associated with the instance. May
* block until required instance data is available.
*/
SSLSocket createSslSocket(long timeoutMs) throws IOException;

/**
* Returns metadata needed to create a connection to the instance.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.cloud.sql.core;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;

/**
Expand All @@ -27,16 +28,19 @@ public class ConnectionMetadata {
private final String preferredIpAddress;
private final KeyManagerFactory keyManagerFactory;
private final TrustManagerFactory trustManagerFactory;
private final SSLContext sslContext;

/** Construct an immutable ConnectionMetadata. */
public ConnectionMetadata(
String preferredIpAddress,
KeyManagerFactory keyManagerFactory,
TrustManagerFactory trustManagerFactory) {
TrustManagerFactory trustManagerFactory,
SSLContext sslContext) {

this.preferredIpAddress = preferredIpAddress;
this.keyManagerFactory = keyManagerFactory;
this.trustManagerFactory = trustManagerFactory;
this.sslContext = sslContext;
}

public String getPreferredIpAddress() {
Expand All @@ -50,4 +54,8 @@ public KeyManagerFactory getKeyManagerFactory() {
public TrustManagerFactory getTrustManagerFactory() {
return trustManagerFactory;
}

public SSLContext getSslContext() {
return sslContext;
}
}
8 changes: 4 additions & 4 deletions core/src/main/java/com/google/cloud/sql/core/Connector.java
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@ Socket connect(ConnectionConfig config, long timeoutMs) throws IOException {

ConnectionInfoCache instance = getConnection(config);
try {

String instanceIp = instance.getConnectionMetadata(timeoutMs).getPreferredIpAddress();
ConnectionMetadata metadata = instance.getConnectionMetadata(timeoutMs);
String instanceIp = metadata.getPreferredIpAddress();
logger.debug(String.format("[%s] Connecting to instance.", instanceIp));

SSLSocket socket = instance.createSslSocket(timeoutMs);
SSLSocket socket = (SSLSocket) metadata.getSslContext().getSocketFactory().createSocket();
socket.setKeepAlive(true);
socket.setTcpNoDelay(true);
socket.connect(new InetSocketAddress(instanceIp, serverProxyPort));
Expand Down Expand Up @@ -154,7 +154,7 @@ ConnectionInfoCache getConnection(ConnectionConfig config) {
private ConnectionInfoCache createConnectionInfo(ConnectionConfig config) {
logger.debug(
String.format("[%s] Connection info added to cache.", config.getCloudSqlInstance()));
return new BaseConnectionInfoCache(
return RefreshAheadConnectionInfoCache.newInstance(
config, adminApi, instanceCredentialFactory, executor, localKeyPair, minRefreshDelayMs);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ public DefaultConnectionInfoRepositoryFactory(String userAgents) {
@Override
public DefaultConnectionInfoRepository create(
HttpRequestInitializer requestInitializer, ConnectorConfig config) {
SQLAdmin adminApiBuilder = getApiBuilder(requestInitializer, config);
return new DefaultConnectionInfoRepository(adminApiBuilder);
}

private SQLAdmin getApiBuilder(
HttpRequestInitializer requestInitializer, ConnectorConfig config) {
HttpTransport httpTransport;
try {
httpTransport = GoogleNetHttpTransport.newTrustedTransport();
Expand Down Expand Up @@ -71,6 +77,6 @@ public DefaultConnectionInfoRepository create(
if (config.getUniverseDomain() != null) {
adminApiBuilder.setUniverseDomain(config.getUniverseDomain());
}
return new DefaultConnectionInfoRepository(adminApiBuilder.build());
return adminApiBuilder.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.sql.core;

import com.google.cloud.sql.CredentialFactory;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningScheduledExecutorService;
import java.security.KeyPair;

/**
* Implements the refresh ahead cache strategy, which will load the new ConnectionInfo using a
* background thread before its certificate expires.
*/
class RefreshAheadConnectionInfoCache extends BaseConnectionInfoCache {

/**
* Initializes a new Cloud SQL instance based on the given connection name. + * Initializes a new
* Cloud SQL instance based on the given connection name using the background + * refresh
* strategy.
*
* @param config instance connection name in the format "PROJECT_ID:REGION_ID:INSTANCE_ID"
* @param connectionInfoRepository Service class for interacting with the Cloud SQL Admin API
* @param executor executor used to schedule asynchronous tasks
* @param keyPair public/private key pair used to authenticate connections
*/
public static RefreshAheadConnectionInfoCache newInstance(
ConnectionConfig config,
ConnectionInfoRepository connectionInfoRepository,
CredentialFactory tokenSourceFactory,
ListeningScheduledExecutorService executor,
ListenableFuture<KeyPair> keyPair,
long minRefreshDelayMs) {

AccessTokenSupplier accessTokenSupplier =
BaseConnectionInfoCache.newAccessTokenSupplier(config, tokenSourceFactory);
CloudSqlInstanceName instanceName = new CloudSqlInstanceName(config.getCloudSqlInstance());

RefreshAheadStrategy strategy =
new RefreshAheadStrategy(
config.getCloudSqlInstance(),
executor,
() ->
connectionInfoRepository.getConnectionInfo(
instanceName, accessTokenSupplier, config.getAuthType(), executor, keyPair),
new AsyncRateLimiter(minRefreshDelayMs));

return new RefreshAheadConnectionInfoCache(config, accessTokenSupplier, strategy);
}

private RefreshAheadConnectionInfoCache(
ConnectionConfig config,
AccessTokenSupplier accessTokenSupplier,
RefreshStrategy refreshStrategy) {
super(config, accessTokenSupplier, refreshStrategy);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DefaultConnectionInfoCacheConcurrencyTest {
public class RefreshAheadConnectionInfoCacheConcurrencyTest {

public static final int DEFAULT_WAIT = 200;

private static final Logger logger =
LoggerFactory.getLogger(DefaultConnectionInfoCacheConcurrencyTest.class);
LoggerFactory.getLogger(RefreshAheadConnectionInfoCacheConcurrencyTest.class);
public static final int FORCE_REFRESH_COUNT = 10;

private static class TestCredentialFactory implements CredentialFactory, HttpRequestInitializer {
Expand Down Expand Up @@ -68,7 +68,7 @@ public void testForceRefreshDoesNotCauseADeadlockOrBrokenRefreshLoop() throws Ex

for (int i = 0; i < instanceCount; i++) {
caches.add(
new BaseConnectionInfoCache(
RefreshAheadConnectionInfoCache.newInstance(
new ConnectionConfig.Builder().withCloudSqlInstance("a:b:instance" + i).build(),
supplier,
new TestCredentialFactory(),
Expand Down
Loading

0 comments on commit 3716afd

Please sign in to comment.