diff --git a/src/genai/extensions/lm_eval/model.py b/src/genai/extensions/lm_eval/model.py
index d71114dc..5195cf74 100644
--- a/src/genai/extensions/lm_eval/model.py
+++ b/src/genai/extensions/lm_eval/model.py
@@ -132,21 +132,24 @@ def _check_last_token_is_stop_token(self, response_tokens: list[str], context_to
         """
         context_length = len(context_tokens)
 
-        if response_tokens[: context_length - 1] == context_tokens[:-1]:
-            if response_tokens[-1].startswith(context_tokens[-1]):
-                raise RuntimeError(
-                    f"The context sent to loglikelihood evaluation ends with a token that is substring of the "
-                    f"continuation token:\n"
-                    f"context_tokens={context_tokens}\n"
-                    f"response_tokens={response_tokens[:context_length]}\n"
-                    "This is not allowed as it would skew the results. Please check your data."
-                )
-            return response_tokens[:context_length][-1] != context_tokens[-1]
-        raise RuntimeError(
-            f"There is an unexpected difference between tokenizer and model tokens:\n"
-            f"context_tokens={context_tokens}\n"
-            f"response_tokens={response_tokens[:context_length]}"
-        )
+        if response_tokens[: context_length - 1] != context_tokens[: context_length - 1]:
+            raise RuntimeError(
+                f"There is an unexpected difference between tokenizer and model tokens:\n"
+                f"context_tokens={context_tokens}\n"
+                f"response_tokens={response_tokens[:context_length]}"
+            )
+
+        last_context_token = context_tokens[context_length - 1]
+        last_context_token_resp = response_tokens[context_length - 1]
+        if last_context_token != last_context_token_resp and last_context_token_resp.startswith(last_context_token):
+            raise RuntimeError(
+                f"The context sent to loglikelihood evaluation ends with a token ({last_context_token}) "
+                f"that is substring of the continuation token ({last_context_token_resp}).\n"
+                f"context_tokens={context_tokens}\n"
+                f"response_tokens={response_tokens[:context_length]}\n"
+                "This is not allowed as it would skew the results. Please check your data."
+            )
+        return last_context_token != last_context_token_resp
 
     def _check_model_logprobs_support(self):
         input_tokens = (
diff --git a/tests/integration/extensions/test_lm_eval.py b/tests/integration/extensions/test_lm_eval.py
index d8d42e0e..14a4410b 100644
--- a/tests/integration/extensions/test_lm_eval.py
+++ b/tests/integration/extensions/test_lm_eval.py
@@ -28,7 +28,9 @@ def test_create_from_arg_string(self):
     def test_loglikelihood_raises_for_invalid_tokenization(self):
         """Test loglikelihood of part of token is invalid"""
         lm = IBMGenAILMEval(model_id="tiiuae/falcon-40b")
-        with pytest.raises(RuntimeError, match=r".*ends with a token that is substring of the continuation token"):
+        with pytest.raises(
+            RuntimeError, match=r".*ends with a token .* that is substring of the continuation token .*"
+        ):
             requests = [
                 Instance(request_type="loglikelihood", doc=args, arguments=args, idx=i)
                 for i, args in enumerate([("test str", "ing")])