-
Notifications
You must be signed in to change notification settings - Fork 20
/
STSCredentialsProvider.java
156 lines (141 loc) · 6.94 KB
/
STSCredentialsProvider.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package com.amazon.aws.emr.credentials;
import com.amazon.aws.emr.ApplicationConfiguration;
import com.amazonaws.AmazonClientException;
import com.amazonaws.AmazonServiceException;
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.AssumeRoleRequest;
import com.amazonaws.services.securitytoken.model.AssumeRoleResult;
import com.amazonaws.services.securitytoken.model.Credentials;
import com.amazonaws.util.EC2MetadataUtils;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.time.Duration;
import java.util.Date;
import java.util.Optional;
import java.util.TimeZone;
import java.util.concurrent.ThreadLocalRandom;
import javax.inject.Inject;
import javax.inject.Singleton;
import lombok.extern.slf4j.Slf4j;
/**
 * Fetches credentials for {@code AssumeRoleRequest} from STS.
 * <p>
 * Successful results are cached per {@link AssumeRoleRequest} (bounded by
 * {@code CREDENTIALS_MAP_MAX_SIZE} entries) and re-fetched once they get close to expiring.
 * Lookups that yield no credentials are not kept in the cache, so the next request for the
 * same role goes back to STS.
 */
@Slf4j
@Singleton
public class STSCredentialsProvider implements MetadataCredentialsProvider {
    @Inject
    STSClient stsClient;

    /** Minimum remaining lifetime below which cached credentials are always re-fetched. */
    public static final Duration MIN_REMAINING_TIME_TO_REFRESH_CREDENTIALS = Duration.ofMinutes(10);
    /** Upper bound (exclusive) of the random extra margin added to the minimum above. */
    public static final Duration MAX_RANDOM_TIME_TO_REFRESH_CREDENTIALS = Duration.ofMinutes(5);
    private static final int CREDENTIALS_MAP_MAX_SIZE = 20000;

    private final LoadingCache<AssumeRoleRequest, Optional<EC2MetadataUtils.IAMSecurityCredential>> credentialsCache = CacheBuilder
            .newBuilder().maximumSize(CREDENTIALS_MAP_MAX_SIZE)
            .build(new CacheLoader<AssumeRoleRequest, Optional<EC2MetadataUtils.IAMSecurityCredential>>() {
                @Override
                public Optional<EC2MetadataUtils.IAMSecurityCredential> load(AssumeRoleRequest assumeRoleRequest) {
                    return assumeRole(assumeRoleRequest);
                }
            });

    /**
     * Creates a new {@link SimpleDateFormat} for the {@code yyyy-MM-dd'T'HH:mm:ss'Z'} pattern in UTC.
     * <p>
     * {@code SimpleDateFormat} is not thread safe, so a fresh instance is created per use instead of
     * sharing one.
     *
     * @return a new formatter/parser for credential timestamps, set to the UTC time zone
     */
    static SimpleDateFormat createInterceptorDateTimeFormat() {
        SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss'Z'");
        dateFormat.setTimeZone(TimeZone.getTimeZone("UTC"));
        return dateFormat;
    }

    /**
     * {@inheritDoc}
     * <p>
     * Returns cached credentials when they are not close to expiring, otherwise re-fetches them from
     * STS. An empty result is returned to the caller but never left in the cache, so a later call
     * retries STS.
     */
    @Override
    public Optional<EC2MetadataUtils.IAMSecurityCredential> getUserCredentials(AssumeRoleRequest assumeRoleRequest) {
        log.debug("Request to assume role {} with STS", assumeRoleRequest);
        Optional<EC2MetadataUtils.IAMSecurityCredential> credentials = credentialsCache.getUnchecked(assumeRoleRequest);
        if (credentials.isPresent() && shouldRefresh(credentials.get())) {
            // TODO: we should consider using Caffeine which provides ttl at item level
            log.debug("Invalidating the cache for assume role {}", assumeRoleRequest);
            /*
             * Evict only the exact stale value we inspected. If another thread has already replaced
             * it with fresh credentials, this remove is a no-op and getUnchecked() returns those
             * instead of triggering yet another STS call.
             */
            credentialsCache.asMap().remove(assumeRoleRequest, credentials);
            credentials = credentialsCache.getUnchecked(assumeRoleRequest);
        }
        if (!credentials.isPresent()) {
            /*
             * assumeRole() returns empty after a client-side failure. Caching that would pin the
             * failure until size-based eviction, because only present credentials are ever
             * re-checked above. Drop it (conditionally, to avoid evicting a value another thread
             * loaded meanwhile) so the next call retries STS.
             */
            log.debug("No credentials obtained for assume role {}; not caching the failure", assumeRoleRequest);
            credentialsCache.asMap().remove(assumeRoleRequest, credentials);
        }
        return credentials;
    }

    /**
     * Makes the actual call to STS. Used as the cache loader.
     *
     * @param assumeRoleRequest the request to assume
     * @return an {@code Optional} containing {@link EC2MetadataUtils.IAMSecurityCredential}, or empty
     *         if the call failed on the client side
     * @throws AmazonServiceException if STS returned an error response; it is rethrown, and Guava does
     *         not cache loader exceptions (callers of {@code getUnchecked} see it wrapped)
     */
    private Optional<EC2MetadataUtils.IAMSecurityCredential> assumeRole(AssumeRoleRequest assumeRoleRequest) {
        log.info("Need to assume role {} with STS", assumeRoleRequest);
        try {
            AssumeRoleResult assumeRoleResult = stsClient.assumeRole(assumeRoleRequest);
            EC2MetadataUtils.IAMSecurityCredential credentials = createIAMSecurityCredential(assumeRoleResult.getCredentials());
            log.debug("Procured credentials from STS for assume role {}", assumeRoleRequest);
            return Optional.of(credentials);
        } catch (AmazonServiceException ase) {
            // STS received the request but answered with an error response (any status, not only 5xx).
            // This catch must precede the AmazonClientException one: AmazonServiceException is a subclass.
            log.error("AWS Service exception {}", ase.getErrorMessage(), ase);
            throw ase;
        } catch (AmazonClientException ace) {
            // Client-side failure, e.g. no response from STS or an unparseable one; reported as "no credentials".
            log.error("AWS Client exception {}", ace.getMessage(), ace);
            return Optional.empty();
        }
    }

    /**
     * Converts STS {@link Credentials} into the instance-metadata credential representation.
     *
     * @param credentials credentials returned by STS
     * @return a credential with code {@code Success}, type {@code AWS-HMAC}, the STS expiration, and
     *         {@code lastUpdated} set to the current time, both formatted in UTC
     */
    private EC2MetadataUtils.IAMSecurityCredential createIAMSecurityCredential(Credentials credentials) {
        // One formatter per call is enough; it is confined to this method, so thread safety is not a concern.
        SimpleDateFormat dateFormat = createInterceptorDateTimeFormat();
        EC2MetadataUtils.IAMSecurityCredential iamCredential = new EC2MetadataUtils.IAMSecurityCredential();
        iamCredential.accessKeyId = credentials.getAccessKeyId();
        iamCredential.secretAccessKey = credentials.getSecretAccessKey();
        iamCredential.token = credentials.getSessionToken();
        iamCredential.code = "Success";
        iamCredential.type = "AWS-HMAC";
        iamCredential.expiration = dateFormat.format(credentials.getExpiration());
        iamCredential.lastUpdated = dateFormat.format(new Date());
        return iamCredential;
    }

    /**
     * Determines whether cached credentials should be re-fetched from STS.
     * <p>
     * Returns {@code true} when the current time plus {@link #getRandomTimeInRange()} is later than
     * the expiration of the cached credentials. That margin is
     * {@link STSCredentialsProvider#MIN_REMAINING_TIME_TO_REFRESH_CREDENTIALS} plus a random time in
     * [0, {@link STSCredentialsProvider#MAX_RANDOM_TIME_TO_REFRESH_CREDENTIALS})
     * (presumably randomized so that many entries do not all refresh at the same moment). Also
     * returns {@code true} if the cached expiration cannot be parsed.
     *
     * @param credentials the cached (present) credentials
     * @return {@code true} if we need to assume role with STS again, else {@code false}
     */
    private boolean shouldRefresh(EC2MetadataUtils.IAMSecurityCredential credentials) {
        try {
            Date expirationDate = createInterceptorDateTimeFormat().parse(credentials.expiration);
            return getRandomTimeInRange() + System.currentTimeMillis() > expirationDate.getTime();
        } catch (ParseException ex) {
            log.error("Unable to parse the expiration in the cached assume role credentials. Refreshing credentials anyway.", ex);
            return true;
        }
    }

    /**
     * Returns the refresh margin in milliseconds.
     *
     * @return {@code MIN_REMAINING_TIME_TO_REFRESH_CREDENTIALS} plus a uniformly random value in
     *         [0, {@code MAX_RANDOM_TIME_TO_REFRESH_CREDENTIALS}), in milliseconds
     */
    @VisibleForTesting
    public long getRandomTimeInRange() {
        long minTimeMs = MIN_REMAINING_TIME_TO_REFRESH_CREDENTIALS.toMillis();
        long maxRandomTimeMs = MAX_RANDOM_TIME_TO_REFRESH_CREDENTIALS.toMillis();
        return minTimeMs + ThreadLocalRandom.current().nextLong(maxRandomTimeMs);
    }
}