Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 29 additions & 22 deletions internal/pkg/aws/sessions/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,32 +64,24 @@ func UserAgentExtras(extras ...string) func(*Provider) {
}

// Default returns a session configured against the "default" AWS profile.
// Default assumes that a region must be present with a session, otherwise it returns an error.
func (p *Provider) Default() (*session.Session, error) {
if p.defaultSess != nil {
return p.defaultSess, nil
}

sess, err := session.NewSessionWithOptions(session.Options{
Config: *newConfig(),
SharedConfigState: session.SharedConfigEnable,
})
sess, err := p.defaultSession()
if err != nil {
return nil, err
}
if aws.StringValue(sess.Config.Region) == "" {
return nil, &errMissingRegion{}
}

sess.Handlers.Build.PushBackNamed(p.userAgentHandler())
p.defaultSess = sess
return sess, nil
}

// DefaultWithRegion returns a session configured against the "default" AWS profile and the input region.
func (p *Provider) DefaultWithRegion(region string) (*session.Session, error) {
sess, err := session.NewSessionWithOptions(session.Options{
Config: *newConfig().WithRegion(region),
SharedConfigState: session.SharedConfigEnable,
Config: *newConfig().WithRegion(region),
SharedConfigState: session.SharedConfigEnable,
AssumeRoleTokenProvider: stscreds.StdinTokenProvider,
})
if err != nil {
return nil, err
Expand All @@ -101,9 +93,10 @@ func (p *Provider) DefaultWithRegion(region string) (*session.Session, error) {
// FromProfile returns a session configured against the input profile name.
func (p *Provider) FromProfile(name string) (*session.Session, error) {
sess, err := session.NewSessionWithOptions(session.Options{
Config: *newConfig(),
SharedConfigState: session.SharedConfigEnable,
Profile: name,
Config: *newConfig(),
SharedConfigState: session.SharedConfigEnable,
Profile: name,
AssumeRoleTokenProvider: stscreds.StdinTokenProvider,
})
if err != nil {
return nil, err
Expand All @@ -117,14 +110,10 @@ func (p *Provider) FromProfile(name string) (*session.Session, error) {

// FromRole returns a session configured against the input role and region.
func (p *Provider) FromRole(roleARN string, region string) (*session.Session, error) {
defaultSession, err := session.NewSessionWithOptions(session.Options{
Config: *newConfig(),
SharedConfigState: session.SharedConfigEnable,
})
defaultSession, err := p.defaultSession()
if err != nil {
return nil, fmt.Errorf("error creating default session: %w", err)
return nil, fmt.Errorf("create default session: %w", err)
}
defaultSession.Handlers.Build.PushBackNamed(p.userAgentHandler())

creds := stscreds.NewCredentials(defaultSession, roleARN)
sess, err := session.NewSession(
Expand Down Expand Up @@ -153,6 +142,24 @@ func (p *Provider) FromStaticCreds(accessKeyID, secretAccessKey, sessionToken st
return sess, nil
}

func (p *Provider) defaultSession() (*session.Session, error) {
if p.defaultSess != nil {
return p.defaultSess, nil
}

sess, err := session.NewSessionWithOptions(session.Options{
Config: *newConfig(),
SharedConfigState: session.SharedConfigEnable,
AssumeRoleTokenProvider: stscreds.StdinTokenProvider,
})
if err != nil {
return nil, err
}
sess.Handlers.Build.PushBackNamed(p.userAgentHandler())
p.defaultSess = sess
return sess, nil
}

// AreCredsFromEnvVars returns true if the session's credentials provider is environment variables, false otherwise.
// An error is returned if the credentials are invalid or the request times out.
func AreCredsFromEnvVars(sess *session.Session) (bool, error) {
Expand Down