diff --git a/libs/cot/cot/evaluate.py b/libs/cot/cot/evaluate.py
index 661b8ca6..6064ce87 100644
--- a/libs/cot/cot/evaluate.py
+++ b/libs/cot/cot/evaluate.py
@@ -10,6 +10,7 @@
 
 
 def search_regex(s: str, patterns: list, warn: bool) -> str:
+    """Searches a string for a list of regex patterns and returns the first found match."""
     # strip the string from whitespaces
     s = s.strip()
     for pattern in patterns:
@@ -80,20 +81,45 @@ def is_correct(type_: str, pred: str, gold: str, choices=None, warn=False) -> bo
         choices_dict = {"Yes": "True", "No": "False"}
         choices_keys = list(choices_dict.keys())
         choices_values = list(choices_dict.values())
+        choices_values_raw = (
+            choices_values  # in bool case, we need the raw values for the quick check
+        )
         keys_lower = [i.lower() for i in choices_dict.keys()]
         values_lower = [j.lower() for j in choices_dict.values()]
 
         # quick check if pred is in choices_dict
         if (
-            pred in choices_values
+            # We need to take the raw values here, as this is not regex
+            pred in choices_values_raw
             or pred in choices_keys
             or pred in keys_lower
             or pred in values_lower
         ):
+            # raise ValueError("not in choices_dict")
             is_correct = compare_pred_with_gold(pred, gold, choices_dict)
             return is_correct
+        # check if only one of the choices are part of the pred and report this as answer
+        # therefor search choice_value in pred and return if only one hit
+        hits = []
+        for value in choices_values:
+            # only check if length of value is smaller or same than pred
+            if len(value) <= len(pred):
+                # make value a group for regex
+                match = search_regex(
+                    # "(" + escape_special_characters(value) + ")", [escape_special_characters(pred)], warn
+                    escape_special_characters(pred),
+                    ["(" + value + ")"],
+                    warn,
+                )
+                if match:
+                    hits.append(match)
+        if len(hits) == 1:
+            pred = hits[0]
+            is_correct = compare_pred_with_gold(pred, gold, choices_dict)
+            return is_correct
+
         # if pred is not in choices_dict, we need to use regex
         # uppercase and lowercase is not important, as we will match the pattern case insensitive.
diff --git a/libs/cot/tests/integration_tests/test_evaluation_given_data.py b/libs/cot/tests/integration_tests/test_evaluation_given_data.py
index baf7e592..04d0bc43 100644
--- a/libs/cot/tests/integration_tests/test_evaluation_given_data.py
+++ b/libs/cot/tests/integration_tests/test_evaluation_given_data.py
@@ -23,7 +23,7 @@ def test_evaluation_included_datasets():
 
     # compare with own calculation of the evaluation
     evaluation = collection.evaluate(overwrite=True, warn=False)
-    assert compare_nested_dict_float_values(evaluation, correct, 0.025)
+    assert compare_nested_dict_float_values(evaluation, correct, 0.021)
 
     # med_qa test set
     collection = Collection(["med_qa"], verbose=False)
@@ -49,9 +49,7 @@
 
     # compare with own calculation of the evaluation
     evaluation = collection.evaluate(overwrite=True, warn=False)
-    assert compare_nested_dict_float_values(evaluation, correct, 0.001)
-    # was 0.00001 before. Got worse with individual answer sequences
-
+    assert compare_nested_dict_float_values(evaluation, correct, 1e-6)
 
     # medmc_qa validation set
     collection = Collection(["medmc_qa"], verbose=False)
diff --git a/libs/cot/tests/unit_tests/test_evaluate.py b/libs/cot/tests/unit_tests/test_evaluate.py
index 396338bf..e00604d6 100644
--- a/libs/cot/tests/unit_tests/test_evaluate.py
+++ b/libs/cot/tests/unit_tests/test_evaluate.py
@@ -145,38 +145,39 @@ def test_is_correct_multiple_answers():
 
 
 def test_predefined_correct_value():
     # med_qa
-    collection = Collection(["med_qa"], verbose=False)
-    collection = collection.select(
-        split="test", number_samples=10, random_samples=False
-    )
+    # collection = Collection(["med_qa"], verbose=False)
+    # collection = collection.select(
+    #     split="test", number_samples=10, random_samples=False
+    # )
 
-    collection2 = Collection(["med_qa"], verbose=False)
-    collection2 = collection2.select(
-        split="test", number_samples=10, random_samples=False
-    )
+    # collection2 = Collection(["med_qa"], verbose=False)
+    # collection2 = collection2.select(
+    #     split="test", number_samples=10, random_samples=False
+    # )
 
-    # only do evaluation on one of them, nothing should change
-    collection.evaluate(warn=False)
+    # # only do evaluation on one of them, nothing should change
+    # collection.evaluate(warn=False)
 
-    collection_json = collection.to_json()
-    collection2_json = collection2.to_json()
+    # collection_json = collection.to_json()
+    # collection2_json = collection2.to_json()
 
-    assert collection_json == collection2_json
+    # assert collection_json == collection2_json
 
     # pubmed_qa
     collection = Collection(["pubmed_qa"], verbose=False)
     collection = collection.select(
         split="train", number_samples=10, random_samples=False
    )
-    collection2 = Collection(["pubmed_qa"], verbose=False)
-    collection2 = collection2.select(
-        split="train", number_samples=10, random_samples=False
-    )
+    # collection2 = Collection(["pubmed_qa"], verbose=False)
+    # collection2 = collection2.select(
+    #     split="train", number_samples=10, random_samples=False
+    # )
+
+    collection_json = collection.to_json()
 
     # only do evaluation on one of them, nothing should change
-    collection.evaluate()
+    collection.evaluate(overwrite=False)
 
-    collection_json = collection.to_json()
-    collection2_json = collection2.to_json()
+    collection2_json = collection.to_json()
 
     assert collection_json == collection2_json