Skip to content

Commit

Permalink
Add test for pickling + units
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jul 23, 2023
1 parent 2a90e3f commit 9c17dcf
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion pysr/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,17 +1010,23 @@ def test_unit_propagation(self):
"""Check that units are propagated correctly."""
X = np.ones((100, 3))
y = np.ones((100, 1))
temp_dir = Path(tempfile.mkdtemp())
equation_file = str(temp_dir / "equation_file.csv")
model = PySRRegressor(
binary_operators=["+", "*"],
early_stop_condition="(l, c) -> l < 1e-8 && c == 3",
**self.default_test_kwargs,
progress=False,
model_selection="accuracy",
niterations=DEFAULT_NITERATIONS * 2,
populations=DEFAULT_POPULATIONS * 2,
complexity_of_constants=10,
weight_mutate_constant=0.0,
should_optimize_constants=False,
multithreading=False,
deterministic=True,
procs=0,
random_state=0,
equation_file=equation_file,
)
model.fit(
X,
Expand All @@ -1034,6 +1040,23 @@ def test_unit_propagation(self):
self.assertIn("x2", best["equation"])
self.assertEqual(best["complexity"], 3)

# With pkl file:
pkl_file = str(temp_dir / "equation_file.pkl")
model2 = PySRRegressor.from_file(pkl_file)
best2 = model2.get_best()
self.assertIn("x0", best2["equation"])

# From csv file alone (we need to delete pkl file:)
# First, we delete the pkl file:
os.remove(pkl_file)
model3 = PySRRegressor.from_file(
equation_file, binary_operators=["+", "*"], n_features_in=X.shape[1]
)
best3 = model3.get_best()
self.assertIn("x0", best3["equation"])

# TODO: Determine desired behavior if second .fit() call does not have units


def runtests():
"""Run all tests in test.py."""
Expand Down

0 comments on commit 9c17dcf

Please sign in to comment.