Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{
"id": "/subscriptions/sub-id/test_workspace/providers/Microsoft.Storage/storageAccounts/storage-account-name",
"name": "test_workspace",
"location": "test_location",
"kind": "test_kind",
"sku": {"name": "test_sku"},
"tags": {"test_tag": "test_value"},
"properties": {
"description": "test_description",
"friendlyName": "test_friendly_name",
"discoveryUrl": "test_discovery_url",
"hbiWorkspace": true,
"storageAccount": "test_storage_account",
"containerRegistry": "test_container_registry",
"keyVault": "test_key_vault",
"applicationInsights": "test_application_insights",
"imageBuildCompute": "test_image_build_compute",
"publicNetworkAccess": "Enabled",
"enableDataIsolation": true,
"allowPublicAccessWhenBehindVnet": true,
"primaryUserAssignedIdentity": "test_primary_user_assigned_identity",
"encryption": {
"status": "Enabled",
"keyVaultProperties": {
"keyIdentifier": "key_identifier",
"keyVaultArmId": "key_vault_arm_id"
}
},
"hubResourceId": "hub_resource_id",
"workspaceId": "workspace_id",
"mlFlowTrackingUri": "ml_flow_tracking_uri",
"allowRoleAssignmentOnRG": true,
"systemDatastoresAuthMode": "AccessKey",
"managedNetwork": {},
"provisionNetworkNow": true,
"featureStoreSettings": {},
"serverlessComputeSettings": {},
"networkAcls": {}
},
"identity": {"type": "ManagedServiceIdentityType"}
}
104 changes: 101 additions & 3 deletions sdk/ml/azure-ai-ml/tests/workspace/unittests/test_workspace_entity.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Optional

import pytest
import json
from marshmallow.exceptions import ValidationError

from azure.ai.ml import load_workspace
from azure.ai.ml._restclient.v2024_10_01_preview.models import Workspace
from azure.ai.ml._restclient.v2024_10_01_preview.models import (
Workspace as RestWorkspace,
)
from azure.ai.ml.constants._workspace import FirewallSku, IsolationMode
from azure.ai.ml.entities import ServerlessComputeSettings, Workspace

Expand Down Expand Up @@ -38,6 +41,95 @@ def test_serverless_compute_settings_loaded_from_rest_object(
else:
assert ServerlessComputeSettings._from_rest_object(rest_object.serverless_compute_settings) == settings

def test_from_rest_object(self) -> None:
with open("./tests/test_configs/workspace/workspace_full_rest_response.json", "r") as f:
rest_object = RestWorkspace.deserialize(json.load(f))

workspace = Workspace._from_rest_object(rest_object)

assert (
workspace.id
== "/subscriptions/sub-id/test_workspace/providers/Microsoft.Storage/storageAccounts/storage-account-name"
)
assert workspace.name == "test_workspace"
assert workspace.location == "test_location"
assert workspace.description == "test_description"
assert workspace.tags == {"test_tag": "test_value"}
assert workspace.display_name == "test_friendly_name"
assert workspace.discovery_url == "test_discovery_url"
assert workspace.resource_group == "providers"
assert workspace.storage_account == "test_storage_account"
assert workspace.key_vault == "test_key_vault"
assert workspace.application_insights == "test_application_insights"
assert workspace.container_registry == "test_container_registry"
assert workspace.customer_managed_key.key_uri == "key_identifier"
assert workspace.customer_managed_key.key_vault == "key_vault_arm_id"
assert workspace.hbi_workspace is True
assert workspace.public_network_access == "Enabled"
assert workspace.image_build_compute == "test_image_build_compute"
assert workspace.discovery_url == "test_discovery_url"
assert workspace.mlflow_tracking_uri == "ml_flow_tracking_uri"
assert workspace.primary_user_assigned_identity == "test_primary_user_assigned_identity"
assert workspace.system_datastores_auth_mode == "AccessKey"
assert workspace.enable_data_isolation == True
assert workspace.allow_roleassignment_on_rg == True
assert workspace._hub_id == "hub_resource_id"
assert workspace._kind == "project"
assert workspace._workspace_id == "workspace_id"
assert workspace.identity is not None
assert workspace.managed_network is not None
assert workspace._feature_store_settings is not None
assert workspace.network_acls is not None
assert workspace.provision_network_now == True
assert workspace.serverless_compute is not None
assert workspace.network_acls is not None

def test_from_rest_object_for_attributes_none(self) -> None:
with open("./tests/test_configs/workspace/workspace_full_rest_response.json", "r") as f:
rest_json = json.load(f)
del rest_json["properties"]["managedNetwork"]
del rest_json["properties"]["encryption"]
rest_json["id"] = "/subscriptions/sub-id"
del rest_json["identity"]
del rest_json["properties"]["featureStoreSettings"]
del rest_json["properties"]["serverlessComputeSettings"]
del rest_json["properties"]["networkAcls"]
rest_object = RestWorkspace.deserialize(rest_json)

workspace = Workspace._from_rest_object(rest_object)

assert workspace.id == "/subscriptions/sub-id"
assert workspace.name == "test_workspace"
assert workspace.location == "test_location"
assert workspace.description == "test_description"
assert workspace.tags == {"test_tag": "test_value"}
assert workspace.display_name == "test_friendly_name"
assert workspace.discovery_url == "test_discovery_url"
assert workspace.resource_group is None
assert workspace.storage_account == "test_storage_account"
assert workspace.key_vault == "test_key_vault"
assert workspace.application_insights == "test_application_insights"
assert workspace.container_registry == "test_container_registry"
assert workspace.customer_managed_key is None
assert workspace.hbi_workspace is True
assert workspace.public_network_access == "Enabled"
assert workspace.image_build_compute == "test_image_build_compute"
assert workspace.discovery_url == "test_discovery_url"
assert workspace.mlflow_tracking_uri == "ml_flow_tracking_uri"
assert workspace.primary_user_assigned_identity == "test_primary_user_assigned_identity"
assert workspace.system_datastores_auth_mode == "AccessKey"
assert workspace.enable_data_isolation == True
assert workspace.allow_roleassignment_on_rg == True
assert workspace._hub_id == "hub_resource_id"
assert workspace._kind == "project"
assert workspace._workspace_id == "workspace_id"
assert workspace.identity is None
assert workspace.managed_network is None
assert workspace._feature_store_settings is None
assert workspace.network_acls is None
assert workspace.provision_network_now == True
assert workspace.serverless_compute is None

def test_serverless_compute_settings_subnet_name_must_be_an_arm_id(self) -> None:
with pytest.raises(ValidationError):
ServerlessComputeSettings(custom_subnet="justaname", no_public_ip=True)
Expand Down Expand Up @@ -65,11 +157,17 @@ def test_serverless_compute_settings_subnet_name_must_be_an_arm_id(self) -> None
)
def test_workspace_load_override_serverless(self, settings: ServerlessComputeSettings) -> None:
params_override = [
{"serverless_compute": {"custom_subnet": settings.custom_subnet, "no_public_ip": settings.no_public_ip}}
{
"serverless_compute": {
"custom_subnet": settings.custom_subnet,
"no_public_ip": settings.no_public_ip,
}
}
]

workspace_override = load_workspace(
"./tests/test_configs/workspace/workspace_serverless.yaml", params_override=params_override
"./tests/test_configs/workspace/workspace_serverless.yaml",
params_override=params_override,
)
assert workspace_override.serverless_compute == settings

Expand Down
Loading