Skip to content

Commit

Permalink
Merge pull request #412 from AzureAD/release-1.0.0
Browse files Browse the repository at this point in the history
MSAL Go 1.0.0
  • Loading branch information
rayluo committed Apr 20, 2023
2 parents 8801762 + 8e66327 commit 4d3329f
Show file tree
Hide file tree
Showing 10 changed files with 244 additions and 176 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Acquiring tokens with MSAL Go follows this general three step pattern. There mig
* Initializing a public client:

```go
publicClientApp, err := public.New("client_id", public.WithAuthority("https://login.microsoftonline.com/Enter_The_Tenant_Name_Here"))
publicClientApp, err := public.New("client_id", public.WithAuthority("https://login.microsoft.com/Enter_The_Tenant_Name_Here"))
```

* Initializing a confidential client:
Expand All @@ -54,7 +54,7 @@ Acquiring tokens with MSAL Go follows this general three step pattern. There mig
if err != nil {
return nil, fmt.Errorf("could not create a cred from a secret: %w", err)
}
confidentialClientApp, err := confidential.New("client_id", cred, confidential.WithAuthority("https://login.microsoftonline.com/Enter_The_Tenant_Name_Here"))
confidentialClientApp, err := confidential.New("client_id", cred, confidential.WithAuthority("https://login.microsoft.com/Enter_The_Tenant_Name_Here"))
```

1. MSAL comes packaged with an in-memory cache. Utilizing the cache is optional, but we would highly recommend it.
Expand Down
234 changes: 91 additions & 143 deletions apps/internal/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/url"
"reflect"
"strings"
"sync"
"time"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
Expand All @@ -27,31 +28,21 @@ const (
)

// manager provides an internal cache. It is defined to allow faking the cache in tests.
// In all production use it is a *storage.Manager.
// In production it's a *storage.Manager or *storage.PartitionedManager.
type manager interface {
Read(ctx context.Context, authParameters authority.AuthParams, account shared.Account) (storage.TokenResponse, error)
Write(authParameters authority.AuthParams, tokenResponse accesstokens.TokenResponse) (shared.Account, error)
cache.Serializer
Read(context.Context, authority.AuthParams) (storage.TokenResponse, error)
Write(authority.AuthParams, accesstokens.TokenResponse) (shared.Account, error)
}

// accountManager is a manager that also caches accounts. In production it's a *storage.Manager.
type accountManager interface {
manager
AllAccounts() []shared.Account
Account(homeAccountID string) shared.Account
RemoveAccount(account shared.Account, clientID string)
}

// partitionedManager provides an internal cache. It is defined to allow faking the cache in tests.
// In all production use it is a *storage.PartitionedManager.
type partitionedManager interface {
Read(ctx context.Context, authParameters authority.AuthParams) (storage.TokenResponse, error)
Write(authParameters authority.AuthParams, tokenResponse accesstokens.TokenResponse) (shared.Account, error)
}

type noopCacheAccessor struct{}

func (n noopCacheAccessor) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error {
return nil
}
func (n noopCacheAccessor) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) error {
return nil
}

// AcquireTokenSilentParameters contains the parameters to acquire a token silently (from cache).
type AcquireTokenSilentParameters struct {
Scopes []string
Expand Down Expand Up @@ -137,12 +128,14 @@ func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Acco
// Client is a base client that provides access to common methods and primatives that
// can be used by multiple clients.
type Client struct {
Token *oauth.Client
manager manager // *storage.Manager or fakeManager in tests
pmanager partitionedManager // *storage.PartitionedManager or fakeManager in tests

AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New().
cacheAccessor cache.ExportReplace
Token *oauth.Client
manager accountManager // *storage.Manager or fakeManager in tests
// pmanager is a partitioned cache for OBO authentication. *storage.PartitionedManager or fakeManager in tests
pmanager manager

AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New().
cacheAccessor cache.ExportReplace
cacheAccessorMu *sync.RWMutex
}

// Option is an optional argument to the New constructor.
Expand Down Expand Up @@ -214,11 +207,11 @@ func New(clientID string, authorityURI string, token *oauth.Client, options ...O
}
authParams := authority.NewAuthParams(clientID, authInfo)
client := Client{ // Note: Hey, don't even THINK about making Base into *Base. See "design notes" in public.go and confidential.go
Token: token,
AuthParams: authParams,
cacheAccessor: noopCacheAccessor{},
manager: storage.New(token),
pmanager: storage.NewPartitionedManager(token),
Token: token,
AuthParams: authParams,
cacheAccessorMu: &sync.RWMutex{},
manager: storage.New(token),
pmanager: storage.NewPartitionedManager(token),
}
for _, o := range options {
if err = o(&client); err != nil {
Expand Down Expand Up @@ -283,8 +276,9 @@ func (b Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, s
return baseURL.String(), nil
}

func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (ar AuthResult, err error) {
// when tenant == "", the caller didn't specify a tenant and WithTenant will use the client's configured tenant
func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (AuthResult, error) {
ar := AuthResult{}
// when tenant == "", the caller didn't specify a tenant and WithTenant will choose the client's configured tenant
tenant := silent.TenantID
authParams, err := b.AuthParams.WithTenant(tenant)
if err != nil {
Expand All @@ -296,38 +290,23 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen
authParams.Claims = silent.Claims
authParams.UserAssertion = silent.UserAssertion

var storageTokenResponse storage.TokenResponse
if authParams.AuthorizationType == authority.ATOnBehalfOf {
if s, ok := b.pmanager.(cache.Serializer); ok {
suggestedCacheKey := authParams.CacheKey(silent.IsAppCache)
err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey})
if err != nil {
return ar, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}
storageTokenResponse, err = b.pmanager.Read(ctx, authParams)
if err != nil {
return ar, err
}
} else {
if s, ok := b.manager.(cache.Serializer); ok {
suggestedCacheKey := authParams.CacheKey(silent.IsAppCache)
err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey})
if err != nil {
return ar, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}
m := b.pmanager
if authParams.AuthorizationType != authority.ATOnBehalfOf {
authParams.AuthorizationType = authority.ATRefreshToken
storageTokenResponse, err = b.manager.Read(ctx, authParams, silent.Account)
if err != nil {
return ar, err
}
m = b.manager
}
if b.cacheAccessor != nil {
key := authParams.CacheKey(silent.IsAppCache)
b.cacheAccessorMu.RLock()
err = b.cacheAccessor.Replace(ctx, m, cache.ReplaceHints{PartitionKey: key})
b.cacheAccessorMu.RUnlock()
}
if err != nil {
return ar, err
}
storageTokenResponse, err := m.Read(ctx, authParams)
if err != nil {
return ar, err
}

// ignore cached access tokens when given claims
Expand All @@ -340,21 +319,17 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen

// redeem a cached refresh token, if available
if reflect.ValueOf(storageTokenResponse.RefreshToken).IsZero() {
err = errors.New("no token found")
return ar, err
return ar, errors.New("no token found")
}
var cc *accesstokens.Credential
if silent.RequestType == accesstokens.ATConfidential {
cc = silent.Credential
}

token, err := b.Token.Refresh(ctx, silent.RequestType, authParams, cc, storageTokenResponse.RefreshToken)
if err != nil {
return ar, err
}

ar, err = b.AuthResultFromToken(ctx, authParams, token, true)
return ar, err
return b.AuthResultFromToken(ctx, authParams, token, true)
}

func (b Client) AcquireTokenByAuthCode(ctx context.Context, authCodeParams AcquireTokenAuthCodeParameters) (AuthResult, error) {
Expand Down Expand Up @@ -417,103 +392,76 @@ func (b Client) AcquireTokenOnBehalfOf(ctx context.Context, onBehalfOfParams Acq
return ar, err
}

func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (ar AuthResult, err error) {
func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (AuthResult, error) {
if !cacheWrite {
return NewAuthResult(token, shared.Account{})
}

var account shared.Account
var m manager = b.manager
if authParams.AuthorizationType == authority.ATOnBehalfOf {
if s, ok := b.pmanager.(cache.Serializer); ok {
suggestedCacheKey := token.CacheKey(authParams)
err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey})
if err != nil {
return ar, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}
account, err = b.pmanager.Write(authParams, token)
m = b.pmanager
}
key := token.CacheKey(authParams)
if b.cacheAccessor != nil {
b.cacheAccessorMu.Lock()
defer b.cacheAccessorMu.Unlock()
err := b.cacheAccessor.Replace(ctx, m, cache.ReplaceHints{PartitionKey: key})
if err != nil {
return ar, err
}
} else {
if s, ok := b.manager.(cache.Serializer); ok {
suggestedCacheKey := token.CacheKey(authParams)
err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey})
if err != nil {
return ar, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}
account, err = b.manager.Write(authParams, token)
if err != nil {
return ar, err
return AuthResult{}, err
}
}
ar, err = NewAuthResult(token, account)
account, err := m.Write(authParams, token)
if err != nil {
return AuthResult{}, err
}
ar, err := NewAuthResult(token, account)
if err == nil && b.cacheAccessor != nil {
err = b.cacheAccessor.Export(ctx, b.manager, cache.ExportHints{PartitionKey: key})
}
return ar, err
}

func (b Client) AllAccounts(ctx context.Context) (accts []shared.Account, err error) {
if s, ok := b.manager.(cache.Serializer); ok {
suggestedCacheKey := b.AuthParams.CacheKey(false)
err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey})
func (b Client) AllAccounts(ctx context.Context) ([]shared.Account, error) {
if b.cacheAccessor != nil {
b.cacheAccessorMu.RLock()
defer b.cacheAccessorMu.RUnlock()
key := b.AuthParams.CacheKey(false)
err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key})
if err != nil {
return accts, err
return nil, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}

accts = b.manager.AllAccounts()
return accts, err
return b.manager.AllAccounts(), nil
}

func (b Client) Account(ctx context.Context, homeAccountID string) (acct shared.Account, err error) {
authParams := b.AuthParams // This is a copy, as we dont' have a pointer receiver and .AuthParams is not a pointer.
authParams.AuthorizationType = authority.AccountByID
authParams.HomeAccountID = homeAccountID
if s, ok := b.manager.(cache.Serializer); ok {
suggestedCacheKey := b.AuthParams.CacheKey(false)
err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey})
func (b Client) Account(ctx context.Context, homeAccountID string) (shared.Account, error) {
if b.cacheAccessor != nil {
b.cacheAccessorMu.RLock()
defer b.cacheAccessorMu.RUnlock()
authParams := b.AuthParams // This is a copy, as we don't have a pointer receiver and .AuthParams is not a pointer.
authParams.AuthorizationType = authority.AccountByID
authParams.HomeAccountID = homeAccountID
key := b.AuthParams.CacheKey(false)
err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key})
if err != nil {
return acct, err
return shared.Account{}, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}
acct = b.manager.Account(homeAccountID)
return acct, err
return b.manager.Account(homeAccountID), nil
}

// RemoveAccount removes all the ATs, RTs and IDTs from the cache associated with this account.
func (b Client) RemoveAccount(ctx context.Context, account shared.Account) (err error) {
if s, ok := b.manager.(cache.Serializer); ok {
suggestedCacheKey := b.AuthParams.CacheKey(false)
err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey})
if err != nil {
return err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
func (b Client) RemoveAccount(ctx context.Context, account shared.Account) error {
if b.cacheAccessor == nil {
b.manager.RemoveAccount(account, b.AuthParams.ClientID)
return nil
}
b.manager.RemoveAccount(account, b.AuthParams.ClientID)
return err
}

// export helps other methods defer exporting the cache after possibly updating its in-memory content.
// err is the error the calling method will return. If err isn't nil, export returns it without
// exporting the cache.
func (b Client) export(ctx context.Context, marshal cache.Marshaler, key string, err error) error {
b.cacheAccessorMu.Lock()
defer b.cacheAccessorMu.Unlock()
key := b.AuthParams.CacheKey(false)
err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key})
if err != nil {
return err
}
return b.cacheAccessor.Export(ctx, marshal, cache.ExportHints{PartitionKey: key})
b.manager.RemoveAccount(account, b.AuthParams.ClientID)
return b.cacheAccessor.Export(ctx, b.manager, cache.ExportHints{PartitionKey: key})
}

0 comments on commit 4d3329f

Please sign in to comment.