Skip to content

Commit

Permalink
chore: use CaCert from resp body (#266)
Browse files Browse the repository at this point in the history
  • Loading branch information
ttosta-google authored Oct 26, 2023
1 parent 7f97c5c commit d247d38
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,19 @@ class ConnectionInfo {
private final String instanceUid;
private final X509Certificate clientCertificate;
private final List<X509Certificate> certificateChain;
private final X509Certificate caCertificate;

ConnectionInfo(
String ipAddress,
String instanceUid,
X509Certificate clientCertificate,
List<X509Certificate> certificateChain) {
List<X509Certificate> certificateChain,
X509Certificate caCertificate) {
this.ipAddress = ipAddress;
this.instanceUid = instanceUid;
this.clientCertificate = clientCertificate;
this.certificateChain = certificateChain;
this.caCertificate = caCertificate;
}

String getIpAddress() {
Expand All @@ -59,6 +62,10 @@ List<X509Certificate> getCertificateChain() {
return certificateChain;
}

X509Certificate getCaCertificate() {
return caCertificate;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand All @@ -72,12 +79,14 @@ public boolean equals(Object o) {
return Objects.equal(ipAddress, that.ipAddress)
&& Objects.equal(instanceUid, that.instanceUid)
&& Objects.equal(clientCertificate, that.clientCertificate)
&& Objects.equal(certificateChain, that.certificateChain);
&& Objects.equal(certificateChain, that.certificateChain)
&& Objects.equal(caCertificate, that.caCertificate);
}

@Override
public int hashCode() {
return Objects.hashCode(ipAddress, instanceUid, clientCertificate, certificateChain);
return Objects.hashCode(
ipAddress, instanceUid, clientCertificate, certificateChain, caCertificate);
}

@Override
Expand All @@ -93,6 +102,8 @@ public String toString() {
+ clientCertificate
+ ", certificateChain="
+ certificateChain
+ ", caCertificate="
+ caCertificate
+ '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,16 @@ class Connector {
}

private static SSLSocket buildSocket(
X509Certificate clientCertificate,
X509Certificate caCertificate,
List<X509Certificate> certificateChain,
PrivateKey privateKey) {
try {
// First initialize a KeyManager with the ephemeral certificate
// (including the chain of trust to the root CA cert) and the connector's private key.
KeyManager[] keyManagers =
initializeKeyManager(clientCertificate, certificateChain, privateKey);
KeyManager[] keyManagers = initializeKeyManager(certificateChain, privateKey);

// Next, initialize a TrustManager with the root CA certificate.
TrustManager[] trustManagers = initializeTrustManager(certificateChain);
TrustManager[] trustManagers = initializeTrustManager(caCertificate);

// Now, create a TLS 1.3 SSLContext initialized with the KeyManager and the TrustManager,
// and create the SSL Socket.
Expand All @@ -97,26 +96,21 @@ private static SSLSocket buildSocket(
}
}

private static TrustManager[] initializeTrustManager(List<X509Certificate> certificateChain)
private static TrustManager[] initializeTrustManager(X509Certificate caCertificate)
throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException {
KeyStore trustedKeyStore = KeyStore.getInstance(KeyStore.getDefaultType());
trustedKeyStore.load(
null, // don't load the key store from an input stream
null // there is no password
);
trustedKeyStore.setCertificateEntry(
ROOT_CA_CERT,
certificateChain.get(certificateChain.size() - 1) // root CA cert is last in the chain
);
trustedKeyStore.setCertificateEntry(ROOT_CA_CERT, caCertificate);
TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(X_509);
trustManagerFactory.init(trustedKeyStore);
return trustManagerFactory.getTrustManagers();
}

private static KeyManager[] initializeKeyManager(
X509Certificate clientCertificate,
List<X509Certificate> certificateChain,
PrivateKey privateKey)
List<X509Certificate> certificateChain, PrivateKey privateKey)
throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException,
UnrecoverableKeyException {
KeyStore clientAuthenticationKeyStore = KeyStore.getInstance(KeyStore.getDefaultType());
Expand All @@ -125,7 +119,6 @@ private static KeyManager[] initializeKeyManager(
null // there is no password
);
List<Certificate> chain = new ArrayList<>();
chain.add(clientCertificate);
chain.addAll(certificateChain);
Certificate[] chainArray = chain.toArray(new Certificate[] {});
PrivateKeyEntry privateKeyEntry = new PrivateKeyEntry(privateKey, chainArray);
Expand Down Expand Up @@ -157,7 +150,7 @@ Socket connect(InstanceName instanceName) throws IOException {
try {
SSLSocket socket =
buildSocket(
connectionInfo.getClientCertificate(),
connectionInfo.getCaCertificate(),
connectionInfo.getCertificateChain(),
this.clientConnectorKeyPair.getPrivate());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,15 @@ public ConnectionInfo getConnectionInfo(InstanceName instanceName, KeyPair keyPa
certificateChain.add(parseCertificate(certificateChainByte));
}
X509Certificate clientCertificate = certificateChain.get(0);
ByteString caCertificateBytes = certificateResponse.getCaCertBytes();
X509Certificate caCertificate = parseCertificate(caCertificateBytes);

return new ConnectionInfo(
info.getIpAddress(), info.getInstanceUid(), clientCertificate, certificateChain);
info.getIpAddress(),
info.getInstanceUid(),
clientCertificate,
certificateChain,
caCertificate);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ public void testGetConnectionInfo_returnsConnectionInfo() {
testCertificates.getEphemeralCertificate(keyPair.getPublic(), ONE_HOUR_FROM_NOW),
Arrays.asList(
testCertificates.getIntermediateCertificate(),
testCertificates.getRootCertificate())));
testCertificates.getRootCertificate()),
testCertificates.getRootCertificate()));
DefaultConnectionInfoCache connectionInfoCache =
new DefaultConnectionInfoCache(
executor,
Expand Down Expand Up @@ -110,13 +111,15 @@ public void testGetConnectionInfo_schedulesNextOperation() {
TEST_INSTANCE_IP,
TEST_INSTANCE_ID,
testCertificates.getEphemeralCertificate(keyPair.getPublic(), ONE_HOUR_FROM_NOW),
certificateChain),
certificateChain,
testCertificates.getRootCertificate()),
() ->
new ConnectionInfo(
TEST_INSTANCE_IP,
TEST_INSTANCE_ID,
testCertificates.getEphemeralCertificate(keyPair.getPublic(), TWO_HOURS_FROM_NOW),
certificateChain));
certificateChain,
testCertificates.getRootCertificate()));
DefaultConnectionInfoCache connectionInfoCache =
new DefaultConnectionInfoCache(
executor,
Expand Down Expand Up @@ -181,7 +184,8 @@ public Object getTransportCode() {
TEST_INSTANCE_IP,
TEST_INSTANCE_ID,
testCertificates.getEphemeralCertificate(keyPair.getPublic(), ONE_HOUR_FROM_NOW),
certificateChain));
certificateChain,
testCertificates.getRootCertificate()));
DefaultConnectionInfoCache connectionInfoCache =
new DefaultConnectionInfoCache(
executor,
Expand Down Expand Up @@ -224,7 +228,8 @@ public void testGetConnectionInfo_scheduledNextOperationImmediately_onCertificat
TEST_INSTANCE_IP,
TEST_INSTANCE_ID,
testCertificates.getEphemeralCertificate(keyPair.getPublic(), ONE_HOUR_FROM_NOW),
certificateChain));
certificateChain,
testCertificates.getRootCertificate()));
DefaultConnectionInfoCache connectionInfoCache =
new DefaultConnectionInfoCache(
executor,
Expand Down Expand Up @@ -267,7 +272,8 @@ public void testGetConnectionInfo_isRateLimited() {
testCertificates.getEphemeralCertificate(keyPair.getPublic(), ONE_HOUR_FROM_NOW),
Arrays.asList(
testCertificates.getIntermediateCertificate(),
testCertificates.getRootCertificate())));
testCertificates.getRootCertificate()),
testCertificates.getRootCertificate()));
DefaultConnectionInfoCache connectionInfoCache =
new DefaultConnectionInfoCache(
executor,
Expand Down Expand Up @@ -299,13 +305,15 @@ public void testForceRefresh_schedulesNextRefreshImmediately() throws Interrupte
TEST_INSTANCE_IP,
TEST_INSTANCE_ID,
testCertificates.getEphemeralCertificate(keyPair.getPublic(), ONE_HOUR_FROM_NOW),
certificateChain),
certificateChain,
testCertificates.getRootCertificate()),
() ->
new ConnectionInfo(
TEST_INSTANCE_IP,
TEST_INSTANCE_ID,
testCertificates.getEphemeralCertificate(keyPair.getPublic(), TWO_HOURS_FROM_NOW),
certificateChain));
certificateChain,
testCertificates.getRootCertificate()));
DefaultConnectionInfoCache connectionInfoCache =
new DefaultConnectionInfoCache(
executor,
Expand Down Expand Up @@ -369,19 +377,22 @@ public void testForceRefresh_refreshCalledOnlyOnceDuringMultipleCalls()
TEST_INSTANCE_IP,
TEST_INSTANCE_ID,
testCertificates.getEphemeralCertificate(keyPair.getPublic(), ONE_HOUR_FROM_NOW),
certificateChain),
certificateChain,
testCertificates.getRootCertificate()),
() ->
new ConnectionInfo(
TEST_INSTANCE_IP,
TEST_INSTANCE_ID,
testCertificates.getEphemeralCertificate(keyPair.getPublic(), TWO_HOURS_FROM_NOW),
certificateChain),
certificateChain,
testCertificates.getRootCertificate()),
() ->
new ConnectionInfo(
TEST_INSTANCE_IP,
TEST_INSTANCE_ID,
testCertificates.getEphemeralCertificate(keyPair.getPublic(), THREE_HOURS_FROM_NOW),
certificateChain));
certificateChain,
testCertificates.getRootCertificate()));
DefaultConnectionInfoCache connectionInfoCache =
new DefaultConnectionInfoCache(
executor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ public void testGetClientCertificateExpiration()
testCertificates.getEphemeralCertificate(testKeyPair.getPublic(), expected),
Arrays.asList(
testCertificates.getIntermediateCertificate(),
testCertificates.getRootCertificate()));
testCertificates.getRootCertificate()),
testCertificates.getRootCertificate());

assertThat(connectionInfo.getClientCertificateExpiration()).isEqualTo(expected);
}
Expand All @@ -74,7 +75,8 @@ public void testEquals() throws CertificateException, OperatorCreationException,
ephemeralCertificate,
Arrays.asList(
testCertificates.getIntermediateCertificate(),
testCertificates.getRootCertificate()));
testCertificates.getRootCertificate()),
testCertificates.getRootCertificate());

//noinspection EqualsWithItself
assertThat(c1.equals(c1)).isTrue();
Expand All @@ -87,7 +89,8 @@ public void testEquals() throws CertificateException, OperatorCreationException,
ephemeralCertificate,
Arrays.asList(
testCertificates.getIntermediateCertificate(),
testCertificates.getRootCertificate())));
testCertificates.getRootCertificate()),
testCertificates.getRootCertificate()));
assertThat(c1)
.isNotEqualTo(
new ConnectionInfo(
Expand All @@ -96,7 +99,8 @@ public void testEquals() throws CertificateException, OperatorCreationException,
ephemeralCertificate,
Arrays.asList(
testCertificates.getIntermediateCertificate(),
testCertificates.getRootCertificate())));
testCertificates.getRootCertificate()),
testCertificates.getRootCertificate()));
assertThat(c1)
.isNotEqualTo(
new ConnectionInfo(
Expand All @@ -106,15 +110,17 @@ public void testEquals() throws CertificateException, OperatorCreationException,
testKeyPair.getPublic(), Instant.now().plus(1, ChronoUnit.DAYS)),
Arrays.asList(
testCertificates.getIntermediateCertificate(),
testCertificates.getRootCertificate())));
testCertificates.getRootCertificate()),
testCertificates.getRootCertificate()));
assertThat(c1)
.isNotEqualTo(
new ConnectionInfo(
IP_ADDRESS,
INSTANCE_UID,
testCertificates.getEphemeralCertificate(
testKeyPair.getPublic(), Instant.now().plus(1, ChronoUnit.DAYS)),
Collections.emptyList()));
Collections.emptyList(),
testCertificates.getRootCertificate()));

ConnectionInfo c2 =
new ConnectionInfo(
Expand All @@ -123,7 +129,8 @@ public void testEquals() throws CertificateException, OperatorCreationException,
ephemeralCertificate,
Arrays.asList(
testCertificates.getIntermediateCertificate(),
testCertificates.getRootCertificate()));
testCertificates.getRootCertificate()),
testCertificates.getRootCertificate());
assertThat(c1).isEqualTo(c2);
}

Expand All @@ -137,7 +144,8 @@ public void testHashCode()
testCertificates.getEphemeralCertificate(testKeyPair.getPublic(), Instant.now()),
Arrays.asList(
testCertificates.getIntermediateCertificate(),
testCertificates.getRootCertificate()));
testCertificates.getRootCertificate()),
testCertificates.getRootCertificate());

assertThat(c1.hashCode()).isEqualTo(getHashCode(c1));
}
Expand All @@ -147,6 +155,7 @@ long getHashCode(ConnectionInfo connectionInfo) {
connectionInfo.getIpAddress(),
connectionInfo.getInstanceUid(),
connectionInfo.getClientCertificate(),
connectionInfo.getCertificateChain());
connectionInfo.getCertificateChain(),
connectionInfo.getCaCertificate());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ public void testConnect_whenTlsHandshakeFails()
clientConnectorKeyPair.getPublic(), Instant.now()),
Arrays.asList(
testCertificates.getIntermediateCertificate(),
testCertificates.getRootCertificate())));
testCertificates.getRootCertificate()),
testCertificates.getRootCertificate()));
StubConnectionInfoCacheFactory connectionInfoCacheFactory =
new StubConnectionInfoCacheFactory(stubConnectionInfoCache);
SSLSocket socket = null;
Expand Down

0 comments on commit d247d38

Please sign in to comment.