Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use singleflight to coalesce credentials Get() calls into a single Retrieve() #3127

Merged
merged 5 commits into from Feb 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG_PENDING.md
Expand Up @@ -6,5 +6,7 @@
* This utility is no longer relevant. The utility allowed users the beta pre-release v0 SDK to update to the v1.0 released version of the SDK.

### SDK Enhancements
* `aws/credentials`: Add grouping of concurrent refresh of credentials ([#3127](https://github.com/aws/aws-sdk-go/pull/3127/)
* Concurrent calls to `Credentials.Get` are now grouped in order to prevent numerous synchronous calls to refresh the credentials. Replacing the mutex with a singleflight reduces the overall amount of time request signatures need to wait while retrieving credentials. This is improvement becomes pronounced when many requests are being made concurrently.

### SDK Bugs
3 changes: 2 additions & 1 deletion Makefile
Expand Up @@ -6,6 +6,7 @@ LINTIGNOREINFLECTS3UPLOAD='service/s3/s3manager/upload\.go:.+struct field SSEKMS
LINTIGNOREENDPOINTS='aws/endpoints/(defaults|dep_service_ids).go:.+(method|const) .+ should be '
LINTIGNOREDEPS='vendor/.+\.go'
LINTIGNOREPKGCOMMENT='service/[^/]+/doc_custom.go:.+package comment should be of the form'
LINTIGNORESINGLEFIGHT='internal/sync/singleflight/singleflight.go:.+error should be the last type'
UNIT_TEST_TAGS="example codegen awsinclude"
ALL_TAGS="example codegen awsinclude integration perftest"

Expand Down Expand Up @@ -191,7 +192,7 @@ verify: lint vet
lint:
@echo "go lint SDK and vendor packages"
@lint=`golint ./...`; \
dolint=`echo "$$lint" | grep -E -v -e ${LINTIGNOREDOC} -e ${LINTIGNORECONST} -e ${LINTIGNORESTUTTER} -e ${LINTIGNOREINFLECT} -e ${LINTIGNOREDEPS} -e ${LINTIGNOREINFLECTS3UPLOAD} -e ${LINTIGNOREPKGCOMMENT} -e ${LINTIGNOREENDPOINTS}`; \
dolint=`echo "$$lint" | grep -E -v -e ${LINTIGNOREDOC} -e ${LINTIGNORECONST} -e ${LINTIGNORESTUTTER} -e ${LINTIGNOREINFLECT} -e ${LINTIGNOREDEPS} -e ${LINTIGNOREINFLECTS3UPLOAD} -e ${LINTIGNOREPKGCOMMENT} -e ${LINTIGNOREENDPOINTS} -e ${LINTIGNORESINGLEFIGHT}`; \
echo "$$dolint"; \
if [ "$$dolint" != "" ]; then exit 1; fi

Expand Down
67 changes: 26 additions & 41 deletions aws/credentials/credentials.go
Expand Up @@ -50,10 +50,11 @@ package credentials

import (
"fmt"
"sync"
"sync/atomic"
"time"

"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/sync/singleflight"
)

// AnonymousCredentials is an empty Credential object that can be used as
Expand Down Expand Up @@ -197,20 +198,19 @@ func (e *Expiry) ExpiresAt() time.Time {
// first instance of the credentials Value. All calls to Get() after that
// will return the cached credentials Value until IsExpired() returns true.
type Credentials struct {
creds Value
forceRefresh bool

m sync.RWMutex
creds atomic.Value
sf singleflight.Group

provider Provider
}

// NewCredentials returns a pointer to a new Credentials with the provider set.
func NewCredentials(provider Provider) *Credentials {
return &Credentials{
provider: provider,
forceRefresh: true,
c := &Credentials{
provider: provider,
}
c.creds.Store(Value{})
return c
}

// Get returns the credentials value, or error if the credentials Value failed
Expand All @@ -223,30 +223,24 @@ func NewCredentials(provider Provider) *Credentials {
// If Credentials.Expire() was called the credentials Value will be force
// expired, and the next call to Get() will cause them to be refreshed.
func (c *Credentials) Get() (Value, error) {
// Check the cached credentials first with just the read lock.
c.m.RLock()
if !c.isExpired() {
creds := c.creds
c.m.RUnlock()
return creds, nil
if creds := c.creds.Load(); !c.isExpired(creds) {
return creds.(Value), nil
}
c.m.RUnlock()

// Credentials are expired need to retrieve the credentials taking the full
// lock.
c.m.Lock()
defer c.m.Unlock()
creds, err, _ := c.sf.Do("", func() (interface{}, error) {
if creds := c.creds.Load(); !c.isExpired(creds) {
return creds.(Value), nil
}

if c.isExpired() {
creds, err := c.provider.Retrieve()
if err != nil {
return Value{}, err
if err == nil {
c.creds.Store(creds)
}
c.creds = creds
c.forceRefresh = false
}

return c.creds, nil
return creds, err
})

return creds.(Value), err
}

// Expire expires the credentials and forces them to be retrieved on the
Expand All @@ -255,10 +249,7 @@ func (c *Credentials) Get() (Value, error) {
// This will override the Provider's expired state, and force Credentials
// to call the Provider's Retrieve().
func (c *Credentials) Expire() {
c.m.Lock()
defer c.m.Unlock()

c.forceRefresh = true
c.creds.Store(Value{})
}

// IsExpired returns if the credentials are no longer valid, and need
Expand All @@ -267,31 +258,25 @@ func (c *Credentials) Expire() {
// If the Credentials were forced to be expired with Expire() this will
// reflect that override.
func (c *Credentials) IsExpired() bool {
c.m.RLock()
defer c.m.RUnlock()

return c.isExpired()
return c.isExpired(c.creds.Load())
}

// isExpired helper method wrapping the definition of expired credentials.
func (c *Credentials) isExpired() bool {
return c.forceRefresh || c.provider.IsExpired()
func (c *Credentials) isExpired(creds interface{}) bool {
return creds == nil || creds.(Value) == Value{} || c.provider.IsExpired()
}

// ExpiresAt provides access to the functionality of the Expirer interface of
// the underlying Provider, if it supports that interface. Otherwise, it returns
// an error.
func (c *Credentials) ExpiresAt() (time.Time, error) {
c.m.RLock()
defer c.m.RUnlock()

expirer, ok := c.provider.(Expirer)
if !ok {
return time.Time{}, awserr.New("ProviderNotExpirer",
fmt.Sprintf("provider %s does not support ExpiresAt()", c.creds.ProviderName),
fmt.Sprintf("provider %s does not support ExpiresAt()", c.creds.Load().(Value).ProviderName),
nil)
}
if c.forceRefresh {
if c.creds.Load().(Value) == (Value{}) {
// set expiration time to the distant past
return time.Time{}, nil
}
Expand Down
33 changes: 32 additions & 1 deletion aws/credentials/credentials_test.go
Expand Up @@ -71,7 +71,7 @@ func TestCredentialsExpire(t *testing.T) {
t.Errorf("Expected to be expired")
}

c.forceRefresh = false
c.Get()
if c.IsExpired() {
t.Errorf("Expected not to be expired")
}
Expand Down Expand Up @@ -171,3 +171,34 @@ func TestCredentialsExpiresAt_HasExpirer(t *testing.T) {
t.Errorf("Expected distant past expiration, got %v", expiration)
}
}

type stubProviderConcurrent struct {
stubProvider
done chan struct{}
}

func (s *stubProviderConcurrent) Retrieve() (Value, error) {
<-s.done
return s.stubProvider.Retrieve()
}

func TestCredentialsGetConcurrent(t *testing.T) {
stub := &stubProviderConcurrent{
done: make(chan struct{}),
}

c := NewCredentials(stub)
done := make(chan struct{})

for i := 0; i < 2; i++ {
go func() {
c.Get()
done <- struct{}{}
}()
}

// Validates that a single call to Retrieve is shared between two calls to Get
stub.done <- struct{}{}
<-done
<-done
}
27 changes: 27 additions & 0 deletions internal/sync/singleflight/LICENSE
@@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
120 changes: 120 additions & 0 deletions internal/sync/singleflight/singleflight.go
@@ -0,0 +1,120 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package singleflight provides a duplicate function call suppression
// mechanism.
package singleflight

import "sync"

// call is an in-flight or completed singleflight.Do call
type call struct {
wg sync.WaitGroup

// These fields are written once before the WaitGroup is done
// and are only read after the WaitGroup is done.
val interface{}
err error

// forgotten indicates whether Forget was called with this call's key
// while the call was still in flight.
forgotten bool

// These fields are read and written with the singleflight
// mutex held before the WaitGroup is done, and are read but
// not written after the WaitGroup is done.
dups int
chans []chan<- Result
}

// Group represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
type Group struct {
mu sync.Mutex // protects m
m map[string]*call // lazily initialized
}

// Result holds the results of Do, so they can be passed
// on a channel.
type Result struct {
Val interface{}
Err error
Shared bool
}

// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
// The return value shared indicates whether v was given to multiple callers.
func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
c.dups++
g.mu.Unlock()
c.wg.Wait()
return c.val, c.err, true
}
c := new(call)
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()

g.doCall(c, key, fn)
return c.val, c.err, c.dups > 0
}

// DoChan is like Do but returns a channel that will receive the
// results when they are ready.
func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result {
ch := make(chan Result, 1)
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
c.dups++
c.chans = append(c.chans, ch)
g.mu.Unlock()
return ch
}
c := &call{chans: []chan<- Result{ch}}
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()

go g.doCall(c, key, fn)

return ch
}

// doCall handles the single call for a key.
func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) {
c.val, c.err = fn()
c.wg.Done()

g.mu.Lock()
if !c.forgotten {
delete(g.m, key)
}
for _, ch := range c.chans {
ch <- Result{c.val, c.err, c.dups > 0}
}
g.mu.Unlock()
}

// Forget tells the singleflight to forget about a key. Future calls
// to Do for this key will call the function rather than waiting for
// an earlier call to complete.
func (g *Group) Forget(key string) {
g.mu.Lock()
if c, ok := g.m[key]; ok {
c.forgotten = true
}
delete(g.m, key)
g.mu.Unlock()
}