Skip to content

Commit

Permalink
fix awsjson error deserialization to not expect string code (#2489)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucix-aws committed Feb 15, 2024
1 parent 6ae62c2 commit a264562
Show file tree
Hide file tree
Showing 142 changed files with 33,740 additions and 53,137 deletions.
144 changes: 144 additions & 0 deletions .changelog/ee9db7c34c7946709eba3af651c48631.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
{
"id": "ee9db7c3-4c79-4670-9eba-3af651c48631",
"type": "bugfix",
"collapse": true,
"description": "Correct failure to determine the error type in awsJson services that could occur when errors were modeled with a non-string `code` field.",
"modules": [
"service/acm",
"service/acmpca",
"service/alexaforbusiness",
"service/applicationautoscaling",
"service/applicationdiscoveryservice",
"service/applicationinsights",
"service/apprunner",
"service/appstream",
"service/athena",
"service/autoscalingplans",
"service/b2bi",
"service/backupgateway",
"service/bcmdataexports",
"service/budgets",
"service/cloud9",
"service/cloudcontrol",
"service/cloudhsm",
"service/cloudhsmv2",
"service/cloudtrail",
"service/cloudwatchevents",
"service/cloudwatchlogs",
"service/codebuild",
"service/codecommit",
"service/codedeploy",
"service/codepipeline",
"service/codestar",
"service/codestarconnections",
"service/cognitoidentity",
"service/cognitoidentityprovider",
"service/comprehend",
"service/comprehendmedical",
"service/computeoptimizer",
"service/configservice",
"service/costandusagereportservice",
"service/costexplorer",
"service/costoptimizationhub",
"service/databasemigrationservice",
"service/datapipeline",
"service/datasync",
"service/dax",
"service/devicefarm",
"service/directconnect",
"service/directoryservice",
"service/dynamodb",
"service/dynamodbstreams",
"service/ec2instanceconnect",
"service/ecr",
"service/ecrpublic",
"service/ecs",
"service/emr",
"service/eventbridge",
"service/firehose",
"service/fms",
"service/forecast",
"service/forecastquery",
"service/frauddetector",
"service/freetier",
"service/fsx",
"service/gamelift",
"service/globalaccelerator",
"service/glue",
"service/health",
"service/healthlake",
"service/identitystore",
"service/inspector",
"service/iotfleetwise",
"service/iotsecuretunneling",
"service/iotthingsgraph",
"service/kendra",
"service/kendraranking",
"service/keyspaces",
"service/kinesis",
"service/kinesisanalytics",
"service/kinesisanalyticsv2",
"service/kms",
"service/licensemanager",
"service/lightsail",
"service/lookoutequipment",
"service/machinelearning",
"service/marketplaceagreement",
"service/marketplacecommerceanalytics",
"service/marketplaceentitlementservice",
"service/marketplacemetering",
"service/mediastore",
"service/memorydb",
"service/migrationhub",
"service/migrationhubconfig",
"service/mturk",
"service/networkfirewall",
"service/opensearchserverless",
"service/opsworks",
"service/opsworkscm",
"service/organizations",
"service/paymentcryptography",
"service/personalize",
"service/pi",
"service/pinpointsmsvoicev2",
"service/pricing",
"service/proton",
"service/qldbsession",
"service/redshiftdata",
"service/redshiftserverless",
"service/rekognition",
"service/resourcegroupstaggingapi",
"service/route53domains",
"service/route53recoverycluster",
"service/route53resolver",
"service/sagemaker",
"service/secretsmanager",
"service/servicecatalog",
"service/servicediscovery",
"service/servicequotas",
"service/sfn",
"service/shield",
"service/sms",
"service/snowball",
"service/sqs",
"service/ssm",
"service/ssmcontacts",
"service/ssoadmin",
"service/storagegateway",
"service/support",
"service/swf",
"service/textract",
"service/timestreamquery",
"service/timestreamwrite",
"service/transcribe",
"service/transfer",
"service/translate",
"service/verifiedpermissions",
"service/voiceid",
"service/waf",
"service/wafregional",
"service/wafv2",
"service/workmail",
"service/workspaces"
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -196,30 +196,6 @@ static void generateHttpProtocolTests(GenerationContext context) {
).generateProtocolTests();
}

public static void writeJsonErrorMessageCodeDeserializer(GenerationContext context) {
GoWriter writer = context.getWriter().get();
// The error code could be in the headers, even though for this protocol it should be in the body.
writer.write("headerCode := response.Header.Get(\"X-Amzn-ErrorType\")");
writer.write("if len(headerCode) != 0 { errorCode = restjson.SanitizeErrorCode(headerCode) }");
writer.write("");

initializeJsonDecoder(writer, "errorBody");
writer.addUseImports(AwsGoDependency.AWS_REST_JSON_PROTOCOL);
// This will check various body locations for the error code and error message
writer.write("jsonCode, message, err := restjson.GetErrorInfo(decoder)");
handleDecodeError(writer);

writer.addUseImports(SmithyGoDependency.IO);
// Reset the body in case it needs to be used for anything else.
writer.write("errorBody.Seek(0, io.SeekStart)");

// Only set the values if something was found so that we keep the default values.
// The header version of the error wins out over either of the body fields.
writer.write("if len(headerCode) == 0 && len(jsonCode) != 0 { errorCode = restjson.SanitizeErrorCode(jsonCode) }");
writer.write("if len(message) != 0 { errorMessage = message }");
writer.write("");
}

public static void initializeJsonDecoder(GoWriter writer, String bodyLocation) {
// Use a ring buffer and tee reader to help in pinpointing any deserialization errors.
writer.addUseImports(SmithyGoDependency.SMITHY_IO);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import static software.amazon.smithy.aws.go.codegen.AwsProtocolUtils.handleDecodeError;
import static software.amazon.smithy.aws.go.codegen.AwsProtocolUtils.initializeJsonDecoder;
import static software.amazon.smithy.aws.go.codegen.AwsProtocolUtils.writeJsonErrorMessageCodeDeserializer;
import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;

import java.util.HashSet;
import java.util.Set;
Expand All @@ -44,6 +44,7 @@
import software.amazon.smithy.model.traits.EventPayloadTrait;
import software.amazon.smithy.go.codegen.endpoints.EndpointResolutionGenerator;
import software.amazon.smithy.go.codegen.endpoints.FnGenerator;
import software.amazon.smithy.utils.MapUtils;

/**
* Handles generating the aws.rest-json protocol for services.
Expand Down Expand Up @@ -174,7 +175,30 @@ public void generateProtocolTests(GenerationContext context) {

@Override
protected void writeErrorMessageCodeDeserializer(GenerationContext context) {
writeJsonErrorMessageCodeDeserializer(context);
var tmpl = goTemplate("""
headerCode := response.Header.Get("X-Amzn-ErrorType")
$initDecoder:W
bodyInfo, err := getProtocolErrorInfo(decoder)
$handleDecodeError:W
errorBody.Seek(0, io.SeekStart)
if typ, ok := resolveProtocolErrorType(headerCode, bodyInfo); ok {
errorCode = restjson.SanitizeErrorCode(typ)
}
if len(bodyInfo.Message) != 0 {
errorMessage = bodyInfo.Message
}
""",
MapUtils.of(
"initDecoder", (GoWriter.Writable) writer -> initializeJsonDecoder(writer, "errorBody"),
"handleDecodeError", (GoWriter.Writable) AwsProtocolUtils::handleDecodeError
));
context.getWriter().get()
.addUseImports(AwsGoDependency.AWS_REST_JSON_PROTOCOL)
.addUseImports(SmithyGoDependency.IO)
.write(tmpl);
}

@Override
Expand Down Expand Up @@ -367,4 +391,47 @@ public void generateEndpointResolution(GenerationContext context) {
generator.generate(context);
}

@Override
public void generateSharedDeserializerComponents(GenerationContext context) {
super.generateSharedDeserializerComponents(context);
writeGetProtocolErrorInfo(context);
}

private void writeGetProtocolErrorInfo(GenerationContext context) {
var tmpl = goTemplate("""
type protocolErrorInfo struct {
Type string `json:"__type"`
Message string
Code any // nonstandard for awsjson but some services do present the type here
}
func getProtocolErrorInfo(decoder *json.Decoder) (protocolErrorInfo, error) {
var errInfo protocolErrorInfo
if err := decoder.Decode(&errInfo); err != nil {
if err == io.EOF {
return errInfo, nil
}
return errInfo, err
}
return errInfo, nil
}
func resolveProtocolErrorType(headerType string, bodyInfo protocolErrorInfo) (string, bool) {
if len(headerType) != 0 {
return headerType, true
} else if len(bodyInfo.Type) != 0 {
return bodyInfo.Type, true
} else if code, ok := bodyInfo.Code.(string); ok && len(code) != 0 {
return code, true
}
return "", false
}
""");
context.getWriter().get()
.addUseImports(SmithyGoDependency.JSON)
.addUseImports(SmithyGoDependency.IO)
.write(tmpl);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import static software.amazon.smithy.go.codegen.integration.HttpProtocolGeneratorUtils.isShapeWithResponseBindings;
import static software.amazon.smithy.aws.go.codegen.AwsProtocolUtils.handleDecodeError;
import static software.amazon.smithy.aws.go.codegen.AwsProtocolUtils.initializeJsonDecoder;
import static software.amazon.smithy.aws.go.codegen.AwsProtocolUtils.writeJsonErrorMessageCodeDeserializer;

import java.util.HashSet;
import java.util.Optional;
Expand Down Expand Up @@ -363,7 +362,27 @@ protected void generateOperationDocumentDeserializer(

@Override
protected void writeErrorMessageCodeDeserializer(GenerationContext context) {
writeJsonErrorMessageCodeDeserializer(context);
GoWriter writer = context.getWriter().get();
// The error code could be in the headers, even though for this protocol it should be in the body.
writer.write("headerCode := response.Header.Get(\"X-Amzn-ErrorType\")");
writer.write("if len(headerCode) != 0 { errorCode = restjson.SanitizeErrorCode(headerCode) }");
writer.write("");

initializeJsonDecoder(writer, "errorBody");
writer.addUseImports(AwsGoDependency.AWS_REST_JSON_PROTOCOL);
// This will check various body locations for the error code and error message
writer.write("jsonCode, message, err := restjson.GetErrorInfo(decoder)");
handleDecodeError(writer);

writer.addUseImports(SmithyGoDependency.IO);
// Reset the body in case it needs to be used for anything else.
writer.write("errorBody.Seek(0, io.SeekStart)");

// Only set the values if something was found so that we keep the default values.
// The header version of the error wins out over either of the body fields.
writer.write("if len(headerCode) == 0 && len(jsonCode) != 0 { errorCode = restjson.SanitizeErrorCode(jsonCode) }");
writer.write("if len(message) != 0 { errorMessage = message }");
writer.write("");
}

@Override
Expand Down
Loading

0 comments on commit a264562

Please sign in to comment.