Skip to content

Commit

Permalink
Support Azure ML managed identity (#21851)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell committed Nov 3, 2023
1 parent 6da3b75 commit 1922a11
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 3 deletions.
1 change: 1 addition & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## 1.5.0-beta.2 (Unreleased)

### Features Added
* `DefaultAzureCredential` and `ManagedIdentityCredential` support Azure ML managed identity

### Breaking Changes

Expand Down
2 changes: 1 addition & 1 deletion sdk/azidentity/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "go",
"TagPrefix": "go/azidentity",
"Tag": "go/azidentity_ae45facec3"
"Tag": "go/azidentity_db4a26f583"
}
22 changes: 22 additions & 0 deletions sdk/azidentity/live_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ const (
recordingDirectory = "sdk/azidentity/testdata"
azidentityRunManualTests = "AZIDENTITY_RUN_MANUAL_TESTS"
fakeClientID = "fake-client-id"
fakeMIEndpoint = "https://fake.local"
fakeResourceID = "/fake/resource/ID"
fakeTenantID = "fake-tenant"
fakeUsername = "fake@user"
fakeAdfsAuthority = "fake.adfs.local"
fakeAdfsScope = "fake.adfs.local/fake-scope/.default"
liveTestScope = "https://management.core.windows.net//.default"
redacted = "redacted"
)

var adfsLiveSP = struct {
Expand Down Expand Up @@ -159,6 +161,9 @@ func run(m *testing.M) int {
strings.TrimPrefix(adfsScope, "https://"): fakeAdfsScope,
strings.TrimPrefix(adfsAuthority, "https://"): fakeAdfsAuthority,
}
if id := os.Getenv(defaultIdentityClientID); id != "" {
pathVars[id] = fakeClientID
}
for target, replacement := range pathVars {
if target != "" {
err := recording.AddURISanitizer(replacement, target, nil)
Expand All @@ -184,6 +189,23 @@ func run(m *testing.M) int {
if err != nil {
panic(err)
}
// some managed identity requests include a "secret" header. It isn't dangerous
// to record the value, however it must be static for matching to work in playback
err = recording.AddHeaderRegexSanitizer("secret", redacted, "", nil)
if err != nil {
panic(err)
}
if url, ok := os.LookupEnv(msiEndpoint); ok {
err = recording.AddURISanitizer(fakeMIEndpoint, url, nil)
if err == nil {
if clientID, ok := os.LookupEnv(defaultIdentityClientID); ok {
err = recording.AddURISanitizer(fakeClientID, clientID, nil)
}
}
if err != nil {
panic(err)
}
}
// redact secrets returned by Microsoft Entra ID
for _, key := range []string{"access_token", "device_code", "message", "refresh_token", "user_code"} {
err = recording.AddBodyKeySanitizer("$."+key, "redacted", "", nil)
Expand Down
37 changes: 35 additions & 2 deletions sdk/azidentity/managed_identity_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ import (

const (
arcIMDSEndpoint = "IMDS_ENDPOINT"
defaultIdentityClientID = "DEFAULT_IDENTITY_CLIENT_ID"
identityEndpoint = "IDENTITY_ENDPOINT"
identityHeader = "IDENTITY_HEADER"
identityServerThumbprint = "IDENTITY_SERVER_THUMBPRINT"
headerMetadata = "Metadata"
imdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
msiEndpoint = "MSI_ENDPOINT"
msiSecret = "MSI_SECRET"
imdsAPIVersion = "2018-02-01"
azureArcAPIVersion = "2019-08-15"
serviceFabricAPIVersion = "2019-07-01-preview"
Expand All @@ -47,6 +49,7 @@ type msiType int
const (
msiTypeAppService msiType = iota
msiTypeAzureArc
msiTypeAzureML
msiTypeCloudShell
msiTypeIMDS
msiTypeServiceFabric
Expand Down Expand Up @@ -135,9 +138,14 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*manag
c.msiType = msiTypeAzureArc
}
} else if endpoint, ok := os.LookupEnv(msiEndpoint); ok {
env = "Cloud Shell"
c.endpoint = endpoint
c.msiType = msiTypeCloudShell
if _, ok := os.LookupEnv(msiSecret); ok {
env = "Azure ML"
c.msiType = msiTypeAzureML
} else {
env = "Cloud Shell"
c.msiType = msiTypeCloudShell
}
} else {
setIMDSRetryOptionDefaults(&cp.Retry)
}
Expand Down Expand Up @@ -247,6 +255,8 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id Manage
return nil, newAuthenticationFailedError(credNameManagedIdentity, msg, nil, err)
}
return c.createAzureArcAuthRequest(ctx, id, scopes, key)
case msiTypeAzureML:
return c.createAzureMLAuthRequest(ctx, id, scopes)
case msiTypeServiceFabric:
return c.createServiceFabricAuthRequest(ctx, id, scopes)
case msiTypeCloudShell:
Expand Down Expand Up @@ -296,6 +306,29 @@ func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context,
return request, nil
}

func (c *managedIdentityClient) createAzureMLAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
request, err := runtime.NewRequest(ctx, http.MethodGet, c.endpoint)
if err != nil {
return nil, err
}
request.Raw().Header.Set("secret", os.Getenv(msiSecret))
q := request.Raw().URL.Query()
q.Add("api-version", "2017-09-01")
q.Add("resource", strings.Join(scopes, " "))
q.Add("clientid", os.Getenv(defaultIdentityClientID))
if id != nil {
if id.idKind() == miResourceID {
log.Write(EventAuthentication, "WARNING: Azure ML doesn't support specifying a managed identity by resource ID")
q.Set("clientid", "")
q.Set(qpResID, id.String())
} else {
q.Set("clientid", id.String())
}
}
request.Raw().URL.RawQuery = q.Encode()
return request, nil
}

func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
request, err := runtime.NewRequest(ctx, http.MethodGet, c.endpoint)
if err != nil {
Expand Down
28 changes: 28 additions & 0 deletions sdk/azidentity/managed_identity_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,34 @@ func TestManagedIdentityCredential_AzureArcErrors(t *testing.T) {
})
}

func TestManagedIdentityCredential_AzureMLLive(t *testing.T) {
switch recording.GetRecordMode() {
case recording.LiveMode:
t.Skip("this test doesn't run in live mode because it can't pass in CI")
case recording.PlaybackMode:
t.Setenv(defaultIdentityClientID, fakeClientID)
t.Setenv(msiEndpoint, fakeMIEndpoint)
t.Setenv(msiSecret, redacted)
case recording.RecordingMode:
missing := []string{}
for _, v := range []string{defaultIdentityClientID, msiEndpoint, msiSecret} {
if len(os.Getenv(v)) == 0 {
missing = append(missing, v)
}
}
if len(missing) > 0 {
t.Skip("no value for " + strings.Join(missing, ", "))
}
}
opts, stop := initRecording(t)
defer stop()
cred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ClientOptions: opts})
if err != nil {
t.Fatal(err)
}
testGetTokenSuccess(t, cred)
}

func TestManagedIdentityCredential_CloudShell(t *testing.T) {
validateReq := func(req *http.Request) *http.Response {
err := req.ParseForm()
Expand Down

0 comments on commit 1922a11

Please sign in to comment.