From 42260bb91c37d4e9f31587e92c434d462077d266 Mon Sep 17 00:00:00 2001 From: Steven Yuan Date: Thu, 20 Apr 2023 14:33:25 -0700 Subject: [PATCH] Fix EC2 Presigned URL customization (#4808) In some cases, the Presigned URL was NOT being sent when calling EC2 CopySnapshot. This PR fixes the customization responsible for generating the Presigned URL. --- service/ec2/customizations.go | 22 +++- service/ec2/customizations_test.go | 186 +++++++++++++++++++++++++++++ 2 files changed, 203 insertions(+), 5 deletions(-) diff --git a/service/ec2/customizations.go b/service/ec2/customizations.go index 5b5395356fa..621712d29f0 100644 --- a/service/ec2/customizations.go +++ b/service/ec2/customizations.go @@ -11,6 +11,9 @@ import ( ) const ( + // ec2CopySnapshotPresignedUrlCustomization handler name + ec2CopySnapshotPresignedUrlCustomization = "ec2CopySnapshotPresignedUrl" + // customRetryerMinRetryDelay sets min retry delay customRetryerMinRetryDelay = 1 * time.Second @@ -21,7 +24,10 @@ const ( func init() { initRequest = func(r *request.Request) { if r.Operation.Name == opCopySnapshot { // fill the PresignedURL parameter - r.Handlers.Build.PushFront(fillPresignedURL) + r.Handlers.Build.PushFrontNamed(request.NamedHandler{ + Name: ec2CopySnapshotPresignedUrlCustomization, + Fn: fillPresignedURL, + }) } // only set the retryer on request if config doesn't have a retryer @@ -48,13 +54,15 @@ func fillPresignedURL(r *request.Request) { origParams := r.Params.(*CopySnapshotInput) - // Stop if PresignedURL/DestinationRegion is set - if origParams.PresignedUrl != nil || origParams.DestinationRegion != nil { + // Stop if PresignedURL is set + if origParams.PresignedUrl != nil { return } + // Always use config region as destination region for SDKs origParams.DestinationRegion = r.Config.Region - newParams := awsutil.CopyOf(r.Params).(*CopySnapshotInput) + + newParams := awsutil.CopyOf(origParams).(*CopySnapshotInput) // Create a new request based on the existing request. We will use this to // presign the CopySnapshot request against the source region. @@ -82,8 +90,12 @@ func fillPresignedURL(r *request.Request) { clientInfo.Endpoint = resolved.URL clientInfo.SigningRegion = resolved.SigningRegion + // Copy handlers without Presigned URL customization to avoid an infinite loop + handlersWithoutPresignCustomization := r.Handlers.Copy() + handlersWithoutPresignCustomization.Build.RemoveByName(ec2CopySnapshotPresignedUrlCustomization) + // Presign a CopySnapshot request with modified params - req := request.New(*cfg, clientInfo, r.Handlers, r.Retryer, r.Operation, newParams, r.Data) + req := request.New(*cfg, clientInfo, handlersWithoutPresignCustomization, r.Retryer, r.Operation, newParams, r.Data) url, err := req.Presign(5 * time.Minute) // 5 minutes should be enough. if err != nil { // bubble error back up to original request r.Error = err diff --git a/service/ec2/customizations_test.go b/service/ec2/customizations_test.go index 4df996fd593..68dee8ec342 100644 --- a/service/ec2/customizations_test.go +++ b/service/ec2/customizations_test.go @@ -6,14 +6,17 @@ package ec2_test import ( "bytes" "context" + "fmt" "io/ioutil" "net/http" "net/url" "regexp" + "strconv" "testing" "github.com/aws/aws-sdk-go/aws" sdkclient "github.com/aws/aws-sdk-go/aws/client" + "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/awstesting/unit" "github.com/aws/aws-sdk-go/service/ec2" @@ -55,6 +58,189 @@ func TestCopySnapshotPresignedURL(t *testing.T) { } } +func TestCopySnapshotPresignedURLConfig(t *testing.T) { + const ( + inputKmsKeyId = "KMS_KEY_ID" + inputSnapshotId = "SNAPSHOT_ID" + clientRegion = endpoints.UsEast1RegionID + inputSourceRegion = endpoints.UsWest2RegionID + ) + cases := map[string]struct { + Encrypted bool + DestinationRegion string + KmsKeyId string + }{ + // Not Encrypted + "Not Encrypted": {}, + // Not Encrypted with KmsKeyId + "Not Encrypted with KmsKeyId": { + KmsKeyId: inputKmsKeyId, + }, + // Not Encrypted with DestinationRegion + "Not Encrypted with DestinationRegion": { + DestinationRegion: endpoints.UsEast2RegionID, + }, + // Not Encrypted with KmsKeyId and DestinationRegion + "Not Encrypted with KmsKeyId and DestinationRegion": { + KmsKeyId: inputKmsKeyId, + DestinationRegion: endpoints.UsEast2RegionID, + }, + // Encrypted + "Encrypted": { + Encrypted: true, + }, + // Encrypted with KmsKeyId + "Encrypted with KmsKeyId": { + Encrypted: true, + KmsKeyId: inputKmsKeyId, + }, + // Encrypted with DestinationRegion + "Encrypted with DestinationRegion": { + Encrypted: true, + DestinationRegion: endpoints.UsEast2RegionID, + }, + // Encrypted with KmsKeyId and DestinationRegion + "Encrypted with KmsKeyId and DestinationRegion": { + Encrypted: true, + KmsKeyId: inputKmsKeyId, + DestinationRegion: endpoints.UsEast2RegionID, + }, + } + + for name, config := range cases { + t.Run(name, func(t *testing.T) { + t.Log(name) + + // Set up new client + svc := ec2.New(unit.Session, &aws.Config{ + Region: aws.String(clientRegion), + }) + + // Base input + input := ec2.CopySnapshotInput{ + SourceRegion: aws.String(inputSourceRegion), + SourceSnapshotId: aws.String(inputSnapshotId), + } + + // Add input from test case config + if config.Encrypted != false { + input.Encrypted = &config.Encrypted + } + if config.DestinationRegion != "" { + input.DestinationRegion = &config.DestinationRegion + } + if config.KmsKeyId != "" { + input.KmsKeyId = &config.KmsKeyId + } + + // Execute request + req, _ := svc.CopySnapshotRequest(&input) + req.Sign() + + // Parse request + body, _ := ioutil.ReadAll(req.HTTPRequest.Body) + query, _ := url.ParseQuery(string(body)) + + // Test Body SourceRegion + sourceRegion := query.Get("SourceRegion") + if sourceRegion == "" { + t.Errorf("SourceRegion should always be sent in the request") + } + if sourceRegion != inputSourceRegion { + t.Errorf("SourceRegion should be `%v`, but found `%v`", inputSourceRegion, sourceRegion) + } + // Test Body SourceSnapshotId + sourceSnapshotId := query.Get("SourceSnapshotId") + if sourceSnapshotId == "" { + t.Errorf("SourceSnapshotId should always be sent in the request") + } + if sourceSnapshotId != inputSnapshotId { + t.Errorf("SourceSnapshotId should be `%v`, but found `%v`", inputSnapshotId, sourceSnapshotId) + } + // Test Body Encrypted + encrypted := query.Get("Encrypted") + if config.Encrypted && strconv.FormatBool(config.Encrypted) != encrypted { + t.Errorf("Encrypted should be `%v`, but found `%v`", config.Encrypted, encrypted) + } + if !config.Encrypted && encrypted != "" { + t.Errorf("Encrypted should be empty, but found `%v`", encrypted) + } + // Test Body DestinationRegion + destinationRegion := query.Get("DestinationRegion") + if destinationRegion != clientRegion { + t.Errorf("DestinationRegion should always be equal to the client region `%v`, but found `%v`", clientRegion, destinationRegion) + } + if destinationRegion == "" { + t.Errorf("DestinationRegion should never empty") + } + // Test Body KmsKeyId + kmsKeyId := query.Get("KmsKeyId") + if config.KmsKeyId != "" && config.KmsKeyId != kmsKeyId { + t.Errorf("KmsKeyId should be `%v`, but found `%v`", config.KmsKeyId, kmsKeyId) + } + if config.KmsKeyId == "" && kmsKeyId != "" { + t.Errorf("KmsKeyId should be empty, but found `%v`", kmsKeyId) + } + + // Assert PresignedUrl + presignedUrl, _ := url.QueryUnescape(query.Get("PresignedUrl")) + if presignedUrl == "" { + t.Errorf("PresignedUrl should always be sent in the request") + } + // Test PresignedUrl EC2 URL + baseEc2UrlRegex := regexp.MustCompile(fmt.Sprintf(`^https://ec2\.%s\.amazonaws\.com/`, inputSourceRegion)) + if !baseEc2UrlRegex.MatchString(presignedUrl) { + t.Errorf("Expected PresignedUrl to match `%v`, but found `%v`", baseEc2UrlRegex.String(), presignedUrl) + } + + presignedUrlQuery, _ := url.ParseQuery(presignedUrl) + // Test PresignedUrl SourceRegion + presignedUrlSourceRegion := presignedUrlQuery.Get("SourceRegion") + if presignedUrlSourceRegion == "" { + t.Errorf("PresignedUrl SourceRegion should always be sent in the request") + } + if presignedUrlSourceRegion != inputSourceRegion { + t.Errorf("PresignedUrl SourceRegion should be `%v`, but found `%v`", inputSourceRegion, presignedUrlSourceRegion) + } + // Test PresignedUrl SourceSnapshotId + presignedUrlSourceSnapshotId := presignedUrlQuery.Get("SourceSnapshotId") + if presignedUrlSourceSnapshotId == "" { + t.Errorf("PresignedUrl SourceSnapshotId should always be sent in the request") + } + if presignedUrlSourceSnapshotId != inputSnapshotId { + t.Errorf("PresignedUrl SourceSnapshotId should be `%v`, but found `%v`", inputSnapshotId, presignedUrlSourceSnapshotId) + } + // Test PresignedUrl Encrypted + presignedUrlEncrypted := query.Get("Encrypted") + if config.Encrypted && strconv.FormatBool(config.Encrypted) != presignedUrlEncrypted { + t.Errorf("PresignedUrl Encrypted should be `%v`, but found `%v`", config.Encrypted, presignedUrlEncrypted) + } + if !config.Encrypted && presignedUrlEncrypted != "" { + t.Errorf("PresignedUrl Encrypted should be empty, but found `%v`", presignedUrlEncrypted) + } + // Test PresignedUrl DestinationRegion + presignedUrlDestinationRegion := presignedUrlQuery.Get("DestinationRegion") + if presignedUrlDestinationRegion != clientRegion { + t.Errorf("PresignedUrl DestinationRegion should always be equal to the client region `%v`, but found `%v`", clientRegion, presignedUrlDestinationRegion) + } + // Test PresignedUrl KmsKeyId + presignedUrlKmsKeyId := query.Get("KmsKeyId") + if config.KmsKeyId != "" && config.KmsKeyId != presignedUrlKmsKeyId { + t.Errorf("PresignedUrl KmsKeyId should be `%v`, but found `%v`", config.KmsKeyId, presignedUrlKmsKeyId) + } + if config.KmsKeyId == "" && presignedUrlKmsKeyId != "" { + t.Errorf("PresignedUrl KmsKeyId should be empty, but found `%v`", presignedUrlKmsKeyId) + } + // Test PresignedUrl X-Amz-Credential + presignedUrlAmzCredential := presignedUrlQuery.Get("X-Amz-Credential") + amzCredentialRegex := regexp.MustCompile(fmt.Sprintf(`^\w{4}/\d{8}/%s/ec2/aws4_request$`, inputSourceRegion)) + if !amzCredentialRegex.MatchString(presignedUrlAmzCredential) { + t.Errorf("Expected PresignedUrl X-Amz-Credential to match `%v`, but found `%v`", amzCredentialRegex.String(), presignedUrlAmzCredential) + } + }) + } +} + func TestNoCustomRetryerWithMaxRetries(t *testing.T) { cases := map[string]struct { Config aws.Config