From fc25e2f9e19a93bbd7e6183081bae3b1853c3d45 Mon Sep 17 00:00:00 2001 From: Federico Negri Date: Fri, 12 Jan 2024 11:24:53 +0100 Subject: [PATCH] Improve type checking, add tests --- ansys/hps/client/jms/api/base.py | 37 +++++++---- ansys/hps/client/jms/api/jms_api.py | 11 +++- ansys/hps/client/jms/api/project_api.py | 46 +++++++------ ansys/hps/client/rms/api/base.py | 4 +- .../project_setup.py | 29 ++++---- tests/jms/test_jms_api.py | 66 ++++++++++++++++++- tests/jms/test_parameter_definitions.py | 34 ++++++++-- tests/test_examples.py | 43 +++++++++++- 8 files changed, 206 insertions(+), 64 deletions(-) diff --git a/ansys/hps/client/jms/api/base.py b/ansys/hps/client/jms/api/base.py index acf831250..8bbf1680f 100644 --- a/ansys/hps/client/jms/api/base.py +++ b/ansys/hps/client/jms/api/base.py @@ -52,17 +52,29 @@ def get_object( ) +def _check_object_types(objects: List[Object], obj_type: Type[Object]): + + are_same = [isinstance(o, obj_type) for o in objects] + if not all(are_same): + actual_types = set([type(o) for o in objects]) + if len(actual_types) == 1: + actual_types = actual_types.pop() + raise ClientError(f"Wrong object types: expected '{obj_type}', got {actual_types}.") + + def create_objects( - session: Session, url: str, objects: List[Object], as_objects=True, **query_params + session: Session, + url: str, + objects: List[Object], + obj_type: Type[Object], + as_objects=True, + **query_params, ): if not objects: return [] - are_same = [o.__class__ == objects[0].__class__ for o in objects[1:]] - if not all(are_same): - raise ClientError("Mixed object types") + _check_object_types(objects, obj_type) - obj_type = objects[0].__class__ rest_name = obj_type.Meta.rest_name url = f"{url}/{rest_name}" @@ -87,12 +99,10 @@ def update_objects( **query_params, ): - if objects is None: - raise ClientError("objects can't be None") + if not objects: + return [] - are_same = [o.__class__ == obj_type for o in objects] - if not all(are_same): - raise ClientError("Mixed object types") + _check_object_types(objects, obj_type) rest_name = obj_type.Meta.rest_name @@ -109,13 +119,12 @@ def update_objects( return schema.load(data) -def delete_objects(session: Session, url: str, objects: List[Object]): +def delete_objects(session: Session, url: str, objects: List[Object], obj_type: Type[Object]): + if not objects: return - are_same = [o.__class__ == objects[0].__class__ for o in objects[1:]] - if not all(are_same): - raise ClientError("Mixed object types") + _check_object_types(objects, obj_type) obj_type = objects[0].__class__ rest_name = obj_type.Meta.rest_name diff --git a/ansys/hps/client/jms/api/jms_api.py b/ansys/hps/client/jms/api/jms_api.py index c601ba920..d82718951 100644 --- a/ansys/hps/client/jms/api/jms_api.py +++ b/ansys/hps/client/jms/api/jms_api.py @@ -129,7 +129,14 @@ def create_task_definition_templates( templates (list of :class:`ansys.hps.client.jms.TaskDefinitionTemplate`): A list of task definition templates """ - return create_objects(self.client.session, self.url, templates, as_objects, **query_params) + return create_objects( + self.client.session, + self.url, + templates, + TaskDefinitionTemplate, + as_objects, + **query_params, + ) def update_task_definition_templates( self, templates: List[TaskDefinitionTemplate], as_objects=True, **query_params @@ -156,7 +163,7 @@ def delete_task_definition_templates(self, templates: List[TaskDefinitionTemplat templates (list of :class:`ansys.hps.client.jms.TaskDefinitionTemplate`): A list of task definition templates """ - return delete_objects(self.client.session, self.url, templates) + return delete_objects(self.client.session, self.url, templates, TaskDefinitionTemplate) def copy_task_definition_templates( self, templates: List[TaskDefinitionTemplate], wait: bool = True diff --git a/ansys/hps/client/jms/api/project_api.py b/ansys/hps/client/jms/api/project_api.py index d50f6e33d..370bf6eb9 100644 --- a/ansys/hps/client/jms/api/project_api.py +++ b/ansys/hps/client/jms/api/project_api.py @@ -139,7 +139,7 @@ def update_files(self, files: List[File], as_objects=True): return update_files(self, files, as_objects=as_objects) def delete_files(self, files: List[File]): - return self._delete_objects(files) + return self._delete_objects(files, File) def download_file( self, @@ -163,9 +163,9 @@ def get_parameter_definitions( return self._get_objects(ParameterDefinition, as_objects, **query_params) def create_parameter_definitions( - self, parameter_definitions, as_objects=True + self, parameter_definitions: List[ParameterDefinition], as_objects=True ) -> List[ParameterDefinition]: - return self._create_objects(parameter_definitions, as_objects) + return self._create_objects(parameter_definitions, ParameterDefinition, as_objects) def update_parameter_definitions( self, parameter_definitions: List[ParameterDefinition], as_objects=True @@ -173,7 +173,7 @@ def update_parameter_definitions( return self._update_objects(parameter_definitions, ParameterDefinition, as_objects) def delete_parameter_definitions(self, parameter_definitions: List[ParameterDefinition]): - return self._delete_objects(parameter_definitions) + return self._delete_objects(parameter_definitions, ParameterDefinition) ################################################################ # Parameter mappings @@ -183,7 +183,7 @@ def get_parameter_mappings(self, as_objects=True, **query_params) -> List[Parame def create_parameter_mappings( self, parameter_mappings: List[ParameterMapping], as_objects=True ) -> List[ParameterMapping]: - return self._create_objects(parameter_mappings, as_objects=as_objects) + return self._create_objects(parameter_mappings, ParameterMapping, as_objects=as_objects) def update_parameter_mappings( self, parameter_mappings: List[ParameterMapping], as_objects=True @@ -191,7 +191,7 @@ def update_parameter_mappings( return self._update_objects(parameter_mappings, ParameterMapping, as_objects=as_objects) def delete_parameter_mappings(self, parameter_mappings: List[ParameterMapping]): - return self._delete_objects(parameter_mappings) + return self._delete_objects(parameter_mappings, ParameterMapping) ################################################################ # Task definitions @@ -201,7 +201,7 @@ def get_task_definitions(self, as_objects=True, **query_params) -> List[TaskDefi def create_task_definitions( self, task_definitions: List[TaskDefinition], as_objects=True ) -> List[TaskDefinition]: - return self._create_objects(task_definitions, as_objects=as_objects) + return self._create_objects(task_definitions, TaskDefinition, as_objects=as_objects) def update_task_definitions( self, task_definitions: List[TaskDefinition], as_objects=True @@ -209,7 +209,7 @@ def update_task_definitions( return self._update_objects(task_definitions, TaskDefinition, as_objects=as_objects) def delete_task_definitions(self, task_definitions: List[TaskDefinition]): - return self._delete_objects(task_definitions) + return self._delete_objects(task_definitions, TaskDefinition) def copy_task_definitions( self, task_definitions: List[TaskDefinition], wait: bool = True @@ -242,7 +242,7 @@ def get_job_definitions(self, as_objects=True, **query_params) -> List[JobDefini def create_job_definitions( self, job_definitions: List[JobDefinition], as_objects=True ) -> List[JobDefinition]: - return self._create_objects(job_definitions, as_objects=as_objects) + return self._create_objects(job_definitions, JobDefinition, as_objects=as_objects) def update_job_definitions( self, job_definitions: List[JobDefinition], as_objects=True @@ -250,7 +250,7 @@ def update_job_definitions( return self._update_objects(job_definitions, JobDefinition, as_objects=as_objects) def delete_job_definitions(self, job_definitions: List[JobDefinition]): - return self._delete_objects(job_definitions) + return self._delete_objects(job_definitions, JobDefinition) def copy_job_definitions( self, job_definitions: List[JobDefinition], wait: bool = True @@ -290,7 +290,7 @@ def create_jobs(self, jobs: List[Job], as_objects=True) -> List[Job]: Returns: List of :class:`ansys.hps.client.jms.Job` or list of dict if `as_objects` is False """ - return self._create_objects(jobs, as_objects=as_objects) + return self._create_objects(jobs, Job, as_objects=as_objects) def copy_jobs(self, jobs: List[Job], wait: bool = True) -> Union[str, List[str]]: """Create new jobs by copying existing ones @@ -342,7 +342,7 @@ def delete_jobs(self, jobs: List[Job]): >>> project_api.delete_jobs(jobs_to_delete) """ - return self._delete_objects(jobs) + return self._delete_objects(jobs, Job) def sync_jobs(self, jobs: List[Job]): return sync_jobs(self, jobs) @@ -372,7 +372,7 @@ def get_job_selections(self, as_objects=True, **query_params) -> List[JobSelecti def create_job_selections( self, selections: List[JobSelection], as_objects=True ) -> List[JobSelection]: - return self._create_objects(selections, as_objects=as_objects) + return self._create_objects(selections, JobSelection, as_objects=as_objects) def update_job_selections( self, selections: List[JobSelection], as_objects=True @@ -380,7 +380,7 @@ def update_job_selections( return self._update_objects(selections, JobSelection, as_objects=as_objects) def delete_job_selections(self, selections: List[JobSelection]): - return self._delete_objects(selections) + return self._delete_objects(selections, JobSelection) ################################################################ # Algorithms @@ -388,13 +388,13 @@ def get_algorithms(self, as_objects=True, **query_params) -> List[Algorithm]: return self._get_objects(Algorithm, as_objects=as_objects, **query_params) def create_algorithms(self, algorithms: List[Algorithm], as_objects=True) -> List[Algorithm]: - return self._create_objects(algorithms, as_objects=as_objects) + return self._create_objects(algorithms, Algorithm, as_objects=as_objects) def update_algorithms(self, algorithms: List[Algorithm], as_objects=True) -> List[Algorithm]: return self._update_objects(algorithms, Algorithm, as_objects=as_objects) def delete_algorithms(self, algorithms: List[Algorithm]): - return self._delete_objects(algorithms) + return self._delete_objects(algorithms, Algorithm) ################################################################ # Permissions @@ -466,8 +466,12 @@ def copy_default_execution_script(self, filename: str) -> File: def _get_objects(self, obj_type: Object, as_objects=True, **query_params): return get_objects(self.client.session, self.url, obj_type, as_objects, **query_params) - def _create_objects(self, objects: List[Object], as_objects=True, **query_params): - return create_objects(self.client.session, self.url, objects, as_objects, **query_params) + def _create_objects( + self, objects: List[Object], obj_type: Type[Object], as_objects=True, **query_params + ): + return create_objects( + self.client.session, self.url, objects, obj_type, as_objects, **query_params + ) def _update_objects( self, objects: List[Object], obj_type: Type[Object], as_objects=True, **query_params @@ -476,8 +480,8 @@ def _update_objects( self.client.session, self.url, objects, obj_type, as_objects, **query_params ) - def _delete_objects(self, objects: List[Object]): - delete_objects(self.client.session, self.url, objects) + def _delete_objects(self, objects: List[Object], obj_type: Type[Object]): + delete_objects(self.client.session, self.url, objects, obj_type) def _download_files(project_api: ProjectApi, files: List[File]): @@ -534,7 +538,7 @@ def _upload_files(project_api: ProjectApi, files): def create_files(project_api: ProjectApi, files, as_objects=True) -> List[File]: # (1) Create file resources in JMS created_files = create_objects( - project_api.client.session, project_api.url, files, as_objects=as_objects + project_api.client.session, project_api.url, files, File, as_objects=as_objects ) # (2) Check if there are src properties, files to upload diff --git a/ansys/hps/client/rms/api/base.py b/ansys/hps/client/rms/api/base.py index 70225bdb6..0e65087ed 100644 --- a/ansys/hps/client/rms/api/base.py +++ b/ansys/hps/client/rms/api/base.py @@ -142,8 +142,8 @@ def update_objects( **query_params, ): - if objects is None: - raise ClientError("objects can't be None") + if not objects: + return [] are_same = [o.__class__ == obj_type for o in objects] if not all(are_same): diff --git a/examples/python_multi_process_step/project_setup.py b/examples/python_multi_process_step/project_setup.py index 6d3e331a1..f78f34f72 100644 --- a/examples/python_multi_process_step/project_setup.py +++ b/examples/python_multi_process_step/project_setup.py @@ -1,5 +1,5 @@ """ -Project setup script for multi process step and task file replacement testing. +Project setup script for multi steps (task definitions) and task file replacement testing. Author(s): R.Walker @@ -25,8 +25,6 @@ import os import random -from task_files import update_task_files - from ansys.hps.client import Client, HPSError from ansys.hps.client.jms import ( File, @@ -44,6 +42,8 @@ TaskDefinition, ) +from .task_files import update_task_files + log = logging.getLogger(__name__) @@ -57,8 +57,8 @@ def main( change_job_tasks, inactive, sequential, -): - """Python project implementing multiple process steps and optional image generation.""" +) -> Project: + """Python project implementing multiple steps and optional image generation.""" log.debug("=== Project") name = f"Python - {num_task_definitions} Task Defs {' - Img' if images else ''}" name += f"{' - Sequential' if sequential else ' - Parallel'}" @@ -136,23 +136,20 @@ def main( params = [] mappings = [] for i in range(num_task_definitions): - int_params = [ + new_params = [ IntParameterDefinition(name=f"period{i}", lower_limit=1, upper_limit=period, units="s"), IntParameterDefinition( name=f"duration{i}", lower_limit=0, upper_limit=duration, units="s" ), IntParameterDefinition(name=f"steps{i}", units=""), - ] - str_params = [ StringParameterDefinition( name=f"color{i}", value_list=["red", "blue", "green", "yellow", "cyan"], default='"orange"', ), ] - int_params = project_api.create_parameter_definitions(int_params) - str_params = project_api.create_parameter_definitions(str_params) - params.extend(int_params + str_params) + new_params = project_api.create_parameter_definitions(new_params) + params.extend(new_params) input_file_id = file_ids[f"td{i}_input"] result_file_id = file_ids[f"td{i}_results_json"] @@ -161,7 +158,7 @@ def main( ParameterMapping( key_string='"period"', tokenizer=":", - parameter_definition_id=int_params[0].id, + parameter_definition_id=new_params[0].id, file_id=input_file_id, ) ) @@ -169,7 +166,7 @@ def main( ParameterMapping( key_string='"duration"', tokenizer=":", - parameter_definition_id=int_params[1].id, + parameter_definition_id=new_params[1].id, file_id=input_file_id, ) ) @@ -177,7 +174,7 @@ def main( ParameterMapping( key_string='"steps"', tokenizer=":", - parameter_definition_id=int_params[2].id, + parameter_definition_id=new_params[2].id, file_id=result_file_id, ) ) @@ -186,7 +183,7 @@ def main( key_string='"color"', tokenizer=":", string_quote='"', - parameter_definition_id=str_params[0].id, + parameter_definition_id=new_params[3].id, file_id=input_file_id, ) ) @@ -260,6 +257,8 @@ def main( log.info(f"Created project '{proj.name}', ID='{proj.id}'") + return proj + if __name__ == "__main__": diff --git a/tests/jms/test_jms_api.py b/tests/jms/test_jms_api.py index 963341176..0b13b9ed7 100644 --- a/tests/jms/test_jms_api.py +++ b/tests/jms/test_jms_api.py @@ -11,9 +11,15 @@ from examples.mapdl_motorbike_frame.project_setup import create_project from marshmallow.utils import missing -from ansys.hps.client import Client +from ansys.hps.client import Client, ClientError from ansys.hps.client.jms import JmsApi, ProjectApi -from ansys.hps.client.jms.resource import Job, Project +from ansys.hps.client.jms.resource import ( + FloatParameterDefinition, + IntParameterDefinition, + Job, + JobDefinition, + Project, +) from tests.rep_test import REPTestCase log = logging.getLogger(__name__) @@ -130,6 +136,62 @@ def test_storage_configuration(self): self.assertTrue("priority" in storage) self.assertTrue("obj_type" in storage) + def test_objects_type_check(self): + + proj_name = f"test_objects_type_check" + + client = self.client + + proj = Project(name=proj_name, active=True) + job_def = JobDefinition(name="Job Def", active=True) + job = Job(name="test") + + jms_api = JmsApi(client) + + with self.assertRaises(ClientError) as context: + _ = jms_api.create_task_definition_templates([job]) + assert "Wrong object type" in str(context.exception) + assert "got " in str(context.exception) + + proj = jms_api.create_project(proj, replace=True) + project_api = ProjectApi(client, proj.id) + + job_def = JobDefinition(name="New Config", active=True) + + with self.assertRaises(ClientError) as context: + _ = project_api.create_jobs([job_def]) + assert "Wrong object type" in str(context.exception) + assert "got " in str( + context.exception + ) + + job_def = project_api.create_job_definitions([job_def])[0] + + # verify support for mixed parameter definitions + with self.assertRaises(ClientError) as context: + _ = project_api.create_parameter_definitions( + [ + FloatParameterDefinition(), + Job(), + ] + ) + msg = str(context.exception) + assert "Wrong object type" in msg + assert "" in msg + assert ( + "" + in msg + ) + + _ = project_api.create_parameter_definitions( + [ + FloatParameterDefinition(), + IntParameterDefinition(), + ] + ) + + JmsApi(client).delete_project(proj) + if __name__ == "__main__": unittest.main() diff --git a/tests/jms/test_parameter_definitions.py b/tests/jms/test_parameter_definitions.py index b72698473..e2a8a3e12 100644 --- a/tests/jms/test_parameter_definitions.py +++ b/tests/jms/test_parameter_definitions.py @@ -200,13 +200,33 @@ def test_parameter_definition_integration(self): self.assertTrue(fp.id in job_def.parameter_definition_ids) self.assertTrue(bp.id in job_def.parameter_definition_ids) - # job_def.parameter_definitions[2].upper_limit = 13.0 - # job_def.parameter_definitions[2].lower_limit = 4.5 - # job_def = proj.update_job_definitions([job_def])[0] - # job_def = proj.get_job_definitions([job_def])[0] - # self.assertEqual(len(job_def.parameter_definitions), 4) - # self.assertEqual(job_def.parameter_definitions[2].upper_limit, 13.0) - # self.assertEqual(job_def.parameter_definitions[2].lower_limit, 4.5) + # Delete project + jms_api.delete_project(proj) + + def test_mixed_parameter_definition(self): + + client = self.client + proj_name = f"test_mixed_parameter_definition" + + proj = Project(name=proj_name, active=True) + jms_api = JmsApi(client) + proj = jms_api.create_project(proj, replace=True) + project_api = ProjectApi(client, proj.id) + + ip = IntParameterDefinition(name="int_param", upper_limit=27) + sp = StringParameterDefinition(name="s_param", value_list=["l1", "l2"]) + fp = FloatParameterDefinition(name="f_param", display_text="A Float Parameter") + bp = BoolParameterDefinition(name="b_param", display_text="A Bool Parameter", default=False) + + original_pds = [ip, sp, fp, bp] + pds = project_api.create_parameter_definitions(original_pds) + + for pd, original_pd in zip(pds, original_pds): + assert type(pd) == type(original_pd) + assert pd.name == original_pd.name + + assert pds[0].upper_limit == 27 + assert pds[1].value_list == ["l1", "l2"] # Delete project jms_api.delete_project(proj) diff --git a/tests/test_examples.py b/tests/test_examples.py index dd8983c34..f625605a0 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -2,7 +2,12 @@ import unittest from ansys.hps.client import __ansys_apps_version__ as ansys_version -from ansys.hps.client.jms import JmsApi, ProjectApi +from ansys.hps.client.jms import ( + IntParameterDefinition, + JmsApi, + ProjectApi, + StringParameterDefinition, +) from tests.rep_test import REPTestCase log = logging.getLogger(__name__) @@ -241,6 +246,42 @@ def test_cfx_static_mixer(self): jms_api.delete_project(project) + def test_python_multi_steps(self): + + from examples.python_multi_process_step.project_setup import main as create_project + + num_jobs = 3 + num_task_definitions = 2 + project = create_project( + self.client, + num_task_definitions=num_task_definitions, + num_jobs=num_jobs, + duration=10, + period=3, + images=False, + change_job_tasks=0, + inactive=True, + sequential=False, + ) + self.assertIsNotNone(project) + + project_api = ProjectApi(self.client, project.id) + + self.assertEqual(len(project_api.get_jobs()), num_jobs) + + # verify we created int and string type parameter definitions + pds = project_api.get_parameter_definitions() + types = [type(pd) for pd in pds] + + assert len(types) == 4 * num_task_definitions + + types = set(types) + assert len(types) == 2 + assert StringParameterDefinition in types + assert IntParameterDefinition in types + + JmsApi(self.client).delete_project(project) + if __name__ == "__main__": unittest.main()