Skip to content

Commit

Permalink
add context suppress cancel utility
Browse files Browse the repository at this point in the history
  • Loading branch information
jasdel committed Apr 8, 2022
1 parent a888b08 commit eb6dd35
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 13 deletions.
7 changes: 6 additions & 1 deletion auth/bearer/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,13 @@ func (m *AuthenticationMiddleware) HandleFinalize(
// not HTTPS.
type SignHTTPSMessage struct{}

// NewSignHTTPSMessage returns an initialized signer for HTTP messages.
func NewSignHTTPSMessage() *SignHTTPSMessage {
return &SignHTTPSMessage{}
}

// SignWithBearerToken returns a copy of the HTTP request with the bearer token
// added via the "Authorization" header, per [RFC 6750](https://datatracker.ietf.org/doc/html/rfc6750).
// added via the "Authorization" header, per RFC 6750, https://datatracker.ietf.org/doc/html/rfc6750.
//
// Returns an error if the request's URL scheme is not HTTPS, or the request
// message is not an smithy-go HTTP Request pointer type.
Expand Down
2 changes: 1 addition & 1 deletion auth/bearer/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (t Token) Expired(now time.Time) bool {
if !t.CanExpire {
return false
}
return now.Equal(t.Expires) || now.After(t.Expires)
return now.Round(0).Equal(t.Expires) || now.After(t.Expires)
}

// TokenProvider provides interface for retrieving bearer tokens.
Expand Down
10 changes: 9 additions & 1 deletion auth/bearer/token_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sync/atomic"
"time"

smithycontext "github.com/aws/smithy-go/context"
"github.com/aws/smithy-go/internal/sync/singleflight"
)

Expand Down Expand Up @@ -99,6 +100,13 @@ func NewTokenCache(provider TokenProvider, optFns ...func(*TokenCacheOptions)) *
// and not be canceled with the Context. Set RetrieveBearerTokenTimeout to
// provide a timeout, preventing the underlying TokenProvider blocking forever.
//
// By default, if the passed in Context is canceled, all of its values will be
// considered expired. The wrapped TokenProvider will not be able to lookup the
// values from the Context once it is expired. This is done to protect against
// expired values no longer being valid. To disable this behavior, use
// smithy-go's context.WithPreserveExpiredValues to add a value to the Context
// before calling RetrieveBearerToken to enable support for expired values.
//
// Without RetrieveBearerTokenTimeout there is the potential for a underlying
// Provider's RetrieveBearerToken call to sit forever. Blocking in subsequent
// attempts at refreshing the token.
Expand Down Expand Up @@ -157,7 +165,7 @@ func (p *TokenCache) tryAsyncRefresh(ctx context.Context) {

func (p *TokenCache) refreshBearerToken(ctx context.Context) (Token, error) {
resCh := p.sfGroup.DoChan("refresh-token", func() (interface{}, error) {
var ctx context.Context = &suppressedContext{ctx}
ctx := smithycontext.WithSuppressCancel(ctx)
if v := p.options.RetrieveBearerTokenTimeout; v != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public class HttpBearerAuth implements GoIntegration {
private static final String SIGNER_OPTION_NAME = "BearerAuthSigner";
private static final String NEW_DEFAULT_SIGNER_NAME = "newDefault" + SIGNER_OPTION_NAME;
private static final String SIGNER_RESOLVER_NAME = "resolve" + SIGNER_OPTION_NAME;
private static final String REGISTER_MIDDLEWARE_NAME = "add";
private static final String REGISTER_MIDDLEWARE_NAME = "add" + SIGNER_OPTION_NAME + "Middleware";

@Override
public void writeAdditionalFiles(
Expand All @@ -63,7 +63,7 @@ public void writeAdditionalFiles(

goDelegator.useShapeWriter(service, (writer) -> {
writeMiddlewareRegister(writer);
writeConfigFieldResolver(writer);
writeSignerConfigFieldResolver(writer);
writeNewSignerFunc(writer);
});
}
Expand All @@ -88,21 +88,19 @@ private void writeMiddlewareRegister(GoWriter writer) {
writer.popState();
}

private void writeConfigFieldResolver(GoWriter writer) {
private void writeSignerConfigFieldResolver(GoWriter writer) {
writer.pushState();

writer.putContext("funcName", SIGNER_RESOLVER_NAME);
writer.putContext("signer", SymbolUtils.createValueSymbolBuilder("TokenProvider",
SmithyGoDependency.SMITHY_AUTH_BEARER).build());

writer.putContext("signerOption", SIGNER_OPTION_NAME);
writer.putContext("newDefaultSigner", NEW_DEFAULT_SIGNER_NAME);

writer.write("""
func $funcName:L(o *Options) {
if o.$signerOption:L != nil {
return
}
o.$signerOption:L = $newDefaultSigner(*o)
o.$signerOption:L = $newDefaultSigner:L(*o)
}
""");

Expand All @@ -113,15 +111,15 @@ private void writeNewSignerFunc(GoWriter writer) {
writer.pushState();

writer.putContext("funcName", NEW_DEFAULT_SIGNER_NAME);
writer.putContext("signer", SymbolUtils.createValueSymbolBuilder("TokenProvider",
writer.putContext("signerInterface", SymbolUtils.createValueSymbolBuilder("Signer",
SmithyGoDependency.SMITHY_AUTH_BEARER).build());

// TODO this is HTTP specific, should be based on protocol/transport of API.
writer.putContext("newDefaultSigner", SymbolUtils.createValueSymbolBuilder("NewSignHTTPSMessage",
SmithyGoDependency.SMITHY_AUTH_BEARER).build());

writer.write("""
func $funcName:L(o Options) *$signer:T {
func $funcName:L(o Options) $signerInterface:T {
return $newDefaultSigner:T()
}
""");
Expand Down Expand Up @@ -161,7 +159,7 @@ public List<RuntimeClientPlugin> getClientPlugins() {
.operationPredicate(HttpBearerAuth::hasBearerAuthScheme)
.registerMiddleware(MiddlewareRegistrar.builder()
.resolvedFunction(SymbolUtils.createValueSymbolBuilder(
REGISTER_MIDDLEWARE_NAME).build())
REGISTER_MIDDLEWARE_NAME).build())
.useClientOptions()
.build())
.build()
Expand Down
81 changes: 81 additions & 0 deletions context/suppress_expired.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package context

import "context"

// valueOnlyContext provides a utility to preserve only the values of a
// Context. Suppressing any cancellation or deadline on that context being
// propagated downstream of this value.
//
// If preserveExpiredValues is false (default), and the valueCtx is canceled,
// calls to lookup values with the Values method, will always return nil. Setting
// preserveExpiredValues to true, will allow the valueOnlyContext to lookup
// values in valueCtx even if valueCtx is canceled.
//
// Based on the Go standard libraries net/lookup.go onlyValuesCtx utility.
// https://github.com/golang/go/blob/da2773fe3e2f6106634673a38dc3a6eb875fe7d8/src/net/lookup.go
type valueOnlyContext struct {
context.Context

preserveExpiredValues bool
valuesCtx context.Context
}

var _ context.Context = (*valueOnlyContext)(nil)

// Value looks up the key, returning its value. If configured to not preserve
// values of expired context, and the wrapping context is canceled, nil will be
// returned.
func (v *valueOnlyContext) Value(key interface{}) interface{} {
if !v.preserveExpiredValues {
select {
case <-v.valuesCtx.Done():
return nil
default:
}
}

return v.valuesCtx.Value(key)
}

// WithSuppressCancel wraps the Context value, suppressing its deadline and
// cancellation events being propagated downstream to consumer of the returned
// context.
//
// By default the wrapped Context's Values are available downstream until the
// wrapped Context is canceled. Once the wrapped Context is canceled, Values
// method called on the context return will no longer lookup any key. As they
// are now considered expired.
//
// To override this behavior, use WithPreserveExpiredValues on the Context
// before it is wrapped by WithSuppressCancel. This will make the Context
// returned by WithSuppressCancel allow lookup of expired values.
func WithSuppressCancel(ctx context.Context) context.Context {
return &valueOnlyContext{
Context: context.Background(),
valuesCtx: ctx,

preserveExpiredValues: GetPreserveExpiredValues(ctx),
}
}

type preserveExpiredValuesKey struct{}

// WithPreserveExpiredValues adds a Value to the Context if expired values
// should be preserved, and looked up by a Context wrapped by
// WithSuppressCancel.
//
// WithPreserveExpiredValues must be added as a value to a Context, before that
// Context is wrapped by WithSuppressCancel
func WithPreserveExpiredValues(ctx context.Context, enable bool) context.Context {
return context.WithValue(ctx, preserveExpiredValuesKey{}, enable)
}

// GetPreserveExpiredValues looks up, and returns the PreserveExpressValues
// value in the context. Returning true if enabled, false otherwise.
func GetPreserveExpiredValues(ctx context.Context) bool {
v := ctx.Value(preserveExpiredValuesKey{})
if v != nil {
return v.(bool)
}
return false
}

0 comments on commit eb6dd35

Please sign in to comment.