Skip to content

Commit

Permalink
Merge pull request #281 from MilesCranmer/complex-numbers
Browse files Browse the repository at this point in the history
Complex-valued expressions
  • Loading branch information
MilesCranmer committed Mar 21, 2023
2 parents ab9ae60 + 38f33fd commit 81ba2f3
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 6 deletions.
36 changes: 35 additions & 1 deletion docs/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,41 @@ You can get the sympy version of the best equation with:
model.sympy()
```

## 8. Additional features
## 8. Complex numbers

PySR can also search for complex-valued expressions. Simply pass
data with a complex datatype (e.g., `np.complex128`),
and PySR will automatically search for complex-valued expressions:

```python
import numpy as np

X = np.random.randn(100, 1) + 1j * np.random.randn(100, 1)
y = (1 + 2j) * np.cos(X[:, 0] * (0.5 - 0.2j))

model = PySRRegressor(
binary_operators=["+", "-", "*"], unary_operators=["cos"], niterations=100,
)

model.fit(X, y)
```

You can see that all of the learned constants are now complex numbers.
We can get the sympy version of the best equation with:

```python
model.sympy()
```

We can also make predictions normally, by passing complex data:

```python
model.predict(X, -1)
```

to make predictions with the most accurate expression.

## 9. Additional features

For the many other features available in PySR, please
read the [Options section](options.md).
1 change: 1 addition & 0 deletions pysr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import sklearn_monkeypatch
from .version import __version__
from .sr import (
pysr,
Expand Down
13 changes: 13 additions & 0 deletions pysr/sklearn_monkeypatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Here, we monkey patch scikit-learn until this
# issue is fixed: https://github.com/scikit-learn/scikit-learn/issues/25922
from sklearn.utils import validation


def _ensure_no_complex_data(*args, **kwargs):
...


try:
validation._ensure_no_complex_data = _ensure_no_complex_data
except AttributeError:
...
29 changes: 28 additions & 1 deletion pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
What precision to use for the data. By default this is `32`
(float32), but you can select `64` or `16` as well, giving
you 64 or 16 bits of floating point precision, respectively.
If you pass complex data, the corresponding complex precision
will be used (i.e., `64` for complex128, `32` for complex64).
Default is `32`.
random_state : int, Numpy RandomState instance or None
Pass an int for reproducible results across multiple function calls.
Expand Down Expand Up @@ -1619,7 +1621,13 @@ def _run(self, X, y, mutated_params, weights, seed):
)

# Convert data to desired precision
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self.precision]
test_X = np.array(X)
is_complex = np.issubdtype(test_X.dtype, np.complexfloating)
is_real = not is_complex
if is_real:
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self.precision]
else:
np_dtype = {32: np.complex64, 64: np.complex128}[self.precision]

# This converts the data into a Julia array:
Main.X = np.array(X, dtype=np_dtype).T
Expand Down Expand Up @@ -2007,6 +2015,7 @@ def pytorch(self, index=None):

def _read_equation_file(self):
"""Read the hall of fame file created by `SymbolicRegression.jl`."""

try:
if self.nout_ > 1:
all_outputs = []
Expand All @@ -2024,6 +2033,7 @@ def _read_equation_file(self):
},
inplace=True,
)
df["equation"] = df["equation"].apply(_preprocess_julia_floats)

all_outputs.append(df)
else:
Expand All @@ -2039,6 +2049,10 @@ def _read_equation_file(self):
},
inplace=True,
)
all_outputs[-1]["equation"] = all_outputs[-1]["equation"].apply(
_preprocess_julia_floats
)

except FileNotFoundError:
raise RuntimeError(
"Couldn't find equation file! The equation search likely exited "
Expand Down Expand Up @@ -2329,3 +2343,16 @@ def _csv_filename_to_pkl_filename(csv_filename) -> str:
pkl_basename = base + ".pkl"

return os.path.join(dirname, pkl_basename)


_regexp_im = re.compile(r"\b(\d+\.\d+)im\b")
_regexp_im_sci = re.compile(r"\b(\d+\.\d+)[eEfF]([+-]?\d+)im\b")
_regexp_sci = re.compile(r"\b(\d+\.\d+)[eEfF]([+-]?\d+)\b")

_apply_regexp_im = lambda x: _regexp_im.sub(r"\1j", x)
_apply_regexp_im_sci = lambda x: _regexp_im_sci.sub(r"\1e\2j", x)
_apply_regexp_sci = lambda x: _regexp_sci.sub(r"\1e\2", x)


def _preprocess_julia_floats(s: str) -> str:
return _apply_regexp_sci(_apply_regexp_im_sci(_apply_regexp_im(s)))
17 changes: 15 additions & 2 deletions pysr/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,20 @@ def test_multioutput_weighted_with_callable_temp_equation(self):
print("Model equations: ", model.sympy()[1])
print("True equation: x1^2")

def test_complex_equations_anonymous_stop(self):
X = self.rstate.randn(100, 3) + 1j * self.rstate.randn(100, 3)
y = (2 + 1j) * np.cos(X[:, 0] * (0.5 - 0.3j))
model = PySRRegressor(
binary_operators=["+", "-", "*"],
unary_operators=["cos"],
**self.default_test_kwargs,
early_stop_condition="(loss, complexity) -> loss <= 1e-4 && complexity <= 6",
)
model.fit(X, y)
test_y = model.predict(X)
self.assertTrue(np.issubdtype(test_y.dtype, np.complexfloating))
self.assertLessEqual(np.average(np.abs(test_y - y) ** 2), 1e-4)

def test_empty_operators_single_input_warm_start(self):
X = self.rstate.randn(100, 1)
y = X[:, 0] + 3.0
Expand Down Expand Up @@ -230,7 +244,6 @@ def test_warm_start_set_at_init(self):
regressor.fit(self.X, y)

def test_noisy(self):

y = self.X[:, [0, 1]] ** 2 + self.rstate.randn(self.X.shape[0], 1) * 0.05
model = PySRRegressor(
# Test that passing a single operator works:
Expand Down Expand Up @@ -664,7 +677,7 @@ def test_scikit_learn_compatibility(self):

check_generator = check_estimator(model, generate_only=True)
exception_messages = []
for (_, check) in check_generator:
for _, check in check_generator:
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
Expand Down
4 changes: 2 additions & 2 deletions pysr/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = "0.11.17"
__symbolic_regression_jl_version__ = "0.15.3"
__version__ = "0.12.0"
__symbolic_regression_jl_version__ = "0.16.1"

0 comments on commit 81ba2f3

Please sign in to comment.