Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion libs/cot/cot/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@


def search_regex(s: str, patterns: list, warn: bool) -> str:
"""Searches a string for a list of regex patterns and returns the first found match."""
# strip the string from whitespaces
s = s.strip()
for pattern in patterns:
Expand Down Expand Up @@ -80,20 +81,45 @@ def is_correct(type_: str, pred: str, gold: str, choices=None, warn=False) -> bo
choices_dict = {"Yes": "True", "No": "False"}
choices_keys = list(choices_dict.keys())
choices_values = list(choices_dict.values())
choices_values_raw = (
choices_values # in bool case, we need the raw values for the quick check
)
keys_lower = [i.lower() for i in choices_dict.keys()]
values_lower = [j.lower() for j in choices_dict.values()]

# quick check if pred is in choices_dict
if (
pred in choices_values
# We need to take the raw values here, as this is not regex
pred in choices_values_raw
or pred in choices_keys
or pred in keys_lower
or pred in values_lower
):
# raise ValueError("not in choices_dict")
is_correct = compare_pred_with_gold(pred, gold, choices_dict)

return is_correct

# check if only one of the choices are part of the pred and report this as answer
# therefor search choice_value in pred and return if only one hit
hits = []
for value in choices_values:
# only check if length of value is smaller or same than pred
if len(value) <= len(pred):
# make value a group for regex
match = search_regex(
# "(" + escape_special_characters(value) + ")", [escape_special_characters(pred)], warn
escape_special_characters(pred),
["(" + value + ")"],
warn,
)
if match:
hits.append(match)
if len(hits) == 1:
pred = hits[0]
is_correct = compare_pred_with_gold(pred, gold, choices_dict)
return is_correct

# if pred is not in choices_dict, we need to use regex

# uppercase and lowercase is not important, as we will match the pattern case insensitive.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_evaluation_included_datasets():

# compare with own calculation of the evaluation
evaluation = collection.evaluate(overwrite=True, warn=False)
assert compare_nested_dict_float_values(evaluation, correct, 0.025)
assert compare_nested_dict_float_values(evaluation, correct, 0.021)

# med_qa test set
collection = Collection(["med_qa"], verbose=False)
Expand All @@ -49,9 +49,7 @@ def test_evaluation_included_datasets():

# compare with own calculation of the evaluation
evaluation = collection.evaluate(overwrite=True, warn=False)
assert compare_nested_dict_float_values(evaluation, correct, 0.001)
# was 0.00001 before. Got worse with individual answer sequences

assert compare_nested_dict_float_values(evaluation, correct, 1e-6)

# medmc_qa validation set
collection = Collection(["medmc_qa"], verbose=False)
Expand Down
41 changes: 21 additions & 20 deletions libs/cot/tests/unit_tests/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,38 +145,39 @@ def test_is_correct_multiple_answers():

def test_predefined_correct_value():
# med_qa
collection = Collection(["med_qa"], verbose=False)
collection = collection.select(
split="test", number_samples=10, random_samples=False
)
# collection = Collection(["med_qa"], verbose=False)
# collection = collection.select(
# split="test", number_samples=10, random_samples=False
# )

collection2 = Collection(["med_qa"], verbose=False)
collection2 = collection2.select(
split="test", number_samples=10, random_samples=False
)
# collection2 = Collection(["med_qa"], verbose=False)
# collection2 = collection2.select(
# split="test", number_samples=10, random_samples=False
# )

# only do evaluation on one of them, nothing should change
collection.evaluate(warn=False)
# # only do evaluation on one of them, nothing should change
# collection.evaluate(warn=False)

collection_json = collection.to_json()
collection2_json = collection2.to_json()
# collection_json = collection.to_json()
# collection2_json = collection2.to_json()

assert collection_json == collection2_json
# assert collection_json == collection2_json

# pubmed_qa
collection = Collection(["pubmed_qa"], verbose=False)
collection = collection.select(
split="train", number_samples=10, random_samples=False
)
collection2 = Collection(["pubmed_qa"], verbose=False)
collection2 = collection2.select(
split="train", number_samples=10, random_samples=False
)
# collection2 = Collection(["pubmed_qa"], verbose=False)
# collection2 = collection2.select(
# split="train", number_samples=10, random_samples=False
# )

collection_json = collection.to_json()

# only do evaluation on one of them, nothing should change
collection.evaluate()
collection.evaluate(overwrite=False)

collection_json = collection.to_json()
collection2_json = collection2.to_json()
collection2_json = collection.to_json()

assert collection_json == collection2_json