Skip to content

Commit

Permalink
Using google-auth-library-oauth2-http to get default credentials
Browse files Browse the repository at this point in the history
This should enable use of Workload Identities on GKE.
  • Loading branch information
Mattias Öhrn committed Jul 26, 2019
1 parent 9bfe635 commit b3e61ac
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 47 deletions.
6 changes: 6 additions & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
<version>v1beta4-rev20190510-1.28.0</version>
</dependency>

<dependency>
<groupId>com.google.auth</groupId>
<artifactId>google-auth-library-oauth2-http</artifactId>
<version>0.16.2</version>
</dependency>

<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
package com.google.cloud.sql;

import com.google.api.client.auth.oauth2.Credential;
import com.google.api.client.http.HttpRequestInitializer;

/** Factory for creating {@link Credential}s for interaction with Cloud SQL Admin API. */
public interface CredentialFactory {
/** Name of system property that can specify an alternative credential factory. */
String CREDENTIAL_FACTORY_PROPERTY = "cloudSql.socketFactory.credentialFactory";

Credential create();
HttpRequestInitializer create();
}
45 changes: 10 additions & 35 deletions core/src/main/java/com/google/cloud/sql/core/CoreSocketFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

package com.google.cloud.sql.core;

import com.google.api.client.auth.oauth2.Credential;
import com.google.api.client.googleapis.auth.oauth2.GoogleCredential;
import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport;
import com.google.api.client.http.HttpRequestInitializer;
import com.google.api.client.http.HttpTransport;
Expand All @@ -26,6 +24,8 @@
import com.google.api.services.sqladmin.SQLAdmin;
import com.google.api.services.sqladmin.SQLAdmin.Builder;
import com.google.api.services.sqladmin.SQLAdminScopes;
import com.google.auth.http.HttpCredentialsAdapter;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.cloud.sql.CredentialFactory;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
Expand All @@ -40,8 +40,6 @@
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand All @@ -50,7 +48,6 @@
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.logging.Logger;
import javax.annotation.Nullable;
import javax.net.ssl.SSLSocket;
import jnr.unixsocket.UnixSocketAddress;
import jnr.unixsocket.UnixSocketChannel;
Expand Down Expand Up @@ -93,9 +90,7 @@ public final class CoreSocketFactory {

private static CoreSocketFactory coreSocketFactory;

private final CertificateFactory certificateFactory;
private final ListenableFuture<KeyPair> localKeyPair;
private final Credential credential;
private final ConcurrentHashMap<String, CloudSqlInstance> instances = new ConcurrentHashMap<>();
private final ListeningScheduledExecutorService executor;
private final SQLAdmin adminApi;
Expand All @@ -104,12 +99,9 @@ public final class CoreSocketFactory {
@VisibleForTesting
CoreSocketFactory(
ListenableFuture<KeyPair> localKeyPair,
Credential credential,
SQLAdmin adminApi,
int serverProxyPort,
ListeningScheduledExecutorService executor) {
this.certificateFactory = x509CertificateFactory();
this.credential = credential;
this.adminApi = adminApi;
this.serverProxyPort = serverProxyPort;
this.executor = executor;
Expand All @@ -134,30 +126,20 @@ public static synchronized CoreSocketFactory getInstance() {
credentialFactory = new ApplicationDefaultCredentialFactory();
}

Credential credential = credentialFactory.create();
HttpRequestInitializer credential = credentialFactory.create();
SQLAdmin adminApi = createAdminApiClient(credential);
ListeningScheduledExecutorService executor = getDefaultExecutor();

coreSocketFactory =
new CoreSocketFactory(
executor.submit(CoreSocketFactory::generateRsaKeyPair),
credential,
adminApi,
DEFAULT_SERVER_PROXY_PORT,
executor);
}
return coreSocketFactory;
}

// Returns a factory used to create X.509 public certificates
private static CertificateFactory x509CertificateFactory() {
try {
return CertificateFactory.getInstance("X.509");
} catch (CertificateException err) {
throw new RuntimeException("X509 implementation not available", err);
}
}

// TODO(kvg): Figure out better executor to use for testing
@VisibleForTesting
// Returns a listenable, scheduled executor that exits upon shutdown.
Expand Down Expand Up @@ -286,13 +268,6 @@ private static List<String> listIpTypes(String cloudSqlIpTypes) {
return result;
}

@Nullable
private String getCredentialServiceAccount(Credential credential) {
return credential instanceof GoogleCredential
? ((GoogleCredential) credential).getServiceAccountId()
: null;
}

private static SQLAdmin createAdminApiClient(HttpRequestInitializer requestInitializer) {
HttpTransport httpTransport;
try {
Expand Down Expand Up @@ -321,19 +296,19 @@ private static SQLAdmin createAdminApiClient(HttpRequestInitializer requestIniti

private static class ApplicationDefaultCredentialFactory implements CredentialFactory {
@Override
public Credential create() {
GoogleCredential credential;
public HttpRequestInitializer create() {
GoogleCredentials credentials;
try {
credential = GoogleCredential.getApplicationDefault();
credentials = GoogleCredentials.getApplicationDefault();
} catch (IOException err) {
throw new RuntimeException(
"Unable to obtain credentials to communicate with the Cloud SQL API", err);
}
if (credential.createScopedRequired()) {
credential =
credential.createScoped(Collections.singletonList(SQLAdminScopes.SQLSERVICE_ADMIN));
if (credentials.createScopedRequired()) {
credentials =
credentials.createScoped(Collections.singletonList(SQLAdminScopes.SQLSERVICE_ADMIN));
}
return credential;
return new HttpCredentialsAdapter(credentials);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

import com.google.api.client.googleapis.auth.oauth2.GoogleCredential;
import com.google.api.client.googleapis.json.GoogleJsonError;
import com.google.api.client.googleapis.json.GoogleJsonError.ErrorInfo;
import com.google.api.client.googleapis.json.GoogleJsonResponseException;
Expand Down Expand Up @@ -109,7 +108,6 @@ public class CoreSocketFactoryTest {
// TODO(kvg): Remove this when updating tests to use single CoreSocketFactory
private ListeningScheduledExecutorService defaultExecutor;

@Mock private GoogleCredential credential;
@Mock private SQLAdmin adminApi;
@Mock private SQLAdmin.Instances adminApiInstances;
@Mock private SQLAdmin.Instances.Get adminApiInstancesGet;
Expand Down Expand Up @@ -160,7 +158,7 @@ public void setup() throws IOException, GeneralSecurityException, ExecutionExcep
@Test
public void create_throwsErrorForInvalidInstanceName() throws IOException {
CoreSocketFactory coreSocketFactory =
new CoreSocketFactory(clientKeyPair, credential, adminApi, 3307, defaultExecutor);
new CoreSocketFactory(clientKeyPair, adminApi, 3307, defaultExecutor);
try {
coreSocketFactory.createSslSocket("foo", Arrays.asList("PRIMARY"));
fail();
Expand All @@ -187,7 +185,7 @@ public void create_throwsErrorForInvalidInstanceRegion() throws IOException {
.setRegion("beer"));

CoreSocketFactory coreSocketFactory =
new CoreSocketFactory(clientKeyPair, credential, adminApi, 3307, defaultExecutor);
new CoreSocketFactory(clientKeyPair, adminApi, 3307, defaultExecutor);
try {
coreSocketFactory.createSslSocket("foo:bar:baz", Arrays.asList("PRIMARY"));
fail();
Expand All @@ -209,7 +207,7 @@ public void create_successfulPrivateConnection()
int port = sslServer.start(PRIVATE_IP);

CoreSocketFactory coreSocketFactory =
new CoreSocketFactory(clientKeyPair, credential, adminApi, port, defaultExecutor);
new CoreSocketFactory(clientKeyPair, adminApi, port, defaultExecutor);
Socket socket =
coreSocketFactory.createSslSocket(INSTANCE_CONNECTION_STRING, Arrays.asList("PRIVATE"));

Expand All @@ -230,7 +228,7 @@ public void create_successfulConnection() throws IOException, InterruptedExcepti
int port = sslServer.start();

CoreSocketFactory coreSocketFactory =
new CoreSocketFactory(clientKeyPair, credential, adminApi, port, defaultExecutor);
new CoreSocketFactory(clientKeyPair, adminApi, port, defaultExecutor);
Socket socket =
coreSocketFactory.createSslSocket(INSTANCE_CONNECTION_STRING, Arrays.asList("PRIMARY"));

Expand Down Expand Up @@ -261,7 +259,7 @@ public void create_expiredCertificateOnFirstConnection_certificateRenewed()
.thenReturn(new SslCert().setCert(createEphemeralCert(Duration.ofMinutes(65))));

CoreSocketFactory coreSocketFactory =
new CoreSocketFactory(clientKeyPair, credential, adminApi, port, defaultExecutor);
new CoreSocketFactory(clientKeyPair, adminApi, port, defaultExecutor);
Socket socket =
coreSocketFactory.createSslSocket(INSTANCE_CONNECTION_STRING, Arrays.asList("PRIMARY"));

Expand All @@ -282,7 +280,7 @@ public void create_certificateReusedIfNotExpired() throws IOException, Interrupt
int port = sslServer.start();

CoreSocketFactory coreSocketFactory =
new CoreSocketFactory(clientKeyPair, credential, adminApi, port, defaultExecutor);
new CoreSocketFactory(clientKeyPair, adminApi, port, defaultExecutor);
coreSocketFactory.createSslSocket(INSTANCE_CONNECTION_STRING, Arrays.asList("PRIMARY"));

verify(adminApiInstances).get(PROJECT_ID, INSTANCE_NAME);
Expand All @@ -308,7 +306,7 @@ public void create_adminApiNotEnabled() throws IOException {
new HttpResponseException.Builder(403, "Forbidden", new HttpHeaders()), details));

CoreSocketFactory coreSocketFactory =
new CoreSocketFactory(clientKeyPair, credential, adminApi, 3307, defaultExecutor);
new CoreSocketFactory(clientKeyPair, adminApi, 3307, defaultExecutor);
try {
coreSocketFactory.createSslSocket(INSTANCE_CONNECTION_STRING, Arrays.asList("PRIMARY"));
fail("Expected RuntimeException");
Expand All @@ -335,7 +333,7 @@ public void create_notAuthorizedToGetInstance() throws IOException {
new HttpResponseException.Builder(403, "Forbidden", new HttpHeaders()), details));

CoreSocketFactory coreSocketFactory =
new CoreSocketFactory(clientKeyPair, credential, adminApi, 3307, defaultExecutor);
new CoreSocketFactory(clientKeyPair, adminApi, 3307, defaultExecutor);
try {
coreSocketFactory.createSslSocket(INSTANCE_CONNECTION_STRING, Arrays.asList("PRIMARY"));
fail("Expected RuntimeException");
Expand All @@ -357,7 +355,7 @@ public void create_notAuthorizedToCreateEphemeralCertificate() throws IOExceptio
new HttpResponseException.Builder(403, "Forbidden", new HttpHeaders()), details));

CoreSocketFactory coreSocketFactory =
new CoreSocketFactory(clientKeyPair, credential, adminApi, 3307, defaultExecutor);
new CoreSocketFactory(clientKeyPair, adminApi, 3307, defaultExecutor);
try {
coreSocketFactory.createSslSocket(INSTANCE_CONNECTION_STRING, Arrays.asList("PRIMARY"));
fail();
Expand Down

0 comments on commit b3e61ac

Please sign in to comment.