diff --git a/sdk/ml/azure-ai-ml/tests/test_configs/workspace/workspace_full_rest_response.json b/sdk/ml/azure-ai-ml/tests/test_configs/workspace/workspace_full_rest_response.json new file mode 100644 index 000000000000..7dec272e8a2a --- /dev/null +++ b/sdk/ml/azure-ai-ml/tests/test_configs/workspace/workspace_full_rest_response.json @@ -0,0 +1,41 @@ +{ + "id": "/subscriptions/sub-id/test_workspace/providers/Microsoft.Storage/storageAccounts/storage-account-name", + "name": "test_workspace", + "location": "test_location", + "kind": "test_kind", + "sku": {"name": "test_sku"}, + "tags": {"test_tag": "test_value"}, + "properties": { + "description": "test_description", + "friendlyName": "test_friendly_name", + "discoveryUrl": "test_discovery_url", + "hbiWorkspace": true, + "storageAccount": "test_storage_account", + "containerRegistry": "test_container_registry", + "keyVault": "test_key_vault", + "applicationInsights": "test_application_insights", + "imageBuildCompute": "test_image_build_compute", + "publicNetworkAccess": "Enabled", + "enableDataIsolation": true, + "allowPublicAccessWhenBehindVnet": true, + "primaryUserAssignedIdentity": "test_primary_user_assigned_identity", + "encryption": { + "status": "Enabled", + "keyVaultProperties": { + "keyIdentifier": "key_identifier", + "keyVaultArmId": "key_vault_arm_id" + } + }, + "hubResourceId": "hub_resource_id", + "workspaceId": "workspace_id", + "mlFlowTrackingUri": "ml_flow_tracking_uri", + "allowRoleAssignmentOnRG": true, + "systemDatastoresAuthMode": "AccessKey", + "managedNetwork": {}, + "provisionNetworkNow": true, + "featureStoreSettings": {}, + "serverlessComputeSettings": {}, + "networkAcls": {} + }, + "identity": {"type": "ManagedServiceIdentityType"} +} diff --git a/sdk/ml/azure-ai-ml/tests/workspace/unittests/test_workspace_entity.py b/sdk/ml/azure-ai-ml/tests/workspace/unittests/test_workspace_entity.py index 61e94175bb58..4584f73b1989 100644 --- a/sdk/ml/azure-ai-ml/tests/workspace/unittests/test_workspace_entity.py +++ b/sdk/ml/azure-ai-ml/tests/workspace/unittests/test_workspace_entity.py @@ -1,10 +1,13 @@ from typing import Optional import pytest +import json from marshmallow.exceptions import ValidationError from azure.ai.ml import load_workspace -from azure.ai.ml._restclient.v2024_10_01_preview.models import Workspace +from azure.ai.ml._restclient.v2024_10_01_preview.models import ( + Workspace as RestWorkspace, +) from azure.ai.ml.constants._workspace import FirewallSku, IsolationMode from azure.ai.ml.entities import ServerlessComputeSettings, Workspace @@ -38,6 +41,95 @@ def test_serverless_compute_settings_loaded_from_rest_object( else: assert ServerlessComputeSettings._from_rest_object(rest_object.serverless_compute_settings) == settings + def test_from_rest_object(self) -> None: + with open("./tests/test_configs/workspace/workspace_full_rest_response.json", "r") as f: + rest_object = RestWorkspace.deserialize(json.load(f)) + + workspace = Workspace._from_rest_object(rest_object) + + assert ( + workspace.id + == "/subscriptions/sub-id/test_workspace/providers/Microsoft.Storage/storageAccounts/storage-account-name" + ) + assert workspace.name == "test_workspace" + assert workspace.location == "test_location" + assert workspace.description == "test_description" + assert workspace.tags == {"test_tag": "test_value"} + assert workspace.display_name == "test_friendly_name" + assert workspace.discovery_url == "test_discovery_url" + assert workspace.resource_group == "providers" + assert workspace.storage_account == "test_storage_account" + assert workspace.key_vault == "test_key_vault" + assert workspace.application_insights == "test_application_insights" + assert workspace.container_registry == "test_container_registry" + assert workspace.customer_managed_key.key_uri == "key_identifier" + assert workspace.customer_managed_key.key_vault == "key_vault_arm_id" + assert workspace.hbi_workspace is True + assert workspace.public_network_access == "Enabled" + assert workspace.image_build_compute == "test_image_build_compute" + assert workspace.discovery_url == "test_discovery_url" + assert workspace.mlflow_tracking_uri == "ml_flow_tracking_uri" + assert workspace.primary_user_assigned_identity == "test_primary_user_assigned_identity" + assert workspace.system_datastores_auth_mode == "AccessKey" + assert workspace.enable_data_isolation == True + assert workspace.allow_roleassignment_on_rg == True + assert workspace._hub_id == "hub_resource_id" + assert workspace._kind == "project" + assert workspace._workspace_id == "workspace_id" + assert workspace.identity is not None + assert workspace.managed_network is not None + assert workspace._feature_store_settings is not None + assert workspace.network_acls is not None + assert workspace.provision_network_now == True + assert workspace.serverless_compute is not None + assert workspace.network_acls is not None + + def test_from_rest_object_for_attributes_none(self) -> None: + with open("./tests/test_configs/workspace/workspace_full_rest_response.json", "r") as f: + rest_json = json.load(f) + del rest_json["properties"]["managedNetwork"] + del rest_json["properties"]["encryption"] + rest_json["id"] = "/subscriptions/sub-id" + del rest_json["identity"] + del rest_json["properties"]["featureStoreSettings"] + del rest_json["properties"]["serverlessComputeSettings"] + del rest_json["properties"]["networkAcls"] + rest_object = RestWorkspace.deserialize(rest_json) + + workspace = Workspace._from_rest_object(rest_object) + + assert workspace.id == "/subscriptions/sub-id" + assert workspace.name == "test_workspace" + assert workspace.location == "test_location" + assert workspace.description == "test_description" + assert workspace.tags == {"test_tag": "test_value"} + assert workspace.display_name == "test_friendly_name" + assert workspace.discovery_url == "test_discovery_url" + assert workspace.resource_group is None + assert workspace.storage_account == "test_storage_account" + assert workspace.key_vault == "test_key_vault" + assert workspace.application_insights == "test_application_insights" + assert workspace.container_registry == "test_container_registry" + assert workspace.customer_managed_key is None + assert workspace.hbi_workspace is True + assert workspace.public_network_access == "Enabled" + assert workspace.image_build_compute == "test_image_build_compute" + assert workspace.discovery_url == "test_discovery_url" + assert workspace.mlflow_tracking_uri == "ml_flow_tracking_uri" + assert workspace.primary_user_assigned_identity == "test_primary_user_assigned_identity" + assert workspace.system_datastores_auth_mode == "AccessKey" + assert workspace.enable_data_isolation == True + assert workspace.allow_roleassignment_on_rg == True + assert workspace._hub_id == "hub_resource_id" + assert workspace._kind == "project" + assert workspace._workspace_id == "workspace_id" + assert workspace.identity is None + assert workspace.managed_network is None + assert workspace._feature_store_settings is None + assert workspace.network_acls is None + assert workspace.provision_network_now == True + assert workspace.serverless_compute is None + def test_serverless_compute_settings_subnet_name_must_be_an_arm_id(self) -> None: with pytest.raises(ValidationError): ServerlessComputeSettings(custom_subnet="justaname", no_public_ip=True) @@ -65,11 +157,17 @@ def test_serverless_compute_settings_subnet_name_must_be_an_arm_id(self) -> None ) def test_workspace_load_override_serverless(self, settings: ServerlessComputeSettings) -> None: params_override = [ - {"serverless_compute": {"custom_subnet": settings.custom_subnet, "no_public_ip": settings.no_public_ip}} + { + "serverless_compute": { + "custom_subnet": settings.custom_subnet, + "no_public_ip": settings.no_public_ip, + } + } ] workspace_override = load_workspace( - "./tests/test_configs/workspace/workspace_serverless.yaml", params_override=params_override + "./tests/test_configs/workspace/workspace_serverless.yaml", + params_override=params_override, ) assert workspace_override.serverless_compute == settings