Skip to content

Commit

Permalink
HDDS-8178. CertificateClient and KeyStoresFactory support multiple Su…
Browse files Browse the repository at this point in the history
…b-CA certificates in the trust chain (#4442)
  • Loading branch information
ChenSammi committed Mar 28, 2023
1 parent 2a82613 commit 012ecd3
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@
import java.security.KeyStore;
import java.security.Principal;
import java.security.PrivateKey;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.atomic.AtomicReference;

Expand Down Expand Up @@ -64,7 +65,7 @@ public class ReloadingX509KeyManager extends X509ExtendedKeyManager {
* materials are changed.
*/
private PrivateKey currentPrivateKey;
private String currentCertId;
private List<String> currentCertIdsList = new ArrayList<>();

/**
* Construct a <code>Reloading509KeystoreManager</code>.
Expand Down Expand Up @@ -150,11 +151,14 @@ public ReloadingX509KeyManager loadFrom(CertificateClient caClient) {
private X509ExtendedKeyManager loadKeyManager(CertificateClient caClient)
throws GeneralSecurityException, IOException {
PrivateKey privateKey = caClient.getPrivateKey();
X509Certificate cert = caClient.getCertificate();
String certId = cert.getSerialNumber().toString();
// Security materials keep the same
if (currentCertId != null && currentPrivateKey != null &&
currentCertId.equals(certId) && currentPrivateKey.equals(privateKey)) {
List<X509Certificate> newCertList = caClient.getTrustChain();
if (currentPrivateKey != null && currentPrivateKey.equals(privateKey) &&
currentCertIdsList.size() > 0 &&
newCertList.size() == currentCertIdsList.size() &&
!newCertList.stream().filter(
c -> !currentCertIdsList.contains(c.getSerialNumber().toString()))
.findAny().isPresent()) {
// Security materials(key and certificates) keep the same.
return null;
}

Expand All @@ -163,7 +167,8 @@ private X509ExtendedKeyManager loadKeyManager(CertificateClient caClient)
keystore.load(null, null);

keystore.setKeyEntry(caClient.getComponentName() + "_key",
privateKey, EMPTY_PASSWORD, new Certificate[]{cert});
privateKey, EMPTY_PASSWORD,
newCertList.toArray(new X509Certificate[0]));

KeyManagerFactory keyMgrFactory = KeyManagerFactory.getInstance(
KeyManagerFactory.getDefaultAlgorithm());
Expand All @@ -176,7 +181,10 @@ private X509ExtendedKeyManager loadKeyManager(CertificateClient caClient)
}

currentPrivateKey = privateKey;
currentCertId = cert.getSerialNumber().toString();
currentCertIdsList.clear();
for (X509Certificate cert: newCertList) {
currentCertIdsList.add(cert.getSerialNumber().toString());
}
return keyManager;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ public final class ReloadingX509TrustManager implements X509TrustManager {
private final String type;
private final AtomicReference<X509TrustManager> trustManagerRef;
/**
* Current CA cert in trustManager, to detect if certificate is changed.
* Current Root CA cert in trustManager, to detect if certificate is changed.
*/
private String currentCACertId = null;
private String currentRootCACertId = null;

/**
* Creates a reloadable trustmanager. The trustmanager reloads itself
Expand Down Expand Up @@ -124,17 +124,21 @@ public ReloadingX509TrustManager loadFrom(CertificateClient caClient) {

X509TrustManager loadTrustManager(CertificateClient caClient)
throws GeneralSecurityException, IOException {
X509Certificate cert = caClient.getCACertificate();
String certId = cert.getSerialNumber().toString();
// SCM certificate client sets root CA as CA cert instead of root CA cert
X509Certificate rootCACert = caClient.getRootCACertificate() != null ?
caClient.getRootCACertificate() : caClient.getCACertificate();

String rootCACertId = rootCACert.getSerialNumber().toString();
// Certificate keeps the same.
if (currentCACertId != null && currentCACertId.equals(certId)) {
if (currentRootCACertId != null &&
currentRootCACertId.equals(rootCACertId)) {
return null;
}

X509TrustManager trustManager = null;
KeyStore ks = KeyStore.getInstance(type);
ks.load(null, null);
ks.setCertificateEntry(certId, cert);
ks.setCertificateEntry(rootCACertId, rootCACert);

TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(
TrustManagerFactory.getDefaultAlgorithm());
Expand All @@ -146,7 +150,7 @@ X509TrustManager loadTrustManager(CertificateClient caClient)
break;
}
}
currentCACertId = certId;
currentRootCACertId = rootCACertId;
return trustManager;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ X509Certificate getCertificate(String certSerialId)
*/
CertPath getCACertPath();

/**
* Return all certificates in this component's trust chain,
* the last one is the root CA certificate.
*/
List<X509Certificate> getTrustChain();

/**
* Return the latest Root CA certificate known to the client.
* @return latest Root CA certificate known to the client.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -331,6 +332,39 @@ public synchronized X509Certificate getCACertificate() {
return firstCertificateFrom(caCertPath);
}

/**
* Return all certificates in this component's trust chain,
* the last one is the root CA certificate.
*/
@Override
public synchronized List<X509Certificate> getTrustChain() {
CertPath path = getCertPath();
if (path == null || path.getCertificates() == null) {
return null;
}

List<X509Certificate> chain = new ArrayList<>();
// certificate bundle case
if (path.getCertificates().size() > 1) {
for (int i = 0; i < path.getCertificates().size(); i++) {
chain.add((X509Certificate) path.getCertificates().get(i));
}
} else {
// case before certificate bundle is supported
chain.add(getCertificate());
X509Certificate cert = getCACertificate();
if (cert != null) {
chain.add(getCACertificate());
}
cert = getRootCACertificate();
if (cert != null) {
chain.add(cert);
}
}

return chain;
}

@Override
public synchronized CertPath getCACertPath() {
if (caCertId != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ public static void setUp() throws Exception {
public void testReload() throws Exception {
TrustManager tm =
caClient.getServerKeyStoresFactory().getTrustManagers()[0];
X509Certificate cert1 = caClient.getCACertificate();
X509Certificate cert1 = caClient.getRootCACertificate();
assertEquals(cert1,
((ReloadingX509TrustManager)tm).getAcceptedIssuers()[0]);

caClient.renewRootCA();
caClient.renewKey();
X509Certificate cert2 = caClient.getCACertificate();
X509Certificate cert2 = caClient.getRootCACertificate();
assertNotEquals(cert1, cert2);

assertEquals(cert2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.time.Duration;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -207,6 +208,14 @@ public CertPath getCACertPath() {
return null;
}

@Override
public List<X509Certificate> getTrustChain() {
List<X509Certificate> list = new ArrayList<>();
list.add(x509Certificate);
list.add(rootCert);
return list;
}

@Override
public X509Certificate getCACertificate() {
return rootCert;
Expand Down Expand Up @@ -262,7 +271,7 @@ public String getComponentName() {

@Override
public X509Certificate getRootCACertificate() {
return x509Certificate;
return rootCert;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ public void testSecureScmStartupSuccess() throws Exception {
ScmInfo scmInfo = scm.getClientProtocolServer().getScmInfo();
assertEquals(clusterId, scmInfo.getClusterId());
assertEquals(scmId, scmInfo.getScmId());
assertEquals(2, scm.getScmCertificateClient().getTrustChain().size());
}

@Test
Expand Down Expand Up @@ -867,6 +868,7 @@ public void testSecureOmInitSuccess() throws Exception {
assertNotNull(om.getCertificateClient().getPublicKey());
assertNotNull(om.getCertificateClient().getPrivateKey());
assertNotNull(om.getCertificateClient().getCertificate());
assertEquals(3, om.getCertificateClient().getTrustChain().size());
assertTrue(omLogs.getOutput().contains("Init response: GETCERT"));
assertTrue(omLogs.getOutput().contains("Successfully stored " +
"SCM signed certificate"));
Expand Down Expand Up @@ -903,7 +905,6 @@ public void testCertificateRotation() throws Exception {

SecurityConfig securityConfig = new SecurityConfig(conf);


// save first cert
final int certificateLifetime = 20; // seconds
KeyCodec keyCodec =
Expand Down

0 comments on commit 012ecd3

Please sign in to comment.