Skip to content

Commit

Permalink
Add Gzip request compression feature (#467)
Browse files Browse the repository at this point in the history
* Add and Merge request compression feature

* Modify and Merge request compression codegen

* Add and Merge changelog for last commit

* Modify logic of request compression middleware

* Add request compression algorithm codegen part

* resolve METAINFO conflict

* Change dependency format

* Revert dependency format

* Change request compression middleware to operation level

* Change codegen comment

* Change static middleware import

* Change go dependency codegen

* Add body compare fn to request compress op unit test

* Solve rebase conflict

---------

Co-authored-by: Tianyi Wang <wty@amazon.com>
  • Loading branch information
wty-Bryant and Tianyi Wang committed Dec 6, 2023
1 parent 88d16be commit 690fcaa
Show file tree
Hide file tree
Showing 12 changed files with 602 additions and 1 deletion.
8 changes: 8 additions & 0 deletions .changelog/80ed28327bcd4301a264f318efaf8216.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "80ed2832-7bcd-4301-a264-f318efaf8216",
"type": "feature",
"description": "Support modeled request compression.",
"modules": [
"."
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ public final class SmithyGoDependency {
public static final GoDependency SMITHY_HTTP_TRANSPORT = smithy("transport/http", "smithyhttp");
public static final GoDependency SMITHY_MIDDLEWARE = smithy("middleware");
public static final GoDependency SMITHY_PRIVATE_PROTOCOL = smithy("private/protocol", "smithyprivateprotocol");
public static final GoDependency SMITHY_REQUEST_COMPRESSION =
smithy("private/requestcompression", "smithyrequestcompression");
public static final GoDependency SMITHY_TIME = smithy("time", "smithytime");
public static final GoDependency SMITHY_HTTP_BINDING = smithy("encoding/httpbinding");
public static final GoDependency SMITHY_JSON = smithy("encoding/json", "smithyjson");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,11 @@ public static final class Bearer {
public static final Symbol NewSignHTTPSMessage = SmithyGoDependency.SMITHY_AUTH_BEARER.valueSymbol("NewSignHTTPSMessage");
}
}

public static final class Private {
public static final class RequestCompression {
public static final Symbol AddRequestCompression = SmithyGoDependency.SMITHY_REQUEST_COMPRESSION.valueSymbol("AddRequestCompression");
public static final Symbol AddCaptureUncompressedRequest = SmithyGoDependency.SMITHY_REQUEST_COMPRESSION.valueSymbol("AddCaptureUncompressedRequestMiddleware");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,31 @@

package software.amazon.smithy.go.codegen.integration;

import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;
import static software.amazon.smithy.go.codegen.SmithyGoTypes.Private.RequestCompression.AddCaptureUncompressedRequest;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Consumer;
import java.util.logging.Logger;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoDependency;
import software.amazon.smithy.go.codegen.SmithyGoTypes;
import software.amazon.smithy.go.codegen.SymbolUtils;
import software.amazon.smithy.model.traits.RequestCompressionTrait;
import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase;
import software.amazon.smithy.utils.MapUtils;

/**
* Generates HTTP protocol unit tests for HTTP request test cases.
*/
public class HttpProtocolUnitTestRequestGenerator extends HttpProtocolUnitTestGenerator<HttpRequestTestCase> {
private static final Logger LOGGER = Logger.getLogger(HttpProtocolUnitTestRequestGenerator.class.getName());

private static final Set<String> ALLOWED_ALGORITHMS = new HashSet<>(Arrays.asList("gzip"));

/**
* Initializes the protocol test generator.
*
Expand Down Expand Up @@ -198,6 +209,10 @@ protected void generateTestCaseValues(GoWriter writer, HttpRequestTestCase testC
*/
protected void generateTestBodySetup(GoWriter writer) {
writer.write("actualReq := &http.Request{}");
if (operation.hasTrait(RequestCompressionTrait.class)) {
writer.addUseImports(SmithyGoDependency.BYTES);
writer.write("rawBodyBuf := &bytes.Buffer{}");
}
}

/**
Expand Down Expand Up @@ -227,8 +242,29 @@ protected void generateTestInvokeClientOperation(GoWriter writer, String clientN
writer.write("return $T(stack, actualReq)",
SymbolUtils.createValueSymbolBuilder("AddCaptureRequestMiddleware",
SmithyGoDependency.SMITHY_PRIVATE_PROTOCOL).build());
});
});
if (operation.hasTrait(RequestCompressionTrait.class)) {
writer.write(goTemplate("""
options.APIOptions = append(options.APIOptions, func(stack $stack:P) error {
return $captureRequest:T(stack, rawBodyBuf)
})
""",
MapUtils.of(
"stack", SmithyGoTypes.Middleware.Stack,
"captureRequest", AddCaptureUncompressedRequest
)));
}
});

if (operation.hasTrait(RequestCompressionTrait.class)) {
writer.write(goTemplate("""
disable := $client:L.Options().DisableRequestCompression
min := $client:L.Options().RequestMinCompressSizeBytes
""",
MapUtils.of(
"client", clientName
)));
}
}

/**
Expand Down Expand Up @@ -259,6 +295,20 @@ protected void generateTestAssertions(GoWriter writer) {
writer.write("t.Errorf(\"expect body equal, got %v\", err)");
});
});

if (operation.hasTrait(RequestCompressionTrait.class)) {
String algorithm = operation.expectTrait(RequestCompressionTrait.class).getEncodings()
.stream().filter(it -> ALLOWED_ALGORITHMS.contains(it)).findFirst().get();
writer.write(goTemplate("""
if err := smithytesting.CompareCompressedBytes(rawBodyBuf, actualReq.Body,
disable, min, $algorithm:S); err != nil {
t.Errorf("unzipped request body not match: %q", err)
}
""",
MapUtils.of(
"algorithm", algorithm
)));
}
}

public static class Builder extends HttpProtocolUnitTestGenerator.Builder<HttpRequestTestCase> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.smithy.go.codegen.requestcompression;

import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;

import java.util.ArrayList;
import java.util.List;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoCodegenPlugin;
import software.amazon.smithy.go.codegen.GoDelegator;
import software.amazon.smithy.go.codegen.GoSettings;
import software.amazon.smithy.go.codegen.GoUniverseTypes;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoTypes;
import software.amazon.smithy.go.codegen.SymbolUtils;
import software.amazon.smithy.go.codegen.integration.ConfigField;
import software.amazon.smithy.go.codegen.integration.GoIntegration;
import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar;
import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.knowledge.TopDownIndex;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.traits.RequestCompressionTrait;
import software.amazon.smithy.utils.ListUtils;
import software.amazon.smithy.utils.MapUtils;


public final class RequestCompression implements GoIntegration {
private static final String DISABLE_REQUEST_COMPRESSION = "DisableRequestCompression";

private static final String REQUEST_MIN_COMPRESSION_SIZE_BYTES = "RequestMinCompressSizeBytes";

private final List<RuntimeClientPlugin> runtimeClientPlugins = new ArrayList<>();

// Write operation plugin for request compression middleware
@Override
public void processFinalizedModel(GoSettings settings, Model model) {
ServiceShape service = settings.getService(model);
TopDownIndex.of(model)
.getContainedOperations(service).forEach(operation -> {
if (!operation.hasTrait(RequestCompressionTrait.class)) {
return;
}
SymbolProvider symbolProvider = GoCodegenPlugin.createSymbolProvider(model, settings);
String funcName = getAddRequestCompressionMiddlewareFuncName(
symbolProvider.toSymbol(operation).getName()
);
runtimeClientPlugins.add(RuntimeClientPlugin.builder().operationPredicate((m, s, o) -> {
if (!o.hasTrait(RequestCompressionTrait.class)) {
return false;
}
return o.equals(operation);
}).registerMiddleware(MiddlewareRegistrar.builder()
.resolvedFunction(SymbolUtils.buildPackageSymbol(funcName))
.useClientOptions().build())
.build());
});
}

@Override
public void writeAdditionalFiles(
GoSettings settings,
Model model,
SymbolProvider symbolProvider,
GoDelegator goDelegator
) {
ServiceShape service = settings.getService(model);
for (ShapeId operationID : service.getAllOperations()) {
OperationShape operation = model.expectShape(operationID, OperationShape.class);
if (!operation.hasTrait(RequestCompressionTrait.class)) {
continue;
}
goDelegator.useShapeWriter(operation, writeMiddlewareHelper(symbolProvider, operation));
}
}


public static boolean isRequestCompressionService(Model model, ServiceShape service) {
return TopDownIndex.of(model)
.getContainedOperations(service).stream()
.anyMatch(it -> it.hasTrait(RequestCompressionTrait.class));
}

@Override
public List<RuntimeClientPlugin> getClientPlugins() {
runtimeClientPlugins.add(
RuntimeClientPlugin.builder()
.servicePredicate(RequestCompression::isRequestCompressionService)
.configFields(ListUtils.of(
ConfigField.builder()
.name(DISABLE_REQUEST_COMPRESSION)
.type(GoUniverseTypes.Bool)
.documentation(
"Whether to disable automatic request compression for supported operations.")
.build(),
ConfigField.builder()
.name(REQUEST_MIN_COMPRESSION_SIZE_BYTES)
.type(GoUniverseTypes.Int64)
.documentation("The minimum request body size, in bytes, at which compression "
+ "should occur. The default value is 10 KiB. Values must fall within "
+ "[0, 1MiB].")
.build()
))
.build()
);

return runtimeClientPlugins;
}

private GoWriter.Writable generateAlgorithmList(List<String> algorithms) {
return goTemplate("""
[]string{
$W
}
""",
GoWriter.ChainWritable.of(
algorithms.stream()
.map(it -> goTemplate("$S,", it))
.toList()
).compose(false));
}

private static String getAddRequestCompressionMiddlewareFuncName(String operationName) {
return String.format("addOperation%sRequestCompressionMiddleware", operationName);
}

private GoWriter.Writable writeMiddlewareHelper(SymbolProvider symbolProvider, OperationShape operation) {
String operationName = symbolProvider.toSymbol(operation).getName();
RequestCompressionTrait trait = operation.expectTrait(RequestCompressionTrait.class);

return goTemplate("""
func $add:L(stack $stack:P, options Options) error {
return $addInternal:T(stack, options.DisableRequestCompression, options.RequestMinCompressSizeBytes,
$algorithms:W)
}
""",
MapUtils.of(
"add", getAddRequestCompressionMiddlewareFuncName(operationName),
"stack", SmithyGoTypes.Middleware.Stack,
"addInternal", SmithyGoTypes.Private.RequestCompression.AddRequestCompression,
"algorithms", generateAlgorithmList(trait.getEncodings())
));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ software.amazon.smithy.go.codegen.endpoints.EndpointClientPluginsGenerator
# modeled auth schemes
software.amazon.smithy.go.codegen.integration.auth.SigV4AuthScheme
software.amazon.smithy.go.codegen.integration.auth.AnonymousAuthScheme

software.amazon.smithy.go.codegen.requestcompression.RequestCompression
30 changes: 30 additions & 0 deletions private/requestcompression/gzip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package requestcompression

import (
"bytes"
"compress/gzip"
"fmt"
"io"
)

func gzipCompress(input io.Reader) ([]byte, error) {
var b bytes.Buffer
w, err := gzip.NewWriterLevel(&b, gzip.DefaultCompression)
if err != nil {
return nil, fmt.Errorf("failed to create gzip writer, %v", err)
}

inBytes, err := io.ReadAll(input)
if err != nil {
return nil, fmt.Errorf("failed read payload to compress, %v", err)
}

if _, err = w.Write(inBytes); err != nil {
return nil, fmt.Errorf("failed to write payload to be compressed, %v", err)
}
if err = w.Close(); err != nil {
return nil, fmt.Errorf("failed to flush payload being compressed, %v", err)
}

return b.Bytes(), nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package requestcompression

import (
"bytes"
"context"
"fmt"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"io"
"net/http"
)

const captureUncompressedRequestID = "CaptureUncompressedRequest"

// AddCaptureUncompressedRequestMiddleware captures http request before compress encoding for check
func AddCaptureUncompressedRequestMiddleware(stack *middleware.Stack, buf *bytes.Buffer) error {
return stack.Serialize.Insert(&captureUncompressedRequestMiddleware{
buf: buf,
}, "RequestCompression", middleware.Before)
}

type captureUncompressedRequestMiddleware struct {
req *http.Request
buf *bytes.Buffer
bytes []byte
}

// ID returns id of the captureUncompressedRequestMiddleware
func (*captureUncompressedRequestMiddleware) ID() string {
return captureUncompressedRequestID
}

// HandleSerialize captures request payload before it is compressed by request compression middleware
func (m *captureUncompressedRequestMiddleware) HandleSerialize(ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler,
) (
output middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
request, ok := input.Request.(*smithyhttp.Request)
if !ok {
return output, metadata, fmt.Errorf("error when retrieving http request")
}

_, err = io.Copy(m.buf, request.GetStream())
if err != nil {
return output, metadata, fmt.Errorf("error when copying http request stream: %q", err)
}
if err = request.RewindStream(); err != nil {
return output, metadata, fmt.Errorf("error when rewinding request stream: %q", err)
}

return next.HandleSerialize(ctx, input)
}
Loading

0 comments on commit 690fcaa

Please sign in to comment.