| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| // Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| package com.amazonaws.examples; | ||
|
|
||
| import org.testng.annotations.Test; | ||
|
|
||
| import java.security.GeneralSecurityException; | ||
| import java.security.KeyPair; | ||
| import java.security.KeyPairGenerator; | ||
|
|
||
| public class AsymmetricEncryptedItemTest { | ||
| private static final String TABLE_NAME = "java-ddbec-test-table-asym-example"; | ||
|
|
||
| @Test | ||
| public void testEncryptAndDecrypt() throws GeneralSecurityException { | ||
| final KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA"); | ||
| keyGen.initialize(2048); | ||
| final KeyPair wrappingKeys = keyGen.generateKeyPair(); | ||
| final KeyPair signingKeys = keyGen.generateKeyPair(); | ||
|
|
||
| AsymmetricEncryptedItem.encryptRecord(TABLE_NAME, wrappingKeys, signingKeys); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| // Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| package com.amazonaws.examples; | ||
|
|
||
| import com.amazonaws.services.kms.AWSKMS; | ||
| import com.amazonaws.services.kms.AWSKMSClientBuilder; | ||
| import org.testng.annotations.Test; | ||
|
|
||
| import java.security.GeneralSecurityException; | ||
|
|
||
| import static com.amazonaws.examples.TestUtils.US_WEST_2; | ||
| import static com.amazonaws.examples.TestUtils.US_WEST_2_KEY_ID; | ||
|
|
||
| public class AwsKmsEncryptedItemIT { | ||
| private static final String TABLE_NAME = "java-ddbec-test-table-kms-item-example"; | ||
|
|
||
| @Test | ||
| public void testEncryptAndDecrypt() throws GeneralSecurityException { | ||
| final AWSKMS kms = AWSKMSClientBuilder.standard().withRegion(US_WEST_2).build(); | ||
| AwsKmsEncryptedItem.encryptRecord(TABLE_NAME, US_WEST_2_KEY_ID, kms); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| // Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| package com.amazonaws.examples; | ||
|
|
||
| import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; | ||
| import com.amazonaws.services.dynamodbv2.local.embedded.DynamoDBEmbedded; | ||
| import com.amazonaws.services.kms.AWSKMS; | ||
| import com.amazonaws.services.kms.AWSKMSClientBuilder; | ||
| import org.testng.annotations.Test; | ||
|
|
||
| import static com.amazonaws.examples.AwsKmsEncryptedObject.EXAMPLE_TABLE_NAME; | ||
| import static com.amazonaws.examples.AwsKmsEncryptedObject.PARTITION_ATTRIBUTE; | ||
| import static com.amazonaws.examples.AwsKmsEncryptedObject.SORT_ATTRIBUTE; | ||
| import static com.amazonaws.examples.TestUtils.US_WEST_2; | ||
| import static com.amazonaws.examples.TestUtils.US_WEST_2_KEY_ID; | ||
| import static com.amazonaws.examples.TestUtils.createDDBTable; | ||
|
|
||
| public class AwsKmsEncryptedObjectIT { | ||
|
|
||
| @Test | ||
| public void testEncryptAndDecrypt() { | ||
| final AWSKMS kms = AWSKMSClientBuilder.standard().withRegion(US_WEST_2).build(); | ||
| final AmazonDynamoDB ddb = DynamoDBEmbedded.create(); | ||
|
|
||
| // Create the table under test | ||
| createDDBTable(ddb, EXAMPLE_TABLE_NAME, PARTITION_ATTRIBUTE, SORT_ATTRIBUTE); | ||
|
|
||
| AwsKmsEncryptedObject.encryptRecord(US_WEST_2_KEY_ID, ddb, kms); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| // Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| package com.amazonaws.examples; | ||
|
|
||
| import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; | ||
| import com.amazonaws.services.dynamodbv2.local.embedded.DynamoDBEmbedded; | ||
|
|
||
| import com.amazonaws.services.kms.AWSKMS; | ||
| import com.amazonaws.services.kms.AWSKMSClientBuilder; | ||
| import org.testng.annotations.Test; | ||
|
|
||
| import java.security.GeneralSecurityException; | ||
|
|
||
| import static com.amazonaws.examples.EncryptionContextOverridesWithDynamoDBMapper.PARTITION_ATTRIBUTE; | ||
| import static com.amazonaws.examples.EncryptionContextOverridesWithDynamoDBMapper.SORT_ATTRIBUTE; | ||
| import static com.amazonaws.examples.EncryptionContextOverridesWithDynamoDBMapper.TABLE_NAME_TO_OVERRIDE; | ||
| import static com.amazonaws.examples.TestUtils.US_WEST_2; | ||
| import static com.amazonaws.examples.TestUtils.US_WEST_2_KEY_ID; | ||
| import static com.amazonaws.examples.TestUtils.createDDBTable; | ||
|
|
||
| public class EncryptionContextOverridesWithDynamoDBMapperIT { | ||
| private static final String OVERRIDE_TABLE_NAME = "java-ddbec-test-table-encctx-override-example"; | ||
|
|
||
| @Test | ||
| public void testEncryptAndDecrypt() throws GeneralSecurityException { | ||
| final AWSKMS kms = AWSKMSClientBuilder.standard().withRegion(US_WEST_2).build(); | ||
| final AmazonDynamoDB ddb = DynamoDBEmbedded.create(); | ||
|
|
||
| // Create the table under test | ||
| createDDBTable(ddb, TABLE_NAME_TO_OVERRIDE, PARTITION_ATTRIBUTE, SORT_ATTRIBUTE); | ||
|
|
||
| EncryptionContextOverridesWithDynamoDBMapper.encryptRecord(US_WEST_2_KEY_ID, OVERRIDE_TABLE_NAME, ddb, kms); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| // Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| package com.amazonaws.examples; | ||
|
|
||
| import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; | ||
| import com.amazonaws.services.dynamodbv2.datamodeling.encryption.providers.store.MetaStore; | ||
| import com.amazonaws.services.dynamodbv2.local.embedded.DynamoDBEmbedded; | ||
| import com.amazonaws.services.dynamodbv2.model.ProvisionedThroughput; | ||
| import com.amazonaws.services.kms.AWSKMS; | ||
| import com.amazonaws.services.kms.AWSKMSClientBuilder; | ||
| import org.testng.annotations.Test; | ||
|
|
||
| import java.security.GeneralSecurityException; | ||
|
|
||
| import static com.amazonaws.examples.MostRecentEncryptedItem.PARTITION_ATTRIBUTE; | ||
| import static com.amazonaws.examples.MostRecentEncryptedItem.SORT_ATTRIBUTE; | ||
| import static com.amazonaws.examples.TestUtils.*; | ||
|
|
||
| public class MostRecentEncryptedItemIT { | ||
| private static final String TABLE_NAME = "java-ddbec-test-table-mostrecent-example"; | ||
| private static final String KEY_TABLE_NAME = "java-ddbec-test-table-mostrecent-example-keys"; | ||
| private static final String MATERIAL_NAME = "testMaterial"; | ||
|
|
||
| @Test | ||
| public void testEncryptAndDecrypt() throws GeneralSecurityException { | ||
| final AWSKMS kms = AWSKMSClientBuilder.standard().withRegion(US_WEST_2).build(); | ||
| final AmazonDynamoDB ddb = DynamoDBEmbedded.create(); | ||
|
|
||
| // Create the key table under test | ||
| MetaStore.createTable(ddb, KEY_TABLE_NAME, new ProvisionedThroughput(1L, 1L)); | ||
|
|
||
| // Create the table under test | ||
| createDDBTable(ddb, TABLE_NAME, PARTITION_ATTRIBUTE, SORT_ATTRIBUTE); | ||
|
|
||
| MostRecentEncryptedItem.encryptRecord(TABLE_NAME, KEY_TABLE_NAME, US_WEST_2_KEY_ID, MATERIAL_NAME, ddb, kms); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| // Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| package com.amazonaws.examples; | ||
|
|
||
| import org.testng.annotations.Test; | ||
|
|
||
| import javax.crypto.SecretKey; | ||
| import javax.crypto.spec.SecretKeySpec; | ||
| import java.security.GeneralSecurityException; | ||
| import java.security.SecureRandom; | ||
|
|
||
| public class SymmetricEncryptedItemTest { | ||
| private static final String TABLE_NAME = "java-ddbec-test-table-sym-example"; | ||
|
|
||
| @Test | ||
| public void testEncryptAndDecrypt() throws GeneralSecurityException { | ||
| final SecureRandom secureRandom = new SecureRandom(); | ||
| byte[] rawAes = new byte[32]; | ||
| byte[] rawHmac = new byte[32]; | ||
| secureRandom.nextBytes(rawAes); | ||
| secureRandom.nextBytes(rawHmac); | ||
| final SecretKey wrappingKey = new SecretKeySpec(rawAes, "AES"); | ||
| final SecretKey signingKey = new SecretKeySpec(rawHmac, "HmacSHA256"); | ||
|
|
||
| SymmetricEncryptedItem.encryptRecord(TABLE_NAME, wrappingKey, signingKey); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| package com.amazonaws.examples; | ||
|
|
||
| import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; | ||
| import com.amazonaws.services.dynamodbv2.model.*; | ||
|
|
||
| import java.util.ArrayList; | ||
|
|
||
| import static com.amazonaws.examples.AwsKmsEncryptedObject.*; | ||
|
|
||
| public class TestUtils { | ||
| private TestUtils() { | ||
| throw new UnsupportedOperationException( | ||
| "This class exists to hold static resources and cannot be instantiated." | ||
| ); | ||
| } | ||
|
|
||
| /** | ||
| * These special test keys have been configured to allow Encrypt, Decrypt, and GenerateDataKey operations from any | ||
| * AWS principal and should be used when adding new KMS tests. | ||
| * | ||
| * This should go without saying, but never use these keys for production purposes (as anyone in the world can | ||
| * decrypt data encrypted using them). | ||
| */ | ||
| public static final String US_WEST_2_KEY_ID = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; | ||
| public static final String US_WEST_2 = "us-west-2"; | ||
|
|
||
| public static void createDDBTable(AmazonDynamoDB ddb, String tableName, String partitionName, String sortName) { | ||
| ArrayList<AttributeDefinition> attrDef = new ArrayList<AttributeDefinition>(); | ||
| attrDef.add(new AttributeDefinition().withAttributeName(partitionName).withAttributeType(ScalarAttributeType.S)); | ||
| attrDef.add(new AttributeDefinition().withAttributeName(sortName).withAttributeType(ScalarAttributeType.N)); | ||
|
|
||
| ArrayList<KeySchemaElement> keySchema = new ArrayList<KeySchemaElement>(); | ||
| keySchema.add(new KeySchemaElement().withAttributeName(partitionName).withKeyType(KeyType.HASH)); | ||
| keySchema.add(new KeySchemaElement().withAttributeName(sortName).withKeyType(KeyType.RANGE)); | ||
|
|
||
| ddb.createTable(new CreateTableRequest().withTableName(tableName) | ||
| .withAttributeDefinitions(attrDef) | ||
| .withKeySchema(keySchema) | ||
| .withProvisionedThroughput(new ProvisionedThroughput(100L, 100L))); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,183 @@ | ||
| // Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| package com.amazonaws.services.dynamodbv2.datamodeling.encryption.providers; | ||
|
|
||
| import com.amazonaws.services.dynamodbv2.datamodeling.encryption.EncryptionContext; | ||
| import com.amazonaws.services.dynamodbv2.datamodeling.encryption.materials.DecryptionMaterials; | ||
| import com.amazonaws.services.dynamodbv2.datamodeling.encryption.materials.EncryptionMaterials; | ||
| import com.amazonaws.services.dynamodbv2.datamodeling.encryption.providers.store.ProviderStore; | ||
| import com.amazonaws.services.dynamodbv2.datamodeling.internal.TTLCache; | ||
| import com.amazonaws.services.dynamodbv2.datamodeling.internal.TTLCache.EntryLoader; | ||
|
|
||
| import java.util.concurrent.TimeUnit; | ||
|
|
||
| import static com.amazonaws.services.dynamodbv2.datamodeling.internal.Utils.checkNotNull; | ||
|
|
||
| /** | ||
| * This meta-Provider encrypts data with the most recent version of keying materials from a | ||
| * {@link ProviderStore} and decrypts using whichever version is appropriate. It also caches the | ||
| * results from the {@link ProviderStore} to avoid excessive load on the backing systems. | ||
| */ | ||
| public class CachingMostRecentProvider implements EncryptionMaterialsProvider { | ||
| private static final long INITIAL_VERSION = 0; | ||
| private static final String PROVIDER_CACHE_KEY_DELIM = "#"; | ||
| private static final int DEFAULT_CACHE_MAX_SIZE = 1000; | ||
|
|
||
| private final long ttlInNanos; | ||
| private final ProviderStore keystore; | ||
| protected final String defaultMaterialName; | ||
| private final TTLCache<EncryptionMaterialsProvider> providerCache; | ||
| private final TTLCache<Long> versionCache; | ||
|
|
||
| private final EntryLoader<Long> versionLoader = new EntryLoader<Long>() { | ||
| @Override | ||
| public Long load(String entryKey) { | ||
| return keystore.getMaxVersion(entryKey); | ||
| } | ||
| }; | ||
|
|
||
| private final EntryLoader<EncryptionMaterialsProvider> providerLoader = new EntryLoader<EncryptionMaterialsProvider>() { | ||
| @Override | ||
| public EncryptionMaterialsProvider load(String entryKey) { | ||
| final String[] parts = entryKey.split(PROVIDER_CACHE_KEY_DELIM, 2); | ||
| if (parts.length != 2) { | ||
| throw new IllegalStateException("Invalid cache key for provider cache: " + entryKey); | ||
| } | ||
| return keystore.getProvider(parts[0], Long.parseLong(parts[1])); | ||
| } | ||
| }; | ||
|
|
||
|
|
||
| /** | ||
| * Creates a new {@link CachingMostRecentProvider}. | ||
| * | ||
| * @param keystore | ||
| * The key store that this provider will use to determine which material and which version of material to use | ||
| * @param materialName | ||
| * The name of the materials associated with this provider | ||
| * @param ttlInMillis | ||
| * The length of time in milliseconds to cache the most recent provider | ||
| */ | ||
| public CachingMostRecentProvider(final ProviderStore keystore, final String materialName, final long ttlInMillis) { | ||
| this(keystore, materialName, ttlInMillis, DEFAULT_CACHE_MAX_SIZE); | ||
| } | ||
|
|
||
| /** | ||
| * Creates a new {@link CachingMostRecentProvider}. | ||
| * | ||
| * @param keystore | ||
| * The key store that this provider will use to determine which material and which version of material to use | ||
| * @param materialName | ||
| * The name of the materials associated with this provider | ||
| * @param ttlInMillis | ||
| * The length of time in milliseconds to cache the most recent provider | ||
| * @param maxCacheSize | ||
| * The maximum size of the underlying caches this provider uses. Entries will be evicted from the cache | ||
| * once this size is exceeded. | ||
| */ | ||
| public CachingMostRecentProvider(final ProviderStore keystore, final String materialName, final long ttlInMillis, final int maxCacheSize) { | ||
| this.keystore = checkNotNull(keystore, "keystore must not be null"); | ||
| this.defaultMaterialName = materialName; | ||
| this.ttlInNanos = TimeUnit.MILLISECONDS.toNanos(ttlInMillis); | ||
|
|
||
| this.providerCache = new TTLCache<>(maxCacheSize, ttlInMillis, providerLoader); | ||
| this.versionCache = new TTLCache<>(maxCacheSize, ttlInMillis, versionLoader); | ||
| } | ||
|
|
||
| @Override | ||
| public EncryptionMaterials getEncryptionMaterials(EncryptionContext context) { | ||
| final String materialName = getMaterialName(context); | ||
| final long currentVersion = versionCache.load(materialName); | ||
|
|
||
| if (currentVersion < 0) { | ||
| // The material hasn't been created yet, so specify a loading function | ||
| // to create the first version of materials and update both caches. | ||
| // We want this to be done as part of the cache load to ensure that this logic | ||
| // only happens once in a multithreaded environment, | ||
| // in order to limit calls to the keystore's dependencies. | ||
| final String cacheKey = buildCacheKey(materialName, INITIAL_VERSION); | ||
| EncryptionMaterialsProvider newProvider = providerCache.load( | ||
| cacheKey, | ||
| s -> { | ||
| // Create the new material in the keystore | ||
| final String[] parts = s.split(PROVIDER_CACHE_KEY_DELIM, 2); | ||
| if (parts.length != 2) { | ||
| throw new IllegalStateException("Invalid cache key for provider cache: " + s); | ||
| } | ||
| EncryptionMaterialsProvider provider = keystore.getOrCreate(parts[0], Long.parseLong(parts[1])); | ||
|
|
||
| // We now should have version 0 in our keystore. | ||
| // Update the version cache for this material as a side effect | ||
| versionCache.put(materialName, INITIAL_VERSION); | ||
|
|
||
| // Return the new materials to be put into the cache | ||
| return provider; | ||
| } | ||
| ); | ||
|
|
||
| return newProvider.getEncryptionMaterials(context); | ||
| } else { | ||
| final String cacheKey = buildCacheKey(materialName, currentVersion); | ||
| return providerCache.load(cacheKey).getEncryptionMaterials(context); | ||
| } | ||
| } | ||
|
|
||
| public DecryptionMaterials getDecryptionMaterials(EncryptionContext context) { | ||
| final long version = keystore.getVersionFromMaterialDescription( | ||
| context.getMaterialDescription()); | ||
| final String materialName = getMaterialName(context); | ||
| final String cacheKey = buildCacheKey(materialName, version); | ||
|
|
||
| EncryptionMaterialsProvider provider = providerCache.load(cacheKey); | ||
| return provider.getDecryptionMaterials(context); | ||
| } | ||
|
|
||
| /** | ||
| * Completely empties the cache of both the current and old versions. | ||
| */ | ||
| @Override | ||
| public void refresh() { | ||
| versionCache.clear(); | ||
| providerCache.clear(); | ||
| } | ||
|
|
||
| public String getMaterialName() { | ||
| return defaultMaterialName; | ||
| } | ||
|
|
||
| public long getTtlInMills() { | ||
| return TimeUnit.NANOSECONDS.toMillis(ttlInNanos); | ||
| } | ||
|
|
||
| /** | ||
| * The current version of the materials being used for encryption. Returns -1 if we do not | ||
| * currently have a current version. | ||
| */ | ||
| public long getCurrentVersion() { | ||
| return versionCache.load(getMaterialName()); | ||
| } | ||
|
|
||
| /** | ||
| * The last time the current version was updated. Returns 0 if we do not currently have a | ||
| * current version. | ||
| */ | ||
| public long getLastUpdated() { | ||
| // We cache a version of -1 to mean that there is not a current version | ||
| if (versionCache.load(getMaterialName()) < 0) { | ||
| return 0; | ||
| } | ||
| // Otherwise, return the last update time of that entry | ||
| return TimeUnit.NANOSECONDS.toMillis(versionCache.getLastUpdated(getMaterialName())); | ||
| } | ||
|
|
||
| protected String getMaterialName(final EncryptionContext context) { | ||
| return defaultMaterialName; | ||
| } | ||
|
|
||
| private static String buildCacheKey(final String materialName, final long version) { | ||
| StringBuilder result = new StringBuilder(materialName); | ||
| result.append(PROVIDER_CACHE_KEY_DELIM); | ||
| result.append(version); | ||
| return result.toString(); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| /* | ||
| * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except | ||
| * in compliance with the License. A copy of the License is located at | ||
| * | ||
| * http://aws.amazon.com/apache2.0 | ||
| * | ||
| * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations under the License. | ||
| */ | ||
|
|
||
| package com.amazonaws.services.dynamodbv2.datamodeling.internal; | ||
|
|
||
| interface MsClock { | ||
| MsClock WALLCLOCK = System::nanoTime; | ||
|
|
||
| public long timestampNano(); | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,259 @@ | ||
| // Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| package com.amazonaws.services.dynamodbv2.datamodeling.internal; | ||
|
|
||
| import com.amazonaws.annotation.ThreadSafe; | ||
|
|
||
| import java.util.concurrent.TimeUnit; | ||
| import java.util.concurrent.atomic.AtomicReference; | ||
| import java.util.concurrent.locks.ReentrantLock; | ||
| import java.util.function.Function; | ||
|
|
||
| import static com.amazonaws.services.dynamodbv2.datamodeling.internal.Utils.checkNotNull; | ||
|
|
||
| /** | ||
| * A cache, backed by an LRUCache, that uses a loader to calculate values on cache miss | ||
| * or expired TTL. | ||
| * | ||
| * Note that this cache does not proactively evict expired entries, | ||
| * however will immediately evict entries discovered to be expired on load. | ||
| * | ||
| * @param <T> | ||
| * value type | ||
| */ | ||
| @ThreadSafe | ||
| public final class TTLCache<T> { | ||
| /** | ||
| * Used for the internal cache. | ||
| */ | ||
| private final LRUCache<LockedState<T>> cache; | ||
|
|
||
| /** | ||
| * Time to live for entries in the cache. | ||
| */ | ||
| private final long ttlInNanos; | ||
|
|
||
| /** | ||
| * Used for loading new values into the cache on cache miss or expiration. | ||
| */ | ||
| private final EntryLoader<T> defaultLoader; | ||
|
|
||
| // Mockable time source, to allow us to test TTL behavior. | ||
| // package access for tests | ||
| MsClock clock = MsClock.WALLCLOCK; | ||
|
|
||
| private static final long TTL_GRACE_IN_NANO = TimeUnit.MILLISECONDS.toNanos(500); | ||
|
|
||
| /** | ||
| * @param maxSize | ||
| * the maximum number of entries of the cache | ||
| * @param ttlInMillis | ||
| * the time to live value for entries of the cache, in milliseconds | ||
| */ | ||
| public TTLCache(final int maxSize, final long ttlInMillis, final EntryLoader<T> loader) { | ||
| if (maxSize < 1) { | ||
| throw new IllegalArgumentException("maxSize " + maxSize + " must be at least 1"); | ||
| } | ||
| if (ttlInMillis < 1) { | ||
| throw new IllegalArgumentException("ttlInMillis " + maxSize + " must be at least 1"); | ||
| } | ||
| this.ttlInNanos = TimeUnit.MILLISECONDS.toNanos(ttlInMillis); | ||
| this.cache = new LRUCache<>(maxSize); | ||
| this.defaultLoader = checkNotNull(loader, "loader must not be null"); | ||
| } | ||
|
|
||
| /** | ||
| * Uses the default loader to calculate the value at key and insert it into the cache, | ||
| * if it doesn't already exist or is expired according to the TTL. | ||
| * | ||
| * This immediately evicts entries past the TTL such that a load failure results | ||
| * in the removal of the entry. | ||
| * | ||
| * Entries that are not expired according to the TTL are returned without recalculating the value. | ||
| * | ||
| * Within a grace period past the TTL, the cache may either return the cached value without recalculating | ||
| * or use the loader to recalculate the value. This is implemented such that, in a multi-threaded environment, | ||
| * only one thread per cache key uses the loader to recalculate the value at one time. | ||
| * | ||
| * @param key | ||
| * The cache key to load the value at | ||
| * @return | ||
| * The value of the given value (already existing or re-calculated). | ||
| */ | ||
| public T load(final String key) { | ||
| return load(key, defaultLoader::load); | ||
| } | ||
|
|
||
| /** | ||
| * Uses the inputted function to calculate the value at key and insert it into the cache, | ||
| * if it doesn't already exist or is expired according to the TTL. | ||
| * | ||
| * This immediately evicts entries past the TTL such that a load failure results | ||
| * in the removal of the entry. | ||
| * | ||
| * Entries that are not expired according to the TTL are returned without recalculating the value. | ||
| * | ||
| * Within a grace period past the TTL, the cache may either return the cached value without recalculating | ||
| * or use the loader to recalculate the value. This is implemented such that, in a multi-threaded environment, | ||
| * only one thread per cache key uses the loader to recalculate the value at one time. | ||
| * | ||
| * Returns the value of the given key (already existing or re-calculated). | ||
| * | ||
| * @param key | ||
| * The cache key to load the value at | ||
| * @param f | ||
| * The function to use to load the value, given key as input | ||
| * @return | ||
| * The value of the given value (already existing or re-calculated). | ||
| */ | ||
| public T load(final String key, Function<String, T> f) { | ||
| final LockedState<T> ls = cache.get(key); | ||
|
|
||
| if (ls == null) { | ||
| // The entry doesn't exist yet, so load a new one. | ||
| return loadNewEntryIfAbsent(key, f); | ||
| } else if (clock.timestampNano() - ls.getState().lastUpdatedNano > ttlInNanos + TTL_GRACE_IN_NANO) { | ||
| // The data has expired past the grace period. | ||
| // Evict the old entry and load a new entry. | ||
| cache.remove(key); | ||
| return loadNewEntryIfAbsent(key, f); | ||
| } else if (clock.timestampNano() - ls.getState().lastUpdatedNano <= ttlInNanos) { | ||
| // The data hasn't expired. Return as-is from the cache. | ||
| return ls.getState().data; | ||
| } else if (!ls.tryLock()) { | ||
| // We are in the TTL grace period. If we couldn't grab the lock, then some other | ||
| // thread is currently loading the new value. Because we are in the grace period, | ||
| // use the cached data instead of waiting for the lock. | ||
| return ls.getState().data; | ||
| } | ||
|
|
||
| // We are in the grace period and have acquired a lock. | ||
| // Update the cache with the value determined by the loading function. | ||
| try { | ||
| T loadedData = f.apply(key); | ||
| ls.update(loadedData, clock.timestampNano()); | ||
| return ls.getState().data; | ||
| } finally { | ||
| ls.unlock(); | ||
| } | ||
| } | ||
|
|
||
| // Synchronously calculate the value for a new entry in the cache if it doesn't already exist. | ||
| // Otherwise return the cached value. | ||
| // It is important that this is the only place where we use the loader for a new entry, | ||
| // given that we don't have the entry yet to lock on. | ||
| // This ensures that the loading function is only called once if multiple threads | ||
| // attempt to add a new entry for the same key at the same time. | ||
| private synchronized T loadNewEntryIfAbsent(final String key, Function<String, T> f) { | ||
| // If the entry already exists in the cache, return it | ||
| final LockedState<T> cachedState = cache.get(key); | ||
| if (cachedState != null) { | ||
| return cachedState.getState().data; | ||
| } | ||
|
|
||
| // Otherwise, load the data and create a new entry | ||
| T loadedData = f.apply(key); | ||
| LockedState<T> ls = new LockedState<>(loadedData, clock.timestampNano()); | ||
| cache.add(key, ls); | ||
| return loadedData; | ||
| } | ||
|
|
||
| /** | ||
| * Put a new entry in the cache. | ||
| * Returns the value previously at that key in the cache, | ||
| * or null if the entry previously didn't exist or | ||
| * is expired. | ||
| */ | ||
| public synchronized T put(final String key, final T value) { | ||
| LockedState<T> ls = new LockedState<>(value, clock.timestampNano()); | ||
| LockedState<T> oldLockedState = cache.add(key, ls); | ||
| if (oldLockedState == null || clock.timestampNano() - oldLockedState.getState().lastUpdatedNano > ttlInNanos + TTL_GRACE_IN_NANO) { | ||
| return null; | ||
| } | ||
| return oldLockedState.getState().data; | ||
| } | ||
|
|
||
| /** | ||
| * Get when the entry at this key was last updated. | ||
| * Returns 0 if the entry doesn't exist at key. | ||
| */ | ||
| public long getLastUpdated(String key) { | ||
| LockedState<T> ls = cache.get(key); | ||
| if (ls == null) { | ||
| return 0; | ||
| } | ||
| return ls.getState().lastUpdatedNano; | ||
| } | ||
|
|
||
| /** | ||
| * Returns the current size of the cache. | ||
| */ | ||
| public int size() { | ||
| return cache.size(); | ||
| } | ||
|
|
||
| /** | ||
| * Returns the maximum size of the cache. | ||
| */ | ||
| public int getMaxSize() { | ||
| return cache.getMaxSize(); | ||
| } | ||
|
|
||
| /** | ||
| * Clears all entries from the cache. | ||
| */ | ||
| public void clear() { | ||
| cache.clear(); | ||
| } | ||
|
|
||
| @Override | ||
| public String toString() { | ||
| return cache.toString(); | ||
| } | ||
|
|
||
| public interface EntryLoader<T> { | ||
| T load(String entryKey); | ||
| } | ||
|
|
||
| // An object which stores a state alongside a lock, | ||
| // and performs updates to that state atomically. | ||
| // The state may only be updated if the lock is acquired by the current thread. | ||
| private static class LockedState<T> { | ||
| private final ReentrantLock lock = new ReentrantLock(true); | ||
| private final AtomicReference<State<T>> state; | ||
|
|
||
| public LockedState(T data, long createTimeNano) { | ||
| state = new AtomicReference<>(new State<>(data, createTimeNano)); | ||
| } | ||
|
|
||
| public State<T> getState() { | ||
| return state.get(); | ||
| } | ||
|
|
||
| public void unlock() { | ||
| lock.unlock(); | ||
| } | ||
|
|
||
| public boolean tryLock() { | ||
| return lock.tryLock(); | ||
| } | ||
|
|
||
| public void update(T data, long createTimeNano) { | ||
| if (!lock.isHeldByCurrentThread()) { | ||
| throw new IllegalStateException("Lock not held by current thread"); | ||
| } | ||
| state.set(new State<>(data, createTimeNano)); | ||
| } | ||
| } | ||
|
|
||
| // An object that holds some data and the time at which this object was created | ||
| private static class State<T> { | ||
| public final T data; | ||
| public final long lastUpdatedNano; | ||
|
|
||
| public State(T data, long lastUpdatedNano) { | ||
| this.data = data; | ||
| this.lastUpdatedNano = lastUpdatedNano; | ||
| } | ||
| } | ||
| } |