Skip to content

Commit

Permalink
changing requestRoleCredentials (#162)
Browse files Browse the repository at this point in the history
* changing requestRoleCredentials

* remove duration since it's default anyways'

* removing idea
  • Loading branch information
wild-endeavor authored and jprobinson committed Nov 6, 2018
1 parent 19b29a0 commit 5d6ab62
Showing 1 changed file with 20 additions and 29 deletions.
49 changes: 20 additions & 29 deletions pubsub/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"github.com/aws/aws-sdk-go/service/sns/snsiface"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/golang/protobuf/proto"
"golang.org/x/net/context"
)
Expand Down Expand Up @@ -43,24 +42,24 @@ func NewPublisher(cfg SNSConfig) (pubsub.Publisher, error) {
return p, errors.New("SNS region is required")
}

sess, err := session.NewSession()
if err != nil {
return p, err
}

var creds *credentials.Credentials
if cfg.AccessKey != "" {
creds = credentials.NewStaticCredentials(cfg.AccessKey, cfg.SecretKey, cfg.SessionToken)
} else if cfg.RoleARN != "" {
var err error
creds, err = requestRoleCredentials(creds, cfg.RoleARN, cfg.MFASerialNumber)
creds, err = requestRoleCredentials(sess, cfg.RoleARN, cfg.MFASerialNumber)
if err != nil {
return p, err
}
} else {
creds = credentials.NewEnvCredentials()
}

sess, err := session.NewSession()
if err != nil {
return p, err
}

p.sns = sns.New(sess, &aws.Config{
Credentials: creds,
Region: &cfg.Region,
Expand Down Expand Up @@ -199,23 +198,24 @@ func NewSubscriber(cfg SQSConfig) (pubsub.Subscriber, error) {
return s, errors.New("sqs queue name or url is required")
}

sess, err := session.NewSession()
if err != nil {
return s, err
}

var creds *credentials.Credentials
if cfg.AccessKey != "" {
creds = credentials.NewStaticCredentials(cfg.AccessKey, cfg.SecretKey, cfg.SessionToken)
} else if cfg.RoleARN != "" {
var err error
creds, err = requestRoleCredentials(creds, cfg.RoleARN, cfg.MFASerialNumber)
creds, err = requestRoleCredentials(sess, cfg.RoleARN, cfg.MFASerialNumber)
if err != nil {
return s, err
}
} else {
creds = credentials.NewEnvCredentials()
}

sess, err := session.NewSession()
if err != nil {
return s, err
}
s.sqs = sqs.New(sess, &aws.Config{
Credentials: creds,
Region: &cfg.Region,
Expand Down Expand Up @@ -399,24 +399,15 @@ func (s *subscriber) Err() error {

// requestRoleCredentials return the credentials from AssumeRoleProvider to assume the role
// referenced by the roleARN. If MFASerialNumber is specified, prompt for MFA token from stdin.
func requestRoleCredentials(creds *credentials.Credentials, roleARN string, MFASerialNumber string) (*credentials.Credentials, error) {
func requestRoleCredentials(sess *session.Session, roleARN string, MFASerialNumber string) (*credentials.Credentials, error) {
if roleARN == "" {
return nil, errors.New("role ARN is required")
}
sess, err := session.NewSessionWithOptions(session.Options{
Config: *aws.NewConfig().WithCredentials(creds),
})
if err != nil {
return nil, err
}
assumeRole := &stscreds.AssumeRoleProvider{
Client: sts.New(sess),
RoleARN: roleARN,
Duration: stscreds.DefaultDuration,
}
if MFASerialNumber != "" {
assumeRole.SerialNumber = &MFASerialNumber
assumeRole.TokenProvider = stscreds.StdinTokenProvider
}
return credentials.NewCredentials(assumeRole), nil

return stscreds.NewCredentials(sess, roleARN, func(provider *stscreds.AssumeRoleProvider) {
if MFASerialNumber != "" {
provider.SerialNumber = &MFASerialNumber
provider.TokenProvider = stscreds.StdinTokenProvider
}
}), nil
}

0 comments on commit 5d6ab62

Please sign in to comment.