Skip to content

Commit

Permalink
[Identity] Add Azure Arc key validation checks
Browse files Browse the repository at this point in the history
Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
  • Loading branch information
pvaneck committed Jun 10, 2024
1 parent f07513c commit e16a704
Show file tree
Hide file tree
Showing 5 changed files with 303 additions and 16 deletions.
6 changes: 6 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Release History

## 1.16.1 (2024-06-11)

### Bugs Fixed

- Managed identity bug fixes

## 1.16.0 (2024-04-09)

### Other Changes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# ------------------------------------
import functools
import os
import sys
from typing import Any, Dict, Optional

from azure.core.exceptions import ClientAuthenticationError
Expand All @@ -24,7 +25,7 @@ def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]:
return ManagedIdentityClient(
_per_retry_policies=[ArcChallengeAuthPolicy()],
request_factory=functools.partial(_get_request, url),
**kwargs
**kwargs,
)
return None

Expand Down Expand Up @@ -70,6 +71,12 @@ def _get_secret_key(response: PipelineResponse) -> str:
raise ClientAuthenticationError(
message="Did not receive a correct value from WWW-Authenticate header: {}".format(header)
) from ex

try:
_validate_key_file(key_file)
except ValueError as ex:
raise ClientAuthenticationError(message="The key file path is invalid: {}".format(ex)) from ex

with open(key_file, "r", encoding="utf-8") as file:
try:
return file.read()
Expand All @@ -80,6 +87,53 @@ def _get_secret_key(response: PipelineResponse) -> str:
) from error


def _get_key_file_path() -> str:
"""Returns the expected path for the Azure Arc MSI key file based on the current platform.
Only Linux and Windows are supported.
:return: The expected path.
:rtype: str
:raises ValueError: If the current platform is not supported.
"""
if sys.platform.startswith("linux"):
return "/var/opt/azcmagent/tokens"
if sys.platform.startswith("win"):
program_data_path = os.environ.get("PROGRAMDATA")
if not program_data_path:
raise ValueError("PROGRAMDATA environment variable is not set or is empty.")
return os.path.join(f"{program_data_path}", "AzureConnectedMachineAgent", "Tokens")
raise ValueError(f"Azure Arc MSI is not supported on this platform {sys.platform}")


def _validate_key_file(file_path: str) -> None:
"""Validates that a given Azure Arc MSI file path is valid for use.
A valid file will:
1. Be in the expected path for the current platform.
2. Have a `.key` extension.
3. Be at most 4096 bytes in size.
:param str file_path: The path to the key file.
:raises ClientAuthenticationError: If the file path is invalid.
"""
if not file_path:
raise ValueError("The file path must not be empty.")

if not os.path.exists(file_path):
raise ValueError(f"The file path does not exist: {file_path}")

expected_directory = _get_key_file_path()
if not os.path.dirname(file_path) == expected_directory:
raise ValueError(f"Unexpected file path from HIMDS service: {file_path}")

if not file_path.endswith(".key"):
raise ValueError("The file path must have a '.key' extension.")

if os.path.getsize(file_path) > 4096:
raise ValueError("The file size must be less than or equal to 4096 bytes.")


class ArcChallengeAuthPolicy(HTTPPolicy):
"""Policy for handling Azure Arc's challenge authentication"""

Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/azure-identity/azure/identity/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
VERSION = "1.16.0"
VERSION = "1.16.1"
132 changes: 124 additions & 8 deletions sdk/identity/azure-identity/tests/test_managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import os
import sys
import time

try:
Expand Down Expand Up @@ -883,9 +884,10 @@ def test_azure_arc(tmpdir):
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
token = ManagedIdentityCredential(transport=transport).get_token(scope)
assert token.token == access_token
assert token.expires_on == expires_on
with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
token = ManagedIdentityCredential(transport=transport).get_token(scope)
assert token.token == access_token
assert token.expires_on == expires_on


def test_azure_arc_tenant_id(tmpdir):
Expand Down Expand Up @@ -936,9 +938,10 @@ def test_azure_arc_tenant_id(tmpdir):
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
assert token.token == access_token
assert token.expires_on == expires_on
with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
assert token.token == access_token
assert token.expires_on == expires_on


def test_azure_arc_client_id():
Expand All @@ -950,10 +953,123 @@ def test_azure_arc_client_id():
EnvironmentVariables.IMDS_ENDPOINT: "http://localhost:42",
},
):
credential = ManagedIdentityCredential(client_id="some-guid")
with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
credential = ManagedIdentityCredential(client_id="some-guid")

with pytest.raises(ClientAuthenticationError):
with pytest.raises(ClientAuthenticationError) as ex:
credential.get_token("scope")
assert "not supported" in str(ex.value)


def test_azure_arc_key_too_large(tmp_path):

api_version = "2019-11-01"
identity_endpoint = "http://localhost:42/token"
imds_endpoint = "http://localhost:42"
scope = "scope"
secret_key = "X" * 4097

key_file = tmp_path / "key_file.key"
key_file.write_text(secret_key)
assert key_file.read_text() == secret_key

transport = validating_transport(
requests=[
Request(
base_url=identity_endpoint,
method="GET",
required_headers={"Metadata": "true"},
required_params={"api-version": api_version, "resource": scope},
),
],
responses=[
mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
],
)

with mock.patch(
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
with pytest.raises(ClientAuthenticationError) as ex:
ManagedIdentityCredential(transport=transport).get_token(scope)
assert "file size" in str(ex.value)


def test_azure_arc_key_not_exist(tmp_path):

api_version = "2019-11-01"
identity_endpoint = "http://localhost:42/token"
imds_endpoint = "http://localhost:42"
scope = "scope"

transport = validating_transport(
requests=[
Request(
base_url=identity_endpoint,
method="GET",
required_headers={"Metadata": "true"},
required_params={"api-version": api_version, "resource": scope},
),
],
responses=[
mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm=/path/to/key_file"}),
],
)

with mock.patch(
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
with pytest.raises(ClientAuthenticationError) as ex:
ManagedIdentityCredential(transport=transport).get_token(scope)
assert "not exist" in str(ex.value)


def test_azure_arc_key_invalid(tmp_path):

api_version = "2019-11-01"
identity_endpoint = "http://localhost:42/token"
imds_endpoint = "http://localhost:42"
scope = "scope"
key_file = tmp_path / "key_file.txt"
key_file.write_text("secret")

transport = validating_transport(
requests=[
Request(
base_url=identity_endpoint,
method="GET",
required_headers={"Metadata": "true"},
required_params={"api-version": api_version, "resource": scope},
),
Request(
base_url=identity_endpoint,
method="GET",
required_headers={"Metadata": "true"},
required_params={"api-version": api_version, "resource": scope},
),
],
responses=[
mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
],
)

with mock.patch(
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: "/foo"):
with pytest.raises(ClientAuthenticationError) as ex:
ManagedIdentityCredential(transport=transport).get_token(scope)
assert "Unexpected file path" in str(ex.value)

with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
with pytest.raises(ClientAuthenticationError) as ex:
ManagedIdentityCredential(transport=transport).get_token(scope)
assert "extension" in str(ex.value)


def test_token_exchange(tmpdir):
Expand Down
Loading

0 comments on commit e16a704

Please sign in to comment.