Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aws/credentials: Add ProviderWithContext optional interface to support passing contexts on credential retrieval #3223

Merged
merged 5 commits into from
Apr 1, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions aws/credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@ type Provider interface {
IsExpired() bool
}

// ProviderWithContext is a Provider that can retrieve credentials with a Context
type ProviderWithContext interface {
Provider

RetrieveWithContext(Context) (Value, error)
}

// An Expirer is an interface that Providers can implement to expose the expiration
// time, if known. If the Provider cannot accurately provide this info,
// it should not implement this interface.
Expand Down Expand Up @@ -233,7 +240,9 @@ func (c *Credentials) GetWithContext(ctx Context) (Value, error) {
// Cannot pass context down to the actual retrieve, because the first
// context would cancel the whole group when there is not direct
// association of items in the group.
resCh := c.sf.DoChan("", c.singleRetrieve)
resCh := c.sf.DoChan("", func() (interface{}, error) {
return c.singleRetrieve(&suppressedContext{ctx})
})
select {
case res := <-resCh:
return res.Val.(Value), res.Err
Expand All @@ -243,12 +252,16 @@ func (c *Credentials) GetWithContext(ctx Context) (Value, error) {
}
}

func (c *Credentials) singleRetrieve() (interface{}, error) {
func (c *Credentials) singleRetrieve(ctx Context) (creds interface{}, err error) {
if curCreds := c.creds.Load(); !c.isExpired(curCreds) {
return curCreds.(Value), nil
}

creds, err := c.provider.Retrieve()
if p, ok := c.provider.(ProviderWithContext); ok {
creds, err = p.RetrieveWithContext(ctx)
} else {
creds, err = c.provider.Retrieve()
}
if err == nil {
c.creds.Store(creds)
}
Expand Down Expand Up @@ -308,3 +321,19 @@ func (c *Credentials) ExpiresAt() (time.Time, error) {
}
return expirer.ExpiresAt(), nil
}

type suppressedContext struct {
Context
}

func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}

func (s *suppressedContext) Done() <-chan struct{} {
return nil
}

func (s *suppressedContext) Err() error {
return nil
}
20 changes: 14 additions & 6 deletions aws/credentials/ec2rolecreds/ec2_role_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
Expand Down Expand Up @@ -87,7 +88,14 @@ func NewCredentialsWithClient(client *ec2metadata.EC2Metadata, options ...func(*
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) {
credsList, err := requestCredList(m.Client)
return m.RetrieveWithContext(aws.BackgroundContext())
}

// RetrieveWithContext retrieves credentials from the EC2 service.
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *EC2RoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
credsList, err := requestCredList(ctx, m.Client)
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
Expand All @@ -97,7 +105,7 @@ func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) {
}
credsName := credsList[0]

roleCreds, err := requestCred(m.Client, credsName)
roleCreds, err := requestCred(ctx, m.Client, credsName)
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
Expand Down Expand Up @@ -130,8 +138,8 @@ const iamSecurityCredsPath = "iam/security-credentials/"

// requestCredList requests a list of credentials from the EC2 service.
// If there are no credentials, or there is an error making or receiving the request
func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
resp, err := client.GetMetadata(iamSecurityCredsPath)
func requestCredList(ctx aws.Context, client *ec2metadata.EC2Metadata) ([]string, error) {
resp, err := client.GetMetadataWithContext(ctx, iamSecurityCredsPath)
if err != nil {
return nil, awserr.New("EC2RoleRequestError", "no EC2 instance role found", err)
}
Expand All @@ -154,8 +162,8 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
//
// If the credentials cannot be found, or there is an error reading the response
// and error will be returned.
func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCredRespBody, error) {
resp, err := client.GetMetadata(sdkuri.PathJoin(iamSecurityCredsPath, credsName))
func requestCred(ctx aws.Context, client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCredRespBody, error) {
resp, err := client.GetMetadataWithContext(ctx, sdkuri.PathJoin(iamSecurityCredsPath, credsName))
if err != nil {
return ec2RoleCredRespBody{},
awserr.New("EC2RoleRequestError",
Expand Down
11 changes: 9 additions & 2 deletions aws/credentials/endpointcreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,13 @@ func (p *Provider) IsExpired() bool {
// Retrieve will attempt to request the credentials from the endpoint the Provider
// was configured for. And error will be returned if the retrieval fails.
func (p *Provider) Retrieve() (credentials.Value, error) {
resp, err := p.getCredentials()
return p.RetrieveWithContext(aws.BackgroundContext())
}

// RetrieveWithContext will attempt to request the credentials from the endpoint the Provider
// was configured for. And error will be returned if the retrieval fails.
func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
resp, err := p.getCredentials(ctx)
if err != nil {
return credentials.Value{ProviderName: ProviderName},
awserr.New("CredentialsEndpointError", "failed to load credentials", err)
Expand Down Expand Up @@ -148,14 +154,15 @@ type errorOutput struct {
Message string `json:"message"`
}

func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
func (p *Provider) getCredentials(ctx aws.Context) (*getCredentialsOutput, error) {
op := &request.Operation{
Name: "GetCredentials",
HTTPMethod: "GET",
}

out := &getCredentialsOutput{}
req := p.Client.NewRequest(op, nil, out)
req.SetContext(ctx)
req.HTTPRequest.Header.Set("Accept", "application/json")
if authToken := p.AuthorizationToken; len(authToken) != 0 {
req.HTTPRequest.Header.Set("Authorization", authToken)
Expand Down
20 changes: 19 additions & 1 deletion aws/credentials/stscreds/assume_role_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkrand"
"github.com/aws/aws-sdk-go/service/sts"
)
Expand Down Expand Up @@ -118,6 +119,10 @@ type AssumeRoler interface {
AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error)
}

type assumeRolerWithContext interface {
AssumeRoleWithContext(aws.Context, *sts.AssumeRoleInput, ...request.Option) (*sts.AssumeRoleOutput, error)
}

// DefaultDuration is the default amount of time in minutes that the credentials
// will be valid for.
var DefaultDuration = time.Duration(15) * time.Minute
Expand Down Expand Up @@ -265,6 +270,11 @@ func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*

// Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(aws.BackgroundContext())
}

// RetrieveWithContext generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
// Apply defaults where parameters are not set.
if p.RoleSessionName == "" {
// Try to work out a role name that will hopefully end up unique.
Expand Down Expand Up @@ -304,7 +314,15 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
}
}

roleOutput, err := p.Client.AssumeRole(input)
var roleOutput *sts.AssumeRoleOutput
var err error

if c, ok := p.Client.(assumeRolerWithContext); ok {
roleOutput, err = c.AssumeRoleWithContext(ctx, input)
} else {
roleOutput, err = p.Client.AssumeRole(input)
}

if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
Expand Down
41 changes: 41 additions & 0 deletions aws/credentials/stscreds/assume_role_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/sts"
)

Expand All @@ -29,6 +31,16 @@ func (s *stubSTS) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput,
}, nil
}

type stubSTSWithContext struct {
stubSTS
called chan struct{}
}

func (s *stubSTSWithContext) AssumeRoleWithContext(context credentials.Context, input *sts.AssumeRoleInput, option ...request.Option) (*sts.AssumeRoleOutput, error) {
<-s.called
return s.stubSTS.AssumeRole(input)
}

func TestAssumeRoleProvider(t *testing.T) {
stub := &stubSTS{}
p := &AssumeRoleProvider{
Expand Down Expand Up @@ -223,3 +235,32 @@ func TestAssumeRoleProvider_WithTags(t *testing.T) {
t.Errorf("expect error")
}
}

func TestAssumeRoleProvider_RetrieveWithContext(t *testing.T) {
stub := &stubSTSWithContext{
called: make(chan struct{}),
}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
}

go func() {
stub.called <- struct{}{}
}()

creds, err := p.RetrieveWithContext(aws.BackgroundContext())
if err != nil {
t.Errorf("expect nil, got %v", err)
}

if e, a := "roleARN", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "assumedSessionToken", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
10 changes: 10 additions & 0 deletions aws/credentials/stscreds/web_identity_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, p
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
// error will be returned.
func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(aws.BackgroundContext())
}

// RetrieveWithContext will attempt to assume a role from a token which is located at
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
// error will be returned.
func (p *WebIdentityRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
b, err := ioutil.ReadFile(p.tokenFilePath)
if err != nil {
errMsg := fmt.Sprintf("unable to read file at %s", p.tokenFilePath)
Expand All @@ -81,6 +88,9 @@ func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
RoleSessionName: &sessionName,
WebIdentityToken: aws.String(string(b)),
})

req.SetContext(ctx)

// InvalidIdentityToken error is a temporary error that can occur
// when assuming an Role with a JWT web identity token.
req.RetryErrorCodes = append(req.RetryErrorCodes, sts.ErrCodeInvalidIdentityTokenException)
Expand Down
Loading