Skip to content

Commit

Permalink
Implement retry mechanism when creating and assuming an IAM role (#358)
Browse files Browse the repository at this point in the history
* Implement retry mechanism when creating and assuming an IAM role (closes #353)

* Add temporary debug code
  • Loading branch information
christophetd committed May 30, 2023
1 parent eb3b361 commit a6351e3
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 21 deletions.
1 change: 1 addition & 0 deletions v2/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/ssm v1.31.0
github.com/aws/aws-sdk-go-v2/service/sts v1.16.19
github.com/aws/smithy-go v1.13.3
github.com/cenkalti/backoff/v4 v4.2.1
github.com/fatih/color v1.13.0
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0
Expand Down
2 changes: 2 additions & 0 deletions v2/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.16.19 h1:9pPi0PsFNAGILFfPCk8Y0iyEBGc
github.com/aws/aws-sdk-go-v2/service/sts v1.16.19/go.mod h1:h4J3oPZQbxLhzGnk+j9dfYHi5qIOVJ5kczZd658/ydM=
github.com/aws/smithy-go v1.13.3 h1:l7LYxGuzK6/K+NzJ2mC+VvLUbae0sL3bXU//04MkmnA=
github.com/aws/smithy-go v1.13.3/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA=
github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ import (
"context"
_ "embed"
"errors"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/datadog/stratus-red-team/v2/internal/utils"
"github.com/datadog/stratus-red-team/v2/pkg/stratus"
"github.com/datadog/stratus-red-team/v2/pkg/stratus/mitreattack"
Expand Down Expand Up @@ -52,12 +49,12 @@ func detonate(params map[string]string, providers stratus.CloudProviders) error
roleArn := params["role_arn"]

awsConnection := providers.AWS().GetConnection()
stsClient := sts.NewFromConfig(awsConnection)
awsConnection.Credentials = aws.NewCredentialsCache(stscreds.NewAssumeRoleProvider(stsClient, roleArn))
if err := utils.WaitForAndAssumeAWSRole(&awsConnection, roleArn); err != nil {
return err
}
ec2Client := ec2.NewFromConfig(awsConnection)

log.Println("Running ec2:GetPasswordData on " + strconv.Itoa(numCalls) + " random instance IDs")

for i := 0; i < numCalls; i++ {
// Generate a fake, real-looking instance ID
// Since we don't have the permission, we don't care if the instance actually exists
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ import (
"context"
_ "embed"
"errors"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/organizations"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/datadog/stratus-red-team/v2/internal/utils"
"github.com/datadog/stratus-red-team/v2/pkg/stratus"
"github.com/datadog/stratus-red-team/v2/pkg/stratus/mitreattack"
"log"
Expand Down Expand Up @@ -51,8 +49,9 @@ func detonate(params map[string]string, providers stratus.CloudProviders) error
roleArn := params["role_arn"]

awsConnection := providers.AWS().GetConnection()
stsClient := sts.NewFromConfig(awsConnection)
awsConnection.Credentials = aws.NewCredentialsCache(stscreds.NewAssumeRoleProvider(stsClient, roleArn))
if err := utils.WaitForAndAssumeAWSRole(&awsConnection, roleArn); err != nil {
return err
}
organizationsClient := organizations.NewFromConfig(awsConnection)

log.Println("Attempting to leave the AWS organization (will trigger an Access Denied error)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@ import (
"github.com/datadog/stratus-red-team/v2/pkg/stratus/mitreattack"
"log"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go-v2/service/sts"
)

//go:embed main.tf
Expand Down Expand Up @@ -58,10 +55,11 @@ See:
const numCalls = 15

func detonate(params map[string]string, providers stratus.CloudProviders) error {

roleArn := params["role_arn"]
awsConnection := providers.AWS().GetConnection()
stsClient := sts.NewFromConfig(awsConnection)
awsConnection.Credentials = aws.NewCredentialsCache(stscreds.NewAssumeRoleProvider(stsClient, params["role_arn"]))
if err := utils.WaitForAndAssumeAWSRole(&awsConnection, roleArn); err != nil {
return err
}
ec2Client := ec2.NewFromConfig(awsConnection)

for i := 0; i < numCalls; i++ {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ import (
_ "embed"
"errors"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/datadog/stratus-red-team/v2/internal/utils"
"github.com/datadog/stratus-red-team/v2/pkg/stratus"
"github.com/datadog/stratus-red-team/v2/pkg/stratus/mitreattack"
"log"
Expand Down Expand Up @@ -57,8 +56,9 @@ func detonate(params map[string]string, providers stratus.CloudProviders) error
roleArn := params["role_arn"]
subnetId := params["subnet_id"]

stsClient := sts.NewFromConfig(awsConnection)
awsConnection.Credentials = aws.NewCredentialsCache(stscreds.NewAssumeRoleProvider(stsClient, roleArn))
if err := utils.WaitForAndAssumeAWSRole(&awsConnection, roleArn); err != nil {
return err
}
ec2Client := ec2.NewFromConfig(awsConnection)

log.Printf("Attempting to run up to %d instances of type %s\n", numInstances, string(instanceType))
Expand Down
29 changes: 29 additions & 0 deletions v2/internal/utils/aws_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@ package utils

import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
backoff "github.com/cenkalti/backoff/v4"
"log"
"strings"
"time"
)

func GetCurrentAccountId(cfg aws.Config) (string, error) {
Expand All @@ -31,6 +35,31 @@ func AwsConfigFromCredentials(accessKeyId string, secretAccessKey string, sessio
return cfg
}

// WaitForAndAssumeAWSRole waits for an AWS role to be assumable (due to eventual consistency)
// then sets a credentials provider that can be used to assume the role.
func WaitForAndAssumeAWSRole(awsConnection *aws.Config, roleArn string) error {
assumeRoleProvider := stscreds.NewAssumeRoleProvider(sts.NewFromConfig(*awsConnection), roleArn)
backoffStrategy := backoff.NewExponentialBackOff()
backoffStrategy.InitialInterval = 1 * time.Second // try to assume the role after 1s
backoffStrategy.Multiplier = 2 // double the interval after each failed attempt
backoffStrategy.MaxInterval = 10 * time.Second // never wait more than 10s between attempts
backoffStrategy.MaxElapsedTime = 1 * time.Minute // stop trying after 1 minute
err := backoff.Retry(func() error {
_, err := assumeRoleProvider.Retrieve(context.Background())
if err == nil {
log.Println("[DEBUG] Successfully assumed role!")
} else {
log.Println("[DEBUG] Unable to assume role, error: ", err.Error())
}
return err
}, backoffStrategy)
if err != nil {
return fmt.Errorf("unable to assume role %s: %v", roleArn, err)
}
awsConnection.Credentials = aws.NewCredentialsCache(assumeRoleProvider)
return nil
}

func IsErrorDueToEBSEncryptionByDefault(err error) bool {
if err == nil {
return false
Expand Down

0 comments on commit a6351e3

Please sign in to comment.