
Commit: Apply changes from code review

jakobj committed Apr 19, 2021
1 parent fbf4c56, commit d79e41e
Showing 1 changed file with 19 additions and 20 deletions.
cgp/cartesian_graph.py (39 changes: 19 additions & 20 deletions)
@@ -234,9 +234,9 @@ def to_func(self) -> Callable[..., List[float]]:
         """Create a Python callable implementing the function described by
         this graph.
 
-        The function expects as many arguments as the number of inputs
-        defined in the genome. The function returns a tuple with
-        length equal to the number of outputs defined in the
+        The returned callable expects as many arguments as the number
+        of inputs defined in the genome. The function returns a tuple
+        with length equal to the number of outputs defined in the
         genome. For convenience, if only a single output is defined
         the function will *not* return a tuple but only its first
         element.
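For reference, a minimal usage sketch of the behaviour the updated `to_func` docstring describes. The genome setup below is an assumption for illustration only (constructor keywords follow the hal-cgp examples) and is not part of this commit:

    import numpy as np

    import cgp  # hal-cgp

    # Assumed setup, not part of this commit: a random genome with two inputs
    # and one output.
    genome = cgp.Genome(
        n_inputs=2,
        n_outputs=1,
        n_columns=3,
        n_rows=2,
        primitives=(cgp.Add, cgp.Mul),
    )
    genome.randomize(np.random.RandomState(seed=1234))
    graph = cgp.CartesianGraph(genome)

    f = graph.to_func()
    y = f(1.0, 2.0)  # one positional argument per input defined in the genome
    # With a single output defined, `y` is the bare value, not a 1-tuple.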
@@ -278,9 +278,9 @@ def to_numpy(self) -> Callable[..., List[np.ndarray]]:
         """Create a NumPy-array-compatible Python callable implementing the
         function described by this graph.
 
-        The function expects as many arguments as the number of inputs
-        defined in the genome. Every argument needs to be a NumPy
-        array of equal length. The function returns a tuple with
+        The returned callable expects as many arguments as the number
+        of inputs defined in the genome. Every argument needs to be a
+        NumPy array of equal length. The function returns a tuple with
         length equal to the number of outputs defined in the
         genome. Each element will have the same length as the input
         arrays. For convenience, if only a single output is defined
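Continuing the sketch above, the `to_numpy` variant follows the same calling convention but over arrays: one equal-length NumPy array per input, and outputs of the same length:

    f_np = graph.to_numpy()
    x0 = np.linspace(-1.0, 1.0, 100)
    x1 = np.linspace(0.0, 2.0, 100)  # every argument must be an array of equal length
    y = f_np(x0, x1)                 # single output -> one array of length 100, not a tuple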
@@ -313,8 +313,8 @@ def _f(*x):
         return locals()["_f"]
 
     def to_torch(self) -> "torch.nn.Module":
-        """Create a Torch class instance implementing the function defined by
-        this graph.
+        """Create a Torch nn.Module instance implementing the function defined
+        by this graph.
 
         The generated instance will have a `forward` method accepting
         Torch tensor of dimension (<batch size>, n_inputs) and
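A corresponding sketch for `to_torch`, continuing from the setup above (torch is an optional dependency; the visible docstring only specifies the input shape, so the output shape noted in the comment is an assumption):

    import torch

    model = graph.to_torch()  # a torch.nn.Module
    x = torch.rand(8, 2)      # forward expects a tensor of shape (<batch size>, n_inputs)
    y = model(x)              # presumably of shape (<batch size>, n_outputs)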
@@ -403,14 +403,6 @@ def to_sympy(
         """
 
-        def possibly_unpack(
-            sympy_exprs: List["sympy_expr.Expr"],
-        ) -> Union["sympy_expr.Expr", List["sympy_expr.Expr"]]:
-            if len(sympy_exprs) == 1:
-                return sympy_exprs[0]
-            else:
-                return sympy_exprs
-
         if not sympy_available:
             raise ModuleNotFoundError("No module named 'sympy' (extra requirement)")
@@ -432,12 +424,19 @@ def possibly_unpack(
             # sympy should not automatically simplify the expression
             sympy_exprs.append(sympy.sympify(s, evaluate=False))
 
-        if not simplify:
-            return possibly_unpack(sympy_exprs)
-        else:  # simplify expression if desired and possible
+        if simplify:
             for i, expr in enumerate(sympy_exprs):
                 try:
                     sympy_exprs[i] = expr.simplify()
                 except TypeError:
                     RuntimeWarning(f"SymPy could not simplify expression: {expr}")
-            return possibly_unpack(sympy_exprs)
+
+        def possibly_unpack(
+            sympy_exprs: List["sympy_expr.Expr"],
+        ) -> Union["sympy_expr.Expr", List["sympy_expr.Expr"]]:
+            if len(sympy_exprs) == 1:
+                return sympy_exprs[0]
+            else:
+                return sympy_exprs
+
+        return possibly_unpack(sympy_exprs)
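With `possibly_unpack` now defined and applied once at the end, both the simplified and the unsimplified paths return a bare sympy expression for a single-output genome and a list otherwise. A hedged usage sketch, continuing from the setup above (sympy is an optional extra; the default of `simplify` is not visible in this hunk, so it is passed explicitly):

    expr = graph.to_sympy(simplify=True)
    print(expr)  # single output defined -> one sympy expression, not a list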
