fix(lm_eval): fix loglikelihood substring problem check
Signed-off-by: Radek Ježek <pc.jezek@gmail.com>
jezekra1 committed Apr 18, 2024
1 parent 0788604 commit feda653
Showing 2 changed files with 21 additions and 16 deletions.
33 changes: 18 additions & 15 deletions src/genai/extensions/lm_eval/model.py
@@ -132,21 +132,24 @@ def _check_last_token_is_stop_token(self, response_tokens: list[str], context_to
         """
         context_length = len(context_tokens)
-        if response_tokens[: context_length - 1] == context_tokens[:-1]:
-            if response_tokens[-1].startswith(context_tokens[-1]):
-                raise RuntimeError(
-                    f"The context sent to loglikelihood evaluation ends with a token that is substring of the "
-                    f"continuation token:\n"
-                    f"context_tokens={context_tokens}\n"
-                    f"response_tokens={response_tokens[:context_length]}\n"
-                    "This is not allowed as it would skew the results. Please check your data."
-                )
-            return response_tokens[:context_length][-1] != context_tokens[-1]
-        raise RuntimeError(
-            f"There is an unexpected difference between tokenizer and model tokens:\n"
-            f"context_tokens={context_tokens}\n"
-            f"response_tokens={response_tokens[:context_length]}"
-        )
+        if response_tokens[: context_length - 1] != context_tokens[: context_length - 1]:
+            raise RuntimeError(
+                f"There is an unexpected difference between tokenizer and model tokens:\n"
+                f"context_tokens={context_tokens}\n"
+                f"response_tokens={response_tokens[:context_length]}"
+            )
+
+        last_context_token = context_tokens[context_length - 1]
+        last_context_token_resp = response_tokens[context_length - 1]
+        if last_context_token != last_context_token_resp and last_context_token_resp.startswith(last_context_token):
+            raise RuntimeError(
+                f"The context sent to loglikelihood evaluation ends with a token ({last_context_token}) "
+                f"that is substring of the continuation token ({last_context_token_resp}).\n"
+                f"context_tokens={context_tokens}\n"
+                f"response_tokens={response_tokens[:context_length]}\n"
+                "This is not allowed as it would skew the results. Please check your data."
+            )
+        return last_context_token != last_context_token_resp
 
     def _check_model_logprobs_support(self):
         input_tokens = (
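Note: the following is a minimal, self-contained sketch of the check the hunk above implements, not code from the repository. The helper name check_last_context_token and the token lists are illustrative assumptions, not the output of a real tokenizer.

def check_last_context_token(response_tokens: list[str], context_tokens: list[str]) -> bool:
    context_length = len(context_tokens)
    # Every context token except the last one must match the model's tokens exactly.
    if response_tokens[: context_length - 1] != context_tokens[: context_length - 1]:
        raise RuntimeError("Unexpected difference between tokenizer and model tokens")

    last_context_token = context_tokens[context_length - 1]
    last_response_token = response_tokens[context_length - 1]
    # If the tokens differ and the response token merely extends the context token,
    # the continuation was merged into the context's last token, which would skew
    # the loglikelihood attributed to the continuation.
    if last_context_token != last_response_token and last_response_token.startswith(last_context_token):
        raise RuntimeError(
            f"Context ends with a token ({last_context_token!r}) that is a substring "
            f"of the continuation token ({last_response_token!r})"
        )
    return last_context_token != last_response_token

# Example with hypothetical tokenizations: "test str" alone tokenizes as ["test", " str"],
# while "test str" + "ing" tokenizes as ["test", " string"], so the check raises.
check_last_context_token(["test", " string"], ["test", " str"])  # raises RuntimeError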
4 changes: 3 additions & 1 deletion tests/integration/extensions/test_lm_eval.py
@@ -28,7 +28,9 @@ def test_create_from_arg_string(self):
     def test_loglikelihood_raises_for_invalid_tokenization(self):
         """Test loglikelihood of part of token is invalid"""
         lm = IBMGenAILMEval(model_id="tiiuae/falcon-40b")
-        with pytest.raises(RuntimeError, match=r".*ends with a token that is substring of the continuation token"):
+        with pytest.raises(
+            RuntimeError, match=r".*ends with a token .* that is substring of the continuation token .*"
+        ):
             requests = [
                 Instance(request_type="loglikelihood", doc=args, arguments=args, idx=i)
                 for i, args in enumerate([("test str", "ing")])
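Not part of the diff: a quick sanity check that the widened regex still matches the new error text. The sample message below is modeled on the RuntimeError raised in model.py above, and the tokens " str" / " string" are made up for illustration.

import re

# Sample message modeled on the new error text.
message = (
    "The context sent to loglikelihood evaluation ends with a token ( str) "
    "that is substring of the continuation token ( string)."
)
# pytest.raises(match=...) uses re.search, so the pattern only needs to hit a substring.
assert re.search(r".*ends with a token .* that is substring of the continuation token .*", message)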
