Skip to content

Commit

Permalink
Make tests non-random
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed May 31, 2021
1 parent 25f8cac commit 51a6b05
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
15 changes: 9 additions & 6 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ def test_multioutput_weighted_with_callable(self):

np.testing.assert_almost_equal(
best_callable()[0](self.X),
self.X[:, 0]**2)
self.X[:, 0]**2,
decimal=4)
np.testing.assert_almost_equal(
best_callable()[1](self.X),
self.X[:, 1]**2)
self.X[:, 1]**2,
decimal=4)

def test_empty_operators_single_input(self):
X = np.random.randn(100, 1)
Expand Down Expand Up @@ -96,19 +98,20 @@ def test_best_lambda(self):
X = np.random.randn(10, 2)
y = np.cos(X[:, 0])**2
for f in [best_callable(), best_callable(self.equations)]:
np.testing.assert_almost_equal(f(X), y)
np.testing.assert_almost_equal(f(X), y, decimal=4)


class TestFeatureSelection(unittest.TestCase):
def test_feature_selection(self):
def setUp(self):
np.random.seed(0)
X = np.random.randn(20001, 5)

def test_feature_selection(self):
X = np.random.randn(20000, 5)
y = X[:, 2]**2 + X[:, 3]**2
selected = run_feature_selection(X, y, select_k_features=2)
self.assertEqual(sorted(selected), [2, 3])

def test_feature_selection_handler(self):
np.random.seed(0)
X = np.random.randn(20000, 5)
y = X[:, 2]**2 + X[:, 3]**2
var_names = [f'x{i}' for i in range(5)]
Expand Down
6 changes: 5 additions & 1 deletion test/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import sympy

class TestJAX(unittest.TestCase):
def setUp(self):
np.random.seed(0)

def test_sympy2jax(self):
x, y, z = sympy.symbols('x y z')
cosx = 1.0 * sympy.cos(x) + y
Expand Down Expand Up @@ -35,5 +38,6 @@ def test_pipeline(self):
jformat = equations.iloc[-1].jax_format
np.testing.assert_almost_equal(
np.array(jformat['callable'](jnp.array(X), jformat['parameters'])),
np.square(np.cos(X[:, 1])) # Select feature 1
np.square(np.cos(X[:, 1])), # Select feature 1
decimal=4
)
8 changes: 6 additions & 2 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
import sympy

class TestTorch(unittest.TestCase):
def setUp(self):
np.random.seed(0)

def test_sympy2torch(self):
x, y, z = sympy.symbols('x y z')
cosx = 1.0 * sympy.cos(x) + y
X = torch.randn((1000, 3))
X = torch.tensor(np.random.randn(1000, 3))
true = 1.0 * torch.cos(X[:, 0]) + X[:, 1]
torch_module = sympy2torch(cosx, [x, y, z])
self.assertTrue(
Expand All @@ -34,5 +37,6 @@ def test_pipeline(self):
tformat = equations.iloc[-1].torch_format
np.testing.assert_almost_equal(
tformat(torch.tensor(X)).detach().numpy(),
np.square(np.cos(X[:, 1])) #Selection 1st feature
np.square(np.cos(X[:, 1])), #Selection 1st feature
decimal=4
)

0 comments on commit 51a6b05

Please sign in to comment.