Skip to content

Commit

Permalink
update express cache key (#2414)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucix-aws committed Dec 7, 2023
1 parent 9b90af4 commit b3c7fbf
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 30 deletions.
8 changes: 8 additions & 0 deletions .changelog/8e6a01197da848c88aaab5adb296abc1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "8e6a0119-7da8-48c8-8aaa-b5adb296abc1",
"type": "bugfix",
"description": "Improve uniqueness of default S3Express sesssion credentials cache keying to prevent collision in multi-credential scenarios.",
"modules": [
"service/s3"
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate;
import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;
import static software.amazon.smithy.go.codegen.SymbolUtils.buildPackageSymbol;

public class S3ExpressAuthScheme implements GoIntegration {
private static final ConfigField s3ExpressCredentials =
Expand All @@ -67,6 +68,14 @@ public class S3ExpressAuthScheme implements GoIntegration {
.withClientInput(true)
.build();

private static final ConfigFieldResolver s3ExpressCredentialsOperationFinalizer =
ConfigFieldResolver.builder()
.location(ConfigFieldResolver.Location.OPERATION)
.target(ConfigFieldResolver.Target.FINALIZATION)
.resolver(buildPackageSymbol("finalizeOperationExpressCredentials"))
.withClientInput(true)
.build();

@Override
public void writeAdditionalFiles(
GoSettings settings, Model model, SymbolProvider symbolProvider, GoDelegator goDelegator
Expand All @@ -84,6 +93,7 @@ public List<RuntimeClientPlugin> getClientPlugins() {
.addConfigField(s3ExpressCredentials)
.addConfigFieldResolver(s3ExpressCredentialsResolver)
.addConfigFieldResolver(s3ExpressCredentialsClientFinalizer)
.addConfigFieldResolver(s3ExpressCredentialsOperationFinalizer)
.addAuthSchemeDefinition(SigV4S3ExpressTrait.ID, new SigV4S3Express())
.build()
);
Expand Down
2 changes: 2 additions & 0 deletions service/s3/api_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

103 changes: 75 additions & 28 deletions service/s3/express_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package s3

import (
"context"
"crypto/hmac"
"crypto/sha256"
"errors"
"fmt"
"sync"
"time"

Expand All @@ -17,18 +20,49 @@ const s3ExpressCacheCap = 100

const s3ExpressRefreshWindow = 1 * time.Minute

type cacheKey struct {
CredentialsHash string // hmac(sigv4 akid, sigv4 secret)
Bucket string
}

func (c cacheKey) Slug() string {
return fmt.Sprintf("%s%s", c.CredentialsHash, c.Bucket)
}

type sessionCredsCache struct {
mu sync.Mutex
cache cache.Cache
}

func (c *sessionCredsCache) Get(key cacheKey) (*aws.Credentials, bool) {
c.mu.Lock()
defer c.mu.Unlock()

if v, ok := c.cache.Get(key); ok {
return v.(*aws.Credentials), true
}
return nil, false
}

func (c *sessionCredsCache) Put(key cacheKey, creds *aws.Credentials) {
c.mu.Lock()
defer c.mu.Unlock()

c.cache.Put(key, creds)
}

// The default S3Express provider uses an LRU cache with a capacity of 100.
//
// Credentials will be refreshed asynchronously when a Retrieve() call is made
// for cached credentials within an expiry window (1 minute, currently
// non-configurable).
type defaultS3ExpressCredentialsProvider struct {
mu sync.Mutex
sf singleflight.Group

client createSessionAPIClient
credsCache cache.Cache
cache *sessionCredsCache
refreshWindow time.Duration
v4creds aws.CredentialsProvider // underlying credentials used for CreateSession
}

type createSessionAPIClient interface {
Expand All @@ -37,35 +71,54 @@ type createSessionAPIClient interface {

func newDefaultS3ExpressCredentialsProvider() *defaultS3ExpressCredentialsProvider {
return &defaultS3ExpressCredentialsProvider{
credsCache: lru.New(s3ExpressCacheCap),
cache: &sessionCredsCache{
cache: lru.New(s3ExpressCacheCap),
},
refreshWindow: s3ExpressRefreshWindow,
}
}

// returns a cloned provider using new base credentials, used when per-op
// config mutations change the credentials provider
func (p *defaultS3ExpressCredentialsProvider) CloneWithBaseCredentials(v4creds aws.CredentialsProvider) *defaultS3ExpressCredentialsProvider {
return &defaultS3ExpressCredentialsProvider{
client: p.client,
cache: p.cache,
refreshWindow: p.refreshWindow,
v4creds: v4creds,
}
}

func (p *defaultS3ExpressCredentialsProvider) Retrieve(ctx context.Context, bucket string) (aws.Credentials, error) {
p.mu.Lock()
defer p.mu.Unlock()
v4creds, err := p.v4creds.Retrieve(ctx)
if err != nil {
return aws.Credentials{}, fmt.Errorf("get sigv4 creds: %w", err)
}

creds, ok := p.getCacheCredentials(bucket)
key := cacheKey{
CredentialsHash: gethmac(v4creds.AccessKeyID, v4creds.SecretAccessKey),
Bucket: bucket,
}
creds, ok := p.cache.Get(key)
if !ok || creds.Expired() {
return p.awaitDoChanRetrieve(ctx, bucket)
return p.awaitDoChanRetrieve(ctx, key)
}

if creds.Expires.Sub(sdk.NowTime()) <= p.refreshWindow {
p.doChanRetrieve(ctx, bucket)
p.doChanRetrieve(ctx, key)
}

return *creds, nil
}

func (p *defaultS3ExpressCredentialsProvider) doChanRetrieve(ctx context.Context, bucket string) <-chan singleflight.Result {
return p.sf.DoChan(bucket, func() (interface{}, error) {
return p.retrieve(ctx, bucket)
func (p *defaultS3ExpressCredentialsProvider) doChanRetrieve(ctx context.Context, key cacheKey) <-chan singleflight.Result {
return p.sf.DoChan(key.Slug(), func() (interface{}, error) {
return p.retrieve(ctx, key)
})
}

func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Context, bucket string) (aws.Credentials, error) {
ch := p.doChanRetrieve(ctx, bucket)
func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Context, key cacheKey) (aws.Credentials, error) {
ch := p.doChanRetrieve(ctx, key)

select {
case r := <-ch:
Expand All @@ -75,9 +128,9 @@ func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Co
}
}

func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, bucket string) (aws.Credentials, error) {
func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, key cacheKey) (aws.Credentials, error) {
resp, err := p.client.CreateSession(ctx, &CreateSessionInput{
Bucket: aws.String(bucket),
Bucket: aws.String(key.Bucket),
})
if err != nil {
return aws.Credentials{}, err
Expand All @@ -88,22 +141,10 @@ func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, buck
return aws.Credentials{}, err
}

p.putCacheCredentials(bucket, creds)
p.cache.Put(key, creds)
return *creds, nil
}

func (p *defaultS3ExpressCredentialsProvider) getCacheCredentials(bucket string) (*aws.Credentials, bool) {
if v, ok := p.credsCache.Get(bucket); ok {
return v.(*aws.Credentials), true
}

return nil, false
}

func (p *defaultS3ExpressCredentialsProvider) putCacheCredentials(bucket string, creds *aws.Credentials) {
p.credsCache.Put(bucket, creds)
}

func credentialsFromResponse(o *CreateSessionOutput) (*aws.Credentials, error) {
if o.Credentials == nil {
return nil, errors.New("s3express session credentials unset")
Expand All @@ -121,3 +162,9 @@ func credentialsFromResponse(o *CreateSessionOutput) (*aws.Credentials, error) {
Expires: *o.Credentials.Expiration,
}, nil
}

func gethmac(p, key string) string {
hash := hmac.New(sha256.New, []byte(key))
hash.Write([]byte(p))
return string(hash.Sum(nil))
}
19 changes: 17 additions & 2 deletions service/s3/express_resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,26 @@ func resolveExpressCredentials(o *Options) {
}
}

// Config finalizer: if we're using the default S3Express implementation,
// grab a reference to the client for its CreateSession API.
// Config finalizer: if we're using the default S3Express implementation, grab
// a reference to the client for its CreateSession API, and the underlying
// sigv4 credentials provider for cache keying.
func finalizeExpressCredentials(o *Options, c *Client) {
if p, ok := o.ExpressCredentials.(*defaultS3ExpressCredentialsProvider); ok {
p.client = c
p.v4creds = o.Credentials
}
}

// Operation config finalizer: update the sigv4 credentials on the default
// express provider if it changed to ensure different cache keys
func finalizeOperationExpressCredentials(o *Options, c Client) {
p, ok := o.ExpressCredentials.(*defaultS3ExpressCredentialsProvider)
if !ok {
return
}

if c.options.Credentials != o.Credentials {
o.ExpressCredentials = p.CloneWithBaseCredentials(o.Credentials)
}
}

Expand Down

0 comments on commit b3c7fbf

Please sign in to comment.