Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
6e892a3
WIP
Nov 15, 2024
76f793c
Merge branch 'main' of https://github.com/sagarsumant/azure-sdk-for-p…
Nov 15, 2024
d14d59a
Fix circular import
Nov 19, 2024
96e410b
Add queue_settings and resources fields.
Nov 19, 2024
139d250
Work in progress.
Nov 21, 2024
b032006
Test run successful
Nov 21, 2024
19ff504
All 3 flavors working in integration tests.
Nov 21, 2024
b4ac643
add unit tests for yaml.
Nov 22, 2024
8b9f738
Add tests for yaml job creation.
Nov 22, 2024
db2ed0c
Merge branch 'main' of https://github.com/sagarsumant/azure-sdk-for-p…
Nov 22, 2024
73ba217
Fix tests.
Nov 22, 2024
a46f956
Fix test.
Nov 22, 2024
e35a85f
Refactor and update tests.
Nov 23, 2024
da142ce
Fix things.
Nov 23, 2024
47486aa
Fix tests.
Nov 23, 2024
db43427
Fix failing test.
Nov 23, 2024
134178c
Fix tests.
Nov 23, 2024
993941e
add @pytest.mark.e2etest
Nov 23, 2024
146dae7
try recorded test.
Nov 23, 2024
7dded02
comment status for non live testing.
Nov 23, 2024
a102afc
Fix things.
Nov 26, 2024
fbaa872
Merge branch 'main' of https://github.com/sagarsumant/azure-sdk-for-p…
Nov 27, 2024
fc86030
Rebase and fix merge issues.
Nov 27, 2024
da83198
format using black.
Nov 27, 2024
cbd70a4
format using black.
Nov 27, 2024
4d54f06
Disable recorded test execution
Nov 28, 2024
41b838b
Remove assets.json changes.
Nov 28, 2024
3010dd9
Fix linting errors.
Nov 28, 2024
ac86738
Fix imports
Nov 28, 2024
18d2d2c
Fix comments.
Dec 2, 2024
003a75d
Remove job.py changes.
Dec 2, 2024
e866fc1
Fix tests.
Dec 2, 2024
0f7ed81
Merge branch 'Azure:main' into sasum/maap-finetuning-job-sdk-cli-impl…
sagarsumant Dec 3, 2024
7cfcce9
Merge branch 'main' of https://github.com/sagarsumant/azure-sdk-for-p…
Dec 3, 2024
433d164
Merge branch 'sasum/maap-finetuning-job-sdk-cli-implementation' of ht…
Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sdk/ml/azure-ai-ml/azure/ai/ml/_ml_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ def __init__(
_service_client_kwargs=kwargs,
requests_pipeline=self._requests_pipeline,
service_client_01_2024_preview=self._service_client_01_2024_preview,
service_client_10_2024_preview=self._service_client_10_2024_preview,
**ops_kwargs,
)
self._operation_container.add(AzureMLResourceType.JOB, self._jobs)
Expand Down Expand Up @@ -746,7 +747,8 @@ def __init__(
**ops_kwargs, # type: ignore[arg-type]
)
self._operation_container.add(
AzureMLResourceType.VIRTUALCLUSTER, self._virtual_clusters # type: ignore[arg-type]
AzureMLResourceType.VIRTUALCLUSTER,
self._virtual_clusters, # type: ignore[arg-type]
)
except Exception as ex: # pylint: disable=broad-except
module_logger.debug("Virtual Cluster operations could not be initialized due to %s ", ex)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,17 @@
from azure.ai.ml._schema.job import BaseJobSchema
from azure.ai.ml._schema.job.input_output_fields_provider import OutputsField
from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml._schema.core.fields import (
NestedField,
)
from ..queue_settings import QueueSettingsSchema
from ..job_resources import JobResourcesSchema

# This is meant to match the yaml definition NOT the models defined in _restclient


@experimental
class FineTuningJobSchema(BaseJobSchema):
    """Marshmallow schema for the fine-tuning job YAML definition.

    Extends :class:`BaseJobSchema` with the fine-tuning-specific fields:
    job outputs, queue settings, and job resources. This matches the YAML
    shape, not the _restclient REST models (see module-level note).
    """

    # Output bindings for the job (standard outputs field shared with other job types).
    outputs = OutputsField()
    # Nested queue_settings block — presumably job priority/queueing options; see QueueSettingsSchema.
    queue_settings = NestedField(QueueSettingsSchema)
    # Nested resources block — instance types available to the job; see JobResourcesSchema.
    resources = NestedField(JobResourcesSchema)
21 changes: 21 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job_resources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=unused-argument

from marshmallow import fields, post_load

from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta


class JobResourcesSchema(metaclass=PatchedSchemaMeta):
    """Schema for the ``resources`` section of a job YAML definition.

    Validated payloads are converted into a
    ``azure.ai.ml.entities.JobResources`` entity by the ``post_load`` hook.
    """

    # Names of the instance types the job may be scheduled on.
    instance_types = fields.List(
        fields.Str(),
        metadata={"description": "The instance type to make available to this job."},
    )

    @post_load
    def make(self, payload, **kwargs):
        """Build a JobResources entity from the deserialized data."""
        # Imported locally: a module-level import would create a circular
        # dependency between the schema and entities packages.
        from azure.ai.ml.entities import JobResources

        return JobResources(**payload)
17 changes: 17 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/constants/_finetuning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------


class FineTuningTaskType:
    """String constants naming the supported fine-tuning task types.

    The values are the wire-format identifiers — presumably the strings the
    service expects verbatim (TODO confirm against the REST models), so they
    must not be altered.
    """

    # Text / NLP tasks
    CHAT_COMPLETION = "ChatCompletion"
    TEXT_COMPLETION = "TextCompletion"
    TEXT_CLASSIFICATION = "TextClassification"
    QUESTION_ANSWERING = "QuestionAnswering"
    TEXT_SUMMARIZATION = "TextSummarization"
    TOKEN_CLASSIFICATION = "TokenClassification"
    TEXT_TRANSLATION = "TextTranslation"

    # Image tasks
    IMAGE_CLASSIFICATION = "ImageClassification"
    IMAGE_INSTANCE_SEGMENTATION = "ImageInstanceSegmentation"
    IMAGE_OBJECT_DETECTION = "ImageObjectDetection"

    # Video tasks
    VIDEO_MULTI_OBJECT_TRACKING = "VideoMultiObjectTracking"
63 changes: 54 additions & 9 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,29 @@
from ._component.pipeline_component import PipelineComponent
from ._component.spark_component import SparkComponent
from ._compute._aml_compute_node_info import AmlComputeNodeInfo
from ._compute._custom_applications import CustomApplications, EndpointsSettings, ImageSettings, VolumeSettings
from ._compute._custom_applications import (
CustomApplications,
EndpointsSettings,
ImageSettings,
VolumeSettings,
)
from ._compute._image_metadata import ImageMetadata
from ._compute._schedule import ComputePowerAction, ComputeSchedules, ComputeStartStopSchedule, ScheduleState
from ._compute._schedule import (
ComputePowerAction,
ComputeSchedules,
ComputeStartStopSchedule,
ScheduleState,
)
from ._compute._setup_scripts import ScriptReference, SetupScripts
from ._compute._usage import Usage, UsageName
from ._compute._vm_size import VmSize
from ._compute.aml_compute import AmlCompute, AmlComputeSshSettings
from ._compute.compute import Compute, NetworkSettings
from ._compute.compute_instance import AssignedUserConfiguration, ComputeInstance, ComputeInstanceSshSettings
from ._compute.compute_instance import (
AssignedUserConfiguration,
ComputeInstance,
ComputeInstanceSshSettings,
)
from ._compute.kubernetes_compute import KubernetesCompute
from ._compute.synapsespark_compute import AutoPauseSettings, AutoScaleSettings, SynapseSparkCompute
from ._compute.unsupported_compute import UnsupportedCompute
Expand All @@ -84,7 +98,11 @@
from ._data_import.data_import import DataImport
from ._data_import.schedule import ImportDataSchedule
from ._datastore.adls_gen1 import AzureDataLakeGen1Datastore
from ._datastore.azure_storage import AzureBlobDatastore, AzureDataLakeGen2Datastore, AzureFileDatastore
from ._datastore.azure_storage import (
AzureBlobDatastore,
AzureDataLakeGen2Datastore,
AzureFileDatastore,
)
from ._datastore.datastore import Datastore
from ._datastore.one_lake import OneLakeArtifact, OneLakeDatastore
from ._deployment.batch_deployment import BatchDeployment
Expand All @@ -94,7 +112,11 @@
from ._deployment.data_asset import DataAsset
from ._deployment.data_collector import DataCollector
from ._deployment.deployment_collection import DeploymentCollection
from ._deployment.deployment_settings import BatchRetrySettings, OnlineRequestSettings, ProbeSettings
from ._deployment.deployment_settings import (
BatchRetrySettings,
OnlineRequestSettings,
ProbeSettings,
)
from ._deployment.model_batch_deployment import ModelBatchDeployment
from ._deployment.model_batch_deployment_settings import ModelBatchDeploymentSettings
from ._deployment.online_deployment import (
Expand All @@ -106,7 +128,11 @@
from ._deployment.pipeline_component_batch_deployment import PipelineComponentBatchDeployment
from ._deployment.request_logging import RequestLogging
from ._deployment.resource_requirements_settings import ResourceRequirementsSettings
from ._deployment.scale_settings import DefaultScaleSettings, OnlineScaleSettings, TargetUtilizationScaleSettings
from ._deployment.scale_settings import (
DefaultScaleSettings,
OnlineScaleSettings,
TargetUtilizationScaleSettings,
)
from ._endpoint.batch_endpoint import BatchEndpoint
from ._endpoint.endpoint import Endpoint
from ._endpoint.online_endpoint import (
Expand Down Expand Up @@ -136,11 +162,19 @@
from ._indexes import ModelConfiguration as IndexModelConfiguration
from ._job.command_job import CommandJob
from ._job.compute_configuration import ComputeConfiguration
from ._job.finetuning.custom_model_finetuning_job import CustomModelFineTuningJob
from ._job.input_port import InputPort
from ._job.job import Job
from ._job.job_limits import CommandJobLimits
from ._job.job_resources import JobResources
from ._job.job_resource_configuration import JobResourceConfiguration
from ._job.job_service import JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService
from ._job.job_service import (
JobService,
JupyterLabJobService,
SshJobService,
TensorBoardJobService,
VsCodeJobService,
)
from ._job.parallel.parallel_task import ParallelTask
from ._job.parallel.retry_settings import RetrySettings
from ._job.parameterized_command import ParameterizedCommand
Expand All @@ -156,7 +190,12 @@
from ._monitoring.alert_notification import AlertNotification
from ._monitoring.compute import ServerlessSparkCompute
from ._monitoring.definition import MonitorDefinition
from ._monitoring.input_data import FixedInputData, MonitorInputData, StaticInputData, TrailingInputData
from ._monitoring.input_data import (
FixedInputData,
MonitorInputData,
StaticInputData,
TrailingInputData,
)
from ._monitoring.schedule import MonitorSchedule
from ._monitoring.signals import (
BaselineDataRange,
Expand Down Expand Up @@ -244,7 +283,11 @@
from ._workspace.private_endpoint import EndpointConnection, PrivateEndpoint
from ._workspace.serverless_compute import ServerlessComputeSettings
from ._workspace.workspace import Workspace
from ._workspace.workspace_keys import ContainerRegistryCredential, NotebookAccessKeys, WorkspaceKeys
from ._workspace.workspace_keys import (
ContainerRegistryCredential,
NotebookAccessKeys,
WorkspaceKeys,
)

__all__ = [
"Resource",
Expand All @@ -258,8 +301,10 @@
"SparkJobEntryType",
"CommandJobLimits",
"ComputeConfiguration",
"CustomModelFineTuningJob",
"CreatedByType",
"ResourceConfiguration",
"JobResources",
"JobResourceConfiguration",
"QueueSettings",
"JobService",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,7 @@ def validate_pipeline_input_key_characters(key: str) -> None:
# so a valid pipeline key is: ^{single_key}([.]{single_key})*$
if re.match(IOConstants.VALID_KEY_PATTERN, key) is None:
msg = (
"Pipeline input key name {} must be composed letters, numbers, and underscores with optional "
"split by dots."
"Pipeline input key name {} must be composed letters, numbers, and underscores with optional split by dots."
)
raise ValidationException(
message=msg.format(key),
Expand Down Expand Up @@ -262,7 +261,6 @@ def to_rest_dataset_literal_inputs(
uri=input_value.path,
mode=(INPUT_MOUNT_MAPPING_TO_REST[input_value.mode.lower()] if input_value.mode else None),
)

else:
msg = f"Job input type {input_value.type} is not supported as job input."
raise ValidationException(
Expand Down Expand Up @@ -415,7 +413,7 @@ def from_rest_data_outputs(outputs: Dict[str, RestJobOutput]) -> Dict[str, Outpu
path_on_compute=sourcePathOnCompute,
description=output_value.description,
name=output_value.asset_name,
version=output_value.asset_version,
version=(output_value.asset_version if hasattr(output_value, "asset_version") else None),
)
else:
msg = "unsupported JobOutput type: {}".format(output_value.job_output_type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from typing import Any, Dict

from azure.ai.ml._restclient.v2024_01_01_preview.models import (
from azure.ai.ml._restclient.v2024_10_01_preview.models import (
ModelProvider as RestModelProvider,
CustomModelFineTuning as RestCustomModelFineTuningVertical,
FineTuningJob as RestFineTuningJob,
Expand All @@ -16,6 +16,8 @@
from_rest_data_outputs,
to_rest_data_outputs,
)
from azure.ai.ml.entities._job.job_resources import JobResources
from azure.ai.ml.entities._job.queue_settings import QueueSettings
from azure.ai.ml.entities._inputs_outputs import Input
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
from azure.ai.ml.entities._job.finetuning.finetuning_vertical import FineTuningVertical
Expand Down Expand Up @@ -89,9 +91,14 @@ def _to_rest_object(self) -> "RestFineTuningJob":
experiment_name=self.experiment_name,
tags=self.tags,
properties=self.properties,
compute_id=self.compute,
fine_tuning_details=custom_finetuning_vertical,
outputs=to_rest_data_outputs(self.outputs),
)
if self.resources:
finetuning_job.resources = self.resources._to_rest_object()
if self.queue_settings:
finetuning_job.queue_settings = self.queue_settings._to_rest_object()

result = RestJobBase(properties=finetuning_job)
result.name = self.name
Expand Down Expand Up @@ -166,9 +173,15 @@ def _from_rest_object(cls, obj: RestJobBase) -> "CustomModelFineTuningJob":
"status": properties.status,
"creation_context": obj.system_data,
"display_name": properties.display_name,
"compute": properties.compute_id,
"outputs": from_rest_data_outputs(properties.outputs),
}

if properties.resources:
job_args_dict["resources"] = JobResources._from_rest_object(properties.resources)
if properties.queue_settings:
job_args_dict["queue_settings"] = QueueSettings._from_rest_object(properties.queue_settings)

custom_model_finetuning_job = cls(
task=finetuning_details.task_type,
model=finetuning_details.model,
Expand Down
Loading
Loading