Skip to content

Commit

Permalink
Allow user to pass extra torch operators to pysr
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed May 31, 2021
1 parent e7ede78 commit 84e4a47
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions pysr/sr.py
Expand Up @@ -102,6 +102,8 @@ def pysr(X, y, weights=None,
perturbationFactor=1.0,
timeout=None,
extra_sympy_mappings=None,
extra_torch_mappings=None,
extra_jax_mappings=None,
equation_file=None,
verbosity=1e9,
progress=True,
Expand Down Expand Up @@ -336,6 +338,8 @@ def pysr(X, y, weights=None,
weightSimplify=weightSimplify,
constraints=constraints,
extra_sympy_mappings=extra_sympy_mappings,
extra_jax_mappings=extra_jax_mappings,
extra_torch_mappings=extra_torch_mappings,
julia_project=julia_project, loss=loss,
output_jax_format=output_jax_format,
output_torch_format=output_torch_format,
Expand Down Expand Up @@ -730,6 +734,7 @@ def run_feature_selection(X, y, select_k_features):
def get_hof(equation_file=None, n_features=None, variable_names=None,
extra_sympy_mappings=None, output_jax_format=False,
output_torch_format=False,
extra_jax_mappings=None, extra_torch_mappings=None,
multioutput=None, nout=None, **kwargs):
"""Get the equations from a hall of fame file. If no arguments
entered, the ones used previously from a call to PySR will be used."""
Expand Down Expand Up @@ -790,20 +795,22 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
for i in range(len(output)):
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
sympy_format.append(eqn)

# Numpy:
lambda_format.append(CallableEquation(sympy_symbols, eqn))

# JAX:
if output_jax_format:
from .export_jax import sympy2jax
func, params = sympy2jax(eqn, sympy_symbols)
jax_format.append({'callable': func, 'parameters': params})
<<<<<<< HEAD

lambda_format.append(CallableEquation(sympy_symbols, eqn))
=======
# Torch:
if output_torch_format:
from .export_torch import sympy2torch
module = sympy2torch(eqn, sympy_symbols)
torch_format.append(module)
lambda_format.append(lambdify(sympy_symbols, eqn))
>>>>>>> 6ba697f (Add torch format output; dont import jax/torch by default)

curMSE = output.loc[i, 'MSE']
curComplexity = output.loc[i, 'Complexity']

Expand Down

0 comments on commit 84e4a47

Please sign in to comment.