diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6dfc524c4..4ea944a52 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,8 +7,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+## v0.8.0
+
+### Added
+- Performance analysis method on DessiaObject
+
 ### Fixed
+- Babylon display fix
 - Any typing does not trigger error with subclass anymore
+- Workflow: imposed variable values fixes (serialization, equality)
+
+### Performance
+- types: cache classes resolved via import_module
 
 ## v0.7.0
 
diff --git a/code_pylint.py b/code_pylint.py
index 78e3bd0cc..ea6448e8d 100644
--- a/code_pylint.py
+++ b/code_pylint.py
@@ -21,7 +21,7 @@
     'protected-access': 27,
     'invalid-name': 8,
     'consider-using-f-string': 10,
-    'no-else-return': 4,
+    'no-else-return': 0,
     'arguments-differ': 12,
     'no-member': 1,
     'too-many-locals': 14,
@@ -40,7 +40,7 @@
     'unnecessary-comprehension': 5,
     'no-value-for-parameter': 2,
     'too-many-return-statements': 8,
-    'raise-missing-from': 6,
+    'raise-missing-from': 0,
     'consider-merging-isinstance': 6,
     'abstract-method': 6,
     'import-outside-toplevel': 7,
@@ -52,7 +52,6 @@
     'consider-using-get': 2,
     'undefined-loop-variable': 2,
     'consider-using-with': 2,
-    'eval-used': 2,
     'too-many-nested-blocks': 2,
     'bad-staticmethod-argument': 1,
     'too-many-public-methods': 2,  # Try to lower by splitting DessiaObject and Workflow
@@ -63,6 +62,7 @@
     'use-maxsplit-arg': 1,
     'duplicate-code': 1,
     # No tolerance errors
+    'eval-used': 0,
     'redefined-builtin': 0,
     'arguments-renamed': 0,
     'ungrouped-imports': 0,
 
diff --git a/dessia_common/breakdown.py b/dessia_common/breakdown.py
index 20e2bb413..27a5a0402 100644
--- a/dessia_common/breakdown.py
+++ b/dessia_common/breakdown.py
@@ -41,7 +41,7 @@ def extract_from_object(object_, segment):
     try:
         return object_[int(segment)]
-    except ValueError:
+    except ValueError as error:
         # should be a tuple
         if segment.startswith('(') and segment.endswith(')') and ',' in segment:
             key = []
@@ -53,7 +53,7 @@
                 key.append(subkey)
             return object_[tuple(key)]
         # else:
-        raise NotImplementedError(f'Cannot extract segment {segment} from object {object_}')
+        raise ValueError(f'Cannot extract segment {segment} from object {object_}') from error
 
     # Finally, it is a regular object
     return getattr(object_, segment)
 
@@ -63,7 +63,15 @@ def get_in_object_from_path(object_, path):
     segments = path.lstrip('#/').split('/')
     element = object_
     for segment in segments:
-        element = extract_from_object(element, segment)
+        if isinstance(element, dict) and '$ref' in element:
+            # Going down in the object and it is a reference
+            # Evaluating subreference
+            element = get_in_object_from_path(object_, element['$ref'])
+        try:
+            element = extract_from_object(element, segment)
+        except ValueError as err:
+            raise ValueError(f'Cannot get segment {segment} from path {path} in element {element}') from err
     return element
 
diff --git a/dessia_common/core.py b/dessia_common/core.py
index 4876d26ac..e15296813 100755
--- a/dessia_common/core.py
+++ b/dessia_common/core.py
@@ -3,6 +3,8 @@
 """
 
 """
+
+import time
 import sys
 import warnings
 import math
@@ -12,7 +14,7 @@
 from copy import deepcopy
 import inspect
 import json
-from operator import attrgetter
 from typing import List, Dict, Any, Tuple, get_type_hints
 import traceback as tb
 
@@ -420,10 +422,10 @@ def dict_to_arguments(self, dict_, method):
                 value = dict_[str(i)]
                 try:
                     deserialized_value = deserialize_argument(arg_specs, value)
-                except TypeError:
+                except TypeError as err:
                     msg = 'Error in deserialisation of value: '
                    msg += f'{value} of expected type {arg_specs}'
-                    raise TypeError(msg)
+                    raise TypeError(msg) from err
                 arguments[arg] = deserialized_value
         return arguments
 
@@ -527,7 +529,7 @@ def mpl_plot(self, **kwargs):
         try:
             plot_datas = self.plot_data(**kwargs)
         except TypeError as error:
-            raise TypeError(f'{self.__class__.__name__}.{error}')
+            raise TypeError(f'{self.__class__.__name__}.{error}') from error
         for data in plot_datas:
             if hasattr(data, 'mpl_plot'):
                 ax = data.mpl_plot()
@@ -580,9 +582,37 @@ def _displays(self, **kwargs) -> List[JsonSerializable]:
             displays.append(display_.to_dict())
         return displays
 
-    def to_markdown(self):
+    def to_markdown(self) -> str:
+        """
+        Render a markdown representation of the object and return it as a string.
+        """
         return templates.dessia_object_markdown_template.substitute(name=self.name)
 
+    def _performance_analysis(self):
+        """
+        Print the time spent in some common operations (hash, serialization, deserialization, displays).
+        """
+        data_hash_time = time.time()
+        self._data_hash()
+        data_hash_time = time.time() - data_hash_time
+        print(f'Data hash time: {round(data_hash_time, 3)} seconds')
+
+        todict_time = time.time()
+        dict_ = self.to_dict()
+        todict_time = time.time() - todict_time
+        print(f'to_dict time: {round(todict_time, 3)} seconds')
+
+        dto_time = time.time()
+        self.dict_to_object(dict_)
+        dto_time = time.time() - dto_time
+        print(f'dict_to_object time: {round(dto_time, 3)} seconds')
+
+        for display_setting in self.display_settings():
+            display_time = time.time()
+            self._display_from_selector(display_setting.selector)
+            display_time = time.time() - display_time
+            print(f'Generation of display {display_setting.selector} in: {round(display_time, 6)} seconds')
+
     def _check_platform(self):
         """
         Reproduce lifecycle on platform (serialization, display)
@@ -601,6 +631,7 @@
                                                  ' after serialization/deserialization')
         copied_object = self.copy()
         if not copied_object._data_eq(self):
+            print('data diff: ', self._data_diff(copied_object))
             raise dessia_common.errors.CopyError('Object is not equal to itself'
                                                  ' after copy')
 
@@ -624,8 +655,7 @@ def to_xlsx_stream(self, stream):
         writer = XLSXWriter(self)
         writer.save_to_stream(stream)
 
-    @staticmethod
-    def _export_formats():
+    def _export_formats(self):
        formats = [{"extension": "json", "method_name": "save_to_stream", "text": True, "args": {}},
                   {"extension": "xlsx", "method_name": "to_xlsx_stream", "text": False, "args": {}}]
        return formats
 
@@ -707,9 +737,8 @@ def save_babylonjs_to_file(self, filename: str = None, use_cdn: bool = True,
                                      use_cdn=use_cdn, debug=debug)
 
-    @staticmethod
-    def _export_formats():
-        formats = DessiaObject._export_formats()
+    def _export_formats(self):
+        formats = DessiaObject._export_formats(self)
         formats3d = [{"extension": "step", "method_name": "to_step_stream", "text": True, "args": {}},
                      {"extension": "stl", "method_name": "to_stl_stream", "text": False, "args": {}}]
         formats.extend(formats3d)
 
diff --git a/dessia_common/forms.py b/dessia_common/forms.py
index e032094ee..14434488b 100644
--- a/dessia_common/forms.py
+++ b/dessia_common/forms.py
@@ -26,7 +26,7 @@
 """
 
 from math import floor, ceil, cos
-from typing import Dict, List, Tuple, Union, TextIO, BinaryIO
+from typing import Dict, List, Tuple, Union
 from numpy import linspace
 
 try:
@@ -46,14 +46,14 @@
 from dessia_common.files import BinaryFile, StringFile
 
 
-class StandaloneSubobject(DessiaObject):
+class StandaloneSubobject(PhysicalObject):
     _standalone_in_db = True
     _generic_eq = True
 
     def __init__(self, floatarg: Distance, name: str = 'Standalone Subobject'):
         self.floatarg = floatarg
-        DessiaObject.__init__(self, name=name)
+        PhysicalObject.__init__(self, name=name)
 
     @classmethod
     def generate(cls, seed: int) -> 'StandaloneSubobject':
@@ -67,16 +67,16 @@ def generate_many(cls, seed: int) -> List['StandaloneSubobject']:
         return subobjects
 
     def contour(self):
-        points = [vm.Point2D(0, 0), vm.Point2D(0, 1),
-                  vm.Point2D(1, 1), vm.Point2D(1, 0)]
+        origin = self.floatarg
+        points = [vm.Point2D(origin, origin), vm.Point2D(origin, origin + 1),
+                  vm.Point2D(origin + 1, origin + 1), vm.Point2D(origin + 1, origin)]
 
         crls = p2d.ClosedRoundedLineSegments2D(points=points, radius={})
         return crls
 
     def voldmlr_primitives(self):
         contour = self.contour()
-        volumes = [p3d.ExtrudedProfile(vm.O3D, vm.X3D, vm.Z3D,
-                                       contour, [], vm.Y3D)]
+        volumes = [p3d.ExtrudedProfile(vm.O3D, vm.X3D, vm.Z3D, contour, [], vm.Y3D)]
         return volumes
 
@@ -149,16 +149,14 @@
     def __init__(self, embedded_list: List[int] = None,
                  name: str = 'Enhanced Embedded Subobject'):
         self.embedded_array = embedded_array
-        EmbeddedSubobject.__init__(self, embedded_list=embedded_list,
-                                   name=name)
+        EmbeddedSubobject.__init__(self, embedded_list=embedded_list, name=name)
 
     @classmethod
     def generate(cls, seed: int) -> 'EnhancedEmbeddedSubobject':
         embedded_list = [seed]
         embedded_array = [[seed, seed * 10, seed * 10]] * seed
         name = 'Embedded Subobject' + str(seed)
-        return cls(embedded_list=embedded_list, embedded_array=embedded_array,
-                   name=name)
+        return cls(embedded_list=embedded_list, embedded_array=embedded_array, name=name)
 
 
 DEF_ES = EmbeddedSubobject.generate(10)
 
@@ -184,8 +182,7 @@ class StandaloneObject(PhysicalObject):
     _standalone_in_db = True
     _generic_eq = True
     _allowed_methods = ['add_standalone_object', 'add_embedded_object',
-                        'add_float', 'generate_from_text', 'generate_from_bin',
-                        'generate_from_bin_file', 'generate_from_text_file']
+                        'add_float', 'generate_from_text', 'generate_from_bin']
 
     def __init__(self, standalone_subobject: StandaloneSubobject, embedded_subobject: EmbeddedSubobject,
                  dynamic_dict: Dict[str, bool], float_dict: Dict[str, float], string_dict: Dict[str, str],
@@ -212,8 +209,7 @@ def __init__(self, standalone_subobject: StandaloneSubobject, embedded_subobject
         PhysicalObject.__init__(self, name=name)
 
     @classmethod
-    def generate(cls, seed: int,
-                 name: str = 'Standalone Object Demo') -> 'StandaloneObject':
+    def generate(cls, seed: int, name: str = 'Standalone Object Demo') -> 'StandaloneObject':
         is_even = not bool(seed % 2)
         standalone_subobject = StandaloneSubobject.generate(seed)
         embedded_subobject = EmbeddedSubobject.generate(seed)
@@ -233,65 +229,27 @@
             subclass_arg = StandaloneSubobject.generate(-seed)
         else:
             subclass_arg = InheritingStandaloneSubobject.generate(seed)
-        return cls(standalone_subobject=standalone_subobject,
-                   embedded_subobject=embedded_subobject,
-                   dynamic_dict=dynamic_dict,
-                   float_dict=float_dict,
-                   string_dict=string_dict,
-                   tuple_arg=tuple_arg,
-                   intarg=intarg, strarg=strarg, object_list=object_list,
-                   subobject_list=subobject_list, builtin_list=builtin_list,
-                   union_arg=union_arg, subclass_arg=subclass_arg,
+        return cls(standalone_subobject=standalone_subobject, embedded_subobject=embedded_subobject,
+                   dynamic_dict=dynamic_dict, float_dict=float_dict, string_dict=string_dict, tuple_arg=tuple_arg,
+                   intarg=intarg, strarg=strarg, object_list=object_list, subobject_list=subobject_list,
+                   builtin_list=builtin_list, union_arg=union_arg, subclass_arg=subclass_arg,
                    array_arg=array_arg, name=name)
 
     @classmethod
-    def generate_from_text(cls, stream: TextIO):
-        try:
-            my_string = stream.read()
-            # this is a hack for test until we get frontend support for types BinaryFile & StringFile
-            # a TextIO does not have filename, but it's ok since we return a StringFile from backend
-            my_file_name = stream.filename
-            _, raw_seed = my_string.split(",")
-            seed = int(raw_seed.strip())
-        finally:
-            stream.close()
-        return cls.generate(seed=seed, name=my_file_name)
-
-    @classmethod
-    def generate_from_bin(cls, stream: BinaryIO):
-        # the user need to decode the binary as he see fit
-        try:
-            my_string = stream.read().decode('utf8')
-            # this is a hack for test until we get frontend support for types BinaryFile & StringFile
-            # a BinaryIO does not have filename, but it's ok since we return a BinaryFile from backend
-            my_file_name = stream.filename
-            _, raw_seed = my_string.split(",")
-            seed = int(raw_seed.strip())
-        finally:
-            stream.close()
-        return cls.generate(seed=seed, name=my_file_name)
-
-    @classmethod
-    def generate_from_bin_file(cls, stream: BinaryFile):
-        # the user need to decode the binary as he see fit
-        try:
-            my_string = stream.read().decode('utf8')
-            my_file_name = stream.filename
-            _, raw_seed = my_string.split(",")
-            seed = int(raw_seed.strip())
-        finally:
-            stream.close()
+    def generate_from_bin(cls, stream: BinaryFile):
+        # The user needs to decode the binary as they see fit
+        my_string = stream.read().decode('utf8')
+        my_file_name = stream.filename
+        _, raw_seed = my_string.split(",")
+        seed = int(raw_seed.strip())
         return cls.generate(seed=seed, name=my_file_name)
 
     @classmethod
-    def generate_from_text_file(cls, stream: StringFile):
-        try:
-            my_text = stream.read()
-            my_file_name = stream.filename
-            _, raw_seed = my_text.split(",")
-            seed = int(raw_seed.strip())
-        finally:
-            stream.close()
+    def generate_from_text(cls, stream: StringFile):
+        my_text = stream.getvalue()
+        my_file_name = stream.filename
+        _, raw_seed = my_text.split(",")
+        seed = int(raw_seed.strip())
         return cls.generate(seed=seed, name=my_file_name)
 
     def add_standalone_object(self, object_: StandaloneSubobject):
@@ -327,37 +285,27 @@ def plot_data(self):
 
         # Contour
         contour = self.standalone_subobject.contour().plot_data()
-        primitives_group = plot_data.PrimitiveGroup(primitives=[contour],
-                                                    name='Contour')
+        primitives_group = plot_data.PrimitiveGroup(primitives=[contour], name='Contour')
 
         # Scatter Plot
         bounds = {'x': [0, 6], 'y': [100, 2000]}
         catalog = Catalog.random_2d(bounds=bounds, threshold=8000)
-        points = [plot_data.Point2D(cx=v[0], cy=v[1], name='Point' + str(i))
-                  for i, v in enumerate(catalog.array)]
+        points = [plot_data.Point2D(cx=v[0], cy=v[1], name='Point' + str(i)) for i, v in enumerate(catalog.array)]
         axis = plot_data.Axis()
-        tooltip = plot_data.Tooltip(attributes=attributes,
-                                    name='Tooltips')
-        scatter_plot = plot_data.Scatter(axis=axis, tooltip=tooltip,
-                                         elements=points,
-                                         x_variable=attributes[0],
-                                         y_variable=attributes[1],
-                                         name='Scatter Plot')
+        tooltip = plot_data.Tooltip(attributes=attributes, name='Tooltips')
+        scatter_plot = plot_data.Scatter(axis=axis, tooltip=tooltip, elements=points, x_variable=attributes[0],
+                                         y_variable=attributes[1], name='Scatter Plot')
 
         # Parallel Plot
         attributes = ['cx', 'cy', 'color_fill', 'color_stroke']
-        parallel_plot = plot_data.ParallelPlot(elements=points,
-                                               axes=attributes,
-                                               name='Parallel Plot')
+        parallel_plot = plot_data.ParallelPlot(elements=points, axes=attributes, name='Parallel Plot')
 
         # Multi Plot
         objects = [scatter_plot, parallel_plot]
-        sizes = [plot_data.Window(width=560, height=300),
-                 plot_data.Window(width=560, height=300)]
+        sizes = [plot_data.Window(width=560, height=300), plot_data.Window(width=560, height=300)]
         coords = [(0, 0), (300, 0)]
-        multi_plot = plot_data.MultiplePlots(elements=points, plots=objects,
-                                             sizes=sizes, coords=coords,
-                                             name='Multiple Plot')
+        multi_plot = plot_data.MultiplePlots(elements=points, plots=objects, sizes=sizes,
+                                             coords=coords, name='Multiple Plot')
 
         attribute_names = ['time', 'electric current']
         tooltip = plot_data.Tooltip(attributes=attribute_names)
@@ -367,15 +315,12 @@
         for time, current in zip(time1, current1):
             elements1.append({'time': time, 'electric current': current})
 
-        # The previous line instantiates a dataset with limited arguments but
-        # several customizations are available
+        # The previous line instantiates a dataset with limited arguments but several customizations are available
         point_style = plot_data.PointStyle(color_fill=plot_data.colors.RED, color_stroke=plot_data.colors.BLACK)
         edge_style = plot_data.EdgeStyle(color_stroke=plot_data.colors.BLUE, dashline=[10, 5])
-        custom_dataset = plot_data.Dataset(elements=elements1, name='I = f(t)',
-                                           tooltip=tooltip,
-                                           point_style=point_style,
-                                           edge_style=edge_style)
+        custom_dataset = plot_data.Dataset(elements=elements1, name='I = f(t)', tooltip=tooltip,
+                                           point_style=point_style, edge_style=edge_style)
 
         # Now let's create another dataset for the purpose of this exercise
         time2 = linspace(0, 20, 100)
@@ -387,8 +332,7 @@
         dataset2 = plot_data.Dataset(elements=elements2, name='I2 = f(t)')
 
         graph2d = plot_data.Graph2D(graphs=[custom_dataset, dataset2],
-                                    x_variable=attribute_names[0],
-                                    y_variable=attribute_names[1])
+                                    x_variable=attribute_names[0], y_variable=attribute_names[1])
 
         return [primitives_group, scatter_plot, parallel_plot, multi_plot, graph2d]
 
@@ -518,15 +462,12 @@ def __init__(self, standalone_subobject: StandaloneSubobject = DEF_SS,
         if array_arg is None:
             array_arg = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
 
-        StandaloneObject.__init__(
-            self, standalone_subobject=standalone_subobject,
-            embedded_subobject=embedded_subobject, dynamic_dict=dynamic_dict,
-            float_dict=float_dict, string_dict=string_dict,
-            tuple_arg=tuple_arg, intarg=intarg, strarg=strarg,
-            object_list=object_list, subobject_list=subobject_list,
-            builtin_list=builtin_list, union_arg=union_arg,
-            subclass_arg=subclass_arg, array_arg=array_arg, name=name
-        )
+        StandaloneObject.__init__(self, standalone_subobject=standalone_subobject,
+                                  embedded_subobject=embedded_subobject, dynamic_dict=dynamic_dict,
+                                  float_dict=float_dict, string_dict=string_dict, tuple_arg=tuple_arg, intarg=intarg,
+                                  strarg=strarg, object_list=object_list, subobject_list=subobject_list,
+                                  builtin_list=builtin_list, union_arg=union_arg, subclass_arg=subclass_arg,
+                                  array_arg=array_arg, name=name)
 
 
 DEF_SOWDV = StandaloneObjectWithDefaultValues()
 
@@ -545,8 +486,7 @@ class Generator(DessiaObject):
     """
     _standalone_in_db = True
 
-    def __init__(self, parameter: int, nb_solutions: int = 25,
-                 models: List[StandaloneObject] = None, name: str = ''):
+    def __init__(self, parameter: int, nb_solutions: int = 25, models: List[StandaloneObject] = None, name: str = ''):
         self.parameter = parameter
         self.nb_solutions = nb_solutions
         self.models = models
@@ -557,8 +497,7 @@ def generate(self) -> List[StandaloneObject]:
         """
         Generates a list of Standalone objects
         """
-        self.models = [StandaloneObject.generate(self.parameter + i)
-                       for i in range(self.nb_solutions)]
+        self.models = [StandaloneObject.generate(self.parameter + i) for i in range(self.nb_solutions)]
         return self.models
 
@@ -602,7 +541,6 @@ def __init__(self, models: List[StandaloneObject] = None, name: str = ""):
         DessiaObject.__init__(self, name=name)
 
     @classmethod
-    def generate_from_text_files(cls, files: List[TextIO],
-                                 name: str = "Generated from text files"):
+    def generate_from_text_files(cls, files: List[StringFile], name: str = "Generated from text files"):
         models = [StandaloneObject.generate_from_text(file) for file in files]
         return cls(models=models, name=name)
 
diff --git a/dessia_common/typings.py b/dessia_common/typings.py
index 5dd398543..591f25917 100644
--- a/dessia_common/typings.py
+++ b/dessia_common/typings.py
@@ -54,6 +54,10 @@ class Distance(Measure):
     units = 'm'
 
 
+class Angle(Measure):
+    units = 'radians'
+
+
 class Torque(Measure):
     units = 'Nm'
 
diff --git a/dessia_common/utils/diff.py b/dessia_common/utils/diff.py
index 3c8372db9..c1878e8e1 100644
--- a/dessia_common/utils/diff.py
+++ b/dessia_common/utils/diff.py
@@ -34,29 +34,29 @@ def diff(value1, value2, path='#'):
         if value1 != value2:
             diff_values.append((path, value1, value2))
         return diff_values, missing_keys_in_other_object, invalid_types
-    elif isinstance(value1, dict):
+    if isinstance(value1, dict):
         return dict_diff(value1, value2, path=path)
     # elif hasattr(value1, '_data_eq'):
-    else:
-        # Should be object
-        if hasattr(value1, '_data_eq'):
-            # DessiaObject
-            if value1._data_eq(value2):
-                return [], [], []
 
-            # Use same code snippet as in data_eq
-            eq_dict = value1._serializable_dict()
-            if 'name' in eq_dict:
-                del eq_dict['name']
+    # Should be object
+    if hasattr(value1, '_data_eq'):
+        # DessiaObject
+        if value1._data_eq(value2):
+            return [], [], []
+
+        # Use same code snippet as in data_eq
+        eq_dict = value1._serializable_dict()
+        if 'name' in eq_dict:
+            del eq_dict['name']
 
-        other_eq_dict = value2._serializable_dict()
+        other_eq_dict = value2._serializable_dict()
 
-        return dict_diff(eq_dict, other_eq_dict)
+        return dict_diff(eq_dict, other_eq_dict)
 
-        if value1 == value2:
-            return [], [], []
+    if value1 == value2:
+        return [], [], []
 
-        raise NotImplementedError('Undefined type in diff: {}'.format(type(value1)))
+    raise NotImplementedError('Undefined type in diff: {}'.format(type(value1)))
 
 
 def dict_diff(dict1, dict2, path='#'):
 
diff --git a/dessia_common/utils/jsonschema.py b/dessia_common/utils/jsonschema.py
index d75b2e601..90594a3f7 100644
--- a/dessia_common/utils/jsonschema.py
+++ b/dessia_common/utils/jsonschema.py
@@ -67,15 +67,15 @@ def chose_default(jsonschema):
     datatype = datatype_from_jsonschema(jsonschema)
     if datatype in ['heterogeneous_sequence', 'homogeneous_sequence']:
         return default_sequence(jsonschema)
-    elif datatype == 'static_dict':
+    if datatype == 'static_dict':
        return default_dict(jsonschema)
-    elif datatype in ['standalone_object', 'embedded_object',
-                      'instance_of', 'union']:
+    if datatype in ['standalone_object', 'embedded_object',
+                    'instance_of', 'union']:
         if 'default_value' in jsonschema:
             return jsonschema['default_value']
         return None
-    else:
-        return None
+
+    return None
 
 
 def default_dict(jsonschema):
 
diff --git a/dessia_common/utils/serialization.py b/dessia_common/utils/serialization.py
index b786eede1..302c5335f 100644
--- a/dessia_common/utils/serialization.py
+++ b/dessia_common/utils/serialization.py
@@ -5,13 +5,14 @@
 """
 
 import warnings
 import inspect
 import collections
 from typing import get_origin, get_args, Union, Any, TextIO, BinaryIO
 import dessia_common as dc
 import dessia_common.errors as dc_err
-import dessia_common.files
+from dessia_common.files import StringFile, BinaryFile
 import dessia_common.utils.types as dcty
 from dessia_common.typings import InstanceOf
 from dessia_common.graph import explore_tree_from_leaves  # , cut_tree_final_branches
@@ -55,7 +56,7 @@ def serialize(value):
         serialized_value = serialize_dict(value)
     elif dcty.is_sequence(value):
         serialized_value = serialize_sequence(value)
-    elif isinstance(value, (dessia_common.files.BinaryFile, dessia_common.files.StringFile)):
+    elif isinstance(value, (BinaryFile, StringFile)):
         serialized_value = value
     elif isinstance(value, type) or dcty.is_typing(value):
         return dcty.serialize_typing(value)
@@ -89,7 +90,7 @@ def serialize_with_pointers(value, memo=None, path='#'):
         serialized, memo = serialize_dict_with_pointers(value, memo, path)
     elif dcty.is_sequence(value):
         serialized, memo = serialize_sequence_with_pointers(value, memo, path)
-    elif isinstance(value, (dessia_common.files.BinaryFile, dessia_common.files.StringFile)):
+    elif isinstance(value, (BinaryFile, StringFile)):
         serialized = value
     else:
         if not dcty.is_jsonable(value):
@@ -155,17 +156,17 @@ def deserialize(serialized_element, sequence_annotation: str = 'List',
             return pointers_memo[path]
 
     if isinstance(serialized_element, dict):
-        try:
-            return dict_to_object(serialized_element, global_dict=global_dict,
-                                  pointers_memo=pointers_memo,
-                                  path=path)
-        except TypeError:
-            warnings.warn(f'specific dict_to_object of class {serialized_element.__class__.__name__}'
-                          ' should implement global_dict and'
-                          ' pointers_memo arguments',
-                          Warning)
-            return dict_to_object(serialized_element)
-    elif dcty.is_sequence(serialized_element):
+        return dict_to_object(serialized_element, global_dict=global_dict,
+                              pointers_memo=pointers_memo,
+                              path=path)
+    if dcty.is_sequence(serialized_element):
         return deserialize_sequence(sequence=serialized_element,
                                     annotation=sequence_annotation,
                                     global_dict=global_dict,
@@ -195,9 +196,9 @@ def deserialize_sequence(sequence, annotation=None,
 
 def dict_to_object(dict_, class_=None, force_generic: bool = False,
                    global_dict=None, pointers_memo=None, path='#'):
-    '''
+    """
     Transform a dict to an object
-    '''
+    """
 
     class_argspec = None
 
@@ -211,16 +212,12 @@
     if '$ref' in dict_:
         return pointers_memo[dict_['$ref']]
 
-    # working_dict = dict_
-
     if class_ is None and 'object_class' in dict_:
         class_ = dcty.get_python_class_from_class_name(dict_['object_class'])
 
     # Create init_dict
-    init_dict = None
     if class_ is not None and hasattr(class_, 'dict_to_object'):
-        different_methods = (class_.dict_to_object.__func__
-                             is not dc.DessiaObject.dict_to_object.__func__)
+        different_methods = (class_.dict_to_object.__func__ is not dc.DessiaObject.dict_to_object.__func__)
 
         if different_methods and not force_generic:
             try:
                 obj = class_.dict_to_object(dict_,
                                             global_dict=global_dict,
                                             pointers_memo=pointers_memo,
                                             path=path)
             except TypeError:
+                warnings.warn(f'specific dict_to_object of class {class_.__name__} should implement '
+                              'global_dict, pointers_memo and path arguments', Warning)
                 obj = class_.dict_to_object(dict_)
+            return obj
 
         class_argspec = inspect.getfullargspec(class_)
-        init_dict = {k: v for k, v in dict_.items()
-                     if k in class_argspec.args + class_argspec.kwonlyargs}
+        init_dict = {k: v for k, v in dict_.items() if k in class_argspec.args + class_argspec.kwonlyargs}
     # TOCHECK Class method to generate init_dict ??
     else:
         init_dict = dict_
@@ -251,10 +250,8 @@
         if key_path in pointers_memo:
             subobjects[key] = pointers_memo[key_path]
         else:
-            subobjects[key] = deserialize(value, annotation,
-                                          global_dict=global_dict,
-                                          pointers_memo=pointers_memo,
-                                          path=key_path)  # , enforce_pointers=False)
+            subobjects[key] = deserialize(value, annotation, global_dict=global_dict,
+                                          pointers_memo=pointers_memo, path=key_path)  # , enforce_pointers=False)
     if class_ is not None:
         obj = class_(**subobjects)
     else:
@@ -265,7 +262,7 @@
 def deserialize_with_type(type_, value):
     if type_ in dcty.TYPES_STRINGS.values():
-        return eval(type_)(value)
+        # eval is unsafe and literal_eval cannot resolve a type name:
+        # look the class up in TYPES_STRINGS (assumed to map types to their names) instead
+        class_ = next(k for k, v in dcty.TYPES_STRINGS.items() if v == type_)
+        return class_(value)
     if isinstance(type_, str):
         class_ = dcty.get_python_class_from_class_name(type_)
         if inspect.isclass(class_):
@@ -316,22 +313,19 @@ def deserialize_with_typing(type_, argument):
     elif origin in [list, collections.Iterator]:
         # Homogeneous sequences (lists)
         sequence_subtype = args[0]
-        deserialized_arg = [deserialize_argument(sequence_subtype, arg)
-                            for arg in argument]
+        deserialized_arg = [deserialize_argument(sequence_subtype, arg) for arg in argument]
         if origin is collections.Iterator:
             deserialized_arg = iter(deserialized_arg)
 
     elif origin is tuple:
         # Heterogeneous sequences (tuples)
-        deserialized_arg = tuple([deserialize_argument(t, arg)
-                                  for t, arg in zip(args, argument)])
+        deserialized_arg = tuple([deserialize_argument(t, arg) for t, arg in zip(args, argument)])
     elif origin is dict:
         # Dynamic dict
         deserialized_arg = argument
     elif origin is InstanceOf:
         classname = args[0]
-        object_class = dc.full_classname(object_=classname,
-                                         compute_for='class')
+        object_class = dc.full_classname(object_=classname, compute_for='class')
         class_ = dcty.get_python_class_from_class_name(object_class)
         deserialized_arg = class_.dict_to_object(argument)
     else:
@@ -346,13 +340,9 @@ def deserialize_argument(type_, argument):
     """
     if argument is None:
         return None
-
     if dcty.is_typing(type_):
         return deserialize_with_typing(type_, argument)
-    elif type_ is TextIO:
-        deserialized_arg = argument
-    elif type_ is BinaryIO:
-        # files are supplied as io.BytesIO which is compatible with : BinaryIO
+    if type_ in [TextIO, BinaryIO, StringFile, BinaryFile]:
         deserialized_arg = argument
     else:
         if type_ in dcty.TYPING_EQUIVALENCES.keys():
             if isinstance(argument, type_):
                 deserialized_arg = argument
             elif isinstance(argument, int) and type_ == float:
                 # Explicit conversion in this case
                 deserialized_arg = float(argument)
             else:
-                msg = 'Given built-in type and argument are incompatible: '
-                msg += '{} and {} in {}'.format(type(argument),
-                                                type_, argument)
+                msg = "Given built-in type and argument are incompatible: " \
+                      f"{type(argument)} and {type_} in {argument}"
                 raise TypeError(msg)
         elif type_ is Any:
             # Any type
@@ -389,7 +378,7 @@ def find_references(value, path='#'):
         return []
     if dcty.is_sequence(value):
         return find_references_sequence(value, path)
-    if isinstance(value, (dessia_common.files.BinaryFile, dessia_common.files.StringFile)):
+    if isinstance(value, (BinaryFile, StringFile)):
         return []
     raise ValueError(value)
 
@@ -507,7 +496,7 @@ def pointer_graph_elements(value, path='#'):
         return pointer_graph_elements_dict(value, path)
     if dcty.isinstance_base_types(value):
         return [], []
-    elif dcty.is_sequence(value):
+    if dcty.is_sequence(value):
         return pointer_graph_elements_sequence(value, path)
 
     raise ValueError(value)
 
diff --git a/dessia_common/utils/types.py b/dessia_common/utils/types.py
index c6a819cf9..28baf6036 100644
--- a/dessia_common/utils/types.py
+++ b/dessia_common/utils/types.py
@@ -3,6 +3,7 @@
 """
 
 """
+import builtins
 from typing import Any, Dict, List, Tuple, Type, Union, TextIO, BinaryIO, get_origin, get_args
 
 import dessia_common as dc
@@ -24,6 +25,8 @@
 SERIALIZED_BUILTINS = ['float', 'builtins.float', 'int', 'builtins.int',
                        'str', 'builtins.str', 'bool', 'builtins.bool']
 
+_PYTHON_CLASS_CACHE = {}
+
 
 def full_classname(object_, compute_for: str = 'instance'):
     if compute_for == 'instance':
@@ -67,9 +70,16 @@ def isinstance_base_types(obj):
 
 
 def get_python_class_from_class_name(full_class_name):
+    cached_value = _PYTHON_CLASS_CACHE.get(full_class_name, None)
+    if cached_value is not None:
+        return cached_value
+
     module_name, class_name = full_class_name.rsplit('.', 1)
     module = import_module(module_name)
     class_ = getattr(module, class_name)
+
+    # Storing in cache
+    _PYTHON_CLASS_CACHE[full_class_name] = class_
     return class_
 
@@ -152,8 +162,7 @@ def type_from_argname(argname):
     if argname:
         if splitted_argname[0] != '__builtins__':
             return get_python_class_from_class_name(argname)
-        # TODO Check for dangerous eval
-        return eval(splitted_argname[1])
+        # Look the builtin up by name instead of evaluating it (literal_eval cannot resolve names)
+        return getattr(builtins, splitted_argname[1])
     return Any
 
@@ -309,7 +318,7 @@ def typematch(type_: Type, match_against: Type) -> bool:
     """
     # TODO Implement a more intelligent check for Unions : Union[T, U] should match against Union[T, U, V]
     # TODO Implement a check for Dict
-    if type_ == match_against or match_against is Any:
+    if type_ == match_against or match_against is Any or particular_typematches(type_, match_against):
         # Trivial cases. If types are strictly equal, then it should pass straight away
         return True
 
@@ -322,42 +331,76 @@
         return True
 
     # type_ is not complex and match_against is
-    origin = get_origin(match_against)
-    args = get_args(match_against)
+    match_against, origin, args = heal_type(match_against)
     if origin is Union:
         matches = [typematch(type_, subtype) for subtype in args]
         return any(matches)
 
     return False
 
 
-def complex_first_type_match(type_: Type, match_against: Type):
+def complex_first_type_match(type_: Type, match_against: Type) -> bool:
     """
     Match type when type_ is a complex typing (List, Union, Tuple,...)
     """
     # Complex typing for the first type_. Cases : List, Tuple, Union
-    type_origin = get_origin(type_)
-    type_args = get_args(type_)
-
     if not is_typing(match_against):
         # Type matching is unilateral and match against should be more open than type_
         return False
 
-    match_against_origin = get_origin(match_against)
-    match_against_args = get_args(match_against)
-
-    if type_origin is Union:
-        # Check for default values false positive
-        if union_is_default_value(type_):
-            return typematch(type_args[0], match_against)
+    # Inspecting and healing types
+    type_, type_origin, type_args = heal_type(type_)
+    match_against, match_against_origin, match_against_args = heal_type(match_against)
 
     if type_origin != match_against_origin:
         # Being strict for now. Is there any other case than default values where this would be wrong ?
         return False
 
     if type_origin is list:
+        # Can only have one arg, should match
         return typematch(type_args[0], match_against_args[0])
 
     if type_origin is tuple:
+        # Order matters, all args should match
         return all(typematch(a, b) for a, b in zip(type_args, match_against_args))
 
+    if type_origin is dict:
+        # key type AND value type should match
+        return typematch(type_args[0], match_against_args[0]) and typematch(type_args[1], match_against_args[1])
+
+    if type_origin is Union:
+        # type args must be a subset of match_against args set
+        type_argsset = set(type_args)
+        match_against_argsset = set(match_against_args)
+        return type_argsset.issubset(match_against_argsset)
+
     # Otherwise, it is not implemented
     raise NotImplementedError(f"Type {type_} is a complex typing and cannot be matched against others yet")
+
+
+def heal_type(type_: Type):
+    """
+    Inspect a type and return its params.
+
+    For now, it only checks whether the type is an 'Optional' / Union[T, NoneType],
+    which should be flattened and not considered a real Union.
+
+    Returns the cleaned type, its origin and its args.
+    """
+    type_origin = get_origin(type_)
+    type_args = get_args(type_)
+    if type_origin is Union:
+        # Check for default values false positive
+        if union_is_default_value(type_):
+            type_ = type_args[0]
+            type_origin = get_origin(type_)
+            type_args = get_args(type_)
+    return type_, type_origin, type_args
+
+
+def particular_typematches(type_: Type, match_against: Type) -> bool:
+    """
+    Check for specific cases of type matches and return a boolean.
+    """
+    if type_ is int and match_against is float:
+        return True
+    return False
 
diff --git a/dessia_common/workflow/blocks.py b/dessia_common/workflow/blocks.py
index 15c5f5fc5..2bfcc1d5d 100644
--- a/dessia_common/workflow/blocks.py
+++ b/dessia_common/workflow/blocks.py
@@ -39,8 +39,9 @@ def set_inputs_from_function(method, inputs=None):
         try:
             annotations = get_type_hints(method)
             type_ = type_from_annotation(annotations[argument], module=method.__module__)
-        except KeyError:
-            raise UntypedArgumentError(f"Argument {argument} of method/function {method.__name__} has no typing")
+        except KeyError as error:
+            raise UntypedArgumentError(f"Argument {argument} of method/function {method.__name__} has no typing")\
+                from error
         if iarg >= nargs - ndefault_args:
             default = args_specs.defaults[ndefault_args - nargs + iarg]
             input_ = TypedVariableWithDefaultValue(type_=type_, default_value=default, name=argument)
@@ -379,7 +380,8 @@ def to_dict(self, use_pointers=True, memo=None, path: str = '#'):
 
     @classmethod
     @set_block_variable_names_from_dict
-    def dict_to_object(cls, dict_):
+    def dict_to_object(cls, dict_: JsonSerializable, force_generic: bool = False,
+                       global_dict=None, pointers_memo: Dict[str, Any] = None, path: str = '#'):
         return cls(workflow=Workflow.dict_to_object(dict_['workflow']), name=dict_['name'])
 
     def evaluate(self, values):
 
diff --git a/dessia_common/workflow/core.py b/dessia_common/workflow/core.py
index f3dabad00..9e94f8e9f 100644
--- a/dessia_common/workflow/core.py
+++ b/dessia_common/workflow/core.py
@@ -3,8 +3,7 @@
 """
 Gathers all workflow relative features
 """
-
-
+import ast
 import time
 import datetime
 import tempfile
@@ -16,16 +15,14 @@
 from typing import List, Union, Type, Any, Dict, Tuple, Optional
 from copy import deepcopy
 from dessia_common.templates import workflow_template
-from dessia_common import DessiaObject, is_sequence,\
-    JSONSCHEMA_HEADER, jsonschema_from_annotation,\
+from dessia_common import DessiaObject, is_sequence, JSONSCHEMA_HEADER, jsonschema_from_annotation,\
     deserialize_argument, set_default_value, prettyname, serialize_dict, DisplaySetting
 from dessia_common.utils.serialization import dict_to_object, deserialize, serialize_with_pointers, serialize,\
     dereference_jsonpointers
-from dessia_common.utils.types import serialize_typing,\
-    deserialize_typing, recursive_type, typematch
+from dessia_common.utils.types import serialize_typing, deserialize_typing, recursive_type, typematch
 from dessia_common.utils.copy import deepcopy_value
-from dessia_common.utils.docstrings import FAILED_ATTRIBUTE_PARSING
+from dessia_common.utils.docstrings import FAILED_ATTRIBUTE_PARSING, EMPTY_PARSED_ATTRIBUTE
 from dessia_common.utils.diff import choose_hash
 from dessia_common.typings import JsonSerializable, MethodType
 import warnings
@@ -68,7 +65,7 @@ def to_dict(self, use_pointers=True, memo=None, path: str = '#'):
 
     @classmethod
     def dict_to_object(cls, dict_: JsonSerializable, force_generic: bool = False,
-                       global_dict=None, pointers_memo: Dict[str, Any] = None) -> 'TypedVariable':
+                       global_dict=None, pointers_memo: Dict[str, Any] = None, path: str = '#') -> 'TypedVariable':
         type_ = deserialize_typing(dict_['type_'])
         memorize = dict_['memorize']
         return cls(type_=type_, memorize=memorize, name=dict_['name'])
@@ -111,7 +108,7 @@ def to_dict(self, use_pointers: bool = True, memo=None, path: str = '#'):
 
     @classmethod
     def dict_to_object(cls, dict_: JsonSerializable, force_generic: bool = False, global_dict=None,
-                       pointers_memo: Dict[str, Any] = None) -> 'TypedVariableWithDefaultValue':
+                       pointers_memo: Dict[str, Any] = None, path: str = '#') -> 'TypedVariableWithDefaultValue':
         type_ = deserialize_typing(dict_['type_'])
         default_value = deserialize(dict_['default_value'], global_dict=global_dict,
                                     pointers_memo=pointers_memo)
@@ -194,12 +191,12 @@ def jointjs_data(self):
         data['name'] = self.__class__.__name__
         return data
 
-    @staticmethod
-    def _docstring():
+    def _docstring(self):
         """
         Base function for submodel docstring computing
         """
-        return None
+        block_docstring = {i: EMPTY_PARSED_ATTRIBUTE for i in self.inputs}
+        return block_docstring
 
 
 class Pipe(DessiaObject):
@@ -340,8 +337,8 @@ def __init__(self, blocks, pipes, output, *, imposed_variable_values=None,
             self.variables.extend(block.outputs)
             try:
                 self.coordinates[block] = (0, 0)
-            except ValueError:
-                raise ValueError(f"Cannot serialize block {block} ({block.name})")
+            except ValueError as err:
+                raise ValueError(f"Cannot serialize block {block} ({block.name})") from err
 
         for pipe in self.pipes:
             upstream_var = pipe.input_variable
@@ -368,15 +365,14 @@
         self.description = description
         self.documentation = documentation
 
-        output.memorize = True
+        if output is not None:
+            output.memorize = True
 
         Block.__init__(self, input_variables, [output], name=name)
         self.output = self.outputs[0]
 
     def _data_hash(self):
-        output_hash = self.variable_indices(self.outputs[0])
-        if not isinstance(output_hash, int):
-            output_hash = sum(output_hash)
+        output_hash = hash(self.variable_indices(self.outputs[0]))
 
         base_hash = len(self.blocks) + 11 * len(self.pipes) + output_hash
         block_hash = int(sum([b.equivalent_hash() for b in self.blocks]) % 10e5)
@@ -389,6 +385,18 @@ def _data_eq(self, other_object):  # TODO: implement imposed_variable_values in
         for block1, block2 in zip(self.blocks, other_object.blocks):
             if not block1.equivalent(block2):
                 return False
+
+        if len(self.imposed_variable_values) != len(other_object.imposed_variable_values):
+            return False
+        for imposed_key1, imposed_key2 in zip(self.imposed_variable_values.keys(),
+                                              other_object.imposed_variable_values.keys()):
+            if hash(imposed_key1) != hash(imposed_key2):
+                return False
+            imposed_value1 = self.imposed_variable_values[imposed_key1]
+            imposed_value2 = other_object.imposed_variable_values[imposed_key2]
+            if hash(imposed_value1) != hash(imposed_value2):
+                return False
+
         return True
 
     def __deepcopy__(self, memo=None):
@@ -400,8 +408,12 @@
         blocks = [b.__deepcopy__() for b in self.blocks]
         output_adress = self.variable_indices(self.output)
-        output_block = blocks[output_adress[0]]
-        output = output_block.outputs[output_adress[2]]
+        if output_adress is None:
+            output = None
+        else:
+            output_block = blocks[output_adress[0]]
+            output = output_block.outputs[output_adress[2]]
+
         copied_workflow = Workflow(blocks=blocks, pipes=[], output=output, name=self.name)
 
         pipes = [self.copy_pipe(pipe=p, copied_workflow=copied_workflow) for p in self.pipes]
@@ -511,11 +523,9 @@ def _export_formats(self):
         """
         Reads block to compute available export formats
         """
-        export_formats = DessiaObject._export_formats()
-        export_formats.append({'extension': 'py',
-                               'method_name': 'save_script_to_stream',
-                               'text': True,
-                               'args': {}})
+        export_formats = DessiaObject._export_formats(self)
+        export_formats.append({'extension': 'py', 'method_name': 'save_script_to_stream',
+                               'text': True, 'args': {}})
         return export_formats
 
     def to_dict(self, use_pointers=True, memo=None, path='#'):
@@ -548,7 +558,7 @@
             else:
                 ser_value = serialize(value)
-            imposed_variable_values[var_index] = ser_value
+            imposed_variable_values[str(var_index)] = ser_value
 
         dict_.update({'description': self.description,
                       'documentation': self.documentation,
                       'imposed_variable_values': imposed_variable_values})
@@ -582,7 +592,10 @@ def dict_to_object(cls, dict_: JsonSerializable, force_generic: bool = False,
             pipes.append(Pipe(variable1, variable2))
 
-        output = blocks[dict_['output'][0]].outputs[dict_['output'][2]]
+        if dict_['output'] is not None:
+            output = blocks[dict_['output'][0]].outputs[dict_['output'][2]]
+        else:
+            output = None
         temp_workflow = cls(blocks=blocks, pipes=pipes, output=output)
 
         if 'imposed_variable_values' in dict_ and 'imposed_variables' in dict_:
@@ -594,20 +607,22 @@
                 variable = temp_workflow.variable_from_index(variable_index)
                 imposed_variable_values[variable] = value
 
-        elif 'imposed_variable_values' in dict_:
-            # New format with a dict
-            imposed_variable_values = {}
-            for variable_index, serialized_value in dict_['imposed_variable_values']:
-                value = deserialize(serialized_value, global_dict=global_dict, pointers_memo=pointers_memo)
-                variable = temp_workflow.variable_from_index(variable_index)
-                imposed_variable_values[variable] = value
-        elif 'imposed_variable_indices' in dict_:
-            imposed_variable_values = {}
-            for variable_index in dict_['imposed_variable_indices']:
-                variable = temp_workflow.variable_from_index(variable_index)
-                imposed_variable_values[variable] = variable.default_value
         else:
-            imposed_variable_values = None
+            imposed_variable_values = {}
+            if 'imposed_variable_indices' in dict_:
+                for variable_index in dict_['imposed_variable_indices']:
+                    variable = temp_workflow.variable_from_index(variable_index)
+                    imposed_variable_values[variable] = variable.default_value
+            if 'imposed_variable_values' in dict_:
+                # New format with a dict
+                for variable_index_str, serialized_value in dict_['imposed_variable_values'].items():
+                    variable_index = ast.literal_eval(variable_index_str)
+                    value = deserialize(serialized_value, global_dict=global_dict, pointers_memo=pointers_memo)
+                    variable = temp_workflow.variable_from_index(variable_index)
+                    imposed_variable_values[variable] = value
+
+            if 'imposed_variable_indices' not in dict_ and 'imposed_variable_values' not in dict_:
+                imposed_variable_values = None
 
         if "description" in dict_:
             # Retro-compatibility
@@ -759,13 +774,16 @@ def upstream_variable(self, variable: Variable) -> Optional[Variable]:
             return incoming_pipe.input_variable
         return None
 
-    def variable_indices(self, variable: Variable) -> Union[Tuple[int, int, int], int]:
+    def variable_indices(self, variable: Variable) -> Optional[Union[Tuple[int, int, int], int]]:
         """
         Returns global address of given variable as a tuple or an int
 
         If variable is non block, return index of variable in variables sequence
         Else returns global address (ib, i, ip)
         """
+        if variable is None:
+            return None
+
         for iblock, block in enumerate(self.blocks):
             if variable in block.inputs:
                 ib1 = iblock
@@ -859,7 +877,7 @@ def match_variables(self, serialize_output: bool = False):
                     other_vartype = other_variable.type_
                     if typematch(vartype, other_vartype):
                         if serialize_output:
-                            varval = self.variable_indices(other_variable)
+                            varval = str(self.variable_indices(other_variable))
                         else:
                             varval = other_variable
                         variable_match[varkey].append(varval)
@@ -993,6 +1011,9 @@ def run(self, input_values, verbose=False, progress_callback=lambda x: None, nam
 
         state.output_value = state.values[self.outputs[0]]
 
+        name_index = str(len(self.inputs) + 1)
+        if name is None and name_index in input_values:
+            name = input_values[name_index]
         if not name:
             timestamp = start_timestamp.strftime("%m-%d (%H:%M)")
             name = f"{self.name} @ [{timestamp}]"
@@ -1078,13 +1099,13 @@ def is_valid(self):
             if type1 != type2:
                 try:
                     issubclass(pipe.input_variable.type_, pipe.output_variable.type_)
                 except TypeError:  # TODO: need a real typing check; the caught error would be
                     consistent = True  # out of scope at raise time, so it cannot be chained
                 if not consistent:
-                    raise TypeError(f"Inconsistent pipe type from pipe input {pipe.input_variable.name}"
-                                    f"to pipe output {pipe.output_variable.name}: "
-                                    f"{pipe.input_variable.type_} incompatible with"
-                                    f"{pipe.output_variable.type_}")
+                    raise TypeError(f"Inconsistent pipe type from pipe input {pipe.input_variable.name} "
+                                    f"to pipe output {pipe.output_variable.name}: "
+                                    f"{pipe.input_variable.type_} incompatible with "
+                                    f"{pipe.output_variable.type_}")
         return True
 
     def package_mix(self) -> Dict[str, float]:
@@ -1149,6 +1170,8 @@ def to_script(self) -> str:
         script += f"pipes = [{', '.join(['pipe_' + str(i) for i in range(len(self.pipes))])}]\n"
 
         workflow_output_index = self.variable_indices(self.output)
+        if workflow_output_index is None:
+            raise ValueError("A workflow output must be set")
         output_name = f"block_{workflow_output_index[0]}.outputs[{workflow_output_index[2]}]"
         script += f"workflow = dcw.Workflow(blocks, pipes, output={output_name},name='{self.name}')\n"
         return script
@@ -1627,7 +1650,7 @@ def _export_formats(self):
         """
         Reads block to compute available export formats
         """
-        export_formats = DessiaObject._export_formats()
+        export_formats = DessiaObject._export_formats(self)
         for i, block in enumerate(self.workflow.blocks):
             if hasattr(block, "_export_format"):
                 export_formats.append(block._export_format(i))
@@ -1718,7 +1741,6 @@ def to_dict(self, use_pointers: bool = True, memo=None, path: str = '#'):
             dict_["variable_values"] = variable_values
         else:
            dict_["variable_values"] = {str(k): serialize(v) for k, v in self.variable_values.items()}
-
         return dict_
 
     def display_settings(self) -> List[DisplaySetting]:
 
@@ -1751,7 +1773,7 @@ def block_display(self, block_index: int):
         """
         self._activate_activable_pipes()
         self.activate_inputs()
-        block = self.blocks[block_index]
+        block = self.workflow.blocks[block_index]
         if block in self._activable_blocks():
             self._evaluate_block(block)
         reference_path = ''
 
diff --git a/tests/type_matching.py b/tests/type_matching.py
index 4e49f6884..dcb36b9c3 100644
--- a/tests/type_matching.py
+++ b/tests/type_matching.py
@@ -1,11 +1,13 @@
 from dessia_common.utils.types import typematch
-from dessia_common import DessiaObject
+from dessia_common import DessiaObject, PhysicalObject
 from dessia_common.forms import StandaloneObject, StandaloneObjectWithDefaultValues
-from typing import List, Tuple, Union, Any, Optional
+from typing import List, Tuple, Union, Any, Optional, Dict
 from dessia_common.typings import Measure
 
+# TRIVIAL AND SPECIFIC
 assert typematch(DessiaObject, Any)
 assert typematch(int, Any)
+assert typematch(int, float) and not typematch(float, int)
 
 # INHERITANCE
 # DessiaObject should pass a test against object, but not the other way around
@@ -23,8 +25,13 @@
 assert typematch(Tuple[int, str], Tuple[int, str])
 assert not typematch(Tuple[int, int], Tuple[str, int])
 
+# DICTS
+assert typematch(Dict[str, PhysicalObject], Dict[str, DessiaObject])
+assert not typematch(Dict[str, str], Dict[int, str]) and not typematch(Dict[str, str], Dict[str, int])
+
 # DEFAULT VALUES
 assert typematch(Optional[List[StandaloneObject]], List[DessiaObject])
+assert typematch(List[StandaloneObject], Optional[List[DessiaObject]])
 assert typematch(Union[List[StandaloneObject], type(None)], List[DessiaObject])
 
 # UNION
@@ -33,7 +40,7 @@
 assert not typematch(Union[DessiaObject, int], DessiaObject)
 assert typematch(StandaloneObjectWithDefaultValues, Union[DessiaObject, StandaloneObject])
 assert typematch(Union[str, int], Union[str, int])
-# assert not typematch(Union[str, int], Union[str, int, bool]) # TODO Not implemented yet
+assert typematch(Union[str, int], Union[bool, int, str])
 assert typematch(Union[str, int], Union[int, str])
 
 # UNEQUAL COMPLEX
 
diff --git a/tests/workflow/forms_simulation.py b/tests/workflow/forms_simulation.py
index e13a00d12..949d7f708 100644
--- a/tests/workflow/forms_simulation.py
+++ b/tests/workflow/forms_simulation.py
@@ -18,18 +18,19 @@
 
 variable_match = workflow_.match_variables(True)
 
-match_dict = {"(0, 0, 0)": [0],
-              "(0, 0, 1)": [0],
-              "(0, 0, 2)": [(1, 1, 0)],
-              "(0, 0, 3)": [1],
-              "(0, 1, 0)": [(1, 0, 0)],
-              "(1, 0, 0)": [(0, 1, 0)],
-              "(1, 1, 0)": [],
-              "(1, 1, 1)": [],
-              "(3, 0, 1)": [1],
-              "(3, 0, 2)": [0],
-              "0": [(0, 0, 0), (0, 0, 1), (3, 0, 2)],
-              "1": [(0, 0, 3), (3, 0, 1)]}
+
+match_dict = {'(0, 0, 0)': ['0'],
+              '(0, 0, 1)': ['0'],
+              '(0, 0, 2)': ['(1, 1, 0)'],
+              '(0, 0, 3)': ['1'],
+              '(0, 1, 0)': ['(1, 0, 0)'],
+              '(1, 0, 0)': ['(0, 1, 0)'],
+              '(1, 1, 0)': ['(0, 0, 2)'],
+              '(1, 1, 1)': [],
+              '(3, 0, 1)': ['1'],
+              '(3, 0, 2)': ['0'],
+              '0': ['(0, 0, 0)', '(0, 0, 1)', '(3, 0, 2)'],
+              '1': ['(0, 0, 3)', '(3, 0, 1)']}
 
 assert variable_match == match_dict
 
diff --git a/tests/workflow/power_simulation.py b/tests/workflow/power_simulation.py
index ff4f5c6f6..ea6678e7f 100644
--- a/tests/workflow/power_simulation.py
+++ b/tests/workflow/power_simulation.py
@@ -55,6 +55,7 @@
 
 manual_run.to_dict(use_pointers=False)
 manual_run.jsonschema()
+manual_run._performance_analysis()
 
 # Testing that there is no pointer when use_pointers=False
 d = workflow_run.to_dict(use_pointers=False)
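
Reviewer note: a minimal sketch (not part of the patch) of the round-trip that the new imposed_variable_values format is meant to support. It assumes a `workflow_` with at least one imposed variable, built as in tests/workflow/forms_simulation.py:

from dessia_common.workflow.core import Workflow

dict_ = workflow_.to_dict(use_pointers=False)
# to_dict now stores imposed variable keys as stringified indices, e.g. '2' or '(0, 0, 3)',
# which keeps the dictionary JSON-serializable
assert all(isinstance(key, str) for key in dict_['imposed_variable_values'])

# dict_to_object parses the keys back with ast.literal_eval and re-imposes the values;
# since _data_eq now also compares imposed variable values, the round-trip preserves data equality
assert Workflow.dict_to_object(dict_)._data_eq(workflow_)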
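
Reviewer note: a quick way to observe the effect of the new _PYTHON_CLASS_CACHE (a sketch; timings are indicative only, and it assumes dessia_common.forms is importable):

from time import perf_counter
from dessia_common.utils.types import get_python_class_from_class_name

start = perf_counter()
get_python_class_from_class_name('dessia_common.forms.StandaloneObject')  # cold: goes through import_module
cold = perf_counter() - start

start = perf_counter()
get_python_class_from_class_name('dessia_common.forms.StandaloneObject')  # warm: served from the cache
warm = perf_counter() - start

print(f'cold lookup: {cold:.6f}s, warm lookup: {warm:.6f}s')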