Skip to content

Commit

Permalink
Fix issue with lambda getting redefined; add test
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed May 31, 2021
1 parent 97e6589 commit 6a4fa2c
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
17 changes: 15 additions & 2 deletions pysr/sr.py
Expand Up @@ -61,6 +61,19 @@
'gamma': lambda x : sympy.gamma(x),
}

class CallableEquation(object):
"""Simple wrapper for numpy lambda functions built with sympy"""
def __init__(self, sympy_symbols, eqn):
self._sympy = eqn
self._sympy_symbols = sympy_symbols
self._lambda = lambdify(sympy_symbols, eqn)

def __repr__(self):
return f"PySRFunction(X=>{self._sympy})"

def __call__(self, X):
return self._lambda(*X.T)

def pysr(X, y, weights=None,
binary_operators=None,
unary_operators=None,
Expand Down Expand Up @@ -774,8 +787,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
if output_jax_format:
func, params = sympy2jax(eqn, sympy_symbols)
jax_format.append({'callable': func, 'parameters': params})
tmp_lambda = lambdify(sympy_symbols, eqn)
lambda_format.append(lambda X: tmp_lambda(*X.T))

lambda_format.append(CallableEquation(sympy_symbols, eqn))
curMSE = output.loc[i, 'MSE']
curComplexity = output.loc[i, 'Complexity']

Expand Down
32 changes: 28 additions & 4 deletions test/test.py
@@ -1,8 +1,9 @@
import unittest
import numpy as np
from pysr import pysr, get_hof, best, best_tex, best_callable
from pysr import pysr, get_hof, best, best_tex, best_callable, best_row
from pysr.sr import run_feature_selection, _handle_feature_selection
import sympy
from sympy import lambdify
import pandas as pd

class TestPipeline(unittest.TestCase):
Expand All @@ -27,20 +28,43 @@ def test_multioutput_custom_operator(self):
y = self.X[:, [0, 1]]**2
equations = pysr(self.X, y,
unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
extra_sympy_mappings={'square': lambda x: x**2},
**self.default_test_kwargs)
extra_sympy_mappings={'sq': lambda x: x**2},
**self.default_test_kwargs,
procs=0)
print(equations)
self.assertLessEqual(equations[0].iloc[-1]['MSE'], 1e-4)
self.assertLessEqual(equations[1].iloc[-1]['MSE'], 1e-4)

def test_multioutput_weighted_with_callable(self):
y = self.X[:, [0, 1]]**2
w = np.random.rand(*y.shape)
w[w < 0.5] = 0.0
w[w >= 0.5] = 1.0

# Double equation when weights are 0:
y += (1-w) * y
# Thus, pysr needs to use the weights to find the right equation!

equations = pysr(self.X, y, weights=w,
unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
extra_sympy_mappings={'sq': lambda x: x**2},
**self.default_test_kwargs,
procs=0)

np.testing.assert_almost_equal(
best_callable()[0](self.X),
self.X[:, 0]**2)
np.testing.assert_almost_equal(
best_callable()[1](self.X),
self.X[:, 1]**2)

def test_empty_operators_single_input(self):
X = np.random.randn(100, 1)
y = X[:, 0] + 3.0
equations = pysr(X, y,
unary_operators=[], binary_operators=["plus"],
**self.default_test_kwargs)

print(equations)
self.assertLessEqual(equations.iloc[-1]['MSE'], 1e-4)

class TestBest(unittest.TestCase):
Expand Down

0 comments on commit 6a4fa2c

Please sign in to comment.