diff --git a/tests/test_survival.py b/tests/test_survival.py index a5a2fba..a9c6a4d 100644 --- a/tests/test_survival.py +++ b/tests/test_survival.py @@ -233,7 +233,6 @@ def test_refining_conditions_for_nominal_attributes(self): survival_time_attr="survival_time", ) clf.fit(X, y, expert_rules=[("expert_rules-1", "IF CMVstatus @= {1} THEN")]) - r = [str(r) for r in clf.model.rules] self.assertEqual( ["IF [[CMVstatus = {1}]] THEN "], @@ -251,6 +250,33 @@ def test_refining_conditions_for_nominal_attributes(self): ), ) + def test_refining_conditions_for_numerical_attributes(self): + df: pd.DataFrame = read_arff( + os.path.join(dir_path, "resources", "data", "bmt-train-0.arff") + ) + X, y = df.drop("survival_status", axis=1), df["survival_status"] + + # Run experiment using python API + clf = survival.ExpertSurvivalRules( + complementary_conditions=True, + extend_using_preferred=False, + extend_using_automatic=False, + induce_using_preferred=False, + induce_using_automatic=False, + preferred_conditions_per_rule=0, + preferred_attributes_per_rule=0, + survival_time_attr="survival_time", + ) + clf.fit(X, y, expert_rules=[("expert_rules-1", "IF CD34kgx10d6 @= Any THEN")]) + self.assertEqual( + ["IF [[CD34kgx10d6 = (-inf, 11.86)]] THEN "], + [str(r) for r in clf.model.rules], + ( + "Ruleset should contain only a single rule configured by expert with " + "a refined condition" + ), + ) + if __name__ == "__main__": unittest.main()