From d79e41eb7e2afd3f7132ca9e20b867a4fb7dea86 Mon Sep 17 00:00:00 2001 From: Jakob Jordan Date: Mon, 19 Apr 2021 16:18:39 +0200 Subject: [PATCH] Apply changes from code review --- cgp/cartesian_graph.py | 39 +++++++++++++++++++-------------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/cgp/cartesian_graph.py b/cgp/cartesian_graph.py index 0c48bbaf..a817da2d 100644 --- a/cgp/cartesian_graph.py +++ b/cgp/cartesian_graph.py @@ -234,9 +234,9 @@ def to_func(self) -> Callable[..., List[float]]: """Create a Python callable implementing the function described by this graph. - The function expects as many arguments as the number of inputs - defined in the genome. The function returns a tuple with - length equal to the number of outputs defined in the + The returned callable expects as many arguments as the number + of inputs defined in the genome. The function returns a tuple + with length equal to the number of outputs defined in the genome. For convenience, if only a single output is defined the function will *not* return a tuple but only its first element. @@ -278,9 +278,9 @@ def to_numpy(self) -> Callable[..., List[np.ndarray]]: """Create a NumPy-array-compatible Python callable implementing the function described by this graph. - The function expects as many arguments as the number of inputs - defined in the genome. Every argument needs to be a NumPy - array of equal length. The function returns a tuple with + The returned callable expects as many arguments as the number + of inputs defined in the genome. Every argument needs to be a + NumPy array of equal length. The function returns a tuple with length equal to the number of outputs defined in the genome. Each element will have the same length as the input arrays. For convenience, if only a single output is defined @@ -313,8 +313,8 @@ def _f(*x): return locals()["_f"] def to_torch(self) -> "torch.nn.Module": - """Create a Torch class instance implementing the function defined by - this graph. + """Create a Torch nn.Module instance implementing the function defined + by this graph. The generated instance will have a `forward` method accepting Torch tensor of dimension (, n_inputs) and @@ -403,14 +403,6 @@ def to_sympy( """ - def possibly_unpack( - sympy_exprs: List["sympy_expr.Expr"], - ) -> Union["sympy_expr.Expr", List["sympy_expr.Expr"]]: - if len(sympy_exprs) == 1: - return sympy_exprs[0] - else: - return sympy_exprs - if not sympy_available: raise ModuleNotFoundError("No module named 'sympy' (extra requirement)") @@ -432,12 +424,19 @@ def possibly_unpack( # sympy should not automatically simplify the expression sympy_exprs.append(sympy.sympify(s, evaluate=False)) - if not simplify: - return possibly_unpack(sympy_exprs) - else: # simplify expression if desired and possible + if simplify: for i, expr in enumerate(sympy_exprs): try: sympy_exprs[i] = expr.simplify() except TypeError: RuntimeWarning(f"SymPy could not simplify expression: {expr}") - return possibly_unpack(sympy_exprs) + + def possibly_unpack( + sympy_exprs: List["sympy_expr.Expr"], + ) -> Union["sympy_expr.Expr", List["sympy_expr.Expr"]]: + if len(sympy_exprs) == 1: + return sympy_exprs[0] + else: + return sympy_exprs + + return possibly_unpack(sympy_exprs)