Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
add pattern to assume role with sts auth
Browse files Browse the repository at this point in the history
  • Loading branch information
Rebecca Dong committed Aug 27, 2020
1 parent f7648fa commit e54e214
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@ import com.amazonaws.DefaultRequest
import com.amazonaws.auth.AWS4Signer
import com.amazonaws.auth.AWSCredentials
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider
import com.amazonaws.auth.profile.internal.securitytoken.RoleInfo
import com.amazonaws.auth.profile.internal.securitytoken.STSProfileCredentialsServiceProvider
import com.amazonaws.http.HttpMethodName
import com.amazonaws.regions.Regions
import com.amazonaws.services.kms.AWSKMSClient
import com.amazonaws.services.kms.model.DecryptRequest
import com.amazonaws.services.kms.model.DecryptResult
import com.amazonaws.services.securitytoken.AWSSecurityTokenService
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder
import com.nike.cerberus.util.PropUtils
import com.google.common.cache.Cache
import com.google.common.cache.CacheBuilder
Expand All @@ -41,6 +44,7 @@ import org.apache.http.HttpStatus
import java.nio.ByteBuffer
import java.util.concurrent.TimeUnit

import static com.nike.cerberus.api.util.TestUtils.updateArnWithPartition
import static io.restassured.RestAssured.*
import static io.restassured.module.jsv.JsonSchemaValidator.*
import static org.hamcrest.Matchers.*
Expand Down Expand Up @@ -123,7 +127,7 @@ class CerberusApiActions {
* Generates and returns signed headers.
* @return Signed headers
*/
static Map<String, String> getSignedHeaders(String region){
static Map<String, String> getSignedHeaders(String region, String accountId, String roleName){

String url = "https://sts." + region + ".amazonaws.com";
if(CHINA_REGIONS.contains(region)) {
Expand All @@ -132,6 +136,13 @@ class CerberusApiActions {

URI endpoint = null;

def iamPrincipalArn = updateArnWithPartition("arn:aws:iam::$accountId:role/$roleName")
AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard().withRegion(Regions.fromName(region)).build()
def credentials = new STSAssumeRoleSessionCredentialsProvider.Builder(iamPrincipalArn, UUID.randomUUID().toString())
.withStsClient(stsClient)
.build()
.getCredentials()

try {
endpoint = new URI(url);
} catch (URISyntaxException e) {
Expand All @@ -150,15 +161,15 @@ class CerberusApiActions {

System.out.println(String.format("Signing request with [%s] as host", url));

signRequest(requestToSign, DefaultAWSCredentialsProviderChain.getInstance().getCredentials(), region);
signRequest(requestToSign, credentials, region);

return requestToSign.getHeaders();
}

static def retrieveStsToken(String region) {
static def retrieveStsToken(String region, String accountId, String roleName) {
// get the encrypted payload and validate response

Map<String, String> signedHeaders = getSignedHeaders(region);
Map<String, String> signedHeaders = getSignedHeaders(region, accountId, roleName);

Response response =
given()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class CerberusIamApiV2Tests {
mapper = new ObjectMapper()
TestUtils.configureRestAssured()
loadRequiredEnvVars()
cerberusAuthData = retrieveStsToken(region)
cerberusAuthData = retrieveStsToken(region, accountId, roleName)
cerberusAuthToken = cerberusAuthData."client_token"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class NegativeIamPermissionsApiTests {
String userGroupOfTestUser = ownerGroup

String iamPrincipalArn = updateArnWithPartition("arn:aws:iam::${accountId}:role/${roleName}")
def iamAuthData = retrieveStsToken(region)
def iamAuthData = retrieveStsToken(region, accountId, roleName)
iamAuthToken = iamAuthData."client_token"

String sdbCategoryId = getCategoryMap(userAuthToken).Applications
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class NegativeUserPermissionsApiTests {
loadRequiredEnvVars()
userAuthData = retrieveUserAuthToken(username, password, otpSecret, otpDeviceId)
String iamPrincipalArn = updateArnWithPartition("arn:aws:iam::${accountId}:role/${roleName}")
def iamAuthData = retrieveStsToken(region)
def iamAuthData = retrieveStsToken(region, accountId, roleName)
userAuthToken = userAuthData."client_token"
iamAuthToken = iamAuthData."client_token"
String userGroupOfTestUser = userGroup
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class ValidationErrorApiTests {
TestUtils.configureRestAssured()
loadRequiredEnvVars()
String iamPrincipalArn = updateArnWithPartition("arn:aws:iam::${accountId}:role/${roleName}")
def iamAuthData = retrieveStsToken(region)
def iamAuthData = retrieveStsToken(region, accountId, roleName)
iamAuthToken = iamAuthData."client_token"

String sdbCategoryId = getCategoryMap(iamAuthToken).Applications
Expand All @@ -67,7 +67,7 @@ class ValidationErrorApiTests {
testSdb = createSdbV2(iamAuthToken, TestUtils.generateRandomSdbName(), sdbDescription, sdbCategoryId, iamPrincipalArn, [], iamPrincipalPermissions)

// regenerate token to get policy for new SDB
iamAuthData = retrieveStsToken(region)
iamAuthData = retrieveStsToken(region, accountId, roleName)
iamAuthToken = iamAuthData."client_token"
}

Expand Down

0 comments on commit e54e214

Please sign in to comment.