diff --git a/gradient/api_sdk/clients/model_client.py b/gradient/api_sdk/clients/model_client.py index 1580be71..647fd762 100644 --- a/gradient/api_sdk/clients/model_client.py +++ b/gradient/api_sdk/clients/model_client.py @@ -29,7 +29,7 @@ def delete(self, model_id): repository = self.build_repository(repositories.DeleteModel) repository.delete(model_id) - def upload(self, path, name, model_type, model_summary=None, notes=None, tags=None, project_id=None): + def upload(self, path, name, model_type, model_summary=None, notes=None, tags=None, project_id=None, cluster_id=None): """Upload model :param file path: path to Model @@ -39,6 +39,7 @@ def upload(self, path, name, model_type, model_summary=None, notes=None, tags=No :param str|None notes: Optional model description :param list[str] tags: List of tags :param str|None project_id: ID of a project + :param str|None cluster_id: ID of a cluster :return: ID of new model :rtype: str @@ -53,7 +54,7 @@ def upload(self, path, name, model_type, model_summary=None, notes=None, tags=No ) repository = self.build_repository(repositories.UploadModel) - model_id = repository.create(model, path=path) + model_id = repository.create(model, path=path, cluster_id=cluster_id) if tags: self.add_tags(entity_id=model_id, tags=tags) diff --git a/gradient/api_sdk/models/model.py b/gradient/api_sdk/models/model.py index 2f347f44..a94583dd 100644 --- a/gradient/api_sdk/models/model.py +++ b/gradient/api_sdk/models/model.py @@ -10,6 +10,7 @@ class Model(object): :param str name: :param str project_id: :param str experiment_id: + :param str cluster_id: :param list tags: :param str model_type: :param str url: @@ -22,6 +23,7 @@ class Model(object): name = attr.ib(type=str, default=None) project_id = attr.ib(type=str, default=None) experiment_id = attr.ib(type=str, default=None) + cluster_id = attr.ib(type=str, default=None) tags = attr.ib(type=list, factory=list) model_type = attr.ib(type=str, default=None) url = attr.ib(type=str, default=None) diff --git a/gradient/api_sdk/repositories/models.py b/gradient/api_sdk/repositories/models.py index 909916d4..3619e619 100644 --- a/gradient/api_sdk/repositories/models.py +++ b/gradient/api_sdk/repositories/models.py @@ -86,21 +86,21 @@ def _get_request_params(self, kwargs): def _get_request_json(self, instance_dict): return None - def create(self, instance, data=None, path=None): + def create(self, instance, data=None, path=None, cluster_id=None): model_id = super(UploadModel, self).create(instance, data=data, path=path) try: - self._upload_model(path, model_id) + self._upload_model(path, model_id, cluster_id=cluster_id) except BaseException: self._delete_model(model_id) raise return model_id - def _upload_model(self, file_path, model_id): + def _upload_model(self, file_path, model_id, cluster_id=None): model_uploader = s3_uploader.S3ModelUploader( self.api_key, logger=self.logger, ps_client_name=self.ps_client_name ) - model_uploader.upload(file_path, model_id) + model_uploader.upload(file_path, model_id, cluster_id=cluster_id) def _delete_model(self, model_id): repository = DeleteModel(self.api_key, logger=self.logger, ps_client_name=self.ps_client_name) diff --git a/gradient/api_sdk/s3_uploader.py b/gradient/api_sdk/s3_uploader.py index a3817c3e..b3750b97 100644 --- a/gradient/api_sdk/s3_uploader.py +++ b/gradient/api_sdk/s3_uploader.py @@ -193,7 +193,7 @@ def __init__(self, api_key, multipart_encoder_cls=None, logger=None, ps_client_n multipart_encoder_cls=self.multipart_encoder_cls ) - def upload(self, file_path, model_id): + def upload(self, file_path, model_id, cluster_id=None): """Upload file to S3 bucket for a project :param str file_path: @@ -202,11 +202,11 @@ def upload(self, file_path, model_id): :rtype: str :return: S3 bucket's URL """ - url = self._get_upload_data(file_path, model_id) + url = self._get_upload_data(file_path, model_id, cluster_id=cluster_id) self.s3uploader.upload(file_path, url) return url - def _get_upload_data(self, file_path, model_id): + def _get_upload_data(self, file_path, model_id, cluster_id=None): """Ask API for data required to upload a file to S3 :param str file_path: @@ -221,6 +221,8 @@ def _get_upload_data(self, file_path, model_id): "modelHandle": model_id, "contentType": mimetypes.guess_type(file_path)[0] or "", } + if cluster_id: + params["clusterId"] = cluster_id response = self.ps_api_client.get("/mlModels/getPresignedModelUrl", params=params) if not response.ok: @@ -239,11 +241,11 @@ def _get_client(self, url, ps_client_name=None, api_key=None): class S3ModelUploader(S3ModelFileUploader): - def upload(self, file_path, model_id): + def upload(self, file_path, model_id, cluster_id=None): if os.path.isdir(file_path): file_path = self._zip_model_directory(file_path) - return super(S3ModelUploader, self).upload(file_path, model_id) + return super(S3ModelUploader, self).upload(file_path, model_id, cluster_id=cluster_id) def _zip_model_directory(self, dir_path): archiver = self._get_archiver() diff --git a/gradient/cli/models.py b/gradient/cli/models.py index b4d5cd34..35d81db0 100644 --- a/gradient/cli/models.py +++ b/gradient/cli/models.py @@ -87,6 +87,12 @@ def delete_model(api_key, model_id, options_file): help="ID of a project", cls=common.GradientOption, ) +@click.option( + "--clusterId", + "cluster_id", + help="ID of a cluster", + cls=common.GradientOption, +) @click.option( "--modelSummary", "model_summary", diff --git a/tests/config_files/models_upload.yaml b/tests/config_files/models_upload.yaml index e7fed145..574070a8 100644 --- a/tests/config_files/models_upload.yaml +++ b/tests/config_files/models_upload.yaml @@ -6,3 +6,4 @@ modelType: Tensorflow name: some_name notes: some notes projectId: some_project_id +clusterId: some_cluster_id diff --git a/tests/functional/test_models.py b/tests/functional/test_models.py index 4292ad11..f54f12fb 100644 --- a/tests/functional/test_models.py +++ b/tests/functional/test_models.py @@ -270,6 +270,7 @@ class TestModelUpload(object): "--modelSummary", """{"key": "value"}""", "--notes", "some notes", "--projectId", "some_project_id", + "--clusterId", "some_cluster_id", ] ALL_OPTIONS_PARAMS = { "name": "some_name", @@ -287,6 +288,7 @@ class TestModelUpload(object): "--modelSummary", """{"key": "value"}""", "--notes", "some notes", "--projectId", "some_project_id", + "--clusterId", "some_cluster_id", "--apiKey", "some_key", ] COMMAND_WITH_OPTIONS_FILE = ["models", "upload", "--optionsFile", ] # path added in test @@ -294,7 +296,8 @@ class TestModelUpload(object): EXPECTED_STDOUT = "Model uploaded with ID: some_model_id\n" GET_PRESIGNED_URL = "https://api.paperspace.io/mlModels/getPresignedModelUrl" - GET_PRESIGNED_URL_PARAMS = {"fileName": "saved_model.pb", "modelHandle": "some_model_id", "contentType": ""} + GET_PRESIGNED_URL_PARAMS = {"fileName": "saved_model.pb", "modelHandle": "some_model_id", "contentType": "", "clusterId": "some_cluster_id"} + GET_PRESIGNED_URL_PARAMS_BASIC = {"fileName": "saved_model.pb", "modelHandle": "some_model_id", "contentType": ""} GET_PRESIGNED_URL_RESPONSE = example_responses.MODEL_UPLOAD_GET_PRESIGNED_URL_RESPONSE CREATE_MODEL_V2_REPONSE = example_responses.MODEL_CREATE_RESPONSE_JSON_V2 @@ -336,7 +339,7 @@ def test_should_send_post_request_when_models_update_command_was_used_with_basic ]) get_patched.assert_called_once_with(self.GET_PRESIGNED_URL, headers=EXPECTED_HEADERS, - params=self.GET_PRESIGNED_URL_PARAMS, + params=self.GET_PRESIGNED_URL_PARAMS_BASIC, json=None, ) assert put_patched.call_args.kwargs["data"].encoder.fields["file"][0] == self.MODEL_FILE @@ -520,7 +523,7 @@ def test_should_send_proper_data_and_tag_machine( mock.call( self.GET_PRESIGNED_URL, headers=EXPECTED_HEADERS, - params=self.GET_PRESIGNED_URL_PARAMS, + params=self.GET_PRESIGNED_URL_PARAMS_BASIC, json=None, ), ]