Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 22 additions & 41 deletions gradient/api_sdk/clients/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from .. import logger as sdk_logger
from ..repositories.common import BaseRepository
from ..repositories.tags import ListTagRepository, UpdateTagRepository
from ...exceptions import ReceivingDataFailedError


class BaseClient(object):
Expand All @@ -24,19 +23,22 @@ def __init__(
self.ps_client_name = ps_client_name
self.logger = logger

KNOWN_TAGS_ENTITIES = [
"project", "job", "notebook", "experiment", "deployment", "mlModel", "machine",
]
entity = ""

def _validate_entities(self, entity):
def build_repository(self, repository_class, *args, **kwargs):
"""
Method to validate if passed entity is correct
:param entity:
:return:
:param type[BaseRepository] repository_class:
:rtype: BaseRepository
"""
if entity not in self.KNOWN_TAGS_ENTITIES:
raise ReceivingDataFailedError("Not known entity type provided")

if self.ps_client_name is not None and kwargs.get("ps_client_name") is None:
kwargs = copy.deepcopy(kwargs)
kwargs["ps_client_name"] = self.ps_client_name

repository = repository_class(*args, api_key=self.api_key, logger=self.logger, **kwargs)
return repository


class TagsSupportMixin(object):
entity = ""

@staticmethod
def merge_tags(entity_id, entity_tags, new_tags):
Expand All @@ -59,66 +61,45 @@ def diff_tags(entity_id, entity_tags, tags_to_remove):

return result_tags

def add_tags(self, entity_id, entity, tags):
def add_tags(self, entity_id, tags):
"""
Add tags to entity.
:param entity_id:
:param entity:
:param tags:
:return:
"""
self._validate_entities(entity)

list_tag_repository = self.build_repository(ListTagRepository)
entity_tags = list_tag_repository.list(entity=entity, entity_ids=[entity_id])
entity_tags = list_tag_repository.list(entity=self.entity, entity_ids=[entity_id])

if entity_tags:
tags = self.merge_tags(entity_id, entity_tags, tags)

update_tag_repository = self.build_repository(UpdateTagRepository)
update_tag_repository.update(entity=entity, entity_id=entity_id, tags=tags)
update_tag_repository.update(entity=self.entity, entity_id=entity_id, tags=tags)

def remove_tags(self, entity_id, entity, tags):
def remove_tags(self, entity_id, tags):
"""
Remove tags from entity.
:param str entity_id:
:param str entity:
:param list[str] tags: list of tags to remove from entity
:return:
"""
self._validate_entities(entity)

list_tag_repository = self.build_repository(ListTagRepository)
entity_tags = list_tag_repository.list(entity=entity, entity_ids=[entity_id])
entity_tags = list_tag_repository.list(entity=self.entity, entity_ids=[entity_id])

if entity_tags:
entity_tags = self.diff_tags(entity_id, entity_tags, tags)

update_tag_repository = self.build_repository(UpdateTagRepository)
update_tag_repository.update(entity=entity, entity_id=entity_id, tags=entity_tags)
update_tag_repository.update(entity=self.entity, entity_id=entity_id, tags=entity_tags)

def list_tags(self, entity_ids, entity):
def list_tags(self, entity_ids):
"""
List tags for entity
:param list[str] entity_ids:
:param str entity:
:return:
"""
self._validate_entities(entity)

list_tag_repository = self.build_repository(ListTagRepository)
entity_tags = list_tag_repository.list(entity=entity, entity_ids=entity_ids)
entity_tags = list_tag_repository.list(entity=self.entity, entity_ids=entity_ids)
return entity_tags

def build_repository(self, repository_class, *args, **kwargs):
"""
:param type[BaseRepository] repository_class:
:rtype: BaseRepository
"""

if self.ps_client_name is not None and kwargs.get("ps_client_name") is None:
kwargs = copy.deepcopy(kwargs)
kwargs["ps_client_name"] = self.ps_client_name

repository = repository_class(*args, api_key=self.api_key, logger=self.logger, **kwargs)
return repository
6 changes: 3 additions & 3 deletions gradient/api_sdk/clients/deployment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

Remember that in code snippets all highlighted lines are required other lines are optional.
"""
from .base_client import BaseClient
from .base_client import BaseClient, TagsSupportMixin
from .. import config, models, repositories


class DeploymentsClient(BaseClient):
class DeploymentsClient(TagsSupportMixin, BaseClient):
"""
Client to handle deployment related actions.

Expand Down Expand Up @@ -137,7 +137,7 @@ def create(
repository = self.build_repository(repositories.CreateDeployment)
deployment_id = repository.create(deployment)
if tags:
self.add_tags(entity_id=deployment_id, entity=self.entity, tags=tags)
self.add_tags(entity_id=deployment_id, tags=tags)
return deployment_id

def get(self, deployment_id):
Expand Down
16 changes: 8 additions & 8 deletions gradient/api_sdk/clients/experiment_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import datetime

from .base_client import BaseClient
from .base_client import BaseClient, TagsSupportMixin
from .. import repositories, models, constants, utils
from ..sdk_exceptions import ResourceCreatingDataError, InvalidParametersError
from ..validation_messages import EXPERIMENT_MODEL_PATH_VALIDATION_ERROR


class ExperimentsClient(utils.ExperimentsClientHelpersMixin, BaseClient):
class ExperimentsClient(TagsSupportMixin, utils.ExperimentsClientHelpersMixin, BaseClient):
entity = "experiment"

def create_single_node(
Expand Down Expand Up @@ -104,7 +104,7 @@ def create_single_node(
handle = repository.create(experiment)

if tags:
self.add_tags(entity_id=handle, entity=self.entity, tags=tags)
self.add_tags(entity_id=handle, tags=tags)

return handle

Expand Down Expand Up @@ -234,7 +234,7 @@ def create_multi_node(
handle = repository.create(experiment)

if tags:
self.add_tags(entity_id=handle, entity=self.entity, tags=tags)
self.add_tags(entity_id=handle, tags=tags)
return handle

def create_mpi_multi_node(
Expand Down Expand Up @@ -360,7 +360,7 @@ def create_mpi_multi_node(
handle = repository.create(experiment)

if tags:
self.add_tags(entity_id=handle, entity=self.entity, tags=tags)
self.add_tags(entity_id=handle, tags=tags)
return handle

def run_single_node(
Expand Down Expand Up @@ -458,7 +458,7 @@ def run_single_node(
handle = repository.create(experiment)

if tags:
self.add_tags(entity_id=handle, entity=self.entity, tags=tags)
self.add_tags(entity_id=handle, tags=tags)
return handle

def run_multi_node(
Expand Down Expand Up @@ -586,7 +586,7 @@ def run_multi_node(
handle = repository.create(experiment)

if tags:
self.add_tags(entity_id=handle, entity=self.entity, tags=tags)
self.add_tags(entity_id=handle, tags=tags)

return handle

Expand Down Expand Up @@ -713,7 +713,7 @@ def run_mpi_multi_node(
handle = repository.create(experiment)

if tags:
self.add_tags(entity_id=handle, entity=self.entity, tags=tags)
self.add_tags(entity_id=handle, tags=tags)

return handle

Expand Down
7 changes: 4 additions & 3 deletions gradient/api_sdk/clients/hyperparameter_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from . import base_client
from .base_client import TagsSupportMixin
from .. import models, repositories


class HyperparameterJobsClient(base_client.BaseClient):
class HyperparameterJobsClient(TagsSupportMixin, base_client.BaseClient):
entity = "experiment"

def create(
Expand Down Expand Up @@ -110,7 +111,7 @@ def create(
handle = repository.create(hyperparameter)

if tags:
self.add_tags(entity_id=handle, entity=self.entity, tags=tags)
self.add_tags(entity_id=handle, tags=tags)

return handle

Expand Down Expand Up @@ -220,7 +221,7 @@ def run(
handle = repository.create(hyperparameter)

if tags:
self.add_tags(entity_id=handle, entity=self.entity, tags=tags)
self.add_tags(entity_id=handle, tags=tags)

return handle

Expand Down
6 changes: 3 additions & 3 deletions gradient/api_sdk/clients/job_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

Remember that in code snippets all highlighted lines are required other lines are optional.
"""
from .base_client import BaseClient
from .base_client import BaseClient, TagsSupportMixin
from ..models import Artifact, Job
from ..repositories.jobs import ListJobs, ListJobLogs, ListJobArtifacts, CreateJob, DeleteJob, StopJob, \
DeleteJobArtifacts, GetJobArtifacts, GetJobMetrics, StreamJobMetrics


class JobsClient(BaseClient):
class JobsClient(TagsSupportMixin, BaseClient):
"""
Client to handle job related actions.

Expand Down Expand Up @@ -172,7 +172,7 @@ def create(
handle = repository.create(job, data=data)

if tags:
self.add_tags(entity_id=handle, entity=self.entity, tags=tags)
self.add_tags(entity_id=handle, tags=tags)

return handle

Expand Down
6 changes: 3 additions & 3 deletions gradient/api_sdk/clients/machines_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .base_client import BaseClient
from .base_client import BaseClient, TagsSupportMixin
from .. import repositories, models
from ..repositories.machines import CheckMachineAvailability, DeleteMachine, ListMachines, WaitForState


class MachinesClient(BaseClient):
class MachinesClient(TagsSupportMixin, BaseClient):
entity = "machine"

def create(
Expand Down Expand Up @@ -78,7 +78,7 @@ def create(
repository = self.build_repository(repositories.CreateMachine)
handle = repository.create(instance)
if tags:
self.add_tags(entity_id=handle, entity=self.entity, tags=tags)
self.add_tags(entity_id=handle, tags=tags)
return handle

def get(self, id):
Expand Down
6 changes: 3 additions & 3 deletions gradient/api_sdk/clients/model_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import json

from .base_client import BaseClient
from .base_client import BaseClient, TagsSupportMixin
from .. import repositories, models


class ModelsClient(BaseClient):
class ModelsClient(TagsSupportMixin, BaseClient):
entity = "mlModel"

def list(self, experiment_id=None, project_id=None, tags=None):
Expand Down Expand Up @@ -56,7 +56,7 @@ def upload(self, path, name, model_type, model_summary=None, notes=None, tags=No
model_id = repository.create(model, path=path)

if tags:
self.add_tags(entity_id=model_id, entity=self.entity, tags=tags)
self.add_tags(entity_id=model_id, tags=tags)

return model_id

Expand Down
22 changes: 10 additions & 12 deletions gradient/api_sdk/clients/notebook_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .base_client import BaseClient
from .base_client import BaseClient, TagsSupportMixin
from .. import repositories, models


class NotebooksClient(BaseClient):
class NotebooksClient(TagsSupportMixin, BaseClient):
entity = "notebook"

def create(
Expand All @@ -26,7 +26,7 @@ def create(

:param int container_id:
:param str cluster_id:
:param str vm_type_id;
:param str vm_type_id:
:param int vm_type_label:
:param str container_name:
:param str name:
Expand Down Expand Up @@ -54,16 +54,16 @@ def create(
container_user=container_user,
shutdown_timeout=shutdown_timeout,
is_preemptible=is_preemptible,
vm_type_label = vm_type_label,
vm_type_id = vm_type_id,
is_public = is_public,
vm_type_label=vm_type_label,
vm_type_id=vm_type_id,
is_public=is_public,
)

repository = self.build_repository(repositories.CreateNotebook)
handle = repository.create(notebook)

if tags:
self.add_tags(entity_id=handle, entity=self.entity, tags=tags)
self.add_tags(entity_id=handle, tags=tags)

return handle

Expand All @@ -79,7 +79,7 @@ def start(
tags=None,
):
"""Start existing notebook
:param str|int id
:param str|int id:
:param str cluster_id:
:param str vm_type_id:
:param int vm_type_label:
Expand All @@ -106,7 +106,7 @@ def start(
handle = repository.start(notebook)

if tags:
self.add_tags(entity_id=handle, entity=self.entity, tags=tags)
self.add_tags(entity_id=handle, tags=tags)

return handle

Expand All @@ -122,7 +122,7 @@ def fork(self, id, tags=None):
handle = repository.fork(id)

if tags:
self.add_tags(entity_id=handle, entity=self.entity, tags=tags)
self.add_tags(entity_id=handle, tags=tags)

return handle

Expand Down Expand Up @@ -207,7 +207,6 @@ def stop(self, id):
repository = self.build_repository(repositories.StopNotebook)
repository.stop(id)


def artifacts_list(self, notebook_id, files=None, size=False, links=True):
"""
Method to retrieve all artifacts files.
Expand All @@ -234,4 +233,3 @@ def artifacts_list(self, notebook_id, files=None, size=False, links=True):
repository = self.build_repository(repositories.ListNotebookArtifacts)
artifacts = repository.list(notebook_id=notebook_id, files=files, links=links, size=size)
return artifacts

Loading