Skip to content

Commit

Permalink
Refactor unnecessary else / elif when if block has a return s…
Browse files Browse the repository at this point in the history
…tatement
  • Loading branch information
deepsource-autofix[bot] committed Jun 7, 2021
1 parent b5d0afb commit 5bb2875
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 26 deletions.
20 changes: 9 additions & 11 deletions pysr/export_jax.py
Expand Up @@ -55,21 +55,19 @@ def sympy2jaxtext(expr, parameters, symbols_in):
if issubclass(expr.func, sympy.Float):
parameters.append(float(expr))
return f"parameters[{len(parameters) - 1}]"
elif issubclass(expr.func, sympy.Integer):
if issubclass(expr.func, sympy.Integer):
return f"{int(expr)}"
elif issubclass(expr.func, sympy.Symbol):
if issubclass(expr.func, sympy.Symbol):
return (
f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
)
else:
_func = _jnp_func_lookup[expr.func]
args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
if _func == MUL:
return " * ".join(["(" + arg + ")" for arg in args])
elif _func == ADD:
return " + ".join(["(" + arg + ")" for arg in args])
else:
return f'{_func}({", ".join(args)})'
_func = _jnp_func_lookup[expr.func]
args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
if _func == MUL:
return " * ".join(["(" + arg + ")" for arg in args])
if _func == ADD:
return " + ".join(["(" + arg + ")" for arg in args])
return f'{_func}({", ".join(args)})'


jax_initialized = False
Expand Down
23 changes: 8 additions & 15 deletions pysr/sr.py
Expand Up @@ -643,10 +643,9 @@ def _make_hyperparams_julia_str(
def tuple_fix(ops):
if len(ops) > 1:
return ", ".join(ops)
elif len(ops) == 0:
if len(ops) == 0:
return ""
else:
return ops[0] + ","
return ops[0] + ","

def_hyperparams += f"""\n
plus=(+)
Expand Down Expand Up @@ -1025,8 +1024,7 @@ def get_hof(

if multioutput:
return ret_outputs
else:
return ret_outputs[0]
return ret_outputs[0]


def best_row(equations=None):
Expand All @@ -1037,8 +1035,7 @@ def best_row(equations=None):
equations = get_hof()
if isinstance(equations, list):
return [eq.iloc[np.argmax(eq["score"])] for eq in equations]
else:
return equations.iloc[np.argmax(equations["score"])]
return equations.iloc[np.argmax(equations["score"])]


def best_tex(equations=None):
Expand All @@ -1051,8 +1048,7 @@ def best_tex(equations=None):
return [
sympy.latex(best_row(eq)["sympy_format"].simplify()) for eq in equations
]
else:
return sympy.latex(best_row(equations)["sympy_format"].simplify())
return sympy.latex(best_row(equations)["sympy_format"].simplify())


def best(equations=None):
Expand All @@ -1063,8 +1059,7 @@ def best(equations=None):
equations = get_hof()
if isinstance(equations, list):
return [best_row(eq)["sympy_format"].simplify() for eq in equations]
else:
return best_row(equations)["sympy_format"].simplify()
return best_row(equations)["sympy_format"].simplify()


def best_callable(equations=None):
Expand All @@ -1075,8 +1070,7 @@ def best_callable(equations=None):
equations = get_hof()
if isinstance(equations, list):
return [best_row(eq)["lambda_format"] for eq in equations]
else:
return best_row(equations)["lambda_format"]
return best_row(equations)["lambda_format"]


def _escape_filename(filename):
Expand Down Expand Up @@ -1114,5 +1108,4 @@ def __repr__(self):
def __call__(self, X):
if self._selection is not None:
return self._lambda(*X[:, self._selection].T)
else:
return self._lambda(*X.T)
return self._lambda(*X.T)

0 comments on commit 5bb2875

Please sign in to comment.