Skip to content

Commit

Permalink
fix: throw when token is expired or empty (#1233)
Browse files Browse the repository at this point in the history
Related to #1174
  • Loading branch information
enocom committed Apr 4, 2023
1 parent 41d511a commit 970eed0
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.api.services.sqladmin.model.GenerateEphemeralCertRequest;
import com.google.api.services.sqladmin.model.GenerateEphemeralCertResponse;
import com.google.api.services.sqladmin.model.IpMapping;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.OAuth2Credentials;
import com.google.cloud.sql.AuthType;
import com.google.common.base.CharMatcher;
Expand All @@ -45,6 +46,9 @@
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.time.Instant;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -269,7 +273,11 @@ private Certificate fetchEphemeralCertificate(
if (authType == AuthType.IAM) {
try {
credentials.refresh();
String token = credentials.getAccessToken().getTokenValue();
AccessToken accessToken = credentials.getAccessToken();

validateAccessToken(accessToken);

String token = accessToken.getTokenValue();
// TODO: remove this once issue with OAuth2 Tokens is resolved.
// See: https://github.com/GoogleCloudPlatform/cloud-sql-jdbc-socket-factory/issues/565
request.setAccessToken(CharMatcher.is('.').trimTrailingFrom(token));
Expand Down Expand Up @@ -310,6 +318,33 @@ private Certificate fetchEphemeralCertificate(
return ephemeralCertificate;
}

private void validateAccessToken(AccessToken accessToken) {
Date expirationTimeDate = accessToken.getExpirationTime();
String tokenValue = accessToken.getTokenValue();

if (expirationTimeDate != null) {
Instant expirationTime = expirationTimeDate.toInstant();
Instant now = Instant.now();

// Is the token expired?
if (expirationTime.isBefore(now) || expirationTime.equals(now)) {
DateTimeFormatter formatter = DateTimeFormatter.ISO_INSTANT.withZone(ZoneId.of("UTC"));
String nowFormat = formatter.format(now);
String expirationFormat = formatter.format(expirationTime);
throw new RuntimeException(
"Access Token expiration time is in the past. Now = "
+ nowFormat
+ " Expiration = "
+ expirationFormat);
}
}

// Is the token empty?
if (tokenValue.length() == 0) {
throw new RuntimeException("Access Token has length of zero");
}
}

/**
* Creates a new SslData based on the provided parameters. It contains a SSLContext that will be
* used to provide new SSLSockets authorized to connect to a Cloud SQL instance. It also contains
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.io.InputStreamReader;
import java.net.Socket;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -193,7 +194,8 @@ public void create_notAuthorized() throws IOException {

@Test
public void supportsCustomCredentialFactoryWithIAM() throws InterruptedException, IOException {
CredentialFactory stubCredentialFactory = new StubCredentialFactory("foo", 6000L);
CredentialFactory stubCredentialFactory =
new StubCredentialFactory("foo", Instant.now().plusSeconds(3600).toEpochMilli());

FakeSslServer sslServer = new FakeSslServer();
int port = sslServer.start(PUBLIC_IP);
Expand Down
16 changes: 11 additions & 5 deletions core/src/test/java/com/google/cloud/sql/core/MockAdminApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,8 @@
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Base64.Decoder;
Expand Down Expand Up @@ -110,8 +108,8 @@ public KeyPair getClientKeyPair() {
return clientKeyPair;
}

public OAuth2RefreshHandler getRefreshHandler() {
return new MockRefreshHandler();
public OAuth2RefreshHandler getRefreshHandler(String refreshToken, Date expirationTime) {
return new MockRefreshHandler(refreshToken, expirationTime);
}

public void addConnectSettingsResponse(
Expand Down Expand Up @@ -279,9 +277,17 @@ public GenerateEphemeralCertResponse getGenerateEphemeralCertResponse() {
}

private static class MockRefreshHandler implements OAuth2RefreshHandler {
private final String refreshToken;
private final Date expirationTime;

public MockRefreshHandler(String refreshToken, Date expirationTime) {
this.refreshToken = refreshToken;
this.expirationTime = expirationTime;
}

@Override
public AccessToken refreshAccessToken() throws IOException {
return new AccessToken("refreshed-token", Date.from(Instant.now().plus(1, ChronoUnit.HOURS)));
return new AccessToken(refreshToken, expirationTime);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.security.GeneralSecurityException;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Date;
import java.util.Map;
import java.util.concurrent.ExecutionException;
Expand All @@ -39,20 +40,24 @@
import org.junit.Test;

public class SqlAdminApiFetcherTest {

public static final String SAMPLE_PUBLIC_IP = "34.1.2.3";
public static final String SAMPLE_PRIVATE_IP = "10.0.0.1";
public static final String INSTANCE_CONNECTION_NAME = "p:r:i";
public static final String DATABASE_VERSION = "POSTGRES14";

@Test
public void fetchesInstanceData()
public void testFetchInstanceData_returnsIpAddresses()
throws ExecutionException, InterruptedException, GeneralSecurityException,
OperatorCreationException {
MockAdminApi mockAdminApi = new MockAdminApi();
mockAdminApi.addConnectSettingsResponse("p:r:i", "34.1.2.3", "10.0.0.1", "POSTGRES14");
mockAdminApi.addGenerateEphemeralCertResponse("p:r:i", Duration.ofHours(1));
MockAdminApi mockAdminApi = buildMockAdminApi(INSTANCE_CONNECTION_NAME, DATABASE_VERSION);
SqlAdminApiFetcher fetcher =
new StubApiFetcherFactory(mockAdminApi.getHttpTransport())
.create(new StubCredentialFactory().create());

ListenableFuture<InstanceData> instanceDataFuture =
fetcher.getInstanceData(
new CloudSqlInstanceName("p:r:i"),
new CloudSqlInstanceName(INSTANCE_CONNECTION_NAME),
null,
AuthType.PASSWORD,
newTestExecutor(),
Expand All @@ -62,8 +67,8 @@ public void fetchesInstanceData()
assertThat(instanceData.getSslContext()).isInstanceOf(SSLContext.class);

Map<String, String> ipAddrs = instanceData.getIpAddrs();
assertThat(ipAddrs.get("PRIMARY")).isEqualTo("34.1.2.3");
assertThat(ipAddrs.get("PRIVATE")).isEqualTo("10.0.0.1");
assertThat(ipAddrs.get("PRIMARY")).isEqualTo(SAMPLE_PUBLIC_IP);
assertThat(ipAddrs.get("PRIVATE")).isEqualTo(SAMPLE_PRIVATE_IP);
}

private ListeningScheduledExecutorService newTestExecutor() {
Expand All @@ -76,22 +81,21 @@ private ListeningScheduledExecutorService newTestExecutor() {
}

@Test
public void throwsErrorIamAuthNotSupported()
throws GeneralSecurityException, OperatorCreationException, ExecutionException,
InterruptedException {
MockAdminApi mockAdminApi = new MockAdminApi();
mockAdminApi.addConnectSettingsResponse(
"p:r:i", "34.1.2.3", "10.0.0.1", "SQLSERVER_2019_STANDARD");
mockAdminApi.addGenerateEphemeralCertResponse("p:r:i", Duration.ofHours(1));
public void testFetchInstanceData_throwsException_whenIamAuthnIsNotSupported()
throws GeneralSecurityException, OperatorCreationException {
MockAdminApi mockAdminApi =
buildMockAdminApi(INSTANCE_CONNECTION_NAME, "SQLSERVER_2019_STANDARD");
SqlAdminApiFetcher fetcher =
new StubApiFetcherFactory(mockAdminApi.getHttpTransport())
.create(new StubCredentialFactory().create());

ListenableFuture<InstanceData> instanceData =
fetcher.getInstanceData(
new CloudSqlInstanceName("p:r:i"),
new CloudSqlInstanceName(INSTANCE_CONNECTION_NAME),
OAuth2CredentialsWithRefresh.newBuilder()
.setRefreshHandler(mockAdminApi.getRefreshHandler())
.setRefreshHandler(
mockAdminApi.getRefreshHandler(
"refresh-token", Date.from(Instant.now().plus(1, ChronoUnit.HOURS))))
.setAccessToken(new AccessToken("my-token", Date.from(Instant.now())))
.build(),
AuthType.IAM,
Expand All @@ -101,6 +105,69 @@ public void throwsErrorIamAuthNotSupported()
ExecutionException ex = assertThrows(ExecutionException.class, instanceData::get);
assertThat(ex)
.hasMessageThat()
.contains("[p:r:i] " + "IAM Authentication is not supported for SQL Server instances");
.contains("[p:r:i] IAM Authentication is not supported for SQL Server instances");
}

@Test
public void testFetchInstanceData_throwsException_whenTokenIsEmpty()
throws GeneralSecurityException, OperatorCreationException {
MockAdminApi mockAdminApi = buildMockAdminApi(INSTANCE_CONNECTION_NAME, DATABASE_VERSION);
SqlAdminApiFetcher fetcher =
new StubApiFetcherFactory(mockAdminApi.getHttpTransport())
.create(new StubCredentialFactory().create());

ListenableFuture<InstanceData> instanceData =
fetcher.getInstanceData(
new CloudSqlInstanceName(INSTANCE_CONNECTION_NAME),
OAuth2CredentialsWithRefresh.newBuilder()
.setRefreshHandler(
mockAdminApi.getRefreshHandler(
"", Date.from(Instant.now().plus(1, ChronoUnit.HOURS)) /* empty */))
.setAccessToken(new AccessToken("" /* ignored */, Date.from(Instant.now())))
.build(),
AuthType.IAM,
newTestExecutor(),
Futures.immediateFuture(mockAdminApi.getClientKeyPair()));

ExecutionException ex = assertThrows(ExecutionException.class, instanceData::get);

assertThat(ex).hasMessageThat().contains("Access Token has length of zero");
}

@Test
public void testFetchInstanceData_throwsException_whenTokenIsExpired()
throws GeneralSecurityException, OperatorCreationException {
MockAdminApi mockAdminApi = buildMockAdminApi(INSTANCE_CONNECTION_NAME, DATABASE_VERSION);
SqlAdminApiFetcher fetcher =
new StubApiFetcherFactory(mockAdminApi.getHttpTransport())
.create(new StubCredentialFactory().create());

ListenableFuture<InstanceData> instanceData =
fetcher.getInstanceData(
new CloudSqlInstanceName(INSTANCE_CONNECTION_NAME),
OAuth2CredentialsWithRefresh.newBuilder()
.setRefreshHandler(
mockAdminApi.getRefreshHandler(
"refresh-token",
Date.from(Instant.now().minus(1, ChronoUnit.HOURS)) /* 1 hour ago */))
.setAccessToken(new AccessToken("original-token", Date.from(Instant.now())))
.build(),
AuthType.IAM,
newTestExecutor(),
Futures.immediateFuture(mockAdminApi.getClientKeyPair()));

ExecutionException ex = assertThrows(ExecutionException.class, instanceData::get);

assertThat(ex).hasMessageThat().contains("Access Token expiration time is in the past");
}

@SuppressWarnings("SameParameterValue")
private MockAdminApi buildMockAdminApi(String instanceConnectionName, String databaseVersion)
throws GeneralSecurityException, OperatorCreationException {
MockAdminApi mockAdminApi = new MockAdminApi();
mockAdminApi.addConnectSettingsResponse(
instanceConnectionName, SAMPLE_PUBLIC_IP, SAMPLE_PRIVATE_IP, databaseVersion);
mockAdminApi.addGenerateEphemeralCertResponse(instanceConnectionName, Duration.ofHours(1));
return mockAdminApi;
}
}

0 comments on commit 970eed0

Please sign in to comment.