Skip to content

Commit

Permalink
[BJJ: Enter a description for the combined commit.
Browse files Browse the repository at this point in the history
chore: Refactor RefreshAheadConnectionInfoCache. Part of #992.

The lazy refresh strategy only refreshes credentials and certificate information when
the application attempts to establish a new database connection. On Cloud Run
and other serverless runtimes, this is more reliable than the default background
refresh strategy.   

Fixes #992.
  • Loading branch information
hessjcg committed May 23, 2024
1 parent b7493bc commit 82d82ac
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
import com.google.cloud.sql.AuthType;
import com.google.cloud.sql.CredentialFactory;
import com.google.cloud.sql.IpType;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningScheduledExecutorService;
import java.io.IOException;
import java.security.KeyPair;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.net.ssl.SSLSocket;
Expand All @@ -32,51 +29,39 @@
* SQL Admin API. The operations to retrieve information with the API are largely done
* asynchronously, and this class should be considered threadsafe.
*/
class BaseConnectionInfoCache implements ConnectionInfoCache {
private final AccessTokenSupplier accessTokenSupplier;
private final CloudSqlInstanceName instanceName;
private final RefreshStrategy refreshStrategy;
private final ConnectionConfig config;
abstract class BaseConnectionInfoCache implements ConnectionInfoCache {
protected final AccessTokenSupplier accessTokenSupplier;
protected final CloudSqlInstanceName instanceName;
protected final RefreshStrategy refreshStrategy;
protected final ConnectionConfig config;

protected static AccessTokenSupplier newAccessTokenSupplier(
ConnectionConfig config, CredentialFactory tokenSourceFactory) {
if (config.getAuthType() == AuthType.IAM) {
return new DefaultAccessTokenSupplier(tokenSourceFactory);
} else {
return Optional::empty;
}
}

/**
* Initializes a new Cloud SQL instance based on the given connection name.
*
* @param config instance connection name in the format "PROJECT_ID:REGION_ID:INSTANCE_ID"
* @param connectionInfoRepository Service class for interacting with the Cloud SQL Admin API
* @param executor executor used to schedule asynchronous tasks
* @param keyPair public/private key pair used to authenticate connections
* @param tokenSourceFactory The token source factory
*/
BaseConnectionInfoCache(
ConnectionConfig config,
ConnectionInfoRepository connectionInfoRepository,
CredentialFactory tokenSourceFactory,
ListeningScheduledExecutorService executor,
ListenableFuture<KeyPair> keyPair,
long minRefreshDelayMs) {
BaseConnectionInfoCache(ConnectionConfig config, CredentialFactory tokenSourceFactory) {

this.instanceName = new CloudSqlInstanceName(config.getCloudSqlInstance());
this.config = config;

if (config.getAuthType() == AuthType.IAM) {
this.accessTokenSupplier = new DefaultAccessTokenSupplier(tokenSourceFactory);
} else {
this.accessTokenSupplier = Optional::empty;
}
this.accessTokenSupplier = newAccessTokenSupplier(config, tokenSourceFactory);

// Initialize the data refresher to retrieve instance data.
refreshStrategy =
new RefreshAheadStrategy(
config.getCloudSqlInstance(),
executor,
() ->
connectionInfoRepository.getConnectionInfo(
this.instanceName,
this.accessTokenSupplier,
config.getAuthType(),
executor,
keyPair),
new AsyncRateLimiter(minRefreshDelayMs));
this.refreshStrategy = initRefreshStrategy();
}

protected abstract RefreshStrategy initRefreshStrategy();

/**
* Returns the current data related to the instance. May block if no valid data is currently
* available. This method is called by an application thread when it is trying to create a new
Expand Down
5 changes: 5 additions & 0 deletions core/src/main/java/com/google/cloud/sql/core/Connector.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.google.cloud.sql.ConnectorConfig;
import com.google.cloud.sql.CredentialFactory;
import com.google.cloud.sql.RefreshStrategy;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningScheduledExecutorService;
import java.io.File;
Expand All @@ -26,6 +27,7 @@
import java.net.Socket;
import java.security.KeyPair;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import javax.net.ssl.SSLSocket;
import jnr.unixsocket.UnixSocketAddress;
import jnr.unixsocket.UnixSocketChannel;
Expand Down Expand Up @@ -158,6 +160,9 @@ private ConnectionInfoCache createConnectionInfo(ConnectionConfig config) {
config, adminApi, instanceCredentialFactory, executor, localKeyPair, minRefreshDelayMs);
}

return new RefreshAheadConnectionInfoCache(
config, adminApi, instanceCredentialFactory, executor, localKeyPair, minRefreshDelayMs);
}
public void close() {
logger.debug("Close all connections and remove them from cache.");
this.instances.forEach((key, c) -> c.close());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ public DefaultConnectionInfoRepositoryFactory(String userAgents) {
@Override
public DefaultConnectionInfoRepository create(
HttpRequestInitializer requestInitializer, ConnectorConfig config) {
SQLAdmin adminApiBuilder = getApiBuilder(requestInitializer, config);
return new DefaultConnectionInfoRepository(adminApiBuilder);
}

private SQLAdmin getApiBuilder(
HttpRequestInitializer requestInitializer, ConnectorConfig config) {
HttpTransport httpTransport;
try {
httpTransport = GoogleNetHttpTransport.newTrustedTransport();
Expand Down Expand Up @@ -71,6 +77,6 @@ public DefaultConnectionInfoRepository create(
if (config.getUniverseDomain() != null) {
adminApiBuilder.setUniverseDomain(config.getUniverseDomain());
}
return new DefaultConnectionInfoRepository(adminApiBuilder.build());
return adminApiBuilder.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.sql.core;

import com.google.cloud.sql.CredentialFactory;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningScheduledExecutorService;
import java.security.KeyPair;

/**
* Implements the refresh ahead cache strategy, which will load the new ConnectionInfo using a
* background thread before its certificate expires.
*/
class RefreshAheadConnectionInfoCache extends BaseConnectionInfoCache {
private final ConnectionInfoRepository connectionInfoRepository;
private final ListeningScheduledExecutorService executor;
private final ListenableFuture<KeyPair> keyPair;
private final long minRefreshDelayMs;

/**
* - * Initializes a new Cloud SQL instance based on the given connection name. + * Initializes a
* new Cloud SQL instance based on the given connection name using the background + * refresh
* strategy.
*
* @param config instance connection name in the format "PROJECT_ID:REGION_ID:INSTANCE_ID"
* @param connectionInfoRepository Service class for interacting with the Cloud SQL Admin API
* @param executor executor used to schedule asynchronous tasks
* @param keyPair public/private key pair used to authenticate connections
*/
RefreshAheadConnectionInfoCache(
ConnectionConfig config,
ConnectionInfoRepository connectionInfoRepository,
CredentialFactory tokenSourceFactory,
ListeningScheduledExecutorService executor,
ListenableFuture<KeyPair> keyPair,
long minRefreshDelayMs) {
super(config, tokenSourceFactory);
this.connectionInfoRepository = connectionInfoRepository;
this.executor = executor;
this.keyPair = keyPair;
this.minRefreshDelayMs = minRefreshDelayMs;
}

@Override
protected RefreshStrategy initRefreshStrategy() {
return new RefreshAheadStrategy(
config.getCloudSqlInstance(),
executor,
this::connectionInfoSupplier,
new AsyncRateLimiter(minRefreshDelayMs));
}

private ListenableFuture<ConnectionInfo> connectionInfoSupplier() {
return connectionInfoRepository.getConnectionInfo(
this.instanceName, this.accessTokenSupplier, config.getAuthType(), executor, keyPair);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DefaultConnectionInfoCacheConcurrencyTest {
public class RefreshAheadConnectionInfoCacheConcurrencyTest {

public static final int DEFAULT_WAIT = 200;

private static final Logger logger =
LoggerFactory.getLogger(DefaultConnectionInfoCacheConcurrencyTest.class);
LoggerFactory.getLogger(RefreshAheadConnectionInfoCacheConcurrencyTest.class);
public static final int FORCE_REFRESH_COUNT = 10;

private static class TestCredentialFactory implements CredentialFactory, HttpRequestInitializer {
Expand Down Expand Up @@ -68,7 +68,7 @@ public void testForceRefreshDoesNotCauseADeadlockOrBrokenRefreshLoop() throws Ex

for (int i = 0; i < instanceCount; i++) {
caches.add(
new BaseConnectionInfoCache(
new RefreshAheadConnectionInfoCache(
new ConnectionConfig.Builder().withCloudSqlInstance("a:b:instance" + i).build(),
supplier,
new TestCredentialFactory(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import org.junit.Before;
import org.junit.Test;

public class DefaultConnectionInfoCacheTest {
public class RefreshAheadConnectionInfoCacheTest {

public static final long TEST_TIMEOUT_MS = 3000;

Expand Down Expand Up @@ -71,8 +71,8 @@ public void teardown() {
public void testCloudSqlInstanceDataRetrievedSuccessfully() {
TestDataSupplier instanceDataSupplier = new TestDataSupplier(false);
// initialize connectionInfoCache after mocks are set up
BaseConnectionInfoCache connectionInfoCache =
new BaseConnectionInfoCache(
RefreshAheadConnectionInfoCache connectionInfoCache =
new RefreshAheadConnectionInfoCache(
new ConnectionConfig.Builder().withCloudSqlInstance("project:region:instance").build(),
instanceDataSupplier,
stubCredentialFactory,
Expand Down Expand Up @@ -103,8 +103,8 @@ public ListenableFuture<ConnectionInfo> getConnectionInfo(
};

// initialize connectionInfoCache after mocks are set up
BaseConnectionInfoCache connectionInfoCache =
new BaseConnectionInfoCache(
RefreshAheadConnectionInfoCache connectionInfoCache =
new RefreshAheadConnectionInfoCache(
new ConnectionConfig.Builder().withCloudSqlInstance("project:region:instance").build(),
connectionInfoRepository,
stubCredentialFactory,
Expand Down Expand Up @@ -138,8 +138,8 @@ public ListenableFuture<ConnectionInfo> getConnectionInfo(
};

// initialize connectionInfoCache after mocks are set up
BaseConnectionInfoCache connectionInfoCache =
new BaseConnectionInfoCache(
RefreshAheadConnectionInfoCache connectionInfoCache =
new RefreshAheadConnectionInfoCache(
new ConnectionConfig.Builder().withCloudSqlInstance("project:region:instance").build(),
connectionInfoRepository,
stubCredentialFactory,
Expand Down Expand Up @@ -177,8 +177,8 @@ public ListenableFuture<ConnectionInfo> getConnectionInfo(
}
};

BaseConnectionInfoCache connectionInfoCache =
new BaseConnectionInfoCache(
RefreshAheadConnectionInfoCache connectionInfoCache =
new RefreshAheadConnectionInfoCache(
new ConnectionConfig.Builder().withCloudSqlInstance("project:region:instance").build(),
connectionInfoRepository,
stubCredentialFactory,
Expand Down Expand Up @@ -230,8 +230,8 @@ public ListenableFuture<ConnectionInfo> getConnectionInfo(
return Futures.immediateFuture(connectionInfo);
}
};
BaseConnectionInfoCache connectionInfoCache =
new BaseConnectionInfoCache(
RefreshAheadConnectionInfoCache connectionInfoCache =
new RefreshAheadConnectionInfoCache(
new ConnectionConfig.Builder().withCloudSqlInstance("project:region:instance").build(),
connectionInfoRepository,
stubCredentialFactory,
Expand Down Expand Up @@ -302,8 +302,8 @@ public ListenableFuture<ConnectionInfo> getConnectionInfo(
return Futures.immediateFuture(refreshResult);
}
};
BaseConnectionInfoCache connectionInfoCache =
new BaseConnectionInfoCache(
RefreshAheadConnectionInfoCache connectionInfoCache =
new RefreshAheadConnectionInfoCache(
new ConnectionConfig.Builder().withCloudSqlInstance("project:region:instance").build(),
connectionInfoRepository,
stubCredentialFactory,
Expand Down Expand Up @@ -371,8 +371,8 @@ public ListenableFuture<ConnectionInfo> getConnectionInfo(
return Futures.immediateFuture(refreshResult);
}
};
BaseConnectionInfoCache connectionInfoCache =
new BaseConnectionInfoCache(
RefreshAheadConnectionInfoCache connectionInfoCache =
new RefreshAheadConnectionInfoCache(
new ConnectionConfig.Builder().withCloudSqlInstance("project:region:instance").build(),
connectionInfoRepository,
stubCredentialFactory,
Expand Down Expand Up @@ -444,8 +444,8 @@ public ListenableFuture<ConnectionInfo> getConnectionInfo(
}
};

BaseConnectionInfoCache connectionInfoCache =
new BaseConnectionInfoCache(
RefreshAheadConnectionInfoCache connectionInfoCache =
new RefreshAheadConnectionInfoCache(
new ConnectionConfig.Builder().withCloudSqlInstance("project:region:instance").build(),
connectionInfoRepository,
stubCredentialFactory,
Expand Down Expand Up @@ -521,8 +521,8 @@ public ListenableFuture<ConnectionInfo> getConnectionInfo(
}
};

BaseConnectionInfoCache connectionInfoCache =
new BaseConnectionInfoCache(
RefreshAheadConnectionInfoCache connectionInfoCache =
new RefreshAheadConnectionInfoCache(
new ConnectionConfig.Builder().withCloudSqlInstance("project:region:instance").build(),
connectionInfoRepository,
stubCredentialFactory,
Expand Down Expand Up @@ -606,8 +606,8 @@ public ListenableFuture<ConnectionInfo> getConnectionInfo(
values.forEach(
(ipTypes, wantsIp) -> {
// initialize connectionInfoCache after mocks are set up
BaseConnectionInfoCache connectionInfoCache =
new BaseConnectionInfoCache(
RefreshAheadConnectionInfoCache connectionInfoCache =
new RefreshAheadConnectionInfoCache(
new ConnectionConfig.Builder()
.withCloudSqlInstance("project:region:instance")
.withIpTypes(ipTypes)
Expand Down Expand Up @@ -651,8 +651,8 @@ public ListenableFuture<ConnectionInfo> getConnectionInfo(
};

// initialize connectionInfoCache after mocks are set up
BaseConnectionInfoCache connectionInfoCache =
new BaseConnectionInfoCache(
RefreshAheadConnectionInfoCache connectionInfoCache =
new RefreshAheadConnectionInfoCache(
new ConnectionConfig.Builder()
.withCloudSqlInstance("project:region:instance")
.withIpTypes(Collections.singletonList(IpType.PRIVATE))
Expand All @@ -671,8 +671,8 @@ public ListenableFuture<ConnectionInfo> getConnectionInfo(
public void testClosedCloudSqlInstanceDataThrowsException() throws Exception {
TestDataSupplier instanceDataSupplier = new TestDataSupplier(false);
// initialize instance after mocks are set up
BaseConnectionInfoCache instance =
new BaseConnectionInfoCache(
RefreshAheadConnectionInfoCache instance =
new RefreshAheadConnectionInfoCache(
new ConnectionConfig.Builder().withCloudSqlInstance("project:region:instance").build(),
instanceDataSupplier,
stubCredentialFactory,
Expand Down Expand Up @@ -712,8 +712,8 @@ public ListenableFuture<ConnectionInfo> getConnectionInfo(
return Futures.immediateFuture(initialData);
}
};
BaseConnectionInfoCache instance =
new BaseConnectionInfoCache(
RefreshAheadConnectionInfoCache instance =
new RefreshAheadConnectionInfoCache(
new ConnectionConfig.Builder().withCloudSqlInstance("project:region:instance").build(),
connectionInfoRepository,
stubCredentialFactory,
Expand Down
Loading

0 comments on commit 82d82ac

Please sign in to comment.