Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions gradient/api_sdk/clients/model_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def delete(self, model_id):
repository = self.build_repository(repositories.DeleteModel)
repository.delete(model_id)

def upload(self, path, name, model_type, model_summary=None, notes=None, tags=None, project_id=None):
def upload(self, path, name, model_type, model_summary=None, notes=None, tags=None, project_id=None, cluster_id=None):
"""Upload model

:param file path: path to Model
Expand All @@ -39,6 +39,7 @@ def upload(self, path, name, model_type, model_summary=None, notes=None, tags=No
:param str|None notes: Optional model description
:param list[str] tags: List of tags
:param str|None project_id: ID of a project
:param str|None cluster_id: ID of a cluster

:return: ID of new model
:rtype: str
Expand All @@ -53,7 +54,7 @@ def upload(self, path, name, model_type, model_summary=None, notes=None, tags=No
)

repository = self.build_repository(repositories.UploadModel)
model_id = repository.create(model, path=path)
model_id = repository.create(model, path=path, cluster_id=cluster_id)

if tags:
self.add_tags(entity_id=model_id, tags=tags)
Expand Down
2 changes: 2 additions & 0 deletions gradient/api_sdk/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class Model(object):
:param str name:
:param str project_id:
:param str experiment_id:
:param str cluster_id:
:param list tags:
:param str model_type:
:param str url:
Expand All @@ -22,6 +23,7 @@ class Model(object):
name = attr.ib(type=str, default=None)
project_id = attr.ib(type=str, default=None)
experiment_id = attr.ib(type=str, default=None)
cluster_id = attr.ib(type=str, default=None)
tags = attr.ib(type=list, factory=list)
model_type = attr.ib(type=str, default=None)
url = attr.ib(type=str, default=None)
Expand Down
8 changes: 4 additions & 4 deletions gradient/api_sdk/repositories/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,21 +86,21 @@ def _get_request_params(self, kwargs):
def _get_request_json(self, instance_dict):
return None

def create(self, instance, data=None, path=None):
def create(self, instance, data=None, path=None, cluster_id=None):
model_id = super(UploadModel, self).create(instance, data=data, path=path)
try:
self._upload_model(path, model_id)
self._upload_model(path, model_id, cluster_id=cluster_id)
except BaseException:
self._delete_model(model_id)
raise

return model_id

def _upload_model(self, file_path, model_id):
def _upload_model(self, file_path, model_id, cluster_id=None):
model_uploader = s3_uploader.S3ModelUploader(
self.api_key, logger=self.logger, ps_client_name=self.ps_client_name
)
model_uploader.upload(file_path, model_id)
model_uploader.upload(file_path, model_id, cluster_id=cluster_id)

def _delete_model(self, model_id):
repository = DeleteModel(self.api_key, logger=self.logger, ps_client_name=self.ps_client_name)
Expand Down
12 changes: 7 additions & 5 deletions gradient/api_sdk/s3_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __init__(self, api_key, multipart_encoder_cls=None, logger=None, ps_client_n
multipart_encoder_cls=self.multipart_encoder_cls
)

def upload(self, file_path, model_id):
def upload(self, file_path, model_id, cluster_id=None):
"""Upload file to S3 bucket for a project

:param str file_path:
Expand All @@ -202,11 +202,11 @@ def upload(self, file_path, model_id):
:rtype: str
:return: S3 bucket's URL
"""
url = self._get_upload_data(file_path, model_id)
url = self._get_upload_data(file_path, model_id, cluster_id=cluster_id)
self.s3uploader.upload(file_path, url)
return url

def _get_upload_data(self, file_path, model_id):
def _get_upload_data(self, file_path, model_id, cluster_id=None):
"""Ask API for data required to upload a file to S3

:param str file_path:
Expand All @@ -221,6 +221,8 @@ def _get_upload_data(self, file_path, model_id):
"modelHandle": model_id,
"contentType": mimetypes.guess_type(file_path)[0] or "",
}
if cluster_id:
params["clusterId"] = cluster_id

response = self.ps_api_client.get("/mlModels/getPresignedModelUrl", params=params)
if not response.ok:
Expand All @@ -239,11 +241,11 @@ def _get_client(self, url, ps_client_name=None, api_key=None):


class S3ModelUploader(S3ModelFileUploader):
def upload(self, file_path, model_id):
def upload(self, file_path, model_id, cluster_id=None):
if os.path.isdir(file_path):
file_path = self._zip_model_directory(file_path)

return super(S3ModelUploader, self).upload(file_path, model_id)
return super(S3ModelUploader, self).upload(file_path, model_id, cluster_id=cluster_id)

def _zip_model_directory(self, dir_path):
archiver = self._get_archiver()
Expand Down
6 changes: 6 additions & 0 deletions gradient/cli/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ def delete_model(api_key, model_id, options_file):
help="ID of a project",
cls=common.GradientOption,
)
@click.option(
"--clusterId",
"cluster_id",
help="ID of a cluster",
cls=common.GradientOption,
)
@click.option(
"--modelSummary",
"model_summary",
Expand Down
1 change: 1 addition & 0 deletions tests/config_files/models_upload.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ modelType: Tensorflow
name: some_name
notes: some notes
projectId: some_project_id
clusterId: some_cluster_id
9 changes: 6 additions & 3 deletions tests/functional/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ class TestModelUpload(object):
"--modelSummary", """{"key": "value"}""",
"--notes", "some notes",
"--projectId", "some_project_id",
"--clusterId", "some_cluster_id",
]
ALL_OPTIONS_PARAMS = {
"name": "some_name",
Expand All @@ -287,14 +288,16 @@ class TestModelUpload(object):
"--modelSummary", """{"key": "value"}""",
"--notes", "some notes",
"--projectId", "some_project_id",
"--clusterId", "some_cluster_id",
"--apiKey", "some_key",
]
COMMAND_WITH_OPTIONS_FILE = ["models", "upload", "--optionsFile", ] # path added in test

EXPECTED_STDOUT = "Model uploaded with ID: some_model_id\n"

GET_PRESIGNED_URL = "https://api.paperspace.io/mlModels/getPresignedModelUrl"
GET_PRESIGNED_URL_PARAMS = {"fileName": "saved_model.pb", "modelHandle": "some_model_id", "contentType": ""}
GET_PRESIGNED_URL_PARAMS = {"fileName": "saved_model.pb", "modelHandle": "some_model_id", "contentType": "", "clusterId": "some_cluster_id"}
GET_PRESIGNED_URL_PARAMS_BASIC = {"fileName": "saved_model.pb", "modelHandle": "some_model_id", "contentType": ""}
GET_PRESIGNED_URL_RESPONSE = example_responses.MODEL_UPLOAD_GET_PRESIGNED_URL_RESPONSE

CREATE_MODEL_V2_REPONSE = example_responses.MODEL_CREATE_RESPONSE_JSON_V2
Expand Down Expand Up @@ -336,7 +339,7 @@ def test_should_send_post_request_when_models_update_command_was_used_with_basic
])
get_patched.assert_called_once_with(self.GET_PRESIGNED_URL,
headers=EXPECTED_HEADERS,
params=self.GET_PRESIGNED_URL_PARAMS,
params=self.GET_PRESIGNED_URL_PARAMS_BASIC,
json=None,
)
assert put_patched.call_args.kwargs["data"].encoder.fields["file"][0] == self.MODEL_FILE
Expand Down Expand Up @@ -520,7 +523,7 @@ def test_should_send_proper_data_and_tag_machine(
mock.call(
self.GET_PRESIGNED_URL,
headers=EXPECTED_HEADERS,
params=self.GET_PRESIGNED_URL_PARAMS,
params=self.GET_PRESIGNED_URL_PARAMS_BASIC,
json=None,
),
]
Expand Down