Skip to content

Commit

Permalink
Merge afdb8dd into efffd9b
Browse files Browse the repository at this point in the history
  • Loading branch information
tomjelen committed Mar 22, 2024
2 parents efffd9b + afdb8dd commit c5f09df
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
16 changes: 14 additions & 2 deletions pysr/export_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,29 @@
}


def create_sympy_symbols_map(
feature_names_in: List[str],
) -> Dict[str, sympy.Symbol]:
return {variable: sympy.Symbol(variable) for variable in feature_names_in}


def create_sympy_symbols(
feature_names_in: List[str],
) -> List[sympy.Symbol]:
return [sympy.Symbol(variable) for variable in feature_names_in]


def pysr2sympy(
equation: str, *, extra_sympy_mappings: Optional[Dict[str, Callable]] = None
equation: str,
*,
feature_names_in: Optional[List[str]] = None,
extra_sympy_mappings: Optional[Dict[str, Callable]] = None,
):
if feature_names_in is None:
feature_names_in = []
local_sympy_mappings = {
**(extra_sympy_mappings if extra_sympy_mappings else {}),
**create_sympy_symbols_map(feature_names_in),
**(extra_sympy_mappings if extra_sympy_mappings is not None else {}),
**sympy_mappings,
}

Expand Down
1 change: 1 addition & 0 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2226,6 +2226,7 @@ def get_hof(self):
for _, eqn_row in output.iterrows():
eqn = pysr2sympy(
eqn_row["equation"],
feature_names_in=self.feature_names_in_,
extra_sympy_mappings=self.extra_sympy_mappings,
)
sympy_format.append(eqn)
Expand Down
13 changes: 13 additions & 0 deletions pysr/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,19 @@ def test_bad_variable_names_fail(self):
model.fit(X, y, variable_names=["f{c}"])
self.assertIn("Invalid variable name", str(cm.exception))

def test_python_built_in_python_funcs_as_variable_names(self):
model = PySRRegressor()
X = np.random.randn(100, 2)
y = np.random.randn(100)

# Should not throw an error
try:
model.fit(X, y, variable_names=["exec", "hash"])
except Exception:
self.fail(
"Should not have thrown when a variable name is a builtin_function_or_method"
)

def test_bad_kwargs(self):
bad_kwargs = [
dict(
Expand Down

0 comments on commit c5f09df

Please sign in to comment.