From 52015e66396a1ac8c58de2c8818a2779754975b0 Mon Sep 17 00:00:00 2001 From: Bryan Rosander Date: Mon, 22 Aug 2016 12:57:37 -0400 Subject: [PATCH] NIFI-2621 - Generating unique serial numbers for certificates --- .../nifi/security/util/CertificateUtils.java | 43 ++++++++++++++- .../security/util/CertificateUtilsTest.groovy | 55 +++++++++++++++++++ 2 files changed, 95 insertions(+), 3 deletions(-) diff --git a/nifi-commons/nifi-security-utils/src/main/java/org/apache/nifi/security/util/CertificateUtils.java b/nifi-commons/nifi-security-utils/src/main/java/org/apache/nifi/security/util/CertificateUtils.java index 760e82ea4605..df239fabdff5 100644 --- a/nifi-commons/nifi-security-utils/src/main/java/org/apache/nifi/security/util/CertificateUtils.java +++ b/nifi-commons/nifi-security-utils/src/main/java/org/apache/nifi/security/util/CertificateUtils.java @@ -74,9 +74,23 @@ public final class CertificateUtils { private static final Logger logger = LoggerFactory.getLogger(CertificateUtils.class); private static final String PEER_NOT_AUTHENTICATED_MSG = "peer not authenticated"; - private static final Map dnOrderMap = createDnOrderMap(); + /** + * The time in milliseconds that the last unique serial number was generated + */ + private static long lastSerialNumberMillis = 0L; + + /** + * An incrementor to add uniqueness to serial numbers generated in the same millisecond + */ + private static int serialNumberIncrementor = 0; + + /** + * BigInteger value to use for the base of the unique serial number + */ + private static BigInteger millisecondBigInteger; + private static Map createDnOrderMap() { Map orderMap = new HashMap<>(); int count = 0; @@ -438,6 +452,29 @@ private static X500Name reverseX500Name(X500Name x500Name) { return new X500Name(rdns.toArray(new RDN[rdns.size()])); } + /** + * Generates a unique serial number by using the current time in milliseconds left shifted 32 bits (to make room for incrementor) with an incrementor added + * + * @return a unique serial number (technically unique to this classloader) + */ + protected static synchronized BigInteger getUniqueSerialNumber() { + final long currentTimeMillis = System.currentTimeMillis(); + final int incrementorValue; + + if (lastSerialNumberMillis != currentTimeMillis) { + // We can only get into this block once per millisecond + millisecondBigInteger = BigInteger.valueOf(currentTimeMillis).shiftLeft(32); + lastSerialNumberMillis = currentTimeMillis; + incrementorValue = 0; + serialNumberIncrementor = 1; + } else { + // Already created at least one serial number this millisecond + incrementorValue = serialNumberIncrementor++; + } + + return millisecondBigInteger.add(BigInteger.valueOf(incrementorValue)); + } + /** * Generates a self-signed {@link X509Certificate} suitable for use as a Certificate Authority. * @@ -458,7 +495,7 @@ public static X509Certificate generateSelfSignedX509Certificate(KeyPair keyPair, X509v3CertificateBuilder certBuilder = new X509v3CertificateBuilder( reverseX500Name(new X500Name(dn)), - BigInteger.valueOf(System.currentTimeMillis()), + getUniqueSerialNumber(), startDate, endDate, reverseX500Name(new X500Name(dn)), subPubKeyInfo); @@ -507,7 +544,7 @@ public static X509Certificate generateIssuedCertificate(String dn, PublicKey pub X509v3CertificateBuilder certBuilder = new X509v3CertificateBuilder( reverseX500Name(new X500Name(issuer.getSubjectX500Principal().getName())), - BigInteger.valueOf(System.currentTimeMillis()), + getUniqueSerialNumber(), startDate, endDate, reverseX500Name(new X500Name(dn)), subPubKeyInfo); diff --git a/nifi-commons/nifi-security-utils/src/test/groovy/org/apache/nifi/security/util/CertificateUtilsTest.groovy b/nifi-commons/nifi-security-utils/src/test/groovy/org/apache/nifi/security/util/CertificateUtilsTest.groovy index 672736425774..47ac918c126f 100644 --- a/nifi-commons/nifi-security-utils/src/test/groovy/org/apache/nifi/security/util/CertificateUtilsTest.groovy +++ b/nifi-commons/nifi-security-utils/src/test/groovy/org/apache/nifi/security/util/CertificateUtilsTest.groovy @@ -40,9 +40,16 @@ import java.security.SignatureException import java.security.cert.Certificate import java.security.cert.CertificateException import java.security.cert.X509Certificate +import java.util.concurrent.Callable +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.ExecutionException +import java.util.concurrent.Executors +import java.util.concurrent.Future import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean import static org.junit.Assert.assertEquals +import static org.junit.Assert.assertTrue @RunWith(JUnit4.class) class CertificateUtilsTest extends GroovyTestCase { @@ -497,4 +504,52 @@ class CertificateUtilsTest extends GroovyTestCase { assertEquals("$cn,$l,$st,$o,$ou,$c,$street,$dc,$uid,$surname,$givenName,$initials".toString(), CertificateUtils.reorderDn("$surname,$st,$o,$initials,$givenName,$uid,$street,$c,$cn,$ou,$l,$dc")); } + + @Test + public void testUniqueSerialNumbers() { + def running = new AtomicBoolean(true); + def executorService = Executors.newCachedThreadPool() + def serialNumbers = Collections.newSetFromMap(new ConcurrentHashMap()) + try { + def futures = new ArrayList() + for (int i = 0; i < 8; i++) { + futures.add(executorService.submit(new Callable() { + @Override + Integer call() throws Exception { + int count = 0; + while (running.get()) { + def before = System.currentTimeMillis() + def serialNumber = CertificateUtils.getUniqueSerialNumber() + def after = System.currentTimeMillis() + def serialNumberMillis = serialNumber.shiftRight(32) + assertTrue(serialNumberMillis >= before) + assertTrue(serialNumberMillis <= after) + assertTrue(serialNumbers.add(serialNumber)) + count++; + } + return count; + } + })); + } + + Thread.sleep(1000) + + running.set(false) + + def totalRuns = 0; + for (int i = 0; i < futures.size(); i++) { + try { + def numTimes = futures.get(i).get() + logger.info("future $i executed $numTimes times") + totalRuns += numTimes; + } catch (ExecutionException e) { + throw e.getCause() + } + } + logger.info("Generated ${serialNumbers.size()} unique serial numbers") + assertEquals(totalRuns, serialNumbers.size()) + } finally { + executorService.shutdown() + } + } }