diff --git a/contrib/aws/aws-sdk-go-v2/aws/aws.go b/contrib/aws/aws-sdk-go-v2/aws/aws.go index 52fffe7a77..8e07505308 100644 --- a/contrib/aws/aws-sdk-go-v2/aws/aws.go +++ b/contrib/aws/aws-sdk-go-v2/aws/aws.go @@ -22,6 +22,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" + awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/aws/aws-sdk-go-v2/service/eventbridge" "github.com/aws/aws-sdk-go-v2/service/kinesis" @@ -62,9 +63,7 @@ const ( type spanTimestampKey struct{} -// AppendMiddleware takes the aws.Config and adds the Datadog tracing middleware into the APIOptions middleware stack. -// See https://aws.github.io/aws-sdk-go-v2/docs/middleware for more information. -func AppendMiddleware(awsCfg *aws.Config, opts ...Option) { +func prepConfig(opts ...Option) *config { cfg := &config{} defaults(cfg) @@ -72,8 +71,28 @@ func AppendMiddleware(awsCfg *aws.Config, opts ...Option) { opt(cfg) } + return cfg +} + +func appendMiddleware(cfg *config, opts *[]func(*middleware.Stack) error) { tm := traceMiddleware{cfg: cfg} - awsCfg.APIOptions = append(awsCfg.APIOptions, tm.initTraceMiddleware, tm.startTraceMiddleware, tm.deserializeTraceMiddleware) + *opts = append(*opts, tm.initTraceMiddleware, tm.startTraceMiddleware, tm.deserializeTraceMiddleware) +} + +// WithDataDogTracer returns an AWS config LoadOptionsFunc that adds the DataDog tracing middleware into the +// APIOptions middleware stack. +// See https://aws.github.io/aws-sdk-go-v2/docs/middleware for more information. +func WithDataDogTracer(opts ...Option) awsconfig.LoadOptionsFunc { + return func(awsOpt *awsconfig.LoadOptions) error { + appendMiddleware(prepConfig(opts...), &awsOpt.APIOptions) + return nil + } +} + +// AppendMiddleware takes the aws.Config and adds the Datadog tracing middleware into the APIOptions middleware stack. +// See https://aws.github.io/aws-sdk-go-v2/docs/middleware for more information. +func AppendMiddleware(awsCfg *aws.Config, opts ...Option) { + appendMiddleware(prepConfig(opts...), &awsCfg.APIOptions) } type traceMiddleware struct { diff --git a/contrib/aws/aws-sdk-go-v2/aws/aws_test.go b/contrib/aws/aws-sdk-go-v2/aws/aws_test.go index c5986cd02a..327e219db2 100644 --- a/contrib/aws/aws-sdk-go-v2/aws/aws_test.go +++ b/contrib/aws/aws-sdk-go-v2/aws/aws_test.go @@ -8,6 +8,7 @@ package aws import ( "context" "encoding/base64" + "fmt" "net/http" "net/http/httptest" "net/url" @@ -34,7 +35,53 @@ import ( "github.com/stretchr/testify/require" ) -func newIntegrationTestConfig(t *testing.T, opts ...Option) aws.Config { +type awsConfMode string + +const ( + modeOpts awsConfMode = "opts" + modeCfg awsConfMode = "cfg" +) + +var ( + awsConfModes = []awsConfMode{modeCfg, modeOpts} +) + +func newMockAwsResolver(serverURL string) aws.EndpointResolverWithOptionsFunc { + return func(service, region string, options ...interface{}) (aws.Endpoint, error) { + return aws.Endpoint{ + PartitionID: "aws", + URL: serverURL, + SigningRegion: "eu-west-1", + }, nil + } +} + +func newMockTestConfig(t *testing.T, serverURL string, mode awsConfMode, awsOpts []func(options *awsconfig.LoadOptions) error, tracerOpts ...Option) aws.Config { + baseAwsOpts := []func(*awsconfig.LoadOptions) error{ + awsconfig.WithRegion("eu-west-1"), + awsconfig.WithCredentialsProvider(aws.AnonymousCredentials{}), + awsconfig.WithEndpointResolverWithOptions(newMockAwsResolver(serverURL)), + } + + if mode == modeOpts { + baseAwsOpts = append(baseAwsOpts, WithDataDogTracer(tracerOpts...)) + } + + awsCfg, err := awsconfig.LoadDefaultConfig( + context.Background(), + append(baseAwsOpts, awsOpts...)..., + ) + + require.NoError(t, err, "failed to init aws mock config") + + if mode == modeCfg { + AppendMiddleware(&awsCfg, tracerOpts...) + } + + return awsCfg +} + +func newIntegrationTestConfig(t *testing.T, mode awsConfMode, opts ...Option) aws.Config { if _, ok := os.LookupEnv("INTEGRATION"); !ok { t.Skip("🚧 Skipping integration test (INTEGRATION environment variable is not set)") } @@ -48,14 +95,27 @@ func newIntegrationTestConfig(t *testing.T, opts ...Option) aws.Config { SigningRegion: awsRegion, }, nil }) - cfg, err := awsconfig.LoadDefaultConfig( - context.Background(), + + baseAwsOpts := []func(*awsconfig.LoadOptions) error{ awsconfig.WithRegion(awsRegion), awsconfig.WithEndpointResolverWithOptions(customResolver), awsconfig.WithCredentialsProvider(aws.AnonymousCredentials{}), + } + + if mode == modeOpts { + baseAwsOpts = append(baseAwsOpts, WithDataDogTracer(opts...)) + } + + cfg, err := awsconfig.LoadDefaultConfig( + context.Background(), + baseAwsOpts..., ) require.NoError(t, err, "failed to load AWS config") - AppendMiddleware(&cfg, opts...) + + if mode == modeCfg { + AppendMiddleware(&cfg, opts...) + } + return cfg } @@ -77,59 +137,47 @@ func TestAppendMiddleware(t *testing.T) { expectedStatusCode: 200, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mt := mocktracer.Start() - defer mt.Stop() - - server := mockAWS(tt.expectedStatusCode) - defer server.Close() - - resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { - return aws.Endpoint{ - PartitionID: "aws", - URL: server.URL, - SigningRegion: "eu-west-1", - }, nil - }) - - awsCfg := aws.Config{ - Region: "eu-west-1", - Credentials: aws.AnonymousCredentials{}, - EndpointResolver: resolver, - } - - AppendMiddleware(&awsCfg) - - sqsClient := sqs.NewFromConfig(awsCfg) - sqsClient.SendMessage(context.Background(), &sqs.SendMessageInput{ - MessageBody: aws.String("foobar"), - QueueUrl: aws.String("https://sqs.us-west-2.amazonaws.com/123456789012/MyQueueName"), + for _, cm := range awsConfModes { + for _, tt := range tests { + t.Run(fmt.Sprintf("%s %s", tt.name, cm), func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + server := mockAWS(tt.expectedStatusCode) + defer server.Close() + + awsCfg := newMockTestConfig(t, server.URL, cm, nil) + + sqsClient := sqs.NewFromConfig(awsCfg) + _, _ = sqsClient.SendMessage(context.Background(), &sqs.SendMessageInput{ + MessageBody: aws.String("foobar"), + QueueUrl: aws.String("https://sqs.us-west-2.amazonaws.com/123456789012/MyQueueName"), + }) + + spans := mt.FinishedSpans() + + s := spans[0] + assert.Equal(t, "SQS.request", s.OperationName()) + assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") + assert.Equal(t, "SendMessage", s.Tag(tagAWSOperation)) + assert.Equal(t, "SQS", s.Tag(tagAWSService)) + assert.Equal(t, "SQS", s.Tag(tagService)) + assert.Equal(t, "MyQueueName", s.Tag(tagQueueName)) + + assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) + assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) + assert.Equal(t, "SQS.SendMessage", s.Tag(ext.ResourceName)) + assert.Equal(t, "aws.SQS", s.Tag(ext.ServiceName)) + assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) + if tt.expectedStatusCode == 200 { + assert.Equal(t, "test_req", s.Tag("aws.request_id")) + } + assert.Equal(t, "POST", s.Tag(ext.HTTPMethod)) + assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) + assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) + assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) }) - - spans := mt.FinishedSpans() - - s := spans[0] - assert.Equal(t, "SQS.request", s.OperationName()) - assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") - assert.Equal(t, "SendMessage", s.Tag(tagAWSOperation)) - assert.Equal(t, "SQS", s.Tag(tagAWSService)) - assert.Equal(t, "SQS", s.Tag(tagService)) - assert.Equal(t, "MyQueueName", s.Tag(tagQueueName)) - - assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) - assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) - assert.Equal(t, "SQS.SendMessage", s.Tag(ext.ResourceName)) - assert.Equal(t, "aws.SQS", s.Tag(ext.ServiceName)) - assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) - if tt.expectedStatusCode == 200 { - assert.Equal(t, "test_req", s.Tag("aws.request_id")) - } - assert.Equal(t, "POST", s.Tag(ext.HTTPMethod)) - assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) - assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) - assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) - }) + } } } @@ -151,59 +199,48 @@ func TestAppendMiddlewareSqsDeleteMessage(t *testing.T) { expectedStatusCode: 200, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mt := mocktracer.Start() - defer mt.Stop() - - server := mockAWS(tt.expectedStatusCode) - defer server.Close() - - resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { - return aws.Endpoint{ - PartitionID: "aws", - URL: server.URL, - SigningRegion: "eu-west-1", - }, nil - }) - - awsCfg := aws.Config{ - Region: "eu-west-1", - Credentials: aws.AnonymousCredentials{}, - EndpointResolver: resolver, - } - - AppendMiddleware(&awsCfg) - - sqsClient := sqs.NewFromConfig(awsCfg) - sqsClient.DeleteMessage(context.Background(), &sqs.DeleteMessageInput{ - QueueUrl: aws.String("https://sqs.us-west-2.amazonaws.com/123456789012/MyQueueName"), - ReceiptHandle: aws.String("foobar"), + for _, cm := range awsConfModes { + for _, tt := range tests { + t.Run(fmt.Sprintf("%s %s", tt.name, cm), func(t *testing.T) { + + mt := mocktracer.Start() + defer mt.Stop() + + server := mockAWS(tt.expectedStatusCode) + defer server.Close() + + awsCfg := newMockTestConfig(t, server.URL, cm, nil) + + sqsClient := sqs.NewFromConfig(awsCfg) + _, _ = sqsClient.DeleteMessage(context.Background(), &sqs.DeleteMessageInput{ + QueueUrl: aws.String("https://sqs.us-west-2.amazonaws.com/123456789012/MyQueueName"), + ReceiptHandle: aws.String("foobar"), + }) + + spans := mt.FinishedSpans() + + s := spans[0] + assert.Equal(t, "SQS.request", s.OperationName()) + assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") + assert.Equal(t, "DeleteMessage", s.Tag(tagAWSOperation)) + assert.Equal(t, "SQS", s.Tag(tagAWSService)) + assert.Equal(t, "SQS", s.Tag(tagService)) + assert.Equal(t, "MyQueueName", s.Tag(tagQueueName)) + + assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) + assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) + assert.Equal(t, "SQS.DeleteMessage", s.Tag(ext.ResourceName)) + assert.Equal(t, "aws.SQS", s.Tag(ext.ServiceName)) + assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) + if tt.expectedStatusCode == 200 { + assert.Equal(t, "test_req", s.Tag("aws.request_id")) + } + assert.Equal(t, "POST", s.Tag(ext.HTTPMethod)) + assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) + assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) + assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) }) - - spans := mt.FinishedSpans() - - s := spans[0] - assert.Equal(t, "SQS.request", s.OperationName()) - assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") - assert.Equal(t, "DeleteMessage", s.Tag(tagAWSOperation)) - assert.Equal(t, "SQS", s.Tag(tagAWSService)) - assert.Equal(t, "SQS", s.Tag(tagService)) - assert.Equal(t, "MyQueueName", s.Tag(tagQueueName)) - - assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) - assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) - assert.Equal(t, "SQS.DeleteMessage", s.Tag(ext.ResourceName)) - assert.Equal(t, "aws.SQS", s.Tag(ext.ServiceName)) - assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) - if tt.expectedStatusCode == 200 { - assert.Equal(t, "test_req", s.Tag("aws.request_id")) - } - assert.Equal(t, "POST", s.Tag(ext.HTTPMethod)) - assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) - assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) - assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) - }) + } } } func TestAppendMiddlewareSqsReceiveMessage(t *testing.T) { @@ -224,59 +261,47 @@ func TestAppendMiddlewareSqsReceiveMessage(t *testing.T) { expectedStatusCode: 200, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mt := mocktracer.Start() - defer mt.Stop() - - server := mockAWS(tt.expectedStatusCode) - defer server.Close() - - resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { - return aws.Endpoint{ - PartitionID: "aws", - URL: server.URL, - SigningRegion: "eu-west-1", - }, nil - }) - - awsCfg := aws.Config{ - Region: "eu-west-1", - Credentials: aws.AnonymousCredentials{}, - EndpointResolver: resolver, - } - - AppendMiddleware(&awsCfg) - - sqsClient := sqs.NewFromConfig(awsCfg) - sqsClient.ReceiveMessage(context.Background(), &sqs.ReceiveMessageInput{ - QueueUrl: aws.String("https://sqs.us-west-2.amazonaws.com/123456789012/MyQueueName"), + for _, cm := range awsConfModes { + for _, tt := range tests { + t.Run(fmt.Sprintf("%s %s", tt.name, cm), func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + server := mockAWS(tt.expectedStatusCode) + defer server.Close() + + awsCfg := newMockTestConfig(t, server.URL, cm, nil) + + sqsClient := sqs.NewFromConfig(awsCfg) + _, _ = sqsClient.ReceiveMessage(context.Background(), &sqs.ReceiveMessageInput{ + QueueUrl: aws.String("https://sqs.us-west-2.amazonaws.com/123456789012/MyQueueName"), + }) + + spans := mt.FinishedSpans() + + s := spans[0] + assert.Equal(t, "SQS.request", s.OperationName()) + assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") + assert.Equal(t, "ReceiveMessage", s.Tag(tagAWSOperation)) + assert.Equal(t, "SQS", s.Tag(tagAWSService)) + assert.Equal(t, "SQS", s.Tag(tagService)) + assert.Equal(t, "MyQueueName", s.Tag(tagQueueName)) + + assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) + assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) + assert.Equal(t, "SQS", s.Tag(tagAWSService)) + assert.Equal(t, "SQS.ReceiveMessage", s.Tag(ext.ResourceName)) + assert.Equal(t, "aws.SQS", s.Tag(ext.ServiceName)) + assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) + if tt.expectedStatusCode == 200 { + assert.Equal(t, "test_req", s.Tag("aws.request_id")) + } + assert.Equal(t, "POST", s.Tag(ext.HTTPMethod)) + assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) + assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) + assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) }) - - spans := mt.FinishedSpans() - - s := spans[0] - assert.Equal(t, "SQS.request", s.OperationName()) - assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") - assert.Equal(t, "ReceiveMessage", s.Tag(tagAWSOperation)) - assert.Equal(t, "SQS", s.Tag(tagAWSService)) - assert.Equal(t, "SQS", s.Tag(tagService)) - assert.Equal(t, "MyQueueName", s.Tag(tagQueueName)) - - assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) - assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) - assert.Equal(t, "SQS", s.Tag(tagAWSService)) - assert.Equal(t, "SQS.ReceiveMessage", s.Tag(ext.ResourceName)) - assert.Equal(t, "aws.SQS", s.Tag(ext.ServiceName)) - assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) - if tt.expectedStatusCode == 200 { - assert.Equal(t, "test_req", s.Tag("aws.request_id")) - } - assert.Equal(t, "POST", s.Tag(ext.HTTPMethod)) - assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) - assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) - assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) - }) + } } } @@ -298,55 +323,43 @@ func TestAppendMiddlewareS3ListObjects(t *testing.T) { expectedStatusCode: 200, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mt := mocktracer.Start() - defer mt.Stop() - - server := mockAWS(tt.expectedStatusCode) - defer server.Close() - - resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { - return aws.Endpoint{ - PartitionID: "aws", - URL: server.URL, - SigningRegion: "eu-west-1", - }, nil - }) - - awsCfg := aws.Config{ - Region: "eu-west-1", - Credentials: aws.AnonymousCredentials{}, - EndpointResolver: resolver, - } - - AppendMiddleware(&awsCfg) - - s3Client := s3.NewFromConfig(awsCfg) - s3Client.ListObjects(context.Background(), &s3.ListObjectsInput{ - Bucket: aws.String("MyBucketName"), + for _, cm := range awsConfModes { + for _, tt := range tests { + t.Run(fmt.Sprintf("%s %s", tt.name, cm), func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + server := mockAWS(tt.expectedStatusCode) + defer server.Close() + + awsCfg := newMockTestConfig(t, server.URL, cm, nil) + + s3Client := s3.NewFromConfig(awsCfg) + _, _ = s3Client.ListObjects(context.Background(), &s3.ListObjectsInput{ + Bucket: aws.String("MyBucketName"), + }) + + spans := mt.FinishedSpans() + + s := spans[0] + assert.Equal(t, "S3.request", s.OperationName()) + assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") + assert.Equal(t, "ListObjects", s.Tag(tagAWSOperation)) + assert.Equal(t, "S3", s.Tag(tagAWSService)) + assert.Equal(t, "S3", s.Tag(tagService)) + assert.Equal(t, "MyBucketName", s.Tag(tagBucketName)) + + assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) + assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) + assert.Equal(t, "S3.ListObjects", s.Tag(ext.ResourceName)) + assert.Equal(t, "aws.S3", s.Tag(ext.ServiceName)) + assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) + assert.Equal(t, "GET", s.Tag(ext.HTTPMethod)) + assert.Equal(t, server.URL+"/MyBucketName", s.Tag(ext.HTTPURL)) + assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) + assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) }) - - spans := mt.FinishedSpans() - - s := spans[0] - assert.Equal(t, "S3.request", s.OperationName()) - assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") - assert.Equal(t, "ListObjects", s.Tag(tagAWSOperation)) - assert.Equal(t, "S3", s.Tag(tagAWSService)) - assert.Equal(t, "S3", s.Tag(tagService)) - assert.Equal(t, "MyBucketName", s.Tag(tagBucketName)) - - assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) - assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) - assert.Equal(t, "S3.ListObjects", s.Tag(ext.ResourceName)) - assert.Equal(t, "aws.S3", s.Tag(ext.ServiceName)) - assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) - assert.Equal(t, "GET", s.Tag(ext.HTTPMethod)) - assert.Equal(t, server.URL+"/MyBucketName", s.Tag(ext.HTTPURL)) - assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) - assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) - }) + } } } @@ -394,53 +407,41 @@ func TestAppendMiddlewareSnsPublish(t *testing.T) { expectedStatusCode: 200, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mt := mocktracer.Start() - defer mt.Stop() - - server := mockAWS(tt.expectedStatusCode) - defer server.Close() - - resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { - return aws.Endpoint{ - PartitionID: "aws", - URL: server.URL, - SigningRegion: "eu-west-1", - }, nil + for _, cm := range awsConfModes { + for _, tt := range tests { + t.Run(fmt.Sprintf("%s %s", tt.name, cm), func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + server := mockAWS(tt.expectedStatusCode) + defer server.Close() + + awsCfg := newMockTestConfig(t, server.URL, cm, nil) + + snsClient := sns.NewFromConfig(awsCfg) + _, _ = snsClient.Publish(context.Background(), tt.publishInput) + + spans := mt.FinishedSpans() + + s := spans[0] + assert.Equal(t, "SNS.request", s.OperationName()) + assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") + assert.Equal(t, "Publish", s.Tag(tagAWSOperation)) + assert.Equal(t, "SNS", s.Tag(tagAWSService)) + assert.Equal(t, "SNS", s.Tag(tagService)) + assert.Equal(t, tt.expectedTagValue, s.Tag(tt.tagKey)) + + assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) + assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) + assert.Equal(t, "SNS.Publish", s.Tag(ext.ResourceName)) + assert.Equal(t, "aws.SNS", s.Tag(ext.ServiceName)) + assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) + assert.Equal(t, "POST", s.Tag(ext.HTTPMethod)) + assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) + assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) + assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) }) - - awsCfg := aws.Config{ - Region: "eu-west-1", - Credentials: aws.AnonymousCredentials{}, - EndpointResolver: resolver, - } - - AppendMiddleware(&awsCfg) - - snsClient := sns.NewFromConfig(awsCfg) - snsClient.Publish(context.Background(), tt.publishInput) - - spans := mt.FinishedSpans() - - s := spans[0] - assert.Equal(t, "SNS.request", s.OperationName()) - assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") - assert.Equal(t, "Publish", s.Tag(tagAWSOperation)) - assert.Equal(t, "SNS", s.Tag(tagAWSService)) - assert.Equal(t, "SNS", s.Tag(tagService)) - assert.Equal(t, tt.expectedTagValue, s.Tag(tt.tagKey)) - - assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) - assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) - assert.Equal(t, "SNS.Publish", s.Tag(ext.ResourceName)) - assert.Equal(t, "aws.SNS", s.Tag(ext.ServiceName)) - assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) - assert.Equal(t, "POST", s.Tag(ext.HTTPMethod)) - assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) - assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) - assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) - }) + } } } @@ -462,55 +463,43 @@ func TestAppendMiddlewareDynamodbGetItem(t *testing.T) { expectedStatusCode: 200, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mt := mocktracer.Start() - defer mt.Stop() - - server := mockAWS(tt.expectedStatusCode) - defer server.Close() - - resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { - return aws.Endpoint{ - PartitionID: "aws", - URL: server.URL, - SigningRegion: "eu-west-1", - }, nil - }) - - awsCfg := aws.Config{ - Region: "eu-west-1", - Credentials: aws.AnonymousCredentials{}, - EndpointResolver: resolver, - } - - AppendMiddleware(&awsCfg) - - dynamoClient := dynamodb.NewFromConfig(awsCfg) - dynamoClient.Query(context.Background(), &dynamodb.QueryInput{ - TableName: aws.String("MyTableName"), + for _, cm := range awsConfModes { + for _, tt := range tests { + t.Run(fmt.Sprintf("%s %s", tt.name, cm), func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + server := mockAWS(tt.expectedStatusCode) + defer server.Close() + + awsCfg := newMockTestConfig(t, server.URL, cm, nil) + + dynamoClient := dynamodb.NewFromConfig(awsCfg) + _, _ = dynamoClient.Query(context.Background(), &dynamodb.QueryInput{ + TableName: aws.String("MyTableName"), + }) + + spans := mt.FinishedSpans() + + s := spans[0] + assert.Equal(t, "DynamoDB.request", s.OperationName()) + assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") + assert.Equal(t, "Query", s.Tag(tagAWSOperation)) + assert.Equal(t, "DynamoDB", s.Tag(tagAWSService)) + assert.Equal(t, "DynamoDB", s.Tag(tagService)) + assert.Equal(t, "MyTableName", s.Tag(tagTableName)) + + assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) + assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) + assert.Equal(t, "DynamoDB.Query", s.Tag(ext.ResourceName)) + assert.Equal(t, "aws.DynamoDB", s.Tag(ext.ServiceName)) + assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) + assert.Equal(t, "POST", s.Tag(ext.HTTPMethod)) + assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) + assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) + assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) }) - - spans := mt.FinishedSpans() - - s := spans[0] - assert.Equal(t, "DynamoDB.request", s.OperationName()) - assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") - assert.Equal(t, "Query", s.Tag(tagAWSOperation)) - assert.Equal(t, "DynamoDB", s.Tag(tagAWSService)) - assert.Equal(t, "DynamoDB", s.Tag(tagService)) - assert.Equal(t, "MyTableName", s.Tag(tagTableName)) - - assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) - assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) - assert.Equal(t, "DynamoDB.Query", s.Tag(ext.ResourceName)) - assert.Equal(t, "aws.DynamoDB", s.Tag(ext.ServiceName)) - assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) - assert.Equal(t, "POST", s.Tag(ext.HTTPMethod)) - assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) - assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) - assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) - }) + } } } @@ -532,57 +521,45 @@ func TestAppendMiddlewareKinesisPutRecord(t *testing.T) { expectedStatusCode: 200, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mt := mocktracer.Start() - defer mt.Stop() - - server := mockAWS(tt.expectedStatusCode) - defer server.Close() - - resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { - return aws.Endpoint{ - PartitionID: "aws", - URL: server.URL, - SigningRegion: "eu-west-1", - }, nil + for _, cm := range awsConfModes { + for _, tt := range tests { + t.Run(fmt.Sprintf("%s %s", tt.name, cm), func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + server := mockAWS(tt.expectedStatusCode) + defer server.Close() + + awsCfg := newMockTestConfig(t, server.URL, cm, nil) + + kinesisClient := kinesis.NewFromConfig(awsCfg) + _, _ = kinesisClient.PutRecord(context.Background(), &kinesis.PutRecordInput{ + StreamName: aws.String("my-kinesis-stream"), + Data: []byte("Hello, Kinesis!"), + PartitionKey: aws.String("my-partition-key"), + }) + + spans := mt.FinishedSpans() + + s := spans[0] + assert.Equal(t, "Kinesis.request", s.OperationName()) + assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") + assert.Equal(t, "PutRecord", s.Tag(tagAWSOperation)) + assert.Equal(t, "Kinesis", s.Tag(tagAWSService)) + assert.Equal(t, "Kinesis", s.Tag(tagService)) + assert.Equal(t, "my-kinesis-stream", s.Tag(tagStreamName)) + + assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) + assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) + assert.Equal(t, "Kinesis.PutRecord", s.Tag(ext.ResourceName)) + assert.Equal(t, "aws.Kinesis", s.Tag(ext.ServiceName)) + assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) + assert.Equal(t, "POST", s.Tag(ext.HTTPMethod)) + assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) + assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) + assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) }) - - awsCfg := aws.Config{ - Region: "eu-west-1", - Credentials: aws.AnonymousCredentials{}, - EndpointResolver: resolver, - } - - AppendMiddleware(&awsCfg) - - kinesisClient := kinesis.NewFromConfig(awsCfg) - kinesisClient.PutRecord(context.Background(), &kinesis.PutRecordInput{ - StreamName: aws.String("my-kinesis-stream"), - Data: []byte("Hello, Kinesis!"), - PartitionKey: aws.String("my-partition-key"), - }) - - spans := mt.FinishedSpans() - - s := spans[0] - assert.Equal(t, "Kinesis.request", s.OperationName()) - assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") - assert.Equal(t, "PutRecord", s.Tag(tagAWSOperation)) - assert.Equal(t, "Kinesis", s.Tag(tagAWSService)) - assert.Equal(t, "Kinesis", s.Tag(tagService)) - assert.Equal(t, "my-kinesis-stream", s.Tag(tagStreamName)) - - assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) - assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) - assert.Equal(t, "Kinesis.PutRecord", s.Tag(ext.ResourceName)) - assert.Equal(t, "aws.Kinesis", s.Tag(ext.ServiceName)) - assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) - assert.Equal(t, "POST", s.Tag(ext.HTTPMethod)) - assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) - assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) - assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) - }) + } } } @@ -604,55 +581,43 @@ func TestAppendMiddlewareEventBridgePutRule(t *testing.T) { expectedStatusCode: 200, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mt := mocktracer.Start() - defer mt.Stop() - - server := mockAWS(tt.expectedStatusCode) - defer server.Close() - - resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { - return aws.Endpoint{ - PartitionID: "aws", - URL: server.URL, - SigningRegion: "eu-west-1", - }, nil - }) - - awsCfg := aws.Config{ - Region: "eu-west-1", - Credentials: aws.AnonymousCredentials{}, - EndpointResolver: resolver, - } - - AppendMiddleware(&awsCfg) - - eventbridgeClient := eventbridge.NewFromConfig(awsCfg) - eventbridgeClient.PutRule(context.Background(), &eventbridge.PutRuleInput{ - Name: aws.String("my-event-rule-name"), + for _, cm := range awsConfModes { + for _, tt := range tests { + t.Run(fmt.Sprintf("%s %s", tt.name, cm), func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + server := mockAWS(tt.expectedStatusCode) + defer server.Close() + + awsCfg := newMockTestConfig(t, server.URL, cm, nil) + + eventbridgeClient := eventbridge.NewFromConfig(awsCfg) + _, _ = eventbridgeClient.PutRule(context.Background(), &eventbridge.PutRuleInput{ + Name: aws.String("my-event-rule-name"), + }) + + spans := mt.FinishedSpans() + + s := spans[0] + assert.Equal(t, "EventBridge.request", s.OperationName()) + assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") + assert.Equal(t, "PutRule", s.Tag(tagAWSOperation)) + assert.Equal(t, "EventBridge", s.Tag(tagAWSService)) + assert.Equal(t, "EventBridge", s.Tag(tagService)) + assert.Equal(t, "my-event-rule-name", s.Tag(tagRuleName)) + + assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) + assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) + assert.Equal(t, "EventBridge.PutRule", s.Tag(ext.ResourceName)) + assert.Equal(t, "aws.EventBridge", s.Tag(ext.ServiceName)) + assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) + assert.Equal(t, "POST", s.Tag(ext.HTTPMethod)) + assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) + assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) + assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) }) - - spans := mt.FinishedSpans() - - s := spans[0] - assert.Equal(t, "EventBridge.request", s.OperationName()) - assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") - assert.Equal(t, "PutRule", s.Tag(tagAWSOperation)) - assert.Equal(t, "EventBridge", s.Tag(tagAWSService)) - assert.Equal(t, "EventBridge", s.Tag(tagService)) - assert.Equal(t, "my-event-rule-name", s.Tag(tagRuleName)) - - assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) - assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) - assert.Equal(t, "EventBridge.PutRule", s.Tag(ext.ResourceName)) - assert.Equal(t, "aws.EventBridge", s.Tag(ext.ServiceName)) - assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) - assert.Equal(t, "POST", s.Tag(ext.HTTPMethod)) - assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) - assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) - assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) - }) + } } } @@ -674,82 +639,59 @@ func TestAppendMiddlewareSfnDescribeStateMachine(t *testing.T) { expectedStatusCode: 200, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mt := mocktracer.Start() - defer mt.Stop() - - server := mockAWS(tt.expectedStatusCode) - defer server.Close() - - resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { - return aws.Endpoint{ - PartitionID: "aws", - URL: server.URL, - SigningRegion: "eu-west-1", - }, nil + for _, cm := range awsConfModes { + for _, tt := range tests { + t.Run(fmt.Sprintf("%s %s", tt.name, cm), func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + server := mockAWS(tt.expectedStatusCode) + defer server.Close() + + awsCfg := newMockTestConfig(t, server.URL, cm, nil) + + sfnClient := sfn.NewFromConfig(awsCfg) + _, _ = sfnClient.DescribeStateMachine(context.Background(), &sfn.DescribeStateMachineInput{ + StateMachineArn: aws.String("arn:aws:states:us-west-2:123456789012:stateMachine:HelloWorld-StateMachine"), + }) + + spans := mt.FinishedSpans() + + s := spans[0] + assert.Equal(t, "SFN.request", s.OperationName()) + assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") + assert.Equal(t, "DescribeStateMachine", s.Tag(tagAWSOperation)) + assert.Equal(t, "SFN", s.Tag(tagAWSService)) + assert.Equal(t, "SFN", s.Tag(tagService)) + assert.Equal(t, "HelloWorld-StateMachine", s.Tag(tagStateMachineName)) + + assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) + assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) + assert.Equal(t, "SFN.DescribeStateMachine", s.Tag(ext.ResourceName)) + assert.Equal(t, "aws.SFN", s.Tag(ext.ServiceName)) + assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) + assert.Equal(t, "POST", s.Tag(ext.HTTPMethod)) + assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) + assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) + assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) }) - - awsCfg := aws.Config{ - Region: "eu-west-1", - Credentials: aws.AnonymousCredentials{}, - EndpointResolver: resolver, - } - - AppendMiddleware(&awsCfg) - - sfnClient := sfn.NewFromConfig(awsCfg) - sfnClient.DescribeStateMachine(context.Background(), &sfn.DescribeStateMachineInput{ - StateMachineArn: aws.String("arn:aws:states:us-west-2:123456789012:stateMachine:HelloWorld-StateMachine"), - }) - - spans := mt.FinishedSpans() - - s := spans[0] - assert.Equal(t, "SFN.request", s.OperationName()) - assert.Contains(t, s.Tag(tagAWSAgent), "aws-sdk-go-v2") - assert.Equal(t, "DescribeStateMachine", s.Tag(tagAWSOperation)) - assert.Equal(t, "SFN", s.Tag(tagAWSService)) - assert.Equal(t, "SFN", s.Tag(tagService)) - assert.Equal(t, "HelloWorld-StateMachine", s.Tag(tagStateMachineName)) - - assert.Equal(t, "eu-west-1", s.Tag(tagAWSRegion)) - assert.Equal(t, "eu-west-1", s.Tag(tagRegion)) - assert.Equal(t, "SFN.DescribeStateMachine", s.Tag(ext.ResourceName)) - assert.Equal(t, "aws.SFN", s.Tag(ext.ServiceName)) - assert.Equal(t, tt.expectedStatusCode, s.Tag(ext.HTTPCode)) - assert.Equal(t, "POST", s.Tag(ext.HTTPMethod)) - assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) - assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) - assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) - }) + } } } func TestAppendMiddleware_WithNoTracer(t *testing.T) { - server := mockAWS(200) - defer server.Close() + for _, cm := range awsConfModes { + t.Run(string(cm), func(t *testing.T) { + server := mockAWS(200) + defer server.Close() - resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { - return aws.Endpoint{ - PartitionID: "aws", - URL: server.URL, - SigningRegion: "eu-west-1", - }, nil - }) + awsCfg := newMockTestConfig(t, server.URL, cm, nil) - awsCfg := aws.Config{ - Region: "eu-west-1", - Credentials: aws.AnonymousCredentials{}, - EndpointResolver: resolver, + sqsClient := sqs.NewFromConfig(awsCfg) + _, err := sqsClient.ListQueues(context.Background(), &sqs.ListQueuesInput{}) + assert.NoError(t, err) + }) } - - AppendMiddleware(&awsCfg) - - sqsClient := sqs.NewFromConfig(awsCfg) - _, err := sqsClient.ListQueues(context.Background(), &sqs.ListQueuesInput{}) - assert.NoError(t, err) - } func mockAWS(statusCode int) *httptest.Server { @@ -757,7 +699,7 @@ func mockAWS(statusCode int) *httptest.Server { func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Amz-RequestId", "test_req") w.WriteHeader(statusCode) - w.Write([]byte(`{}`)) + _, _ = w.Write([]byte(`{}`)) })) } @@ -805,209 +747,195 @@ func TestAppendMiddleware_WithOpts(t *testing.T) { expectedRate: nil, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mt := mocktracer.Start() - defer mt.Stop() + for _, cm := range awsConfModes { + for _, tt := range tests { + t.Run(fmt.Sprintf("%s %s", tt.name, cm), func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() - server := mockAWS(200) - defer server.Close() + server := mockAWS(200) + defer server.Close() - resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { - return aws.Endpoint{ - PartitionID: "aws", - URL: server.URL, - SigningRegion: "eu-west-1", - }, nil - }) + awsCfg := newMockTestConfig(t, server.URL, cm, nil, tt.opts...) - awsCfg := aws.Config{ - Region: "eu-west-1", - Credentials: aws.AnonymousCredentials{}, - EndpointResolver: resolver, - } + sqsClient := sqs.NewFromConfig(awsCfg) + _, _ = sqsClient.ListQueues(context.Background(), &sqs.ListQueuesInput{}) - AppendMiddleware(&awsCfg, tt.opts...) - - sqsClient := sqs.NewFromConfig(awsCfg) - sqsClient.ListQueues(context.Background(), &sqs.ListQueuesInput{}) - - spans := mt.FinishedSpans() - assert.Len(t, spans, 1) - s := spans[0] - assert.Equal(t, tt.expectedServiceName, s.Tag(ext.ServiceName)) - assert.Equal(t, tt.expectedRate, s.Tag(ext.EventSampleRate)) - }) + spans := mt.FinishedSpans() + assert.Len(t, spans, 1) + s := spans[0] + assert.Equal(t, tt.expectedServiceName, s.Tag(ext.ServiceName)) + assert.Equal(t, tt.expectedRate, s.Tag(ext.EventSampleRate)) + }) + } } } func TestHTTPCredentials(t *testing.T) { - mt := mocktracer.Start() - defer mt.Stop() + for _, cm := range awsConfModes { + t.Run(string(cm), func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() - var auth string + var auth string + + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if enc, ok := r.Header["Authorization"]; ok { + encoded := strings.TrimPrefix(enc[0], "Basic ") + if b64, err := base64.StdEncoding.DecodeString(encoded); err == nil { + auth = string(b64) + } + } + + w.Header().Set("X-Amz-RequestId", "test_req") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{}`)) + })) + defer server.Close() - server := httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if enc, ok := r.Header["Authorization"]; ok { - encoded := strings.TrimPrefix(enc[0], "Basic ") - if b64, err := base64.StdEncoding.DecodeString(encoded); err == nil { - auth = string(b64) - } - } + u, err := url.Parse(server.URL) + require.NoError(t, err) + u.User = url.UserPassword("myuser", "mypassword") - w.Header().Set("X-Amz-RequestId", "test_req") - w.WriteHeader(200) - w.Write([]byte(`{}`)) - })) - defer server.Close() + awsCfg := newMockTestConfig(t, u.String(), cm, nil) - u, err := url.Parse(server.URL) - require.NoError(t, err) - u.User = url.UserPassword("myuser", "mypassword") + sqsClient := sqs.NewFromConfig(awsCfg) + _, _ = sqsClient.ListQueues(context.Background(), &sqs.ListQueuesInput{}) - resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { - return aws.Endpoint{ - PartitionID: "aws", - URL: u.String(), - SigningRegion: "eu-west-1", - }, nil - }) + spans := mt.FinishedSpans() - awsCfg := aws.Config{ - Region: "eu-west-1", - Credentials: aws.AnonymousCredentials{}, - EndpointResolver: resolver, + s := spans[0] + assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) + assert.NotContains(t, s.Tag(ext.HTTPURL), "mypassword") + assert.NotContains(t, s.Tag(ext.HTTPURL), "myuser") + // Make sure we haven't modified the outgoing request, and the server still + // receives the auth request. + assert.Equal(t, auth, "myuser:mypassword") + }) } - - AppendMiddleware(&awsCfg) - - sqsClient := sqs.NewFromConfig(awsCfg) - sqsClient.ListQueues(context.Background(), &sqs.ListQueuesInput{}) - - spans := mt.FinishedSpans() - - s := spans[0] - assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) - assert.NotContains(t, s.Tag(ext.HTTPURL), "mypassword") - assert.NotContains(t, s.Tag(ext.HTTPURL), "myuser") - // Make sure we haven't modified the outgoing request, and the server still - // receives the auth request. - assert.Equal(t, auth, "myuser:mypassword") } func TestNamingSchema(t *testing.T) { - genSpans := namingschematest.GenSpansFn(func(t *testing.T, serviceOverride string) []mocktracer.Span { - var opts []Option - if serviceOverride != "" { - opts = append(opts, WithServiceName(serviceOverride)) - } - mt := mocktracer.Start() - defer mt.Stop() - - awsCfg := newIntegrationTestConfig(t, opts...) - ctx := context.Background() - ec2Client := ec2.NewFromConfig(awsCfg) - s3Client := s3.NewFromConfig(awsCfg) - sqsClient := sqs.NewFromConfig(awsCfg) - snsClient := sns.NewFromConfig(awsCfg) - - _, err := ec2Client.DescribeInstances(ctx, &ec2.DescribeInstancesInput{}) - require.NoError(t, err) - _, err = s3Client.ListBuckets(ctx, &s3.ListBucketsInput{}) - require.NoError(t, err) - _, err = sqsClient.ListQueues(ctx, &sqs.ListQueuesInput{}) - require.NoError(t, err) - _, err = snsClient.ListTopics(ctx, &sns.ListTopicsInput{}) - require.NoError(t, err) - - return mt.FinishedSpans() - }) - assertOpV0 := func(t *testing.T, spans []mocktracer.Span) { - require.Len(t, spans, 4) - assert.Equal(t, "EC2.request", spans[0].OperationName()) - assert.Equal(t, "S3.request", spans[1].OperationName()) - assert.Equal(t, "SQS.request", spans[2].OperationName()) - assert.Equal(t, "SNS.request", spans[3].OperationName()) - } - assertOpV1 := func(t *testing.T, spans []mocktracer.Span) { - require.Len(t, spans, 4) - assert.Equal(t, "aws.ec2.request", spans[0].OperationName()) - assert.Equal(t, "aws.s3.request", spans[1].OperationName()) - assert.Equal(t, "aws.sqs.request", spans[2].OperationName()) - assert.Equal(t, "aws.sns.request", spans[3].OperationName()) - } - serviceOverride := namingschematest.TestServiceOverride - wantServiceNameV0 := namingschematest.ServiceNameAssertions{ - WithDefaults: []string{"aws.EC2", "aws.S3", "aws.SQS", "aws.SNS"}, - WithDDService: []string{"aws.EC2", "aws.S3", "aws.SQS", "aws.SNS"}, - WithDDServiceAndOverride: []string{serviceOverride, serviceOverride, serviceOverride, serviceOverride}, + for _, cm := range awsConfModes { + t.Run(string(cm), func(t *testing.T) { + genSpans := namingschematest.GenSpansFn(func(t *testing.T, serviceOverride string) []mocktracer.Span { + var opts []Option + if serviceOverride != "" { + opts = append(opts, WithServiceName(serviceOverride)) + } + mt := mocktracer.Start() + defer mt.Stop() + + awsCfg := newIntegrationTestConfig(t, cm, opts...) + ctx := context.Background() + ec2Client := ec2.NewFromConfig(awsCfg) + s3Client := s3.NewFromConfig(awsCfg) + sqsClient := sqs.NewFromConfig(awsCfg) + snsClient := sns.NewFromConfig(awsCfg) + + _, err := ec2Client.DescribeInstances(ctx, &ec2.DescribeInstancesInput{}) + require.NoError(t, err) + _, err = s3Client.ListBuckets(ctx, &s3.ListBucketsInput{}) + require.NoError(t, err) + _, err = sqsClient.ListQueues(ctx, &sqs.ListQueuesInput{}) + require.NoError(t, err) + _, err = snsClient.ListTopics(ctx, &sns.ListTopicsInput{}) + require.NoError(t, err) + + return mt.FinishedSpans() + }) + assertOpV0 := func(t *testing.T, spans []mocktracer.Span) { + require.Len(t, spans, 4) + assert.Equal(t, "EC2.request", spans[0].OperationName()) + assert.Equal(t, "S3.request", spans[1].OperationName()) + assert.Equal(t, "SQS.request", spans[2].OperationName()) + assert.Equal(t, "SNS.request", spans[3].OperationName()) + } + assertOpV1 := func(t *testing.T, spans []mocktracer.Span) { + require.Len(t, spans, 4) + assert.Equal(t, "aws.ec2.request", spans[0].OperationName()) + assert.Equal(t, "aws.s3.request", spans[1].OperationName()) + assert.Equal(t, "aws.sqs.request", spans[2].OperationName()) + assert.Equal(t, "aws.sns.request", spans[3].OperationName()) + } + serviceOverride := namingschematest.TestServiceOverride + wantServiceNameV0 := namingschematest.ServiceNameAssertions{ + WithDefaults: []string{"aws.EC2", "aws.S3", "aws.SQS", "aws.SNS"}, + WithDDService: []string{"aws.EC2", "aws.S3", "aws.SQS", "aws.SNS"}, + WithDDServiceAndOverride: []string{serviceOverride, serviceOverride, serviceOverride, serviceOverride}, + } + t.Run("ServiceName", namingschematest.NewServiceNameTest(genSpans, wantServiceNameV0)) + t.Run("SpanName", namingschematest.NewSpanNameTest(genSpans, assertOpV0, assertOpV1)) + }) } - t.Run("ServiceName", namingschematest.NewServiceNameTest(genSpans, wantServiceNameV0)) - t.Run("SpanName", namingschematest.NewSpanNameTest(genSpans, assertOpV0, assertOpV1)) } func TestMessagingNamingSchema(t *testing.T) { - genSpans := namingschematest.GenSpansFn(func(t *testing.T, serviceOverride string) []mocktracer.Span { - var opts []Option - if serviceOverride != "" { - opts = append(opts, WithServiceName(serviceOverride)) - } - mt := mocktracer.Start() - defer mt.Stop() + for _, cm := range awsConfModes { + t.Run(string(cm), func(t *testing.T) { + genSpans := namingschematest.GenSpansFn(func(t *testing.T, serviceOverride string) []mocktracer.Span { + var opts []Option + if serviceOverride != "" { + opts = append(opts, WithServiceName(serviceOverride)) + } + mt := mocktracer.Start() + defer mt.Stop() - awsCfg := newIntegrationTestConfig(t, opts...) - resourceName := "test-naming-schema-aws-v2" - ctx := context.Background() - sqsClient := sqs.NewFromConfig(awsCfg) - snsClient := sns.NewFromConfig(awsCfg) + awsCfg := newIntegrationTestConfig(t, cm, opts...) + resourceName := "test-naming-schema-aws-v2" + ctx := context.Background() + sqsClient := sqs.NewFromConfig(awsCfg) + snsClient := sns.NewFromConfig(awsCfg) - // create a SQS queue - sqsResp, err := sqsClient.CreateQueue(ctx, &sqs.CreateQueueInput{QueueName: aws.String(resourceName)}) - require.NoError(t, err) + // create a SQS queue + sqsResp, err := sqsClient.CreateQueue(ctx, &sqs.CreateQueueInput{QueueName: aws.String(resourceName)}) + require.NoError(t, err) - msg := &sqs.SendMessageInput{QueueUrl: sqsResp.QueueUrl, MessageBody: aws.String("body")} - _, err = sqsClient.SendMessage(ctx, msg) - require.NoError(t, err) + msg := &sqs.SendMessageInput{QueueUrl: sqsResp.QueueUrl, MessageBody: aws.String("body")} + _, err = sqsClient.SendMessage(ctx, msg) + require.NoError(t, err) - entry := types.SendMessageBatchRequestEntry{Id: aws.String("1"), MessageBody: aws.String("body")} - batchMsg := &sqs.SendMessageBatchInput{QueueUrl: sqsResp.QueueUrl, Entries: []types.SendMessageBatchRequestEntry{entry}} - _, err = sqsClient.SendMessageBatch(ctx, batchMsg) - require.NoError(t, err) + entry := types.SendMessageBatchRequestEntry{Id: aws.String("1"), MessageBody: aws.String("body")} + batchMsg := &sqs.SendMessageBatchInput{QueueUrl: sqsResp.QueueUrl, Entries: []types.SendMessageBatchRequestEntry{entry}} + _, err = sqsClient.SendMessageBatch(ctx, batchMsg) + require.NoError(t, err) - // create an SNS topic - snsResp, err := snsClient.CreateTopic(ctx, &sns.CreateTopicInput{Name: aws.String(resourceName)}) - require.NoError(t, err) + // create an SNS topic + snsResp, err := snsClient.CreateTopic(ctx, &sns.CreateTopicInput{Name: aws.String(resourceName)}) + require.NoError(t, err) - _, err = snsClient.Publish(ctx, &sns.PublishInput{TopicArn: snsResp.TopicArn, Message: aws.String("message")}) - require.NoError(t, err) + _, err = snsClient.Publish(ctx, &sns.PublishInput{TopicArn: snsResp.TopicArn, Message: aws.String("message")}) + require.NoError(t, err) - return mt.FinishedSpans() - }) - assertOpV0 := func(t *testing.T, spans []mocktracer.Span) { - require.Len(t, spans, 5) - assert.Equal(t, "SQS.request", spans[0].OperationName()) - assert.Equal(t, "SQS.request", spans[1].OperationName()) - assert.Equal(t, "SQS.request", spans[2].OperationName()) - assert.Equal(t, "SNS.request", spans[3].OperationName()) - assert.Equal(t, "SNS.request", spans[4].OperationName()) - } - assertOpV1 := func(t *testing.T, spans []mocktracer.Span) { - require.Len(t, spans, 5) - assert.Equal(t, "aws.sqs.request", spans[0].OperationName()) - assert.Equal(t, "aws.sqs.send", spans[1].OperationName()) - assert.Equal(t, "aws.sqs.send", spans[2].OperationName()) - assert.Equal(t, "aws.sns.request", spans[3].OperationName()) - assert.Equal(t, "aws.sns.send", spans[4].OperationName()) - } - serviceOverride := namingschematest.TestServiceOverride - wantServiceNameV0 := namingschematest.ServiceNameAssertions{ - WithDefaults: []string{"aws.SQS", "aws.SQS", "aws.SQS", "aws.SNS", "aws.SNS"}, - WithDDService: []string{"aws.SQS", "aws.SQS", "aws.SQS", "aws.SNS", "aws.SNS"}, - WithDDServiceAndOverride: repeat(serviceOverride, 5), + return mt.FinishedSpans() + }) + assertOpV0 := func(t *testing.T, spans []mocktracer.Span) { + require.Len(t, spans, 5) + assert.Equal(t, "SQS.request", spans[0].OperationName()) + assert.Equal(t, "SQS.request", spans[1].OperationName()) + assert.Equal(t, "SQS.request", spans[2].OperationName()) + assert.Equal(t, "SNS.request", spans[3].OperationName()) + assert.Equal(t, "SNS.request", spans[4].OperationName()) + } + assertOpV1 := func(t *testing.T, spans []mocktracer.Span) { + require.Len(t, spans, 5) + assert.Equal(t, "aws.sqs.request", spans[0].OperationName()) + assert.Equal(t, "aws.sqs.send", spans[1].OperationName()) + assert.Equal(t, "aws.sqs.send", spans[2].OperationName()) + assert.Equal(t, "aws.sns.request", spans[3].OperationName()) + assert.Equal(t, "aws.sns.send", spans[4].OperationName()) + } + serviceOverride := namingschematest.TestServiceOverride + wantServiceNameV0 := namingschematest.ServiceNameAssertions{ + WithDefaults: []string{"aws.SQS", "aws.SQS", "aws.SQS", "aws.SNS", "aws.SNS"}, + WithDDService: []string{"aws.SQS", "aws.SQS", "aws.SQS", "aws.SNS", "aws.SNS"}, + WithDDServiceAndOverride: repeat(serviceOverride, 5), + } + t.Run("ServiceName", namingschematest.NewServiceNameTest(genSpans, wantServiceNameV0)) + t.Run("SpanName", namingschematest.NewSpanNameTest(genSpans, assertOpV0, assertOpV1)) + }) } - t.Run("ServiceName", namingschematest.NewServiceNameTest(genSpans, wantServiceNameV0)) - t.Run("SpanName", namingschematest.NewSpanNameTest(genSpans, assertOpV0, assertOpV1)) } func repeat(s string, n int) []string { @@ -1035,7 +963,8 @@ func TestWithErrorCheck(t *testing.T) { return true })}, errExist: true, - }, { + }, + { name: "with errCheck false", opts: []Option{WithErrorCheck(func(err error) bool { return false @@ -1043,37 +972,25 @@ func TestWithErrorCheck(t *testing.T) { errExist: false, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mt := mocktracer.Start() - defer mt.Stop() + for _, cm := range awsConfModes { + for _, tt := range tests { + t.Run(fmt.Sprintf("%s %s", tt.name, cm), func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() - server := mockAWS(400) - defer server.Close() + server := mockAWS(400) + defer server.Close() - resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { - return aws.Endpoint{ - PartitionID: "aws", - URL: server.URL, - SigningRegion: "eu-west-1", - }, nil - }) + awsCfg := newMockTestConfig(t, server.URL, cm, nil, tt.opts...) - awsCfg := aws.Config{ - Region: "eu-west-1", - Credentials: aws.AnonymousCredentials{}, - EndpointResolver: resolver, - } + sqsClient := sqs.NewFromConfig(awsCfg) + _, _ = sqsClient.ListQueues(context.Background(), &sqs.ListQueuesInput{}) - AppendMiddleware(&awsCfg, tt.opts...) - - sqsClient := sqs.NewFromConfig(awsCfg) - sqsClient.ListQueues(context.Background(), &sqs.ListQueuesInput{}) - - spans := mt.FinishedSpans() - assert.Len(t, spans, 1) - s := spans[0] - assert.Equal(t, tt.errExist, s.Tag(ext.Error) != nil) - }) + spans := mt.FinishedSpans() + assert.Len(t, spans, 1) + s := spans[0] + assert.Equal(t, tt.errExist, s.Tag(ext.Error) != nil) + }) + } } } diff --git a/contrib/aws/aws-sdk-go-v2/aws/example_test.go b/contrib/aws/aws-sdk-go-v2/aws/example_test.go index 5aa271f153..b91bb34065 100644 --- a/contrib/aws/aws-sdk-go-v2/aws/example_test.go +++ b/contrib/aws/aws-sdk-go-v2/aws/example_test.go @@ -24,5 +24,15 @@ func Example() { awstrace.AppendMiddleware(&awsCfg) sqsClient := sqs.NewFromConfig(awsCfg) - sqsClient.ListQueues(context.Background(), &sqs.ListQueuesInput{}) + _, _ = sqsClient.ListQueues(context.Background(), &sqs.ListQueuesInput{}) +} + +func ExampleLoadOptions() { + awsCfg, err := awscfg.LoadDefaultConfig(context.TODO(), awstrace.WithDataDogTracer()) + if err != nil { + log.Fatalf(err.Error()) + } + + sqsClient := sqs.NewFromConfig(awsCfg) + _, _ = sqsClient.ListQueues(context.Background(), &sqs.ListQueuesInput{}) }