diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/AwsCustomGoDependency.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/AwsCustomGoDependency.java index 4e3d20b59c6..cc354182804 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/AwsCustomGoDependency.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/AwsCustomGoDependency.java @@ -35,6 +35,8 @@ public final class AwsCustomGoDependency extends AwsGoDependency { "service/internal/accept-encoding", null, Versions.INTERNAL_ACCEPTENCODING, "acceptencodingcust"); public static final GoDependency KINESIS_CUSTOMIZATION = aws( "service/kinesis/internal/customizations", "kinesiscust"); + public static final GoDependency MACHINE_LEARNING_CUSTOMIZATION = aws( + "service/machinelearning/internal/customizations", "mlcust"); private AwsCustomGoDependency() { super(); diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/MachineLearningCustomizations.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/MachineLearningCustomizations.java new file mode 100644 index 00000000000..ebdaf34ec85 --- /dev/null +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/MachineLearningCustomizations.java @@ -0,0 +1,97 @@ +package software.amazon.smithy.aws.go.codegen.customization; + +import java.util.List; +import software.amazon.smithy.aws.traits.ServiceTrait; +import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.go.codegen.GoDelegator; +import software.amazon.smithy.go.codegen.GoSettings; +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.SmithyGoDependency; +import software.amazon.smithy.go.codegen.SymbolUtils; +import software.amazon.smithy.go.codegen.integration.GoIntegration; +import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar; +import software.amazon.smithy.go.codegen.integration.ProtocolUtils; +import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.shapes.Shape; +import software.amazon.smithy.model.shapes.StructureShape; +import software.amazon.smithy.utils.ListUtils; + +public class MachineLearningCustomizations implements GoIntegration { + private static final String ADD_PREDICT_ENDPOINT = "AddPredictEndpointMiddleware"; + private static final String ENDPOINT_ACCESSOR = "getPredictEndpoint"; + + @Override + public byte getOrder() { + // This needs to be run after the generic endpoint resolver gets added + return 50; + } + + @Override + public List getClientPlugins() { + return ListUtils.of( + RuntimeClientPlugin.builder() + .operationPredicate(MachineLearningCustomizations::isPredict) + .registerMiddleware(MiddlewareRegistrar.builder() + .resolvedFunction(SymbolUtils.createValueSymbolBuilder(ADD_PREDICT_ENDPOINT, + AwsCustomGoDependency.MACHINE_LEARNING_CUSTOMIZATION).build()) + .functionArguments(ListUtils.of( + SymbolUtils.createValueSymbolBuilder(ENDPOINT_ACCESSOR).build() + )) + .build()) + .build() + ); + } + + @Override + public void writeAdditionalFiles( + GoSettings settings, + Model model, + SymbolProvider symbolProvider, + GoDelegator goDelegator + ) { + ServiceShape service = settings.getService(model); + if (!isMachineLearning(model, service)) { + return; + } + + service.getAllOperations().stream() + .filter(shapeId -> shapeId.getName().equalsIgnoreCase("Predict")) + .findAny() + .map(model::expectShape) + .flatMap(Shape::asOperationShape) + .ifPresent(operation -> { + goDelegator.useShapeWriter(operation, writer -> writeEndpointAccessor( + writer, model, symbolProvider, operation)); + }); + } + + private void writeEndpointAccessor( + GoWriter writer, + Model model, + SymbolProvider symbolProvider, + OperationShape operation + ) { + StructureShape input = ProtocolUtils.expectInput(model, operation); + writer.openBlock("func $L(input interface{}) (*string, error) {", "}", ENDPOINT_ACCESSOR, () -> { + writer.write("in, ok := input.($P)", symbolProvider.toSymbol(input)); + writer.openBlock("if !ok {", "}", () -> { + writer.addUseImports(SmithyGoDependency.SMITHY); + writer.addUseImports(SmithyGoDependency.FMT); + writer.write("return nil, &smithy.SerializationError{Err: fmt.Errorf(" + + "\"expected $P, but was %T\", input)}", symbolProvider.toSymbol(input)); + }); + writer.write("return in.PredictEndpoint, nil"); + }); + } + + private static boolean isPredict(Model model, ServiceShape service, OperationShape operation) { + return isMachineLearning(model, service) && operation.getId().getName().equalsIgnoreCase("Predict"); + } + + private static boolean isMachineLearning(Model model, ServiceShape service) { + return service.expectTrait(ServiceTrait.class).getSdkId().equalsIgnoreCase("Machine Learning"); + } +} diff --git a/codegen/smithy-aws-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration b/codegen/smithy-aws-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration index 519a33b9369..7107577f00f 100644 --- a/codegen/smithy-aws-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration +++ b/codegen/smithy-aws-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration @@ -18,5 +18,6 @@ software.amazon.smithy.aws.go.codegen.customization.S3ResponseErrorWrapper software.amazon.smithy.aws.go.codegen.customization.S3MetadataRetriever software.amazon.smithy.aws.go.codegen.customization.S3ContentSHA256Header software.amazon.smithy.aws.go.codegen.customization.BackfillS3ObjectSizeMemberShapeType +software.amazon.smithy.aws.go.codegen.customization.MachineLearningCustomizations software.amazon.smithy.aws.go.codegen.customization.S3AcceptEncodingGzip software.amazon.smithy.aws.go.codegen.customization.KinesisCustomizations diff --git a/service/machinelearning/api_op_Predict.go b/service/machinelearning/api_op_Predict.go index d1ce8e53cbe..13a12e0d1e9 100644 --- a/service/machinelearning/api_op_Predict.go +++ b/service/machinelearning/api_op_Predict.go @@ -4,9 +4,11 @@ package machinelearning import ( "context" + "fmt" awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" "github.com/aws/aws-sdk-go-v2/aws/retry" "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + mlcust "github.com/aws/aws-sdk-go-v2/service/machinelearning/internal/customizations" "github.com/aws/aws-sdk-go-v2/service/machinelearning/types" smithy "github.com/awslabs/smithy-go" "github.com/awslabs/smithy-go/middleware" @@ -35,6 +37,7 @@ func (c *Client) Predict(ctx context.Context, params *PredictInput, optFns ...fu smithyhttp.AddCloseResponseBodyMiddleware(stack) addOpPredictValidationMiddleware(stack) stack.Initialize.Add(newServiceMetadataMiddleware_opPredict(options.Region), middleware.Before) + mlcust.AddPredictEndpointMiddleware(stack, getPredictEndpoint) addRequestIDRetrieverMiddleware(stack) addResponseErrorMiddleware(stack) @@ -105,3 +108,11 @@ func newServiceMetadataMiddleware_opPredict(region string) awsmiddleware.Registe OperationName: "Predict", } } + +func getPredictEndpoint(input interface{}) (*string, error) { + in, ok := input.(*PredictInput) + if !ok { + return nil, &smithy.SerializationError{Err: fmt.Errorf("expected *PredictInput, but was %T", input)} + } + return in.PredictEndpoint, nil +} diff --git a/service/machinelearning/internal/customizations/doc.go b/service/machinelearning/internal/customizations/doc.go new file mode 100644 index 00000000000..a7d6dac5095 --- /dev/null +++ b/service/machinelearning/internal/customizations/doc.go @@ -0,0 +1,13 @@ +/* +Package customizations provides customizations for the Machine Learning API client. + +The Machine Learning API client uses one customization to support the PredictEndpoint +input parameter. + +Predict Endpoint + +The predict endpoint customization runs after normal endpoint resolution happens. If +the user has provided a value for PredictEndpoint then this customization will +overwrite the request's endpoint with that value. + */ +package customizations diff --git a/service/machinelearning/internal/customizations/predictendpoint.go b/service/machinelearning/internal/customizations/predictendpoint.go new file mode 100644 index 00000000000..b0818f0f3a4 --- /dev/null +++ b/service/machinelearning/internal/customizations/predictendpoint.go @@ -0,0 +1,58 @@ +package customizations + +import ( + "context" + "fmt" + "github.com/awslabs/smithy-go" + "github.com/awslabs/smithy-go/middleware" + smithyhttp "github.com/awslabs/smithy-go/transport/http" + "net/url" +) + +// AddPredictEndpointMiddleware adds the middleware required to set the endpoint +// based on Predict's PredictEndpoint input member. +func AddPredictEndpointMiddleware(stack *middleware.Stack, endpoint func(interface{}) (*string, error)) { + stack.Serialize.Insert(&predictEndpointMiddleware{}, "ResolveEndpoint", middleware.After) +} + +// predictEndpointMiddleware rewrites the endpoint with whatever is specified in the +// operation input if it is non-nil and non-empty. +type predictEndpointMiddleware struct{ + fetchPredictEndpoint func(interface{}) (*string, error) +} + +// ID returns the id for the middleware. +func (*predictEndpointMiddleware) ID() string { return "MachineLearning:PredictEndpoint" } + +// HandleSerialize implements the SerializeMiddleware interface. +func (m *predictEndpointMiddleware) HandleSerialize( + ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler, +) ( + out middleware.SerializeOutput, metadata middleware.Metadata, err error, +) { + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, &smithy.SerializationError{ + Err: fmt.Errorf("unknown request type %T", in.Request), + } + } + + endpoint, err := m.fetchPredictEndpoint(in.Parameters) + if err != nil { + return out, metadata, &smithy.SerializationError{ + Err: fmt.Errorf("failed to fetch PredictEndpoint value, %v", err), + } + } + + if endpoint != nil && len(*endpoint) != 0 { + uri, err := url.Parse(*endpoint) + if err != nil { + return out, metadata, &smithy.SerializationError{ + Err: fmt.Errorf("unable to parse predict endpoint, %v", err), + } + } + req.URL = uri + } + + return next.HandleSerialize(ctx, in) +} diff --git a/service/machinelearning/internal/customizations/predictendpoint_test.go b/service/machinelearning/internal/customizations/predictendpoint_test.go new file mode 100644 index 00000000000..0421aff9478 --- /dev/null +++ b/service/machinelearning/internal/customizations/predictendpoint_test.go @@ -0,0 +1,75 @@ +package customizations + +import ( + "context" + "github.com/awslabs/smithy-go/middleware" + "github.com/awslabs/smithy-go/ptr" + smithyhttp "github.com/awslabs/smithy-go/transport/http" + "strings" + "testing" +) + +func TestPredictEndpointMiddleware(t *testing.T) { + cases := map[string]struct { + PredictEndpoint *string + ExpectedEndpoint string + ExpectedErr string + }{ + "nil endpoint": {}, + "empty endpoint": { + PredictEndpoint: ptr.String(""), + }, + "invalid endpoint": { + PredictEndpoint: ptr.String("::::::::"), + ExpectedErr: "unable to parse", + }, + "valid endpoint": { + PredictEndpoint: ptr.String("https://example.amazonaws.com/"), + ExpectedEndpoint: "https://example.amazonaws.com/", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + m := &predictEndpointMiddleware{ + fetchPredictEndpoint: func(i interface{}) (*string, error) { + return c.PredictEndpoint, nil + }, + } + _, _, err := m.HandleSerialize(context.Background(), + middleware.SerializeInput{ + Request: smithyhttp.NewStackRequest(), + }, + middleware.SerializeHandlerFunc( + func(ctx context.Context, input middleware.SerializeInput) ( + output middleware.SerializeOutput, metadata middleware.Metadata, err error, + ) { + + req, ok := input.Request.(*smithyhttp.Request) + if !ok || req == nil { + t.Fatalf("expect smithy request, got %T", input.Request) + } + + if c.ExpectedEndpoint != req.URL.String() { + t.Errorf("expected url to be `%v`, but was `%v`", c.ExpectedEndpoint, req.URL.String()) + } + + return output, metadata, err + }), + ) + if len(c.ExpectedErr) != 0 { + if err == nil { + t.Fatalf("expect error, got none") + } + if e, a := c.ExpectedErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect error to contain %v, got %v", e, a) + } + } else { + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + } + }) + } + +}