Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for gMSA s3 test #3886

Merged
merged 6 commits into from Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 10 additions & 4 deletions agent/s3/factory/factory.go
Expand Up @@ -34,7 +34,7 @@ const (

type S3ClientCreator interface {
NewS3ManagerClient(bucket, region string, creds credentials.IAMRoleCredentials) (s3client.S3ManagerClient, error)
NewS3Client(region string, creds credentials.IAMRoleCredentials) s3client.S3Client
NewS3Client(bucket, region string, creds credentials.IAMRoleCredentials) (s3client.S3Client, error)
}

// NewS3ClientCreator provide 2 implementations
Expand Down Expand Up @@ -65,15 +65,21 @@ func (*s3ClientCreator) NewS3ManagerClient(bucket, region string,
}

// NewS3Client returns a new S3 client to support s3 operations which are not provided by s3manager.
func (*s3ClientCreator) NewS3Client(region string,
creds credentials.IAMRoleCredentials) s3client.S3Client {
func (*s3ClientCreator) NewS3Client(bucket, region string,
creds credentials.IAMRoleCredentials) (s3client.S3Client, error) {
cfg := aws.NewConfig().
WithHTTPClient(httpclient.New(roundtripTimeout, false)).
WithCredentials(
awscreds.NewStaticCredentials(creds.AccessKeyID, creds.SecretAccessKey,
creds.SessionToken)).WithRegion(region)
sess := session.Must(session.NewSession(cfg))
return s3.New(sess)
svc := s3.New(sess)
bucketRegion, err := getRegionFromBucket(svc, bucket)
if err != nil {
return nil, err
}
sessWithRegion := session.Must(session.NewSession(cfg.WithRegion(bucketRegion)))
return s3.New(sessWithRegion), nil
}
func getRegionFromBucket(svc *s3.S3, bucket string) (string, error) {
input := &s3.GetBucketLocationInput{
Expand Down
11 changes: 6 additions & 5 deletions agent/s3/factory/mocks/factory_mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion agent/taskresource/credentialspec/credentialspec_linux.go
Expand Up @@ -435,7 +435,12 @@ func (cs *CredentialSpecResource) handleS3CredentialspecFile(originalCredentialS
return
}

s3Client := cs.s3ClientCreator.NewS3Client(cs.region, iamCredentials)
s3Client, err := cs.s3ClientCreator.NewS3Client(bucket, cs.region, iamCredentials)
as14692 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
cs.setTerminalReason(err.Error())
errorEvents <- err
return
}

credSpecJsonStringUnformatted, err := s3.GetObject(bucket, key, s3Client)

Expand Down
Expand Up @@ -374,7 +374,7 @@ func TestHandleS3CredentialSpecFileGetS3SecretValue(t *testing.T) {
Body: io.NopCloser(strings.NewReader(testData)),
}
gomock.InOrder(
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any()).Return(mockS3Client),
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil),
mockS3Client.EXPECT().GetObject(gomock.Any()).Return(s3GetObjectResponse, nil).Times(1),
)

Expand Down Expand Up @@ -439,7 +439,7 @@ func TestHandleS3DomainlessCredentialSpecFileGetS3SecretValue(t *testing.T) {
Body: io.NopCloser(strings.NewReader(testData)),
}
gomock.InOrder(
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any()).Return(mockS3Client),
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil),
mockS3Client.EXPECT().GetObject(gomock.Any()).Return(s3GetObjectResponse, nil).Times(1),
)

Expand Down Expand Up @@ -501,7 +501,7 @@ func TestHandleS3CredentialSpecFileGetS3SecretValueErr(t *testing.T) {
}, apitaskstatus.TaskStatusNone, apitaskstatus.TaskRunning)

gomock.InOrder(
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any()).Return(mockS3Client),
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil),
mockS3Client.EXPECT().GetObject(gomock.Any()).Return(nil, errors.New("test-error")).Times(1),
)

Expand Down