diff --git a/gradient/api_sdk/clients/base_client.py b/gradient/api_sdk/clients/base_client.py index a7047086..815083e9 100644 --- a/gradient/api_sdk/clients/base_client.py +++ b/gradient/api_sdk/clients/base_client.py @@ -3,7 +3,6 @@ from .. import logger as sdk_logger from ..repositories.common import BaseRepository from ..repositories.tags import ListTagRepository, UpdateTagRepository -from ...exceptions import ReceivingDataFailedError class BaseClient(object): @@ -24,19 +23,22 @@ def __init__( self.ps_client_name = ps_client_name self.logger = logger - KNOWN_TAGS_ENTITIES = [ - "project", "job", "notebook", "experiment", "deployment", "mlModel", "machine", - ] - entity = "" - - def _validate_entities(self, entity): + def build_repository(self, repository_class, *args, **kwargs): """ - Method to validate if passed entity is correct - :param entity: - :return: + :param type[BaseRepository] repository_class: + :rtype: BaseRepository """ - if entity not in self.KNOWN_TAGS_ENTITIES: - raise ReceivingDataFailedError("Not known entity type provided") + + if self.ps_client_name is not None and kwargs.get("ps_client_name") is None: + kwargs = copy.deepcopy(kwargs) + kwargs["ps_client_name"] = self.ps_client_name + + repository = repository_class(*args, api_key=self.api_key, logger=self.logger, **kwargs) + return repository + + +class TagsSupportMixin(object): + entity = "" @staticmethod def merge_tags(entity_id, entity_tags, new_tags): @@ -59,7 +61,7 @@ def diff_tags(entity_id, entity_tags, tags_to_remove): return result_tags - def add_tags(self, entity_id, entity, tags): + def add_tags(self, entity_id, tags): """ Add tags to entity. :param entity_id: @@ -67,58 +69,37 @@ def add_tags(self, entity_id, entity, tags): :param tags: :return: """ - self._validate_entities(entity) - list_tag_repository = self.build_repository(ListTagRepository) - entity_tags = list_tag_repository.list(entity=entity, entity_ids=[entity_id]) + entity_tags = list_tag_repository.list(entity=self.entity, entity_ids=[entity_id]) if entity_tags: tags = self.merge_tags(entity_id, entity_tags, tags) update_tag_repository = self.build_repository(UpdateTagRepository) - update_tag_repository.update(entity=entity, entity_id=entity_id, tags=tags) + update_tag_repository.update(entity=self.entity, entity_id=entity_id, tags=tags) - def remove_tags(self, entity_id, entity, tags): + def remove_tags(self, entity_id, tags): """ Remove tags from entity. :param str entity_id: - :param str entity: :param list[str] tags: list of tags to remove from entity :return: """ - self._validate_entities(entity) - list_tag_repository = self.build_repository(ListTagRepository) - entity_tags = list_tag_repository.list(entity=entity, entity_ids=[entity_id]) + entity_tags = list_tag_repository.list(entity=self.entity, entity_ids=[entity_id]) if entity_tags: entity_tags = self.diff_tags(entity_id, entity_tags, tags) update_tag_repository = self.build_repository(UpdateTagRepository) - update_tag_repository.update(entity=entity, entity_id=entity_id, tags=entity_tags) + update_tag_repository.update(entity=self.entity, entity_id=entity_id, tags=entity_tags) - def list_tags(self, entity_ids, entity): + def list_tags(self, entity_ids): """ List tags for entity :param list[str] entity_ids: - :param str entity: :return: """ - self._validate_entities(entity) - list_tag_repository = self.build_repository(ListTagRepository) - entity_tags = list_tag_repository.list(entity=entity, entity_ids=entity_ids) + entity_tags = list_tag_repository.list(entity=self.entity, entity_ids=entity_ids) return entity_tags - - def build_repository(self, repository_class, *args, **kwargs): - """ - :param type[BaseRepository] repository_class: - :rtype: BaseRepository - """ - - if self.ps_client_name is not None and kwargs.get("ps_client_name") is None: - kwargs = copy.deepcopy(kwargs) - kwargs["ps_client_name"] = self.ps_client_name - - repository = repository_class(*args, api_key=self.api_key, logger=self.logger, **kwargs) - return repository diff --git a/gradient/api_sdk/clients/deployment_client.py b/gradient/api_sdk/clients/deployment_client.py index 70969cb2..6777664b 100644 --- a/gradient/api_sdk/clients/deployment_client.py +++ b/gradient/api_sdk/clients/deployment_client.py @@ -3,11 +3,11 @@ Remember that in code snippets all highlighted lines are required other lines are optional. """ -from .base_client import BaseClient +from .base_client import BaseClient, TagsSupportMixin from .. import config, models, repositories -class DeploymentsClient(BaseClient): +class DeploymentsClient(TagsSupportMixin, BaseClient): """ Client to handle deployment related actions. @@ -137,7 +137,7 @@ def create( repository = self.build_repository(repositories.CreateDeployment) deployment_id = repository.create(deployment) if tags: - self.add_tags(entity_id=deployment_id, entity=self.entity, tags=tags) + self.add_tags(entity_id=deployment_id, tags=tags) return deployment_id def get(self, deployment_id): diff --git a/gradient/api_sdk/clients/experiment_client.py b/gradient/api_sdk/clients/experiment_client.py index 9fe5cb51..a9ded9a6 100644 --- a/gradient/api_sdk/clients/experiment_client.py +++ b/gradient/api_sdk/clients/experiment_client.py @@ -1,12 +1,12 @@ import datetime -from .base_client import BaseClient +from .base_client import BaseClient, TagsSupportMixin from .. import repositories, models, constants, utils from ..sdk_exceptions import ResourceCreatingDataError, InvalidParametersError from ..validation_messages import EXPERIMENT_MODEL_PATH_VALIDATION_ERROR -class ExperimentsClient(utils.ExperimentsClientHelpersMixin, BaseClient): +class ExperimentsClient(TagsSupportMixin, utils.ExperimentsClientHelpersMixin, BaseClient): entity = "experiment" def create_single_node( @@ -104,7 +104,7 @@ def create_single_node( handle = repository.create(experiment) if tags: - self.add_tags(entity_id=handle, entity=self.entity, tags=tags) + self.add_tags(entity_id=handle, tags=tags) return handle @@ -234,7 +234,7 @@ def create_multi_node( handle = repository.create(experiment) if tags: - self.add_tags(entity_id=handle, entity=self.entity, tags=tags) + self.add_tags(entity_id=handle, tags=tags) return handle def create_mpi_multi_node( @@ -360,7 +360,7 @@ def create_mpi_multi_node( handle = repository.create(experiment) if tags: - self.add_tags(entity_id=handle, entity=self.entity, tags=tags) + self.add_tags(entity_id=handle, tags=tags) return handle def run_single_node( @@ -458,7 +458,7 @@ def run_single_node( handle = repository.create(experiment) if tags: - self.add_tags(entity_id=handle, entity=self.entity, tags=tags) + self.add_tags(entity_id=handle, tags=tags) return handle def run_multi_node( @@ -586,7 +586,7 @@ def run_multi_node( handle = repository.create(experiment) if tags: - self.add_tags(entity_id=handle, entity=self.entity, tags=tags) + self.add_tags(entity_id=handle, tags=tags) return handle @@ -713,7 +713,7 @@ def run_mpi_multi_node( handle = repository.create(experiment) if tags: - self.add_tags(entity_id=handle, entity=self.entity, tags=tags) + self.add_tags(entity_id=handle, tags=tags) return handle diff --git a/gradient/api_sdk/clients/hyperparameter_client.py b/gradient/api_sdk/clients/hyperparameter_client.py index bad412b3..fbe563ee 100644 --- a/gradient/api_sdk/clients/hyperparameter_client.py +++ b/gradient/api_sdk/clients/hyperparameter_client.py @@ -1,8 +1,9 @@ from . import base_client +from .base_client import TagsSupportMixin from .. import models, repositories -class HyperparameterJobsClient(base_client.BaseClient): +class HyperparameterJobsClient(TagsSupportMixin, base_client.BaseClient): entity = "experiment" def create( @@ -110,7 +111,7 @@ def create( handle = repository.create(hyperparameter) if tags: - self.add_tags(entity_id=handle, entity=self.entity, tags=tags) + self.add_tags(entity_id=handle, tags=tags) return handle @@ -220,7 +221,7 @@ def run( handle = repository.create(hyperparameter) if tags: - self.add_tags(entity_id=handle, entity=self.entity, tags=tags) + self.add_tags(entity_id=handle, tags=tags) return handle diff --git a/gradient/api_sdk/clients/job_client.py b/gradient/api_sdk/clients/job_client.py index e3890631..60e2382b 100644 --- a/gradient/api_sdk/clients/job_client.py +++ b/gradient/api_sdk/clients/job_client.py @@ -3,13 +3,13 @@ Remember that in code snippets all highlighted lines are required other lines are optional. """ -from .base_client import BaseClient +from .base_client import BaseClient, TagsSupportMixin from ..models import Artifact, Job from ..repositories.jobs import ListJobs, ListJobLogs, ListJobArtifacts, CreateJob, DeleteJob, StopJob, \ DeleteJobArtifacts, GetJobArtifacts, GetJobMetrics, StreamJobMetrics -class JobsClient(BaseClient): +class JobsClient(TagsSupportMixin, BaseClient): """ Client to handle job related actions. @@ -172,7 +172,7 @@ def create( handle = repository.create(job, data=data) if tags: - self.add_tags(entity_id=handle, entity=self.entity, tags=tags) + self.add_tags(entity_id=handle, tags=tags) return handle diff --git a/gradient/api_sdk/clients/machines_client.py b/gradient/api_sdk/clients/machines_client.py index 644c353b..bf6509b9 100644 --- a/gradient/api_sdk/clients/machines_client.py +++ b/gradient/api_sdk/clients/machines_client.py @@ -1,9 +1,9 @@ -from .base_client import BaseClient +from .base_client import BaseClient, TagsSupportMixin from .. import repositories, models from ..repositories.machines import CheckMachineAvailability, DeleteMachine, ListMachines, WaitForState -class MachinesClient(BaseClient): +class MachinesClient(TagsSupportMixin, BaseClient): entity = "machine" def create( @@ -78,7 +78,7 @@ def create( repository = self.build_repository(repositories.CreateMachine) handle = repository.create(instance) if tags: - self.add_tags(entity_id=handle, entity=self.entity, tags=tags) + self.add_tags(entity_id=handle, tags=tags) return handle def get(self, id): diff --git a/gradient/api_sdk/clients/model_client.py b/gradient/api_sdk/clients/model_client.py index ef68650f..1580be71 100644 --- a/gradient/api_sdk/clients/model_client.py +++ b/gradient/api_sdk/clients/model_client.py @@ -1,10 +1,10 @@ import json -from .base_client import BaseClient +from .base_client import BaseClient, TagsSupportMixin from .. import repositories, models -class ModelsClient(BaseClient): +class ModelsClient(TagsSupportMixin, BaseClient): entity = "mlModel" def list(self, experiment_id=None, project_id=None, tags=None): @@ -56,7 +56,7 @@ def upload(self, path, name, model_type, model_summary=None, notes=None, tags=No model_id = repository.create(model, path=path) if tags: - self.add_tags(entity_id=model_id, entity=self.entity, tags=tags) + self.add_tags(entity_id=model_id, tags=tags) return model_id diff --git a/gradient/api_sdk/clients/notebook_client.py b/gradient/api_sdk/clients/notebook_client.py index bbd3c32e..5365c9e2 100644 --- a/gradient/api_sdk/clients/notebook_client.py +++ b/gradient/api_sdk/clients/notebook_client.py @@ -1,8 +1,8 @@ -from .base_client import BaseClient +from .base_client import BaseClient, TagsSupportMixin from .. import repositories, models -class NotebooksClient(BaseClient): +class NotebooksClient(TagsSupportMixin, BaseClient): entity = "notebook" def create( @@ -26,7 +26,7 @@ def create( :param int container_id: :param str cluster_id: - :param str vm_type_id; + :param str vm_type_id: :param int vm_type_label: :param str container_name: :param str name: @@ -54,16 +54,16 @@ def create( container_user=container_user, shutdown_timeout=shutdown_timeout, is_preemptible=is_preemptible, - vm_type_label = vm_type_label, - vm_type_id = vm_type_id, - is_public = is_public, + vm_type_label=vm_type_label, + vm_type_id=vm_type_id, + is_public=is_public, ) repository = self.build_repository(repositories.CreateNotebook) handle = repository.create(notebook) if tags: - self.add_tags(entity_id=handle, entity=self.entity, tags=tags) + self.add_tags(entity_id=handle, tags=tags) return handle @@ -79,7 +79,7 @@ def start( tags=None, ): """Start existing notebook - :param str|int id + :param str|int id: :param str cluster_id: :param str vm_type_id: :param int vm_type_label: @@ -106,7 +106,7 @@ def start( handle = repository.start(notebook) if tags: - self.add_tags(entity_id=handle, entity=self.entity, tags=tags) + self.add_tags(entity_id=handle, tags=tags) return handle @@ -122,7 +122,7 @@ def fork(self, id, tags=None): handle = repository.fork(id) if tags: - self.add_tags(entity_id=handle, entity=self.entity, tags=tags) + self.add_tags(entity_id=handle, tags=tags) return handle @@ -207,7 +207,6 @@ def stop(self, id): repository = self.build_repository(repositories.StopNotebook) repository.stop(id) - def artifacts_list(self, notebook_id, files=None, size=False, links=True): """ Method to retrieve all artifacts files. @@ -234,4 +233,3 @@ def artifacts_list(self, notebook_id, files=None, size=False, links=True): repository = self.build_repository(repositories.ListNotebookArtifacts) artifacts = repository.list(notebook_id=notebook_id, files=files, links=links, size=size) return artifacts - diff --git a/gradient/api_sdk/clients/project_client.py b/gradient/api_sdk/clients/project_client.py index 417a2087..c39eb1f8 100644 --- a/gradient/api_sdk/clients/project_client.py +++ b/gradient/api_sdk/clients/project_client.py @@ -1,8 +1,8 @@ -from .base_client import BaseClient +from .base_client import BaseClient, TagsSupportMixin from .. import models, repositories -class ProjectsClient(BaseClient): +class ProjectsClient(TagsSupportMixin, BaseClient): entity = "project" def create(self, name, repository_name=None, repository_url=None, tags=None, ): @@ -47,7 +47,7 @@ def create(self, name, repository_name=None, repository_url=None, tags=None, ): handle = repository.create(project) if tags: - self.add_tags(entity_id=handle, entity=self.entity, tags=tags) + self.add_tags(entity_id=handle, tags=tags) return handle diff --git a/gradient/commands/deployments.py b/gradient/commands/deployments.py index a28db8b3..251aad85 100644 --- a/gradient/commands/deployments.py +++ b/gradient/commands/deployments.py @@ -18,8 +18,6 @@ @six.add_metaclass(abc.ABCMeta) class BaseDeploymentCommand(BaseCommand): - entity = "deployment" - def _get_client(self, api_key, logger): client = DeploymentsClient( api_key=api_key, @@ -174,13 +172,13 @@ def _get_table_data(self, instance): class DeploymentAddTagsCommand(BaseDeploymentCommand): def execute(self, deployment_id, *args, **kwargs): - self.client.add_tags(deployment_id, entity=self.entity, **kwargs) + self.client.add_tags(deployment_id, **kwargs) self.logger.log("Tags added to deployment") class DeploymentRemoveTagsCommand(BaseDeploymentCommand): def execute(self, deployment_id, *args, **kwargs): - self.client.remove_tags(deployment_id, entity=self.entity, **kwargs) + self.client.remove_tags(deployment_id, **kwargs) self.logger.log("Tags removed from deployment") diff --git a/gradient/commands/experiments.py b/gradient/commands/experiments.py index 9c1d5a01..dd1102af 100644 --- a/gradient/commands/experiments.py +++ b/gradient/commands/experiments.py @@ -28,8 +28,6 @@ @six.add_metaclass(abc.ABCMeta) class BaseExperimentCommand(BaseCommand): - entity = "experiment" - def _get_client(self, api_key, logger): client = api_sdk.clients.ExperimentsClient( api_key=api_key, @@ -404,13 +402,13 @@ def execute(self, experiment_id, *args, **kwargs): class ExperimentAddTagsCommand(BaseExperimentCommand): def execute(self, experiment_id, *args, **kwargs): - self.client.add_tags(experiment_id, entity=self.entity, **kwargs) + self.client.add_tags(experiment_id, **kwargs) self.logger.log("Tags added to experiment") class ExperimentRemoveTagsCommand(BaseExperimentCommand): def execute(self, experiment_id, *args, **kwargs): - self.client.remove_tags(experiment_id, entity=self.entity, **kwargs) + self.client.remove_tags(experiment_id, **kwargs) self.logger.log("Tags removed from experiment") diff --git a/gradient/commands/hyperparameters.py b/gradient/commands/hyperparameters.py index c4a4cbfc..7966b8dc 100644 --- a/gradient/commands/hyperparameters.py +++ b/gradient/commands/hyperparameters.py @@ -11,8 +11,6 @@ @six.add_metaclass(abc.ABCMeta) class BaseHyperparameterCommand(BaseCommand): - entity = "experiment" - def _get_client(self, api_key, logger): client = api_sdk.clients.HyperparameterJobsClient( api_key=api_key, @@ -87,11 +85,11 @@ def execute(self, id_): class HyperparameterAddTagsCommand(BaseHyperparameterCommand): def execute(self, hyperparameter_id, *args, **kwargs): - self.client.add_tags(hyperparameter_id, entity=self.entity, **kwargs) + self.client.add_tags(hyperparameter_id, **kwargs) self.logger.log("Tags added to hyperparameter") class HyperparameterRemoveTagsCommand(BaseHyperparameterCommand): def execute(self, hyperparameter_id, *args, **kwargs): - self.client.remove_tags(hyperparameter_id, entity=self.entity, **kwargs) + self.client.remove_tags(hyperparameter_id, **kwargs) self.logger.log("Tags removed from hyperparameter") diff --git a/gradient/commands/jobs.py b/gradient/commands/jobs.py index d72a354f..905d2837 100644 --- a/gradient/commands/jobs.py +++ b/gradient/commands/jobs.py @@ -10,7 +10,7 @@ from gradient import api_sdk, exceptions, Job, JobArtifactsDownloader, cli_constants from gradient.api_sdk import config, sdk_exceptions from gradient.api_sdk.clients import http_client -from gradient.api_sdk.clients.base_client import BaseClient +from gradient.api_sdk.clients.base_client import BaseClient, TagsSupportMixin from gradient.api_sdk.repositories.jobs import RunJob from gradient.api_sdk.utils import print_dict_recursive, concatenate_urls, MultipartEncoder from gradient.cli_constants import CLI_PS_CLIENT_NAME @@ -20,8 +20,6 @@ @six.add_metaclass(abc.ABCMeta) class BaseJobCommand(BaseCommand): - entity = "job" - def _get_client(self, api_key, logger_): client = api_sdk.clients.JobsClient( api_key=api_key, @@ -189,9 +187,7 @@ def _create(self, json_, data): return self.client.create(data=data, **json_) -class JobRunClient(BaseClient): - entity = "job" - +class JobRunClient(TagsSupportMixin, BaseClient): def __init__(self, http_client_, *args, **kwargs): super(JobRunClient, self).__init__(*args, **kwargs) self.client = http_client_ @@ -259,7 +255,7 @@ def create( ) handle = RunJob(self.api_key, self.logger, self.client).create(job, data=data) if tags: - self.add_tags(entity_id=handle, entity=self.entity, tags=tags) + self.add_tags(entity_id=handle, tags=tags) return handle @@ -363,13 +359,13 @@ def execute(self, job_id, destination_directory): class JobAddTagsCommand(BaseJobCommand): def execute(self, job_id, *args, **kwargs): - self.client.add_tags(job_id, entity=self.entity, **kwargs) + self.client.add_tags(job_id, **kwargs) self.logger.log("Tags added to job") class JobRemoveTagsCommand(BaseJobCommand): def execute(self, job_id, *args, **kwargs): - self.client.remove_tags(job_id, entity=self.entity, **kwargs) + self.client.remove_tags(job_id, **kwargs) self.logger.log("Tags removed from job") diff --git a/gradient/commands/machines.py b/gradient/commands/machines.py index 67b5de40..eaa1e98d 100644 --- a/gradient/commands/machines.py +++ b/gradient/commands/machines.py @@ -11,8 +11,6 @@ class GetMachinesClientMixin(object): - entity = "machine" - def _get_client(self, api_key, logger): client = api_sdk.MachinesClient( api_key=api_key, @@ -176,11 +174,11 @@ def execute(self, machine_id, state, interval=5): class MachineAddTagsCommand(GetMachinesClientMixin, BaseCommand): def execute(self, machine_id, *args, **kwargs): - self.client.add_tags(machine_id, entity=self.entity, **kwargs) + self.client.add_tags(machine_id, **kwargs) self.logger.log("Tags added to machine") class MachineRemoveTagsCommand(GetMachinesClientMixin, BaseCommand): def execute(self, machine_id, *args, **kwargs): - self.client.remove_tags(machine_id, entity=self.entity, **kwargs) + self.client.remove_tags(machine_id, **kwargs) self.logger.log("Tags removed from machine") diff --git a/gradient/commands/models.py b/gradient/commands/models.py index 2e402654..4b12f624 100644 --- a/gradient/commands/models.py +++ b/gradient/commands/models.py @@ -12,8 +12,6 @@ @six.add_metaclass(abc.ABCMeta) class GetModelsClientMixin: - entity = "mlModel" - def _get_client(self, api_key, logger): client = api_sdk.clients.ModelsClient( api_key=api_key, @@ -98,11 +96,11 @@ def execute(self, model_id, destination_directory): class MLModelAddTagsCommand(GetModelsClientMixin, BaseCommand): def execute(self, ml_model_id, *args, **kwargs): - self.client.add_tags(ml_model_id, entity=self.entity, **kwargs) + self.client.add_tags(ml_model_id, **kwargs) self.logger.log("Tags added to ml model") class MLModelRemoveTagsCommand(GetModelsClientMixin, BaseCommand): def execute(self, ml_model_id, *args, **kwargs): - self.client.remove_tags(ml_model_id, entity=self.entity, **kwargs) + self.client.remove_tags(ml_model_id, **kwargs) self.logger.log("Tags removed from ml model") diff --git a/gradient/commands/notebooks.py b/gradient/commands/notebooks.py index bc35e13e..cf0d0bee 100644 --- a/gradient/commands/notebooks.py +++ b/gradient/commands/notebooks.py @@ -10,15 +10,12 @@ from gradient import api_sdk, exceptions from gradient.api_sdk import sdk_exceptions from gradient.cli_constants import CLI_PS_CLIENT_NAME -from gradient.commands.common import BaseCommand, ListCommandMixin, DetailsCommandMixin, StreamMetricsCommand from gradient.cliutils import get_terminal_lines -from gradient.api_sdk.utils import print_dict_recursive +from gradient.commands.common import BaseCommand, ListCommandMixin, DetailsCommandMixin, StreamMetricsCommand @six.add_metaclass(abc.ABCMeta) class BaseNotebookCommand(BaseCommand): - entity = "notebook" - def _get_client(self, api_key, logger): client = api_sdk.clients.NotebooksClient( api_key=api_key, @@ -55,6 +52,7 @@ def execute(self, id): class StartNotebookCommand(BaseNotebookCommand): SPINNER_MESSAGE = "Starting notebook" + def execute(self, **kwargs): with halo.Halo(text=self.SPINNER_MESSAGE, spinner="dots"): notebook_id = self.client.start(**kwargs) @@ -131,13 +129,13 @@ def _get_table_data(self, instance): class NotebookAddTagsCommand(BaseNotebookCommand): def execute(self, notebook_id, *args, **kwargs): - self.client.add_tags(notebook_id, entity=self.entity, **kwargs) + self.client.add_tags(notebook_id, **kwargs) self.logger.log("Tags added to notebook") class NotebookRemoveTagsCommand(BaseNotebookCommand): def execute(self, notebook_id, *args, **kwargs): - self.client.remove_tags(notebook_id, entity=self.entity, **kwargs) + self.client.remove_tags(notebook_id, **kwargs) self.logger.log("Tags removed from notebook") @@ -157,6 +155,7 @@ def execute(self, notebook_id, start, end, interval, built_in_metrics, *args, ** class StreamNotebookMetricsCommand(StreamMetricsCommand, BaseNotebookCommand): pass + class NotebookLogsCommand(BaseNotebookCommand): def execute(self, notebook_id, line, limit, follow): @@ -250,4 +249,3 @@ def _make_table(table_data): ascii_table = terminaltables.AsciiTable(table_data) table_string = ascii_table.table return table_string - diff --git a/gradient/commands/projects.py b/gradient/commands/projects.py index 7d3e3911..af8e1b78 100644 --- a/gradient/commands/projects.py +++ b/gradient/commands/projects.py @@ -13,8 +13,6 @@ @six.add_metaclass(abc.ABCMeta) class BaseProjectCommand(BaseCommand): - entity = "project" - def _get_client(self, api_key, logger): client = api_sdk.clients.ProjectsClient( api_key=api_key, @@ -66,13 +64,13 @@ def execute(self, project_id): class ProjectAddTagsCommand(BaseProjectCommand): def execute(self, project_id, *args, **kwargs): - self.client.add_tags(project_id, entity=self.entity, **kwargs) + self.client.add_tags(project_id, **kwargs) self.logger.log("Tags added to project") class ProjectRemoveTagsCommand(BaseProjectCommand): def execute(self, project_id, *args, **kwargs): - self.client.remove_tags(project_id, entity=self.entity, **kwargs) + self.client.remove_tags(project_id, **kwargs) self.logger.log("Tags removed from project") diff --git a/tests/unit/test_base_client.py b/tests/unit/test_tags_support_mixin.py similarity index 76% rename from tests/unit/test_base_client.py rename to tests/unit/test_tags_support_mixin.py index a254fef1..69e1ce91 100644 --- a/tests/unit/test_base_client.py +++ b/tests/unit/test_tags_support_mixin.py @@ -1,9 +1,13 @@ import pytest -from gradient.api_sdk.clients.base_client import BaseClient +from gradient.api_sdk.clients.base_client import BaseClient, TagsSupportMixin -class TestBaseClientMethods(object): +class ClassWithTagsSupportMixin(TagsSupportMixin, BaseClient): + pass + + +class TestTagsSupportMixinMethods(object): example_entity_tags = [{"some_id": ["test0", "test2", "test1", "test3"]}] @pytest.mark.parametrize( @@ -16,7 +20,7 @@ class TestBaseClientMethods(object): ] ) def test_merge_tags(self, entity_id, entity_tags, new_tags, expected_result_tags): - result_tags = BaseClient.merge_tags(entity_id, entity_tags, new_tags) + result_tags = ClassWithTagsSupportMixin.merge_tags(entity_id, entity_tags, new_tags) assert result_tags == expected_result_tags @pytest.mark.parametrize( @@ -29,5 +33,5 @@ def test_merge_tags(self, entity_id, entity_tags, new_tags, expected_result_tags ] ) def test_diff_tags(self, entity_id, entity_tags, tags_to_remove, expected_result_tags): - result_tags = BaseClient.diff_tags(entity_id, entity_tags, tags_to_remove) + result_tags = ClassWithTagsSupportMixin.diff_tags(entity_id, entity_tags, tags_to_remove) assert result_tags == expected_result_tags