diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e428faeed..f26bee116 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: exclude: ^(ansys/rep/client/jms/resource/|ansys/rep/client/auth/resource/) - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.11.5 hooks: - id: isort exclude: ^(ansys/rep/client/jms/resource/|ansys/rep/client/auth/resource/) diff --git a/ansys/rep/client/jms/api/base.py b/ansys/rep/client/jms/api/base.py index da9708b65..1982708bf 100644 --- a/ansys/rep/client/jms/api/base.py +++ b/ansys/rep/client/jms/api/base.py @@ -1,6 +1,6 @@ import json import logging -from typing import List +from typing import List, Type from requests import Session @@ -10,7 +10,9 @@ log = logging.getLogger(__name__) -def get_objects(session: Session, url: str, obj_type: Object, as_objects=True, **query_params): +def get_objects( + session: Session, url: str, obj_type: Type[Object], as_objects=True, **query_params +): rest_name = obj_type.Meta.rest_name url = f"{url}/{rest_name}" @@ -29,7 +31,7 @@ def get_objects(session: Session, url: str, obj_type: Object, as_objects=True, * def get_object( - session: Session, url: str, obj_type: Object, id: str, as_object=True, **query_params + session: Session, url: str, obj_type: Type[Object], id: str, as_object=True, **query_params ): rest_name = obj_type.Meta.rest_name @@ -80,16 +82,21 @@ def create_objects( def update_objects( - session: Session, url: str, objects: List[Object], as_objects=True, **query_params + session: Session, + url: str, + objects: List[Object], + obj_type: Type[Object], + as_objects=True, + **query_params, ): - if not objects: - return [] - are_same = [o.__class__ == objects[0].__class__ for o in objects[1:]] + if objects is None: + raise ClientError("objects can't be None") + + are_same = [o.__class__ == obj_type for o in objects] if not all(are_same): raise ClientError("Mixed object types") - obj_type = objects[0].__class__ rest_name = obj_type.Meta.rest_name url = f"{url}/{rest_name}" diff --git a/ansys/rep/client/jms/api/jms_api.py b/ansys/rep/client/jms/api/jms_api.py index a1c6bdd21..5021c8e2b 100644 --- a/ansys/rep/client/jms/api/jms_api.py +++ b/ansys/rep/client/jms/api/jms_api.py @@ -113,7 +113,9 @@ def update_evaluators( self, evaluators: List[Evaluator], as_objects=True, **query_params ) -> List[Evaluator]: """Update evaluators configuration""" - return update_objects(self.client.session, self.url, evaluators, as_objects, **query_params) + return update_objects( + self.client.session, self.url, evaluators, Evaluator, as_objects, **query_params + ) ################################################################ # Task Definition Templates @@ -146,7 +148,14 @@ def update_task_definition_templates( templates (list of :class:`ansys.rep.client.jms.TaskDefinitionTemplate`): A list of task definition templates """ - return update_objects(self.client.session, self.url, templates, as_objects, **query_params) + return update_objects( + self.client.session, + self.url, + templates, + TaskDefinitionTemplate, + as_objects, + *query_params, + ) def delete_task_definition_templates(self, templates: List[TaskDefinitionTemplate]): """Delete existing task definition templates @@ -180,6 +189,7 @@ def update_task_definition_template_permissions( self.client.session, f"{self.url}/task_definition_templates/{template_id}", permissions, + Permission, as_objects, ) diff --git a/ansys/rep/client/jms/api/project_api.py b/ansys/rep/client/jms/api/project_api.py index c46a7ee9a..365664230 100644 --- a/ansys/rep/client/jms/api/project_api.py +++ b/ansys/rep/client/jms/api/project_api.py @@ -2,7 +2,7 @@ import logging import os from pathlib import Path -from typing import Callable, List +from typing import Callable, List, Type from cachetools import TTLCache, cached from marshmallow.utils import missing @@ -170,7 +170,7 @@ def create_parameter_definitions( def update_parameter_definitions( self, parameter_definitions: List[ParameterDefinition], as_objects=True ) -> List[ParameterDefinition]: - return self._update_objects(parameter_definitions, as_objects) + return self._update_objects(parameter_definitions, ParameterDefinition, as_objects) def delete_parameter_definitions(self, parameter_definitions: List[ParameterDefinition]): return self._delete_objects(parameter_definitions) @@ -188,7 +188,7 @@ def create_parameter_mappings( def update_parameter_mappings( self, parameter_mappings: List[ParameterMapping], as_objects=True ) -> List[ParameterMapping]: - return self._update_objects(parameter_mappings, as_objects=as_objects) + return self._update_objects(parameter_mappings, ParameterMapping, as_objects=as_objects) def delete_parameter_mappings(self, parameter_mappings: List[ParameterMapping]): return self._delete_objects(parameter_mappings) @@ -206,7 +206,7 @@ def create_task_definitions( def update_task_definitions( self, task_definitions: List[TaskDefinition], as_objects=True ) -> List[TaskDefinition]: - return self._update_objects(task_definitions, as_objects=as_objects) + return self._update_objects(task_definitions, TaskDefinition, as_objects=as_objects) def delete_task_definitions(self, task_definitions: List[TaskDefinition]): return self._delete_objects(task_definitions) @@ -224,7 +224,7 @@ def create_job_definitions( def update_job_definitions( self, job_definitions: List[JobDefinition], as_objects=True ) -> List[JobDefinition]: - return self._update_objects(job_definitions, as_objects=as_objects) + return self._update_objects(job_definitions, JobDefinition, as_objects=as_objects) def delete_job_definitions(self, job_definitions: List[JobDefinition]): return self._delete_objects(job_definitions) @@ -267,7 +267,7 @@ def update_jobs(self, jobs: List[Job], as_objects=True) -> List[Job]: Returns: List of :class:`ansys.rep.client.jms.Job` or list of dict if `as_objects` is True """ - return self._update_objects(jobs, as_objects=as_objects) + return self._update_objects(jobs, Job, as_objects=as_objects) def delete_jobs(self, jobs: List[Job]): """Delete existing jobs @@ -298,7 +298,7 @@ def get_tasks(self, as_objects=True, **query_params) -> List[Task]: return self._get_objects(Task, as_objects=as_objects, **query_params) def update_tasks(self, tasks: List[Task], as_objects=True) -> List[Task]: - return self._update_objects(tasks, as_objects=as_objects) + return self._update_objects(tasks, Task, as_objects=as_objects) ################################################################ # Selections @@ -313,7 +313,7 @@ def create_job_selections( def update_job_selections( self, selections: List[JobSelection], as_objects=True ) -> List[JobSelection]: - return self._update_objects(selections, as_objects=as_objects) + return self._update_objects(selections, JobSelection, as_objects=as_objects) def delete_job_selections(self, selections: List[JobSelection]): return self._delete_objects(selections) @@ -327,7 +327,7 @@ def create_algorithms(self, algorithms: List[Algorithm], as_objects=True) -> Lis return self._create_objects(algorithms, as_objects=as_objects) def update_algorithms(self, algorithms: List[Algorithm], as_objects=True) -> List[Algorithm]: - return self._update_objects(algorithms, as_objects=as_objects) + return self._update_objects(algorithms, Algorithm, as_objects=as_objects) def delete_algorithms(self, algorithms: List[Algorithm]): return self._delete_objects(algorithms) @@ -338,7 +338,7 @@ def get_permissions(self, as_objects=True) -> List[Permission]: return self._get_objects(Permission, as_objects=as_objects, fields=None) def update_permissions(self, permissions: List[Permission], as_objects=True): - return self._update_objects(permissions, as_objects=as_objects) + return self._update_objects(permissions, Permission, as_objects=as_objects) ################################################################ # License contexts @@ -357,7 +357,7 @@ def create_license_contexts(self, as_objects=True) -> List[LicenseContext]: return objects def update_license_contexts(self, license_contexts, as_objects=True) -> List[LicenseContext]: - return self._update_objects(self, license_contexts, as_objects=as_objects) + return self._update_objects(self, license_contexts, LicenseContext, as_objects=as_objects) def delete_license_contexts(self): rest_name = LicenseContext.Meta.rest_name @@ -371,8 +371,12 @@ def _get_objects(self, obj_type: Object, as_objects=True, **query_params): def _create_objects(self, objects: List[Object], as_objects=True, **query_params): return create_objects(self.client.session, self.url, objects, as_objects, **query_params) - def _update_objects(self, objects: List[Object], as_objects=True, **query_params): - return update_objects(self.client.session, self.url, objects, as_objects, **query_params) + def _update_objects( + self, objects: List[Object], obj_type: Type[Object], as_objects=True, **query_params + ): + return update_objects( + self.client.session, self.url, objects, obj_type, as_objects, **query_params + ) def _delete_objects(self, objects: List[Object]): delete_objects(self.client.session, self.url, objects) @@ -440,7 +444,7 @@ def create_files(project_api: ProjectApi, files, as_objects=True) -> List[File]: # (4) Update corresponding file resources in JMS with hashes of uploaded files created_files = update_objects( - project_api.client.session, project_api.url, created_files, as_objects=as_objects + project_api.client.session, project_api.url, created_files, File, as_objects=as_objects ) return created_files @@ -450,7 +454,9 @@ def update_files(project_api: ProjectApi, files: List[File], as_objects=True) -> # Upload files first if there are any src parameters _upload_files(project_api, files) # Update file resources in JMS - return update_objects(project_api.client.session, project_api.url, files, as_objects=as_objects) + return update_objects( + project_api.client.session, project_api.url, files, File, as_objects=as_objects + ) def _download_file( diff --git a/tests/auth/test_api.py b/tests/auth/test_api.py index de28078d5..55e7c0575 100644 --- a/tests/auth/test_api.py +++ b/tests/auth/test_api.py @@ -8,8 +8,6 @@ import logging import uuid -from keycloak.exceptions import KeycloakError - from ansys.rep.client import Client from ansys.rep.client.auth import AuthApi, User from tests.rep_test import REPTestCase @@ -58,7 +56,7 @@ def test_auth_client(self): usernames = [x.username for x in users] self.assertNotIn(new_user.username, usernames) - def test_auth_api_exceptions(self): + def test_get_users(self): api = AuthApi(Client(self.rep_url, username=self.username, password=self.password)) users = api.get_users() @@ -78,18 +76,14 @@ def test_auth_api_exceptions(self): last_name="User", ) new_user = api.create_user(new_user) + users = api.get_users() # use non-admin user to get users api_non_admin = AuthApi( Client(self.rep_url, username=username, password="test_auth_client") ) - except_obj = None - try: - users = api_non_admin.get_users() - except KeycloakError as e: - except_obj = e - - self.assertEqual(except_obj.response_code, 403) + users2 = api_non_admin.get_users() + self.assertEqual(len(users), len(users2)) api.delete_user(new_user) users = api.get_users() diff --git a/tests/jms/test_task_definition_templates.py b/tests/jms/test_task_definition_templates.py index 8bf8f83cd..4293c9f14 100644 --- a/tests/jms/test_task_definition_templates.py +++ b/tests/jms/test_task_definition_templates.py @@ -248,6 +248,27 @@ def test_template_permissions(self): # Delete user auth_api.delete_user(user1) + def test_template_permissions_update(self): + + client = self.client() + jms_api = JmsApi(client) + + # create new template and check default permissions + template = TaskDefinitionTemplate(name="my_template", version=uuid.uuid4()) + template = jms_api.create_task_definition_templates([template])[0] + permissions = jms_api.get_task_definition_template_permissions(template_id=template.id) + self.assertEqual(len(permissions), 1) + self.assertEqual(permissions[0].permission_type, "user") + + # remove permissions + permissions = [] + permissions = jms_api.update_task_definition_template_permissions( + template_id=template.id, permissions=permissions + ) + self.assertEqual(len(permissions), 0) + permissions = jms_api.get_task_definition_template_permissions(template_id=template.id) + self.assertEqual(len(permissions), 0) + def test_template_anyone_permission(self): client = self.client()