From 6a4fa2c29c0672e3315e87ae8bb5169fb2f6e4b5 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 31 May 2021 02:19:13 -0400 Subject: [PATCH] Fix issue with lambda getting redefined; add test --- pysr/sr.py | 17 +++++++++++++++-- test/test.py | 32 ++++++++++++++++++++++++++++---- 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/pysr/sr.py b/pysr/sr.py index e694cae28..5bfd3d078 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -61,6 +61,19 @@ 'gamma': lambda x : sympy.gamma(x), } +class CallableEquation(object): + """Simple wrapper for numpy lambda functions built with sympy""" + def __init__(self, sympy_symbols, eqn): + self._sympy = eqn + self._sympy_symbols = sympy_symbols + self._lambda = lambdify(sympy_symbols, eqn) + + def __repr__(self): + return f"PySRFunction(X=>{self._sympy})" + + def __call__(self, X): + return self._lambda(*X.T) + def pysr(X, y, weights=None, binary_operators=None, unary_operators=None, @@ -774,8 +787,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None, if output_jax_format: func, params = sympy2jax(eqn, sympy_symbols) jax_format.append({'callable': func, 'parameters': params}) - tmp_lambda = lambdify(sympy_symbols, eqn) - lambda_format.append(lambda X: tmp_lambda(*X.T)) + + lambda_format.append(CallableEquation(sympy_symbols, eqn)) curMSE = output.loc[i, 'MSE'] curComplexity = output.loc[i, 'Complexity'] diff --git a/test/test.py b/test/test.py index f54d1cfa0..809048dab 100644 --- a/test/test.py +++ b/test/test.py @@ -1,8 +1,9 @@ import unittest import numpy as np -from pysr import pysr, get_hof, best, best_tex, best_callable +from pysr import pysr, get_hof, best, best_tex, best_callable, best_row from pysr.sr import run_feature_selection, _handle_feature_selection import sympy +from sympy import lambdify import pandas as pd class TestPipeline(unittest.TestCase): @@ -27,12 +28,36 @@ def test_multioutput_custom_operator(self): y = self.X[:, [0, 1]]**2 equations = pysr(self.X, y, unary_operators=["sq(x) = x^2"], binary_operators=["plus"], - extra_sympy_mappings={'square': lambda x: x**2}, - **self.default_test_kwargs) + extra_sympy_mappings={'sq': lambda x: x**2}, + **self.default_test_kwargs, + procs=0) print(equations) self.assertLessEqual(equations[0].iloc[-1]['MSE'], 1e-4) self.assertLessEqual(equations[1].iloc[-1]['MSE'], 1e-4) + def test_multioutput_weighted_with_callable(self): + y = self.X[:, [0, 1]]**2 + w = np.random.rand(*y.shape) + w[w < 0.5] = 0.0 + w[w >= 0.5] = 1.0 + + # Double equation when weights are 0: + y += (1-w) * y + # Thus, pysr needs to use the weights to find the right equation! + + equations = pysr(self.X, y, weights=w, + unary_operators=["sq(x) = x^2"], binary_operators=["plus"], + extra_sympy_mappings={'sq': lambda x: x**2}, + **self.default_test_kwargs, + procs=0) + + np.testing.assert_almost_equal( + best_callable()[0](self.X), + self.X[:, 0]**2) + np.testing.assert_almost_equal( + best_callable()[1](self.X), + self.X[:, 1]**2) + def test_empty_operators_single_input(self): X = np.random.randn(100, 1) y = X[:, 0] + 3.0 @@ -40,7 +65,6 @@ def test_empty_operators_single_input(self): unary_operators=[], binary_operators=["plus"], **self.default_test_kwargs) - print(equations) self.assertLessEqual(equations.iloc[-1]['MSE'], 1e-4) class TestBest(unittest.TestCase):