Skip to content

Commit

Permalink
chore(codegen): pass ServiceShape service to ShapeId.getName() (#2198)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivikr committed Mar 30, 2021
1 parent 94aefce commit 325d6b4
Show file tree
Hide file tree
Showing 15 changed files with 55 additions and 35 deletions.
Expand Up @@ -125,7 +125,7 @@ private static boolean operationUsesAwsAuth(Model model, ServiceShape service, O
if (testServiceId(service, "STS")) {
Boolean isUnsignedCommand = SetUtils
.of("AssumeRoleWithWebIdentity", "AssumeRoleWithSAML")
.contains(operation.getId().getName());
.contains(operation.getId().getName(service));
return !isUnsignedCommand;
}

Expand Down
Expand Up @@ -68,19 +68,19 @@ public List<RuntimeClientPlugin> getClientPlugins() {
RuntimeClientPlugin.builder()
.withConventions(AwsDependency.EC2_MIDDLEWARE.dependency,
"CopySnapshotPresignedUrl", HAS_MIDDLEWARE)
.operationPredicate((m, s, o) -> o.getId().getName().equals("CopySnapshot")
.operationPredicate((m, s, o) -> o.getId().getName(s).equals("CopySnapshot")
&& testServiceId(s, "EC2"))
.build(),
RuntimeClientPlugin.builder()
.withConventions(AwsDependency.MACHINELEARNING_MIDDLEWARE.dependency, "PredictEndpoint",
HAS_MIDDLEWARE)
.operationPredicate((m, s, o) -> o.getId().getName().equals("Predict")
.operationPredicate((m, s, o) -> o.getId().getName(s).equals("Predict")
&& testServiceId(s, "Machine Learning"))
.build(),
RuntimeClientPlugin.builder()
.withConventions(AwsDependency.ROUTE53_MIDDLEWARE.dependency,
"ChangeResourceRecordSets", HAS_MIDDLEWARE)
.operationPredicate((m, s, o) -> o.getId().getName().equals("ChangeResourceRecordSets")
.operationPredicate((m, s, o) -> o.getId().getName(s).equals("ChangeResourceRecordSets")
&& testServiceId(s, "Route 53"))
.build(),
RuntimeClientPlugin.builder()
Expand All @@ -92,19 +92,19 @@ && testServiceId(s, "Route 53"))
RuntimeClientPlugin.builder()
.withConventions(AwsDependency.SQS_MIDDLEWARE.dependency, "SendMessage",
HAS_MIDDLEWARE)
.operationPredicate((m, s, o) -> o.getId().getName().equals("SendMessage")
.operationPredicate((m, s, o) -> o.getId().getName(s).equals("SendMessage")
&& testServiceId(s, "SQS"))
.build(),
RuntimeClientPlugin.builder()
.withConventions(AwsDependency.SQS_MIDDLEWARE.dependency, "SendMessageBatch",
HAS_MIDDLEWARE)
.operationPredicate((m, s, o) -> o.getId().getName().equals("SendMessageBatch")
.operationPredicate((m, s, o) -> o.getId().getName(s).equals("SendMessageBatch")
&& testServiceId(s, "SQS"))
.build(),
RuntimeClientPlugin.builder()
.withConventions(AwsDependency.SQS_MIDDLEWARE.dependency, "ReceiveMessage",
HAS_MIDDLEWARE)
.operationPredicate((m, s, o) -> o.getId().getName().equals("ReceiveMessage")
.operationPredicate((m, s, o) -> o.getId().getName(s).equals("ReceiveMessage")
&& testServiceId(s, "SQS"))
.build(),
RuntimeClientPlugin.builder()
Expand Down
Expand Up @@ -43,13 +43,13 @@ public List<RuntimeClientPlugin> getClientPlugins() {
RuntimeClientPlugin.builder()
.withConventions(AwsDependency.RDS_MIDDLEWARE.dependency, "CrossRegionPresignedUrl",
HAS_MIDDLEWARE)
.operationPredicate((m, s, o) -> RDS_PRESIGNED_URL_OPERATIONS.contains(o.getId().getName())
.operationPredicate((m, s, o) -> RDS_PRESIGNED_URL_OPERATIONS.contains(o.getId().getName(s))
&& testServiceId(s, "RDS"))
.build(),
RuntimeClientPlugin.builder()
.withConventions(AwsDependency.RDS_MIDDLEWARE.dependency, "CrossRegionPresignedUrl",
HAS_MIDDLEWARE)
.operationPredicate((m, s, o) -> SHARED_PRESIGNED_URL_OPERATIONS.contains(o.getId().getName())
.operationPredicate((m, s, o) -> SHARED_PRESIGNED_URL_OPERATIONS.contains(o.getId().getName(s))
&& (testServiceId(s, "RDS") || testServiceId(s, "DocDB") || testServiceId(s, "Neptune")))
.build()
);
Expand Down
Expand Up @@ -127,7 +127,7 @@ public List<RuntimeClientPlugin> getClientPlugins() {
.withConventions(AwsDependency.S3_MIDDLEWARE.dependency, "throw200Exceptions",
HAS_MIDDLEWARE)
.operationPredicate(
(m, s, o) -> EXCEPTIONS_OF_200_OPERATIONS.contains(o.getId().getName())
(m, s, o) -> EXCEPTIONS_OF_200_OPERATIONS.contains(o.getId().getName(s))
&& testServiceId(s))
.build(),
RuntimeClientPlugin.builder()
Expand All @@ -143,7 +143,7 @@ && testServiceId(s))
RuntimeClientPlugin.builder()
.withConventions(AwsDependency.LOCATION_CONSTRAINT.dependency, "LocationConstraint",
HAS_MIDDLEWARE)
.operationPredicate((m, s, o) -> o.getId().getName().equals("CreateBucket")
.operationPredicate((m, s, o) -> o.getId().getName(s).equals("CreateBucket")
&& testServiceId(s))
.build(),
/**
Expand All @@ -158,13 +158,13 @@ && testServiceId(s))
RuntimeClientPlugin.builder()
.withConventions(AwsDependency.BUCKET_ENDPOINT_MIDDLEWARE.dependency, "BucketEndpoint",
HAS_MIDDLEWARE)
.operationPredicate((m, s, o) -> !NON_BUCKET_ENDPOINT_OPERATIONS.contains(o.getId().getName())
.operationPredicate((m, s, o) -> !NON_BUCKET_ENDPOINT_OPERATIONS.contains(o.getId().getName(s))
&& testServiceId(s))
.build(),
RuntimeClientPlugin.builder()
.withConventions(AwsDependency.BODY_CHECKSUM.dependency, "ApplyMd5BodyChecksum",
HAS_MIDDLEWARE)
.operationPredicate((m, s, o) -> S3_MD5_OPERATIONS.contains(o.getId().getName())
.operationPredicate((m, s, o) -> S3_MD5_OPERATIONS.contains(o.getId().getName(s))
&& testServiceId(s))
.build()
);
Expand Down
Expand Up @@ -52,15 +52,15 @@ public List<RuntimeClientPlugin> getClientPlugins() {
"ProcessArnables",
HAS_MIDDLEWARE
)
.operationPredicate((m, s, o) -> isS3Control(s) && isArnableOperation(o))
.operationPredicate((m, s, o) -> isS3Control(s) && isArnableOperation(o, s))
.build(),
RuntimeClientPlugin.builder()
.withConventions(
AwsDependency.S3_CONTROL_MIDDLEWARE.dependency,
"RedirectFromPostId",
HAS_MIDDLEWARE
)
.operationPredicate((m, s, o) -> isS3Control(s) && !isArnableOperation(o))
.operationPredicate((m, s, o) -> isS3Control(s) && !isArnableOperation(o, s))
.build());
}

Expand All @@ -70,9 +70,10 @@ public Model preprocessModel(PluginContext context, TypeScriptSettings settings)
if (!isS3Control(settings.getService(model))) {
return model;
}
ServiceShape serviceShape = model.expectShape(settings.getService(), ServiceShape.class);
return ModelTransformer.create().mapShapes(model, shape -> {
Optional<MemberShape> modified = shape.asMemberShape()
.filter(memberShape -> memberShape.getTarget().getName().equals("AccountId"))
.filter(memberShape -> memberShape.getTarget().getName(serviceShape).equals("AccountId"))
.filter(memberShape -> model.expectShape(memberShape.getTarget()).isStringShape())
.filter(memberShape -> memberShape.isRequired())
.map(memberShape -> Shape.shapeToBuilder(memberShape).removeTrait(RequiredTrait.ID).build());
Expand All @@ -85,8 +86,8 @@ private static boolean isS3Control(ServiceShape service) {
return serviceId.equals("S3 Control");
}

private static boolean isArnableOperation(OperationShape operation) {
String operationName = operation.getId().getName();
private static boolean isArnableOperation(OperationShape operation, ServiceShape serviceShape) {
String operationName = operation.getId().getName(serviceShape);
return !operationName.equals("CreateBucket") && !operationName.equals("ListRegionalBuckets");
}
}
Expand Up @@ -19,6 +19,7 @@
import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait;
import software.amazon.smithy.codegen.core.SymbolReference;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.shapes.StructureShape;
Expand Down Expand Up @@ -134,8 +135,9 @@ protected void serializeInputDocument(
writer.write("...$L,",
inputStructure.accept(new QueryMemberSerVisitor(context, "input", Format.DATE_TIME)));
// Set the protocol required values.
writer.write("Action: $S,", operation.getId().getName());
writer.write("Version: $S,", context.getService().getVersion());
ServiceShape serviceShape = context.getService();
writer.write("Action: $S,", operation.getId().getName(serviceShape));
writer.write("Version: $S,", serviceShape.getVersion());
});
}

Expand Down
Expand Up @@ -76,7 +76,7 @@ public void writeAdditionalFiles(

TopDownIndex topDownIndex = TopDownIndex.of(model);
OperationShape firstOperation = topDownIndex.getContainedOperations(service).iterator().next();
String operationName = firstOperation.getId().getName();
String operationName = firstOperation.getId().getName(service);
resource = resource.replaceAll(Pattern.quote("${commandName}"), operationName);
resource = resource.replaceAll(Pattern.quote("${operationName}"),
operationName.substring(0, 1).toLowerCase() + operationName.substring(1));
Expand Down
Expand Up @@ -26,6 +26,7 @@
import software.amazon.smithy.model.neighbor.Walker;
import software.amazon.smithy.model.shapes.MemberShape;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShapeVisitor;
import software.amazon.smithy.model.traits.IdempotencyTokenTrait;
Expand Down Expand Up @@ -178,8 +179,9 @@ static boolean generateUndefinedQueryInputBody(GenerationContext context, Operat
// Set the form encoded string.
writer.openBlock("const body = buildFormUrlencodedString({", "});", () -> {
// Set the protocol required values.
writer.write("Action: $S,", operation.getId().getName());
writer.write("Version: $S,", context.getService().getVersion());
ServiceShape serviceShape = context.getService();
writer.write("Action: $S,", operation.getId().getName(serviceShape));
writer.write("Version: $S,", serviceShape.getVersion());
});

return true;
Expand Down
Expand Up @@ -19,6 +19,7 @@
import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait;
import software.amazon.smithy.codegen.core.SymbolReference;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.shapes.StructureShape;
Expand Down Expand Up @@ -134,8 +135,9 @@ protected void serializeInputDocument(
writer.write("...$L,",
inputStructure.accept(new QueryMemberSerVisitor(context, "input", Format.DATE_TIME)));
// Set the protocol required values.
writer.write("Action: $S,", operation.getId().getName());
writer.write("Version: $S,", context.getService().getVersion());
ServiceShape serviceShape = context.getService();
writer.write("Action: $S,", operation.getId().getName(serviceShape));
writer.write("Version: $S,", serviceShape.getVersion());
});
}

Expand All @@ -160,7 +162,7 @@ protected void deserializeOutputDocument(
) {
TypeScriptWriter writer = context.getWriter();

String dataSource = "data." + operation.getId().getName() + "Result";
String dataSource = "data." + operation.getId().getName(context.getService()) + "Result";
writer.write("contents = $L;",
outputStructure.accept(new XmlMemberDeserVisitor(context, dataSource, Format.DATE_TIME)));
}
Expand Down
Expand Up @@ -24,6 +24,7 @@
import software.amazon.smithy.model.knowledge.HttpBinding.Location;
import software.amazon.smithy.model.shapes.MemberShape;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.shapes.StructureShape;
Expand Down Expand Up @@ -153,18 +154,19 @@ protected void serializeInputDocument(
return;
}

ServiceShape serviceShape = context.getService();
SymbolProvider symbolProvider = context.getSymbolProvider();
ShapeId inputShapeId = documentBindings.get(0).getMember().getContainer();

// Start with the XML declaration.
writer.write("body = \"<?xml version=\\\"1.0\\\" encoding=\\\"UTF-8\\\"?>\";");

writer.addImport("XmlNode", "__XmlNode", "@aws-sdk/xml-builder");
writer.write("const bodyNode = new __XmlNode($S);", inputShapeId.getName());
writer.write("const bodyNode = new __XmlNode($S);", inputShapeId.getName(serviceShape));

// Add @xmlNamespace value of the service to the root node,
// fall back to one from the input shape.
boolean serviceXmlns = AwsProtocolUtils.writeXmlNamespace(context, context.getService(), "bodyNode");
boolean serviceXmlns = AwsProtocolUtils.writeXmlNamespace(context, serviceShape, "bodyNode");
if (!serviceXmlns) {
StructureShape inputShape = context.getModel().expectShape(inputShapeId, StructureShape.class);
AwsProtocolUtils.writeXmlNamespace(context, inputShape, "bodyNode");
Expand Down
Expand Up @@ -17,6 +17,7 @@

import java.util.Set;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.traits.TimestampFormatTrait.Format;
Expand Down Expand Up @@ -88,7 +89,8 @@ protected void writeDefaultHeaders(GenerationContext context, OperationShape ope
// AWS JSON RPC protocols use a combination of the service and operation shape names,
// separated by a '.' character, for the target header.
TypeScriptWriter writer = context.getWriter();
String target = context.getService().getId().getName() + "." + operation.getId().getName();
ServiceShape serviceShape = context.getService();
String target = serviceShape.getId().getName(serviceShape) + "." + operation.getId().getName(serviceShape);
writer.write("'x-amz-target': $S,", target);
}

Expand Down
Expand Up @@ -23,6 +23,7 @@
import software.amazon.smithy.model.shapes.DocumentShape;
import software.amazon.smithy.model.shapes.MapShape;
import software.amazon.smithy.model.shapes.MemberShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.shapes.UnionShape;
Expand Down Expand Up @@ -152,9 +153,10 @@ public void serializeStructure(GenerationContext context, StructureShape shape)
public void serializeUnion(GenerationContext context, UnionShape shape) {
TypeScriptWriter writer = context.getWriter();
Model model = context.getModel();
ServiceShape serviceShape = context.getService();

// Visit over the union type, then get the right serialization for the member.
writer.openBlock("return $L.visit(input, {", "});", shape.getId().getName(), () -> {
writer.openBlock("return $L.visit(input, {", "});", shape.getId().getName(serviceShape), () -> {
// Use a TreeMap to sort the members.
Map<String, MemberShape> members = new TreeMap<>(shape.getAllMembers());
members.forEach((memberName, memberShape) -> {
Expand Down
Expand Up @@ -21,6 +21,7 @@
import software.amazon.smithy.model.shapes.DocumentShape;
import software.amazon.smithy.model.shapes.MapShape;
import software.amazon.smithy.model.shapes.MemberShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.shapes.UnionShape;
Expand Down Expand Up @@ -272,12 +273,13 @@ protected boolean isFlattenedMember(GenerationContext context, MemberShape membe
@Override
protected void serializeUnion(GenerationContext context, UnionShape shape) {
TypeScriptWriter writer = context.getWriter();
ServiceShape serviceShape = context.getService();

// Set up a location to store the entry pair.
writer.write("const entries: any = {};");

// Visit over the union type, then get the right serialization for the member.
writer.openBlock("$L.visit(input, {", "});", shape.getId().getName(), () -> {
writer.openBlock("$L.visit(input, {", "});", shape.getId().getName(serviceShape), () -> {
shape.getAllMembers().forEach((memberName, memberShape) -> {
writer.openBlock("$L: value => {", "},", memberName, () -> {
serializeNamedMember(context, memberName, memberShape, "value");
Expand Down
Expand Up @@ -49,9 +49,11 @@
* @see <a href="https://awslabs.github.io/smithy/spec/xml.html">Smithy XML traits.</a>
*/
final class XmlMemberSerVisitor extends DocumentMemberSerVisitor {
private final GenerationContext context;

XmlMemberSerVisitor(GenerationContext context, String dataSource, Format defaultTimestampFormat) {
super(context, dataSource, defaultTimestampFormat);
this.context = context;
}

@Override
Expand Down Expand Up @@ -112,7 +114,7 @@ String getAsXmlText(Shape shape, String dataSource) {
// Handle the @xmlName trait for the shape itself.
String nodeName = shape.getTrait(XmlNameTrait.class)
.map(XmlNameTrait::getValue)
.orElse(shape.getId().getName());
.orElse(shape.getId().getName(context.getService()));

TypeScriptWriter writer = getContext().getWriter();
writer.addImport("XmlNode", "__XmlNode", "@aws-sdk/xml-builder");
Expand Down

0 comments on commit 325d6b4

Please sign in to comment.