Skip to content

Commit

Permalink
aws/credentials/stscreds: Expose Duration to WebIdentityRoleProvider (#…
Browse files Browse the repository at this point in the history
…3356)

Exposes Duration API parameter for retrieving credentials with WebIdentityRoleProvider
  • Loading branch information
AlexisMontagne committed Aug 10, 2020
1 parent eb7121d commit 29a5bc2
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 13 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG_PENDING.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
### SDK Features

### SDK Enhancements
* `aws/credentials/stscreds`: Add optional expiry duration to WebIdentityRoleProvider ([#3356](https://github.com/aws/aws-sdk-go/pull/3356))
* Adds a new optional field to the WebIdentityRoleProvider that allows you to specify the duration the assumed role credentials will be valid for.
* `example/service/s3/putObjectWithProgress`: Fix example for file upload with progress ([#3377](https://github.com/aws/aws-sdk-go/pull/3377))
* Fixes [#2468](https://github.com/aws/aws-sdk-go/issues/2468) by ignoring the first read of the progress reader wrapper. Since the first read is used for signing the request, not upload progress.
* Updated the example to write progress inline instead of newlines.
Expand Down
21 changes: 20 additions & 1 deletion aws/credentials/stscreds/web_identity_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,21 @@ type WebIdentityRoleProvider struct {
credentials.Expiry
PolicyArns []*sts.PolicyDescriptorType

client stsiface.STSAPI
// Duration the STS credentials will be valid for. Truncated to seconds.
// If unset, the assumed role will use AssumeRoleWithWebIdentity's default
// expiry duration. See
// https://docs.aws.amazon.com/sdk-for-go/api/service/sts/#STS.AssumeRoleWithWebIdentity
// for more information.
Duration time.Duration

// The amount of time the credentials will be refreshed before they expire.
// This is useful refresh credentials before they expire to reduce risk of
// using credentials as they expire. If unset, will default to no expiry
// window.
ExpiryWindow time.Duration

client stsiface.STSAPI

tokenFetcher TokenFetcher
roleARN string
roleSessionName string
Expand Down Expand Up @@ -107,11 +119,18 @@ func (p *WebIdentityRoleProvider) RetrieveWithContext(ctx credentials.Context) (
// uses unix time in nanoseconds to uniquely identify sessions.
sessionName = strconv.FormatInt(now().UnixNano(), 10)
}

var duration *int64
if p.Duration != 0 {
duration = aws.Int64(int64(p.Duration / time.Second))
}

req, resp := p.client.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{
PolicyArns: p.PolicyArns,
RoleArn: &p.roleARN,
RoleSessionName: &sessionName,
WebIdentityToken: aws.String(string(b)),
DurationSeconds: duration,
})

req.SetContext(ctx)
Expand Down
65 changes: 53 additions & 12 deletions aws/credentials/stscreds/web_identity_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package stscreds_test
import (
"net/http"
"reflect"
"strings"
"testing"
"time"

Expand All @@ -20,25 +21,27 @@ import (

func TestWebIdentityProviderRetrieve(t *testing.T) {
var reqCount int
cases := []struct {
name string
cases := map[string]struct {
onSendReq func(*testing.T, *request.Request)
roleARN string
tokenFilepath string
sessionName string
expectedError error
duration time.Duration
expectedError string
expectedCredValue credentials.Value
}{
{
name: "session name case",
"session name case": {
roleARN: "arn01234567890123456789",
tokenFilepath: "testdata/token.jwt",
sessionName: "foo",
onSendReq: func(t *testing.T, r *request.Request) {
input := r.Params.(*sts.AssumeRoleWithWebIdentityInput)
if e, a := "foo", *input.RoleSessionName; !reflect.DeepEqual(e, a) {
if e, a := "foo", *input.RoleSessionName; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if input.DurationSeconds != nil {
t.Errorf("expect no duration, got %v", *input.DurationSeconds)
}

data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput)
*data = sts.AssumeRoleWithWebIdentityOutput{
Expand All @@ -57,8 +60,35 @@ func TestWebIdentityProviderRetrieve(t *testing.T) {
ProviderName: stscreds.WebIdentityProviderName,
},
},
{
name: "invalid token retry",
"with duration": {
roleARN: "arn01234567890123456789",
tokenFilepath: "testdata/token.jwt",
sessionName: "foo",
duration: 15 * time.Minute,
onSendReq: func(t *testing.T, r *request.Request) {
input := r.Params.(*sts.AssumeRoleWithWebIdentityInput)
if e, a := int64((15*time.Minute)/time.Second), *input.DurationSeconds; e != a {
t.Errorf("expect %v duration, got %v", e, a)
}

data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput)
*data = sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
},
}
},
expectedCredValue: credentials.Value{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
SessionToken: "session-token",
ProviderName: stscreds.WebIdentityProviderName,
},
},
"invalid token retry": {
roleARN: "arn01234567890123456789",
tokenFilepath: "testdata/token.jwt",
sessionName: "foo",
Expand Down Expand Up @@ -94,8 +124,8 @@ func TestWebIdentityProviderRetrieve(t *testing.T) {
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
for name, c := range cases {
t.Run(name, func(t *testing.T) {
reqCount = 0

svc := sts.New(unit.Session, &aws.Config{
Expand All @@ -116,9 +146,20 @@ func TestWebIdentityProviderRetrieve(t *testing.T) {
svc.Handlers.UnmarshalError.Clear()

p := stscreds.NewWebIdentityRoleProvider(svc, c.roleARN, c.sessionName, c.tokenFilepath)
p.Duration = c.duration

credValue, err := p.Retrieve()
if e, a := c.expectedError, err; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
if len(c.expectedError) != 0 {
if err == nil {
t.Fatalf("expect error, got none")
}
if e, a := c.expectedError, err.Error(); !strings.Contains(a, e) {
t.Fatalf("expect error to contain %v, got %v", e, a)
}
return
}
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

if e, a := c.expectedCredValue, credValue; !reflect.DeepEqual(e, a) {
Expand Down

0 comments on commit 29a5bc2

Please sign in to comment.