Skip to content

Commit

Permalink
Enable user assigned identity for IMDS (#14667)
Browse files Browse the repository at this point in the history
* enable user assigned identity for IMDS

* update imds auth request test
  • Loading branch information
catalinaperalta committed May 13, 2021
1 parent 5ad85f3 commit 7be1698
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
9 changes: 6 additions & 3 deletions sdk/azidentity/managed_identity_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (c *managedIdentityClient) createAccessToken(res *azcore.Response) (*azcore
func (c *managedIdentityClient) createAuthRequest(ctx context.Context, clientID string, scopes []string) (*azcore.Request, error) {
switch c.msiType {
case msiTypeIMDS:
return c.createIMDSAuthRequest(ctx, scopes)
return c.createIMDSAuthRequest(ctx, clientID, scopes)
case msiTypeAppServiceV20170901, msiTypeAppServiceV20190801:
return c.createAppServiceAuthRequest(ctx, clientID, scopes)
case msiTypeAzureArc:
Expand All @@ -162,7 +162,7 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, clientID
}
}

func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, scopes []string) (*azcore.Request, error) {
func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, clientID string, scopes []string) (*azcore.Request, error) {
request, err := azcore.NewRequest(ctx, http.MethodGet, c.endpoint)
if err != nil {
return nil, err
Expand All @@ -171,6 +171,9 @@ func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, scope
q := request.URL.Query()
q.Add("api-version", c.imdsAPIVersion)
q.Add("resource", strings.Join(scopes, " "))
if clientID != "" {
q.Add(qpClientID, clientID)
}
request.URL.RawQuery = q.Encode()
return request, nil
}
Expand Down Expand Up @@ -262,7 +265,7 @@ func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context,
data := url.Values{}
data.Set("resource", strings.Join(scopes, " "))
if clientID != "" {
data.Set("client_id", clientID)
data.Set(qpClientID, clientID)
}
dataEncoded := data.Encode()
body := azcore.NopCloser(strings.NewReader(dataEncoded))
Expand Down
5 changes: 4 additions & 1 deletion sdk/azidentity/managed_identity_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ func TestManagedIdentityCredential_CreateIMDSAuthRequest(t *testing.T) {
t.Fatalf("unexpected error: %v", err)
}
cred.client.endpoint = imdsEndpoint
req, err := cred.client.createIMDSAuthRequest(context.Background(), []string{msiScope})
req, err := cred.client.createIMDSAuthRequest(context.Background(), clientID, []string{msiScope})
if err != nil {
t.Fatal(err)
}
Expand All @@ -452,6 +452,9 @@ func TestManagedIdentityCredential_CreateIMDSAuthRequest(t *testing.T) {
if reqQueryParams["resource"][0] != msiScope {
t.Fatalf("Unexpected resource in resource query param")
}
if reqQueryParams["client_id"][0] != clientID {
t.Fatalf("Unexpected client ID. Expected: %s, Received: %s", clientID, reqQueryParams["client_id"][0])
}
if u := req.Request.URL.String(); !strings.HasPrefix(u, imdsEndpoint) {
t.Fatalf("Unexpected default authority host %s", u)
}
Expand Down

0 comments on commit 7be1698

Please sign in to comment.