diff --git a/talos/parameters/handling.py b/talos/parameters/handling.py index fba5b397..8cb270a0 100644 --- a/talos/parameters/handling.py +++ b/talos/parameters/handling.py @@ -19,10 +19,10 @@ def run_param_pick(self): _choice = random.choice(self.param_log) elif self.search_method == 'linear': - _choice = self.param_log.min() + _choice = min(self.param_log) elif self.search_method == 'reverse': - _choice = self.param_log.max() + _choice = max(self.param_log) self.param_log.remove(_choice) diff --git a/test/core_tests/test_scan.py b/test/core_tests/test_scan.py index 9dc5fdef..379ef5b6 100644 --- a/test/core_tests/test_scan.py +++ b/test/core_tests/test_scan.py @@ -102,15 +102,28 @@ class TestCancer: def __init__(self): self.x, self.y = datasets.cervical_cancer() + self.model = cervix_model def test_scan_cancer(self): print("Running Cervical Cancer dataset test...") - Scan(self.x, self.y, grid_downsample=0.001, params=p3, + Scan(self.x, self.y, grid_downsample=0.0005, params=p3, dataset_name='testing', experiment_no='a', - model=cervix_model, + model=self.model, reduction_method='spear', reduction_interval=5) Reporting('testing_a.csv') + def test_linear_method(self): + print("Testing linear method on Cancer dataset...") + Scan(self.x, self.y, params=p3, dataset_name='testing', + search_method='linear', grid_downsample=0.0005, + experiment_no='000', model=self.model) + + def test_reverse_method(self): + print("Testing reverse method on Cancer dataset...") + Scan(self.x, self.y, params=p3, dataset_name='testing', + search_method='reverse', grid_downsample=0.0005, + experiment_no='000', model=self.model) + class TestLoadDatasets: diff --git a/test_script.py b/test_script.py index 2e2321cd..4ac2bccb 100644 --- a/test_script.py +++ b/test_script.py @@ -6,6 +6,8 @@ if __name__ == '__main__': # TODO describe what all this does + TestCancer().test_linear_method() + TestCancer().test_reverse_method() TestIris().test_scan_iris_explicit_validation_set() TestIris().test_scan_iris_explicit_validation_set_force_fail() TestIris().test_scan_iris_1()