Skip to content

Commit

Permalink
Add warm start test
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jul 27, 2023
1 parent abd0cfa commit db8bfce
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
2 changes: 2 additions & 0 deletions pysr/sr.py
Expand Up @@ -1784,6 +1784,8 @@ def _run(self, X, y, mutated_params, weights, seed):

y_variable_names = None
if len(y.shape) > 1:
# We set these manually so that they respect Python's 0 indexing
# (by default Julia will use y1, y2...)
y_variable_names = [f"y{_subscriptify(i)}" for i in range(y.shape[1])]

# Call to Julia backend.
Expand Down
17 changes: 15 additions & 2 deletions pysr/test/test.py
Expand Up @@ -1007,14 +1007,17 @@ def test_unit_checks(self):
)

def test_unit_propagation(self):
"""Check that units are propagated correctly."""
"""Check that units are propagated correctly.
This also tests that variables have the correct names.
"""
X = np.ones((100, 3))
y = np.ones((100, 1))
temp_dir = Path(tempfile.mkdtemp())
equation_file = str(temp_dir / "equation_file.csv")
model = PySRRegressor(
binary_operators=["+", "*"],
early_stop_condition="(l, c) -> l < 1e-8 && c == 3",
early_stop_condition="(l, c) -> l < 1e-6 && c == 3",
progress=False,
model_selection="accuracy",
niterations=DEFAULT_NITERATIONS * 2,
Expand All @@ -1027,6 +1030,7 @@ def test_unit_propagation(self):
procs=0,
random_state=0,
equation_file=equation_file,
warm_start=True,
)
model.fit(
X,
Expand All @@ -1039,6 +1043,8 @@ def test_unit_propagation(self):
self.assertNotIn("x1", best["equation"])
self.assertIn("x2", best["equation"])
self.assertEqual(best["complexity"], 3)
self.assertEqual(model.equations_.iloc[0].complexity, 1)
self.assertGreater(model.equations_.iloc[0].loss, 1e-6)

# With pkl file:
pkl_file = str(temp_dir / "equation_file.pkl")
Expand All @@ -1055,6 +1061,13 @@ def test_unit_propagation(self):
best3 = model3.get_best()
self.assertIn("x0", best3["equation"])

# Try warm start, but with no units provided (should
# be a different dataset, and thus different result):
model.fit(X, y)
model.early_stop_condition = "(l, c) -> l < 1e-6 && c == 1"
self.assertEqual(model.equations_.iloc[0].complexity, 1)
self.assertLess(model.equations_.iloc[0].loss, 1e-6)


# TODO: Determine desired behavior if second .fit() call does not have units

Expand Down

0 comments on commit db8bfce

Please sign in to comment.