Skip to content

Commit

Permalink
Merge pull request #52 from MilesCranmer/deepsource-fix-147c2d73
Browse files Browse the repository at this point in the history
Refactor unnecessary `else` / `elif` when `if` block has a `return` statement
  • Loading branch information
MilesCranmer committed Jun 7, 2021
2 parents c623864 + 5bb2875 commit bfe511a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 26 deletions.
20 changes: 9 additions & 11 deletions pysr/export_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,19 @@ def sympy2jaxtext(expr, parameters, symbols_in):
if issubclass(expr.func, sympy.Float):
parameters.append(float(expr))
return f"parameters[{len(parameters) - 1}]"
elif issubclass(expr.func, sympy.Integer):
if issubclass(expr.func, sympy.Integer):
return f"{int(expr)}"
elif issubclass(expr.func, sympy.Symbol):
if issubclass(expr.func, sympy.Symbol):
return (
f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
)
else:
_func = _jnp_func_lookup[expr.func]
args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
if _func == MUL:
return " * ".join(["(" + arg + ")" for arg in args])
elif _func == ADD:
return " + ".join(["(" + arg + ")" for arg in args])
else:
return f'{_func}({", ".join(args)})'
_func = _jnp_func_lookup[expr.func]
args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
if _func == MUL:
return " * ".join(["(" + arg + ")" for arg in args])
if _func == ADD:
return " + ".join(["(" + arg + ")" for arg in args])
return f'{_func}({", ".join(args)})'


jax_initialized = False
Expand Down
23 changes: 8 additions & 15 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,10 +643,9 @@ def _make_hyperparams_julia_str(
def tuple_fix(ops):
if len(ops) > 1:
return ", ".join(ops)
elif len(ops) == 0:
if len(ops) == 0:
return ""
else:
return ops[0] + ","
return ops[0] + ","

def_hyperparams += f"""\n
plus=(+)
Expand Down Expand Up @@ -1024,8 +1023,7 @@ def get_hof(

if multioutput:
return ret_outputs
else:
return ret_outputs[0]
return ret_outputs[0]


def best_row(equations=None):
Expand All @@ -1036,8 +1034,7 @@ def best_row(equations=None):
equations = get_hof()
if isinstance(equations, list):
return [eq.iloc[np.argmax(eq["score"])] for eq in equations]
else:
return equations.iloc[np.argmax(equations["score"])]
return equations.iloc[np.argmax(equations["score"])]


def best_tex(equations=None):
Expand All @@ -1050,8 +1047,7 @@ def best_tex(equations=None):
return [
sympy.latex(best_row(eq)["sympy_format"].simplify()) for eq in equations
]
else:
return sympy.latex(best_row(equations)["sympy_format"].simplify())
return sympy.latex(best_row(equations)["sympy_format"].simplify())


def best(equations=None):
Expand All @@ -1062,8 +1058,7 @@ def best(equations=None):
equations = get_hof()
if isinstance(equations, list):
return [best_row(eq)["sympy_format"].simplify() for eq in equations]
else:
return best_row(equations)["sympy_format"].simplify()
return best_row(equations)["sympy_format"].simplify()


def best_callable(equations=None):
Expand All @@ -1074,8 +1069,7 @@ def best_callable(equations=None):
equations = get_hof()
if isinstance(equations, list):
return [best_row(eq)["lambda_format"] for eq in equations]
else:
return best_row(equations)["lambda_format"]
return best_row(equations)["lambda_format"]


def _escape_filename(filename):
Expand Down Expand Up @@ -1113,5 +1107,4 @@ def __repr__(self):
def __call__(self, X):
if self._selection is not None:
return self._lambda(*X[:, self._selection].T)
else:
return self._lambda(*X.T)
return self._lambda(*X.T)

0 comments on commit bfe511a

Please sign in to comment.