diff --git a/aws/credentials/credentials.go b/aws/credentials/credentials.go index 42416fc2f0..ed086992f6 100644 --- a/aws/credentials/credentials.go +++ b/aws/credentials/credentials.go @@ -178,7 +178,8 @@ func (e *Expiry) IsExpired() bool { type Credentials struct { creds Value forceRefresh bool - m sync.Mutex + + m sync.RWMutex provider Provider } @@ -201,6 +202,17 @@ func NewCredentials(provider Provider) *Credentials { // If Credentials.Expire() was called the credentials Value will be force // expired, and the next call to Get() will cause them to be refreshed. func (c *Credentials) Get() (Value, error) { + // Check the cached credentials first with just the read lock. + c.m.RLock() + if !c.isExpired() { + creds := c.creds + c.m.RUnlock() + return creds, nil + } + c.m.RUnlock() + + // Credentials are expired need to retrieve the credentials taking the full + // lock. c.m.Lock() defer c.m.Unlock() @@ -234,8 +246,8 @@ func (c *Credentials) Expire() { // If the Credentials were forced to be expired with Expire() this will // reflect that override. func (c *Credentials) IsExpired() bool { - c.m.Lock() - defer c.m.Unlock() + c.m.RLock() + defer c.m.RUnlock() return c.isExpired() } diff --git a/aws/credentials/credentials_bench_test.go b/aws/credentials/credentials_bench_test.go new file mode 100644 index 0000000000..01a5d633bf --- /dev/null +++ b/aws/credentials/credentials_bench_test.go @@ -0,0 +1,90 @@ +// +build go1.9 + +package credentials + +import ( + "fmt" + "strconv" + "sync" + "testing" + "time" +) + +func BenchmarkCredentials_Get(b *testing.B) { + stub := &stubProvider{} + + cases := []int{1, 10, 100, 500, 1000, 10000} + + for _, c := range cases { + b.Run(strconv.Itoa(c), func(b *testing.B) { + creds := NewCredentials(stub) + var wg sync.WaitGroup + wg.Add(c) + for i := 0; i < c; i++ { + go func() { + for j := 0; j < b.N; j++ { + v, err := creds.Get() + if err != nil { + b.Fatalf("expect no error %v, %v", v, err) + } + } + wg.Done() + }() + } + b.ResetTimer() + + wg.Wait() + }) + } +} + +func BenchmarkCredentials_Get_Expire(b *testing.B) { + p := &blockProvider{} + + expRates := []int{10000, 1000, 100} + cases := []int{1, 10, 100, 500, 1000, 10000} + + for _, expRate := range expRates { + for _, c := range cases { + b.Run(fmt.Sprintf("%d-%d", expRate, c), func(b *testing.B) { + creds := NewCredentials(p) + var wg sync.WaitGroup + wg.Add(c) + for i := 0; i < c; i++ { + go func(id int) { + for j := 0; j < b.N; j++ { + v, err := creds.Get() + if err != nil { + b.Fatalf("expect no error %v, %v", v, err) + } + // periodically expire creds to cause rwlock + if id == 0 && j%expRate == 0 { + creds.Expire() + } + } + wg.Done() + }(i) + } + b.ResetTimer() + + wg.Wait() + }) + } + } +} + +type blockProvider struct { + creds Value + expired bool + err error +} + +func (s *blockProvider) Retrieve() (Value, error) { + s.expired = false + s.creds.ProviderName = "blockProvider" + time.Sleep(time.Millisecond) + return s.creds, s.err +} +func (s *blockProvider) IsExpired() bool { + return s.expired +}