Skip to content

Commit

Permalink
Prevent ChainedTokenCredential data race (#17170)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell committed Mar 4, 2022
1 parent 3211915 commit 4cfc2d1
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 162 deletions.
9 changes: 4 additions & 5 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
# Release History

## 0.13.2 (Unreleased)

### Features Added

### Breaking Changes
## 0.13.2 (2022-03-08)

### Bugs Fixed
* Prevented a data race in `DefaultAzureCredential` and `ChainedTokenCredential`
([#17144](https://github.com/Azure/azure-sdk-for-go/issues/17144))

### Other Changes
* Upgraded App Service managed identity version from 2017-09-01 to 2019-08-01
([#17086](https://github.com/Azure/azure-sdk-for-go/pull/17086))

## 0.13.1 (2022-02-08)

Expand Down
65 changes: 52 additions & 13 deletions sdk/azidentity/chained_token_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"strings"
"sync"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
Expand All @@ -27,10 +28,12 @@ type ChainedTokenCredentialOptions struct {
// By default, this credential will assume that the first successful credential should be the only credential used on future requests.
// If the `RetrySources` option is set to true, it will always try to get a token using all of the originally provided credentials.
type ChainedTokenCredential struct {
cond *sync.Cond
iterating bool
name string
retrySources bool
sources []azcore.TokenCredential
successfulCredential azcore.TokenCredential
retrySources bool
name string
}

// NewChainedTokenCredential creates a ChainedTokenCredential.
Expand All @@ -50,35 +53,71 @@ func NewChainedTokenCredential(sources []azcore.TokenCredential, options *Chaine
if options == nil {
options = &ChainedTokenCredentialOptions{}
}
return &ChainedTokenCredential{sources: cp, name: "ChainedTokenCredential", retrySources: options.RetrySources}, nil
return &ChainedTokenCredential{
cond: sync.NewCond(&sync.Mutex{}),
name: "ChainedTokenCredential",
retrySources: options.RetrySources,
sources: cp,
}, nil
}

// GetToken calls GetToken on the chained credentials in turn, stopping when one returns a token. This method is called automatically by Azure SDK clients.
// ctx: Context controlling the request lifetime.
// opts: Options for the token request, in particular the desired scope of the access token.
func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (*azcore.AccessToken, error) {
if c.successfulCredential != nil && !c.retrySources {
return c.successfulCredential.GetToken(ctx, opts)
if !c.retrySources {
// ensure only one goroutine at a time iterates the sources and perhaps sets c.successfulCredential
c.cond.L.Lock()
for {
if c.successfulCredential != nil {
c.cond.L.Unlock()
return c.successfulCredential.GetToken(ctx, opts)
}
if !c.iterating {
c.iterating = true
// allow other goroutines to wait while this one iterates
c.cond.L.Unlock()
break
}
c.cond.Wait()
}
}

var err error
var errs []error
var token *azcore.AccessToken
var successfulCredential azcore.TokenCredential
for _, cred := range c.sources {
token, err := cred.GetToken(ctx, opts)
token, err = cred.GetToken(ctx, opts)
if err == nil {
log.Writef(EventAuthentication, "%s authenticated with %s", c.name, extractCredentialName(cred))
c.successfulCredential = cred
return token, nil
successfulCredential = cred
break
}
errs = append(errs, err)
if _, ok := err.(credentialUnavailableError); !ok {
break
}
}
if c.iterating {
c.cond.L.Lock()
c.successfulCredential = successfulCredential
c.iterating = false
c.cond.L.Unlock()
c.cond.Broadcast()
}
// err is the error returned by the last GetToken call. It will be nil when that call succeeds
if err != nil {
// return credentialUnavailableError iff all sources did so; return AuthenticationFailedError otherwise
msg := createChainedErrorMessage(errs)
if _, ok := err.(credentialUnavailableError); ok {
err = newCredentialUnavailableError(c.name, msg)
} else {
res := getResponseFromError(err)
msg := createChainedErrorMessage(errs)
return nil, newAuthenticationFailedError(c.name, msg, res)
err = newAuthenticationFailedError(c.name, msg, res)
}
}
// if we get here, all credentials returned credentialUnavailableError
msg := createChainedErrorMessage(errs)
return nil, newCredentialUnavailableError(c.name, msg)
return token, err
}

func createChainedErrorMessage(errs []error) string {
Expand Down
Loading

0 comments on commit 4cfc2d1

Please sign in to comment.