Skip to content

Commit

Permalink
Merge 781f479 into 887e02d
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer authored Apr 26, 2022
2 parents 887e02d + 781f479 commit 5baa6b7
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 18 deletions.
90 changes: 72 additions & 18 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,15 @@ def __repr__(self):
return f"PySRFunction(X=>{self._sympy})"

def __call__(self, X):
expected_shape = (X.shape[0],)
if isinstance(X, pd.DataFrame):
# Lambda function takes as argument:
return self._lambda(**{k: X[k].values for k in X.columns})
return self._lambda(**{k: X[k].values for k in X.columns}) * np.ones(
expected_shape
)
elif self._selection is not None:
return self._lambda(*X[:, self._selection].T)
return self._lambda(*X.T)
return self._lambda(*X[:, self._selection].T) * np.ones(expected_shape)
return self._lambda(*X.T) * np.ones(expected_shape)


def _get_julia_project(julia_project):
Expand Down Expand Up @@ -779,10 +782,25 @@ def get_params(self, deep=True):
**{key: self.__getattribute__(key) for key in self.surface_parameters},
}

def get_best(self):
"""Get best equation using `model_selection`."""
def get_best(self, index=None):
"""Get best equation using `model_selection`.
:param index: Optional. If you wish to select a particular equation
from `self.equations`, give the row number here. This overrides
the `model_selection` parameter.
:type index: int
:returns: Dictionary representing the best expression found.
:type: pd.Series
"""
if self.equations is None:
raise ValueError("No equations have been generated yet.")

if index is not None:
if isinstance(self.equations, list):
assert isinstance(index, list)
return [eq.iloc[i] for eq, i in zip(self.equations, index)]
return self.equations.iloc[index]

if self.model_selection == "accuracy":
if isinstance(self.equations, list):
return [eq.iloc[-1] for eq in self.equations]
Expand Down Expand Up @@ -826,44 +844,72 @@ def refresh(self):
# such as extra_sympy_mappings.
self.equations = self.get_hof()

def predict(self, X):
def predict(self, X, index=None):
"""Predict y from input X using the equation chosen by `model_selection`.
You may see what equation is used by printing this object. X should have the same
columns as the training data.
:param X: 2D array. Rows are examples, columns are features. If pandas DataFrame, the columns are used for variable names (so make sure they don't contain spaces).
:type X: np.ndarray/pandas.DataFrame
:return: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs).
:param index: Optional. If you want to compute the output of
an expression using a particular row of
`self.equations`, you may specify the index here.
:type index: int
:returns: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs).
:type: np.ndarray
"""
self.refresh()
best = self.get_best()
best = self.get_best(index=index)
if self.multioutput:
return np.stack([eq["lambda_format"](X) for eq in best], axis=1)
return best["lambda_format"](X)

def sympy(self):
"""Return sympy representation of the equation(s) chosen by `model_selection`."""
def sympy(self, index=None):
"""Return sympy representation of the equation(s) chosen by `model_selection`.
:param index: Optional. If you wish to select a particular equation
from `self.equations`, give the index number here. This overrides
the `model_selection` parameter.
:type index: int
:returns: SymPy representation of the best expression.
"""
self.refresh()
best = self.get_best()
best = self.get_best(index=index)
if self.multioutput:
return [eq["sympy_format"] for eq in best]
return best["sympy_format"]

def latex(self):
"""Return latex representation of the equation(s) chosen by `model_selection`."""
def latex(self, index=None):
"""Return latex representation of the equation(s) chosen by `model_selection`.
:param index: Optional. If you wish to select a particular equation
from `self.equations`, give the index number here. This overrides
the `model_selection` parameter.
:type index: int
:returns: LaTeX expression as a string
:type: str
"""
self.refresh()
sympy_representation = self.sympy()
sympy_representation = self.sympy(index=index)
if self.multioutput:
return [sympy.latex(s) for s in sympy_representation]
return sympy.latex(sympy_representation)

def jax(self):
def jax(self, index=None):
"""Return jax representation of the equation(s) chosen by `model_selection`.
Each equation (multiple given if there are multiple outputs) is a dictionary
containing {"callable": func, "parameters": params}. To call `func`, pass
func(X, params). This function is differentiable using `jax.grad`.
:param index: Optional. If you wish to select a particular equation
from `self.equations`, give the index number here. This overrides
the `model_selection` parameter.
:type index: int
:returns: Dictionary of callable jax function in "callable" key,
and jax array of parameters as "parameters" key.
:type: dict
"""
if self.using_pandas:
warnings.warn(
Expand All @@ -873,18 +919,26 @@ def jax(self):
)
self.set_params(output_jax_format=True)
self.refresh()
best = self.get_best()
best = self.get_best(index=index)
if self.multioutput:
return [eq["jax_format"] for eq in best]
return best["jax_format"]

def pytorch(self):
def pytorch(self, index=None):
"""Return pytorch representation of the equation(s) chosen by `model_selection`.
Each equation (multiple given if there are multiple outputs) is a PyTorch module
containing the parameters as trainable attributes. You can use the module like
any other PyTorch module: `module(X)`, where `X` is a tensor with the same
column ordering as trained with.
:param index: Optional. If you wish to select a particular equation
from `self.equations`, give the row number here. This overrides
the `model_selection` parameter.
:type index: int
:returns: PyTorch module representing the expression.
:type: torch.nn.Module
"""
if self.using_pandas:
warnings.warn(
Expand All @@ -894,7 +948,7 @@ def pytorch(self):
)
self.set_params(output_torch_format=True)
self.refresh()
best = self.get_best()
best = self.get_best(index=index)
if self.multioutput:
return [eq["torch_format"] for eq in best]
return best["torch_format"]
Expand Down
19 changes: 19 additions & 0 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,19 @@ def test_multioutput_custom_operator_quiet(self):
self.assertLessEqual(equations[0].iloc[-1]["loss"], 1e-4)
self.assertLessEqual(equations[1].iloc[-1]["loss"], 1e-4)

test_y1 = model.predict(self.X)
test_y2 = model.predict(self.X, index=[-1, -1])

mse1 = np.average((test_y1 - y) ** 2)
mse2 = np.average((test_y2 - y) ** 2)

self.assertLessEqual(mse1, 1e-4)
self.assertLessEqual(mse2, 1e-4)

bad_y = model.predict(self.X, index=[0, 0])
bad_mse = np.average((bad_y - y) ** 2)
self.assertGreater(bad_mse, 1e-4)

def test_multioutput_weighted_with_callable_temp_equation(self):
y = self.X[:, [0, 1]] ** 2
w = np.random.rand(*y.shape)
Expand Down Expand Up @@ -204,6 +217,12 @@ def setUp(self):
def test_best(self):
self.assertEqual(self.model.sympy(), sympy.cos(sympy.Symbol("x0")) ** 2)

def test_index_selection(self):
self.assertEqual(self.model.sympy(-1), sympy.cos(sympy.Symbol("x0")) ** 2)
self.assertEqual(self.model.sympy(2), sympy.cos(sympy.Symbol("x0")) ** 2)
self.assertEqual(self.model.sympy(1), sympy.cos(sympy.Symbol("x0")))
self.assertEqual(self.model.sympy(0), 1.0)

def test_best_tex(self):
self.assertEqual(self.model.latex(), "\\cos^{2}{\\left(x_{0} \\right)}")

Expand Down

0 comments on commit 5baa6b7

Please sign in to comment.