From 3ecb13ee42e8dcf8f59898e56ce81d1dc30ce1e8 Mon Sep 17 00:00:00 2001 From: Dimos Tsouros Date: Tue, 10 Dec 2024 12:41:04 +0100 Subject: [PATCH] proba tests --- tests/test_algorithms.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 5f969e2..88320bf 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -57,8 +57,18 @@ def test_growacq(self, bench, inner_alg): assert len(learned_instance.cl) > 0 assert learned_instance.get_cpmpy_model().solve() + @pytest.mark.parametrize(("bench", "algorithm", "classifier"), _generate_proba_inputs(), ids=str) + def test_proba(self, bench, algorithm, classifier): + env = ca.ProbaActiveCAEnv(classifier=classifier) + (instance, oracle) = bench + ca_system = algorithm + ca_system.env = env + learned_instance = ca_system.learn(instance=instance, oracle=oracle) + assert len(learned_instance.cl) > 0 + assert learned_instance.get_cpmpy_model().solve() + @pytest.mark.parametrize(("bench", "inner_alg", "classifier"), _generate_proba_inputs(), ids=str) - def test_proba(self, bench, inner_alg, classifier): + def test_proba_growacq(self, bench, inner_alg, classifier): env = ca.ProbaActiveCAEnv(classifier=classifier) (instance, oracle) = bench ca_system = ca.GrowAcq(env, inner_algorithm=inner_alg)