Skip to content

Commit

Permalink
We will use the regional endpoint with STS whenever possible. (#45)
Browse files Browse the repository at this point in the history
* We will use the regional endpoint with STS whenever possible.
The global endpoint is used in case the region cannot be found.

#44

* STSClient is now an injectable. This is necessary to consult the
ApplicationConfiguration object to create the AWS STS client.

* Added some more comments.

* Incorporated the PR comments.
- Renamed constant to use regional STS endpoint
- Added some comments about the constant

* Default to global endpoint if we can't find the AWS region.
  • Loading branch information
sharad-oss committed Sep 30, 2021
1 parent 246d139 commit e4a9c4d
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ public Set<String> getAllowedUsersForImpersonation() {
return this.IMPERSONATION_ALLOWED_USERS;
}

public boolean isRegionalStsEnabled() {
return Boolean.parseBoolean(properties.getProperty(Constants.REGIONAL_STS_ENDPOINT_ENABLED,
String.valueOf("true")));
}


public boolean isSetSourceIdentityEnabled() {
return Boolean.parseBoolean(properties.getProperty(Constants.SET_SOURCE_IDENTITY_ENABLED,
String.valueOf("false")));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ final public class Constants {
// Set the source identity in the Assume Role calls
public static final String SET_SOURCE_IDENTITY_ENABLED = "rolemapper.sourceidentity.enabled";

// Determines if regional STS endpoint is used. Setting to "false" uses global endpoint.
// Default value is true.
public static final String REGIONAL_STS_ENDPOINT_ENABLED = "rolemapper.regional.sts.endpoint.enabled";

private Constants() {
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.amazon.aws.emr.credentials;

import com.amazonaws.services.securitytoken.model.AssumeRoleRequest;
import com.amazonaws.services.securitytoken.model.AssumeRoleResult;

/**
* Interface representing the AWS STS client.
*/
public interface STSClient {

AssumeRoleResult assumeRole(AssumeRoleRequest assumeRoleRequest);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package com.amazon.aws.emr.credentials;

import com.amazon.aws.emr.ApplicationConfiguration;
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.AssumeRoleRequest;
import com.amazonaws.services.securitytoken.model.AssumeRoleResult;
import javax.annotation.PostConstruct;
import javax.inject.Inject;
import lombok.extern.slf4j.Slf4j;
import org.glassfish.hk2.api.Immediate;

/**
* Creates a custom AWS STS client based on {@link ApplicationConfiguration}
*/
@Slf4j
@Immediate
public class STSClientImpl implements STSClient {
// If using regional configurations we will use us-west-2 as default
// This is primarily used in integration tests
private String regionString = "us-west-2";

@Inject
ApplicationConfiguration applicationConfiguration;

AWSSecurityTokenService stsClient;

@PostConstruct
void init() {
if (applicationConfiguration.isRegionalStsEnabled()) {
Region region = null;
try {
region = Regions.getCurrentRegion();
regionString = region.getName();
String endpoint = String.format("https://sts.%s.amazonaws.com", regionString);
log.info("Running the application with regional STS endpoint " + endpoint);
stsClient = AWSSecurityTokenServiceClientBuilder
.standard()
.withEndpointConfiguration(new EndpointConfiguration(endpoint, regionString))
.build();
} catch (Exception e) {
log.error("Cannot determine the AWS region. Defaulting to global endpoint.");
createGlobalEndpointClient();
}
} else {
createGlobalEndpointClient();
}
}

private void createGlobalEndpointClient() {
log.info("Running the application with global STS endpoint.");
stsClient = AWSSecurityTokenServiceClientBuilder
.standard()
.build();
}

@Override
public AssumeRoleResult assumeRole(AssumeRoleRequest assumeRoleRequest) {
return stsClient.assumeRole(assumeRoleRequest);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@

package com.amazon.aws.emr.credentials;

import com.amazon.aws.emr.ApplicationConfiguration;
import com.amazonaws.AmazonClientException;
import com.amazonaws.AmazonServiceException;
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.AssumeRoleRequest;
Expand All @@ -15,16 +19,16 @@
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import lombok.extern.slf4j.Slf4j;

import javax.inject.Singleton;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.time.Duration;
import java.util.Date;
import java.util.Optional;
import java.util.TimeZone;
import java.util.concurrent.ThreadLocalRandom;
import javax.inject.Inject;
import javax.inject.Singleton;
import lombok.extern.slf4j.Slf4j;

/**
* Fetches credentials for {@code AssumeRoleRequest} from STS.
Expand All @@ -33,11 +37,12 @@
@Singleton
public class STSCredentialsProvider implements MetadataCredentialsProvider {

@Inject
STSClient stsClient;

public static final Duration MIN_REMAINING_TIME_TO_REFRESH_CREDENTIALS = Duration.ofMinutes(10);
public static final Duration MAX_RANDOM_TIME_TO_REFRESH_CREDENTIALS = Duration.ofMinutes(5);
private static final int CREDENTIALS_MAP_MAX_SIZE = 20000;
// Initialized later for testing using mocks.
public static AWSSecurityTokenService stsClient = null;

private final LoadingCache<AssumeRoleRequest, Optional<EC2MetadataUtils.IAMSecurityCredential>> credentialsCache = CacheBuilder
.newBuilder().maximumSize(CREDENTIALS_MAP_MAX_SIZE)
Expand All @@ -48,15 +53,6 @@ public Optional<EC2MetadataUtils.IAMSecurityCredential> load(AssumeRoleRequest a
}
});

synchronized static AWSSecurityTokenService getStsClient() {
if (stsClient == null) {
stsClient = AWSSecurityTokenServiceClientBuilder
.standard()
.build();
}
return stsClient;
}

/**
* Create an instance of SimpleDataFormat.
* SimpleDateFormat is not thread safe, so we create an instance when needed instead of using a shared one
Expand Down Expand Up @@ -99,7 +95,7 @@ public Optional<EC2MetadataUtils.IAMSecurityCredential> getUserCredentials(Assum
private Optional<EC2MetadataUtils.IAMSecurityCredential> assumeRole(AssumeRoleRequest assumeRoleRequest) {
log.info("Need to assume role {} with STS", assumeRoleRequest);
try {
AssumeRoleResult assumeRoleResult = getStsClient().assumeRole(assumeRoleRequest);
AssumeRoleResult assumeRoleResult = stsClient.assumeRole(assumeRoleRequest);
EC2MetadataUtils.IAMSecurityCredential credentials = createIAMSecurityCredential(assumeRoleResult.getCredentials());
log.debug("Procured credentials from STS for assume role {}", assumeRoleRequest);
return Optional.of(credentials);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import com.amazon.aws.emr.common.system.PrincipalResolver;
import com.amazon.aws.emr.common.system.factory.PrincipalResolverFactory;
import com.amazon.aws.emr.credentials.MetadataCredentialsProvider;
import com.amazon.aws.emr.credentials.STSClient;
import com.amazon.aws.emr.credentials.STSClientImpl;
import com.amazon.aws.emr.credentials.STSCredentialsProvider;
import com.amazon.aws.emr.mapping.MappingInvoker;
import com.amazon.aws.emr.common.system.user.LinuxUserIdService;
Expand All @@ -27,6 +29,7 @@ protected void configure() {
bind(MappingInvoker.class).to(MappingInvoker.class).in(Immediate.class);
bind(STSCredentialsProvider.class).to(MetadataCredentialsProvider.class).in(Singleton.class);
bind(ApplicationConfiguration.class).to(ApplicationConfiguration.class).in(Immediate.class);
bind(STSClientImpl.class).to(STSClient.class).in(Immediate.class);
bindFactory(PrincipalResolverFactory.class).to(PrincipalResolver.class).in(Singleton.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import static org.hamcrest.core.Is.is;
import static org.powermock.api.mockito.PowerMockito.mockStatic;

import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.model.AssumeRoleRequest;
import com.amazonaws.services.securitytoken.model.AssumeRoleResult;
import com.amazonaws.services.securitytoken.model.Credentials;
Expand All @@ -20,6 +19,7 @@
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.powermock.api.mockito.PowerMockito;
Expand All @@ -37,15 +37,16 @@ public class STSCredentialsProviderTest {
private static final long TWO_MIN_MS = 2 * 60 * 1000;

@Mock
AWSSecurityTokenService stsClient;
STSClient stsClient;

@InjectMocks
STSCredentialsProvider stsCredentialsProvider;

AssumeRoleRequest assumeRoleRequest;

@Before
public void setup() {
mockStatic(STSCredentialsProvider.class);
Mockito.when(STSCredentialsProvider.getStsClient())
.thenReturn(stsClient);
Mockito.when(STSCredentialsProvider.createInterceptorDateTimeFormat()).thenCallRealMethod();
assumeRoleRequest = createTestAssumeRoleRequest();
}
Expand All @@ -56,7 +57,6 @@ public void get_credentials() {
Mockito.when(stsClient.assumeRole(assumeRoleRequest)).thenReturn(
new AssumeRoleResult()
.withCredentials(longLivedCredentials));
STSCredentialsProvider stsCredentialsProvider = new STSCredentialsProvider();
Optional<EC2MetadataUtils.IAMSecurityCredential> optionalIAMSecurityCredentials = stsCredentialsProvider
.getUserCredentials(assumeRoleRequest);
assertThat(optionalIAMSecurityCredentials.isPresent(), is(true));
Expand All @@ -78,17 +78,11 @@ public void get_cached_credentials() {
new AssumeRoleResult()
.withCredentials(longLivedCredentials));

STSCredentialsProvider stsCredentialsProvider = new STSCredentialsProvider();
stsCredentialsProvider.getUserCredentials(assumeRoleRequest);

PowerMockito.verifyStatic(STSCredentialsProvider.class, Mockito.times(1));
STSCredentialsProvider.getStsClient();

// Make the second call
Mockito.verify(stsClient, Mockito.times(1)).assumeRole(assumeRoleRequest);
// Make the second call and there should no additional Mock invocation
stsCredentialsProvider.getUserCredentials(assumeRoleRequest);
// The invocations with STS client don't go up
PowerMockito.verifyStatic(STSCredentialsProvider.class, Mockito.times(1));
STSCredentialsProvider.getStsClient();
Mockito.verify(stsClient, Mockito.times(1)).assumeRole(assumeRoleRequest);
}

@Test
Expand All @@ -97,21 +91,21 @@ public void expired_credentials() {
Mockito.when(stsClient.assumeRole(assumeRoleRequest)).thenReturn(
new AssumeRoleResult()
.withCredentials(shortLivedTestCredentials));
STSCredentialsProvider stsCredentialsProvider = new STSCredentialsProvider();
stsCredentialsProvider.getUserCredentials(assumeRoleRequest);

/*
* Why 2?
* First call gets the credentials using sts as the cache is empty.
* Second call is made to STS as the retrieved credentials are expired.
*/
PowerMockito.verifyStatic(STSCredentialsProvider.class, Mockito.times(2));
STSCredentialsProvider.getStsClient();
Mockito.verify(stsClient, Mockito.times(2)).assumeRole(assumeRoleRequest);

// Make the second call and there should no another Mock invocation
stsCredentialsProvider.getUserCredentials(assumeRoleRequest);
Mockito.verify(stsClient, Mockito.times(3)).assumeRole(assumeRoleRequest);

// Make second call, should invoke STS client again
stsCredentialsProvider.getUserCredentials(assumeRoleRequest);
PowerMockito.verifyStatic(STSCredentialsProvider.class, Mockito.times(3));
STSCredentialsProvider.getStsClient();
}

@Test
Expand All @@ -120,20 +114,20 @@ public void about_to_expire_credentials() {
Mockito.when(stsClient.assumeRole(assumeRoleRequest)).thenReturn(
new AssumeRoleResult()
.withCredentials(shortLivedTestCredentials));
STSCredentialsProvider stsCredentialsProvider = new STSCredentialsProvider();
stsCredentialsProvider.getUserCredentials(assumeRoleRequest);
PowerMockito.verifyStatic(STSCredentialsProvider.class, Mockito.times(2));
STSCredentialsProvider.getStsClient();
Mockito.verify(stsClient, Mockito.times(2)).assumeRole(assumeRoleRequest);

// Make the second call and there should no another Mock invocation
stsCredentialsProvider.getUserCredentials(assumeRoleRequest);
Mockito.verify(stsClient, Mockito.times(3)).assumeRole(assumeRoleRequest);

// Make second call, should invoke STS client again
stsCredentialsProvider.getUserCredentials(assumeRoleRequest);
PowerMockito.verifyStatic(STSCredentialsProvider.class, Mockito.times(3));
STSCredentialsProvider.getStsClient();
}

@Test
public void random_refresh_time() {
STSCredentialsProvider stsCredentialsProvider = new STSCredentialsProvider();
assertThat(stsCredentialsProvider.getRandomTimeInRange(), allOf(
greaterThan(STSCredentialsProvider.MIN_REMAINING_TIME_TO_REFRESH_CREDENTIALS.toMillis()),
lessThan(STSCredentialsProvider.MIN_REMAINING_TIME_TO_REFRESH_CREDENTIALS.toMillis() +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import com.amazon.aws.emr.common.system.impl.CommandBasedPrincipalResolver;
import com.amazon.aws.emr.common.system.user.UserIdService;
import com.amazon.aws.emr.credentials.MetadataCredentialsProvider;
import com.amazon.aws.emr.credentials.STSClient;
import com.amazon.aws.emr.credentials.STSClientImpl;
import com.amazon.aws.emr.credentials.STSCredentialsProvider;
import com.amazon.aws.emr.integration.IntegrationTestsUserService;
import com.amazon.aws.emr.mapping.MappingInvoker;
Expand All @@ -24,6 +26,7 @@ public class DefaultMapperIntegrationBinder extends AbstractBinder {
protected void configure() {
bind(IntegrationTestsUserService.class).to(UserIdService.class);
bind(MappingInvoker.class).to(MappingInvoker.class).in(Immediate.class);
bind(STSClientImpl.class).to(STSClient.class).in(Immediate.class);
bind(STSCredentialsProvider.class).to(MetadataCredentialsProvider.class).in(Singleton.class);
bind(DefaultMapperImplApplicationConfig.class).to(ApplicationConfiguration.class)
.in(Immediate.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import com.amazon.aws.emr.common.system.impl.CommandBasedPrincipalResolver;
import com.amazon.aws.emr.common.system.user.UserIdService;
import com.amazon.aws.emr.credentials.MetadataCredentialsProvider;
import com.amazon.aws.emr.credentials.STSClient;
import com.amazon.aws.emr.credentials.STSClientImpl;
import com.amazon.aws.emr.credentials.STSCredentialsProvider;
import com.amazon.aws.emr.integration.IntegrationTestsUserService;
import com.amazon.aws.emr.mapping.MappingInvoker;
Expand All @@ -24,6 +26,7 @@ public class PoliciesUnionMapperIntegrationBinder extends AbstractBinder {
protected void configure() {
bind(IntegrationTestsUserService.class).to(UserIdService.class);
bind(MappingInvoker.class).to(MappingInvoker.class).in(Immediate.class);
bind(STSClientImpl.class).to(STSClient.class).in(Immediate.class);
bind(STSCredentialsProvider.class).to(MetadataCredentialsProvider.class).in(Singleton.class);
bind(PoliciesUnionMapperImplApplicationConfig.class).to(ApplicationConfiguration.class)
.in(Immediate.class);
Expand Down

0 comments on commit e4a9c4d

Please sign in to comment.