Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions codegen/settingsgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,13 @@ def _populate_classes(parent_dir):
if stubf:
stubf.write(f"{istr1}{child_object_type} = ...\n")

return_type = getattr(cls, "return_type", None)
if return_type:
f.write(f"{istr1}return_type = {return_type}\n")
f.write(f'{istr1}"""\n')
if stubf:
stubf.write(f"{istr1}{return_type} = ...\n")


def _populate_init(parent_dir, sinfo):
hash = _gethash(sinfo)
Expand Down
65 changes: 62 additions & 3 deletions src/ansys/fluent/core/solver/flobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,21 @@
import string
import sys
import types
from typing import Any, Dict, Generic, List, NewType, Optional, Tuple, TypeVar, Union
from typing import (
Any,
Dict,
ForwardRef,
Generic,
List,
NewType,
Optional,
Tuple,
TypeVar,
Union,
_eval_type,
get_args,
get_origin,
)
import warnings
import weakref

Expand Down Expand Up @@ -85,6 +99,42 @@ class _InlineConstants:
ListStateType = List["StateType"]
StateType = Union[PrimitiveStateType, DictStateType, ListStateType]


def check_type(val, tp):
if hasattr(tp, "__supertype__"):
return check_type(val, tp.__supertype__)
if isinstance(tp, ForwardRef):
return check_type(val, _eval_type(tp, globals(), locals()))
origin = get_origin(tp)
if origin == list:
return isinstance(val, list) and all(
check_type(x, get_args(tp)[0]) for x in val
)
elif origin == tuple:
return isinstance(val, tuple) and all(
check_type(x, t) for x, t in zip(val, get_args(tp))
)
elif origin == Union:
return any(check_type(val, t) for t in get_args(tp))
elif origin == dict:
k_t, k_v = get_args(tp)
return isinstance(val, dict) and all(
check_type(k, k_t) and check_type(v, k_v) for k, v in val.items()
)
elif origin is None:
try:
return isinstance(val, tp)
except TypeError:
return False
else:
return False


def assert_type(val, tp):
if not check_type(val, tp):
raise TypeError(f"{val} is not of type {tp}.")


_ttable = str.maketrans(string.punctuation, "_" * len(string.punctuation), "?'")


Expand Down Expand Up @@ -1374,11 +1424,16 @@ def execute_command(self, *args, **kwds):
for arg, value in kwds.items():
argument = getattr(self, arg)
argument.before_execute(value)
cmd = self._execute_command(*args, **kwds)
ret = self._execute_command(*args, **kwds)
for arg, value in kwds.items():
argument = getattr(self, arg)
argument.after_execute(value)
return cmd
return_t = getattr(self, "return_type", None)
if return_t:
base_t = _baseTypes.get(return_t)
if base_t:
assert_type(ret, base_t._state_type)
return ret

def __call__(self, *args, **kwds):
return self.execute_command(*args, **kwds)
Expand Down Expand Up @@ -1727,6 +1782,10 @@ def _process_cls_names(info_dict, names, write_doc=False):
_process_cls_names(arguments, cls.argument_names, write_doc=True)
cls.__doc__ = doc

return_type = info.get("return-type") or info.get("return_type")
if return_type:
cls.return_type = return_type

object_type = info.get("object-type", False) or info.get("object_type", False)
if object_type:
cls.child_object_type, _ = get_cls(
Expand Down
44 changes: 44 additions & 0 deletions tests/test_flobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,3 +1115,47 @@ def test_ansys_units_integration_nested_state(load_mixing_elbow_mesh):
"turbulent_viscosity_ratio": (10, None),
},
}


def test_assert_type():
types = [
bool,
int,
flobject.RealType,
str,
flobject.BoolListType,
flobject.IntListType,
flobject.RealListType,
flobject.StringListType,
flobject.RealVectorType,
flobject.DictStateType,
]
vals = [
False,
1,
1.0,
"a",
[False, True],
[1, 2],
[1.0, 2.0],
["a", "b"],
(1.0, 2.0, 3.0),
{"a": 1},
]
subtypes = {
bool: (int,),
str: (flobject.RealType,),
flobject.BoolListType: (flobject.IntListType,),
flobject.StringListType: (flobject.RealListType,),
}
for i_t, tp in enumerate(types):
for i_v, val in enumerate(vals):
if i_t == i_v:
flobject.assert_type(val, tp)
else:
subtype = subtypes.get(types[i_v])
if subtype and types[i_t] in subtype:
flobject.assert_type(val, tp)
else:
with pytest.raises(TypeError):
flobject.assert_type(val, tp)
15 changes: 14 additions & 1 deletion tests/test_settings_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ def test_api_upgrade(new_solver_session, capsys):
"<solver_session>.file.read_case" in capsys.readouterr().out


@pytest.mark.skip(reason="Skipping till docker image is updated")
@pytest.mark.fluent_version(">=24.2")
def test_deprecated_settings(new_solver_session):
solver = new_solver_session
Expand Down Expand Up @@ -237,3 +236,17 @@ def test_deprecated_settings(new_solver_session):
].turbulence.hydraulic_diameter()
== 10
)


@pytest.mark.fluent_version(">=24.2")
def test_command_return_type(new_solver_session):
solver = new_solver_session
case_path = download_file("mixing_elbow.cas.h5", "pyfluent/mixing_elbow")
download_file("mixing_elbow.dat.h5", "pyfluent/mixing_elbow")
ret = solver.file.read_case_data(file_name=case_path)
assert ret is None
solver.solution.report_definitions.surface["surface-1"] = dict(
surface_names=["cold-inlet"]
)
ret = solver.solution.report_definitions.compute(report_defs=["surface-1"])
assert ret is not None