Skip to content

Commit

Permalink
Add test for feature selection
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed May 31, 2021
1 parent a626763 commit 97e6589
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions test/test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
import numpy as np
from pysr import pysr, get_hof, best, best_tex, best_callable
from pysr.sr import run_feature_selection, _handle_feature_selection
import sympy
import pandas as pd

Expand Down Expand Up @@ -72,3 +73,28 @@ def test_best_lambda(self):
y = np.cos(X[:, 0])**2
for f in [best_callable(), best_callable(self.equations)]:
np.testing.assert_almost_equal(f(X), y)


class TestFeatureSelection(unittest.TestCase):
def test_feature_selection(self):
np.random.seed(0)
X = np.random.randn(20001, 5)
y = X[:, 2]**2 + X[:, 3]**2
selected = run_feature_selection(X, y, select_k_features=2)
self.assertEqual(sorted(selected), [2, 3])

def test_feature_selection_handler(self):
np.random.seed(0)
X = np.random.randn(20000, 5)
y = X[:, 2]**2 + X[:, 3]**2
var_names = [f'x{i}' for i in range(5)]
selected_X, selected_var_names = _handle_feature_selection(
X, select_k_features=2,
use_custom_variable_names=True,
variable_names=[f'x{i}' for i in range(5)],
y=y)
self.assertEqual(set(selected_var_names), set('x2 x3'.split(' ')))
np.testing.assert_array_equal(
np.sort(selected_X, axis=1),
np.sort(X[:, [2, 3]], axis=1)
)

0 comments on commit 97e6589

Please sign in to comment.