Skip to content

Commit

Permalink
Fix MSAL caching (#2300)
Browse files Browse the repository at this point in the history
This change updates our MSAL cache adapter to no longer use the hinted partition key as a cache key. Instead, the current user has a fixed cache key equal to the empty string `""`. This creates the behavior of `~/.azd/auth/msal/cache.[bin|json]` being the single file that contains MSAL multi-account data, as defined by this [contract](https://github.com/AzureAD/microsoft-authentication-library-for-go/blob/27c98c8f9db6bc564c5be43677f3e6276b7c4fef/apps/internal/base/internal/storage/items.go#L18).

Also, fix logout not resetting `cache.json` due to using `config.json` (userConfigManger.Load) and not `auth.json` (readAuthConfig).

Fixes #2299
  • Loading branch information
weikanglim committed May 25, 2023
1 parent e6ece84 commit 8407ba5
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 38 deletions.
21 changes: 17 additions & 4 deletions cli/azd/pkg/auth/cache.go
Expand Up @@ -10,13 +10,26 @@ import (
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
)

// The MSAL cache key for the current user. The stored MSAL cached data contains
// all accounts with stored credentials, across all tenants.
// Currently, the underlying MSAL cache data is represented as [Contract] inside the library.
//
// For simplicity in naming the final cached file, which has a unique directory (see [fileCache]),
// and for historical purposes, we use empty string as the key.
//
// It may be tempting to instead use the partition key provided by [cache.ReplaceHints],
// but note that the key is a partitioning key and not a unique user key.
// Also, given that the data contains auth data for all users, we only need a single key
// to store all cached auth information.
const cCurrentUserCacheKey = ""

// msalCacheAdapter adapts our interface to the one expected by cache.ExportReplace.
type msalCacheAdapter struct {
cache Cache
}

func (a *msalCacheAdapter) Replace(ctx context.Context, cache cache.Unmarshaler, cacheHints cache.ReplaceHints) error {
val, err := a.cache.Read(cacheHints.PartitionKey)
func (a *msalCacheAdapter) Replace(ctx context.Context, cache cache.Unmarshaler, _ cache.ReplaceHints) error {
val, err := a.cache.Read(cCurrentUserCacheKey)
if errors.Is(err, errCacheKeyNotFound) {
return nil
} else if err != nil {
Expand All @@ -30,13 +43,13 @@ func (a *msalCacheAdapter) Replace(ctx context.Context, cache cache.Unmarshaler,
return nil
}

func (a *msalCacheAdapter) Export(ctx context.Context, cache cache.Marshaler, cacheHints cache.ExportHints) error {
func (a *msalCacheAdapter) Export(ctx context.Context, cache cache.Marshaler, _ cache.ExportHints) error {
val, err := cache.Marshal()
if err != nil {
return err
}

return a.cache.Set(cacheHints.PartitionKey, val)
return a.cache.Set(cCurrentUserCacheKey, val)
}

type Cache interface {
Expand Down
69 changes: 36 additions & 33 deletions cli/azd/pkg/auth/cache_test.go
Expand Up @@ -5,66 +5,69 @@ package auth

import (
"context"
"math/rand"
"testing"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/stretchr/testify/require"
)

var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

func randSeq(n int, rng rand.Rand) string {
b := make([]rune, n)
for i := range b {
b[i] = letters[rng.Intn(len(letters))]
}
return string(b)
}

func TestCache(t *testing.T) {
root := t.TempDir()
ctx := context.Background()
c := newCache(root)
// weak rng is fine for testing
//nolint:gosec
rng := rand.New(rand.NewSource(0))

d1 := fixedMarshaller{
val: []byte("some data"),
key := func() string {
return randSeq(10, *rng)
}

d2 := fixedMarshaller{
val: []byte("some different data"),
data := fixedMarshaller{
val: []byte("some data"),
}

// write some data.
err := c.Export(ctx, &d1, cache.ExportHints{PartitionKey: "d1"})
err := c.Export(ctx, &data, cache.ExportHints{PartitionKey: key()})
require.NoError(t, err)
err = c.Export(ctx, &d2, cache.ExportHints{PartitionKey: "d2"})
require.NoError(t, err)

var r1 fixedMarshaller
var r2 fixedMarshaller

// read back that data we wrote.
err = c.Replace(ctx, &r1, cache.ReplaceHints{PartitionKey: "d1"})
var reader fixedMarshaller
err = c.Replace(ctx, &reader, cache.ReplaceHints{PartitionKey: key()})
require.NoError(t, err)
err = c.Replace(ctx, &r2, cache.ReplaceHints{PartitionKey: "d2"})
require.NoError(t, err)

require.NotNil(t, r1.val)
require.NotNil(t, r2.val)
require.Equal(t, d1.val, r1.val)
require.Equal(t, d2.val, r2.val)
require.NotNil(t, reader.val)
require.Equal(t, data.val, reader.val)

// the data should be shared across instances.
c = newCache(root)

err = c.Replace(ctx, &r1, cache.ReplaceHints{PartitionKey: "d1"})
require.NoError(t, err)
err = c.Replace(ctx, &r2, cache.ReplaceHints{PartitionKey: "d2"})
reader = fixedMarshaller{}
err = c.Replace(ctx, &reader, cache.ReplaceHints{PartitionKey: key()})
require.NoError(t, err)
require.Equal(t, data.val, reader.val)

require.NotNil(t, r1.val)
require.NotNil(t, r2.val)
require.Equal(t, d1.val, r1.val)
require.Equal(t, d2.val, r2.val)

// read some non-existing data
nonExist := fixedMarshaller{
val: []byte("some data"),
// update existing data
otherData := fixedMarshaller{
val: []byte("other data"),
}
err = c.Replace(ctx, &nonExist, cache.ReplaceHints{PartitionKey: "nonExist"})
err = c.Export(ctx, &otherData, cache.ExportHints{PartitionKey: key()})
require.NoError(t, err)

// read back data
err = c.Replace(ctx, &reader, cache.ReplaceHints{PartitionKey: key()})
require.NoError(t, err)
// data should not be overwritten when key is not found.
require.Equal(t, []byte("some data"), nonExist.val)
require.NotNil(t, reader.val)
require.Equal(t, otherData.val, reader.val)
}

func TestCredentialCache(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion cli/azd/pkg/auth/manager.go
Expand Up @@ -585,7 +585,7 @@ func (m *Manager) saveLoginForServicePrincipal(tenantId, clientId string, secret
// getSignedInAccount fetches the public.Account for the signed in user, or nil if one does not exist
// (e.g when logged in with a service principal).
func (m *Manager) getSignedInAccount(ctx context.Context) (*public.Account, error) {
cfg, err := m.userConfigManager.Load()
cfg, err := m.readAuthConfig()
if err != nil {
return nil, fmt.Errorf("fetching current user: %w", err)
}
Expand Down

0 comments on commit 8407ba5

Please sign in to comment.