Skip to content

Commit

Permalink
Merge 44d1e08 into 7156334
Browse files Browse the repository at this point in the history
  • Loading branch information
tomjelen committed Mar 1, 2024
2 parents 7156334 + 44d1e08 commit 88a96ee
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
6 changes: 5 additions & 1 deletion pysr/export_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,13 @@ def create_sympy_symbols(


def pysr2sympy(
equation: str, *, extra_sympy_mappings: Optional[Dict[str, Callable]] = None
equation: str,
feature_names_in: List[str],
*,
extra_sympy_mappings: Optional[Dict[str, Callable]] = None,
):
local_sympy_mappings = {
**{f: sympy.Symbol(f) for f in feature_names_in},
**(extra_sympy_mappings if extra_sympy_mappings else {}),
**sympy_mappings,
}
Expand Down
1 change: 1 addition & 0 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2226,6 +2226,7 @@ def get_hof(self):
for _, eqn_row in output.iterrows():
eqn = pysr2sympy(
eqn_row["equation"],
self.feature_names_in_,
extra_sympy_mappings=self.extra_sympy_mappings,
)
sympy_format.append(eqn)
Expand Down
13 changes: 13 additions & 0 deletions pysr/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,19 @@ def test_bad_variable_names_fail(self):
model.fit(X, y, variable_names=["f{c}"])
self.assertIn("Invalid variable name", str(cm.exception))

def test_python_built_in_python_funcs_as_variable_names(self):
model = PySRRegressor()
X = np.random.randn(100, 2)
y = np.random.randn(100)

# Should not throw an error
try:
model.fit(X, y, variable_names=["exec", "hash"])
except Exception:
self.fail(
"Should not have thrown when a variable name is a builtin_function_or_method"
)

def test_bad_kwargs(self):
bad_kwargs = [
dict(
Expand Down

0 comments on commit 88a96ee

Please sign in to comment.