diff --git a/internal/pkg/aws/sessions/sessions.go b/internal/pkg/aws/sessions/sessions.go index 022cadf0b92..18159f3f5bf 100644 --- a/internal/pkg/aws/sessions/sessions.go +++ b/internal/pkg/aws/sessions/sessions.go @@ -64,32 +64,24 @@ func UserAgentExtras(extras ...string) func(*Provider) { } // Default returns a session configured against the "default" AWS profile. +// Default assumes that a region must be present with a session, otherwise it returns an error. func (p *Provider) Default() (*session.Session, error) { - if p.defaultSess != nil { - return p.defaultSess, nil - } - - sess, err := session.NewSessionWithOptions(session.Options{ - Config: *newConfig(), - SharedConfigState: session.SharedConfigEnable, - }) + sess, err := p.defaultSession() if err != nil { return nil, err } if aws.StringValue(sess.Config.Region) == "" { return nil, &errMissingRegion{} } - - sess.Handlers.Build.PushBackNamed(p.userAgentHandler()) - p.defaultSess = sess return sess, nil } // DefaultWithRegion returns a session configured against the "default" AWS profile and the input region. func (p *Provider) DefaultWithRegion(region string) (*session.Session, error) { sess, err := session.NewSessionWithOptions(session.Options{ - Config: *newConfig().WithRegion(region), - SharedConfigState: session.SharedConfigEnable, + Config: *newConfig().WithRegion(region), + SharedConfigState: session.SharedConfigEnable, + AssumeRoleTokenProvider: stscreds.StdinTokenProvider, }) if err != nil { return nil, err @@ -101,9 +93,10 @@ func (p *Provider) DefaultWithRegion(region string) (*session.Session, error) { // FromProfile returns a session configured against the input profile name. func (p *Provider) FromProfile(name string) (*session.Session, error) { sess, err := session.NewSessionWithOptions(session.Options{ - Config: *newConfig(), - SharedConfigState: session.SharedConfigEnable, - Profile: name, + Config: *newConfig(), + SharedConfigState: session.SharedConfigEnable, + Profile: name, + AssumeRoleTokenProvider: stscreds.StdinTokenProvider, }) if err != nil { return nil, err @@ -117,14 +110,10 @@ func (p *Provider) FromProfile(name string) (*session.Session, error) { // FromRole returns a session configured against the input role and region. func (p *Provider) FromRole(roleARN string, region string) (*session.Session, error) { - defaultSession, err := session.NewSessionWithOptions(session.Options{ - Config: *newConfig(), - SharedConfigState: session.SharedConfigEnable, - }) + defaultSession, err := p.defaultSession() if err != nil { - return nil, fmt.Errorf("error creating default session: %w", err) + return nil, fmt.Errorf("create default session: %w", err) } - defaultSession.Handlers.Build.PushBackNamed(p.userAgentHandler()) creds := stscreds.NewCredentials(defaultSession, roleARN) sess, err := session.NewSession( @@ -153,6 +142,24 @@ func (p *Provider) FromStaticCreds(accessKeyID, secretAccessKey, sessionToken st return sess, nil } +func (p *Provider) defaultSession() (*session.Session, error) { + if p.defaultSess != nil { + return p.defaultSess, nil + } + + sess, err := session.NewSessionWithOptions(session.Options{ + Config: *newConfig(), + SharedConfigState: session.SharedConfigEnable, + AssumeRoleTokenProvider: stscreds.StdinTokenProvider, + }) + if err != nil { + return nil, err + } + sess.Handlers.Build.PushBackNamed(p.userAgentHandler()) + p.defaultSess = sess + return sess, nil +} + // AreCredsFromEnvVars returns true if the session's credentials provider is environment variables, false otherwise. // An error is returned if the credentials are invalid or the request times out. func AreCredsFromEnvVars(sess *session.Session) (bool, error) {